diff --git a/src/crud/courseprogress.py b/src/crud/courseprogress.py index b7d587e..e848004 100644 --- a/src/crud/courseprogress.py +++ b/src/crud/courseprogress.py @@ -1,11 +1,25 @@ from fastapi import HTTPException from sqlalchemy.orm import Session +from typing import List from src.enums import CourseEnum from src.models import CourseProgress, User from src.schemas.courseprogress import CourseProgressBase, CourseProgressParent, SavedCourseProgress +from src.schemas.learnableprogress import SavedLearnableProgress from src.crud.learnableprogress import get_learnables +def get_learnable_values(learnables: List[SavedLearnableProgress]): + completed_learnables = sum( + [1 if learnable.progress == 5.0 else 0 for learnable in learnables] + ) + + in_use_learnables = sum( + [1 if learnable.in_use else 0 for learnable in learnables] + ) + + total_learnables = len(learnables) + + return completed_learnables, in_use_learnables, total_learnables def get_course_progress(db: Session, user: User, course: CourseEnum): """Get the progress a user has for a certain course""" @@ -31,15 +45,7 @@ def get_course_progress(db: Session, user: User, course: CourseEnum): print("CURRENT COURSE: " + course_progress.course) learnables = get_learnables(db, user, course) - completed_learnables = sum( - [1 if learnable.progress == 5.0 else 0 for learnable in learnables] - ) - - in_use_learnables = sum( - [1 if learnable.in_use else 0 for learnable in learnables] - ) - - total_learnables = len(learnables) + completed_learnables, in_use_learnables, total_learnables = get_learnable_values(learnables) result.append(SavedCourseProgress( course_index=course_progress.course, @@ -104,14 +110,20 @@ def patch_course_progress( db_course_progress.progress = course_progress.progress db.commit() - return [ - SavedCourseProgress( + result = [] + + for db_cp in db_course_progress_list: + learnables = get_learnables(db, user, db_cp.course) + + completed_learnables, in_use_learnables, total_learnables = get_learnable_values(learnables) + + result.append(SavedCourseProgress( course_index=db_cp.course, progress=db_cp.progress, - completed_learnables=db, - in_use_learnables=0, - total_learnables=0, - learnables=[], - ) - for db_cp in db_course_progress_list - ] + completed_learnables=completed_learnables, + in_use_learnables=in_use_learnables, + total_learnables=total_learnables, + learnables=learnables, + )) + + return result diff --git a/tests/test_courseprogress.py b/tests/test_courseprogress.py index 8345369..8c105c4 100644 --- a/tests/test_courseprogress.py +++ b/tests/test_courseprogress.py @@ -117,7 +117,7 @@ async def test_patch_all_should_patch_all_courses(): for course in CourseEnum: if course != CourseEnum.All: - assert {"progress": 0.0, "course_index": course, "completed_learnables": 0, "in_use_learnables": 0, "total_learnables": 0, "learnables": []} in response + assert {"progress": progress, "course_index": course, "completed_learnables": 0, "in_use_learnables": 0, "total_learnables": 0, "learnables": []} in response @pytest.mark.asyncio