diff --git a/src/crud.py b/src/crud.py index 6e1e59f..1321cb2 100644 --- a/src/crud.py +++ b/src/crud.py @@ -1,42 +1,96 @@ from fastapi import HTTPException from sqlalchemy import desc -from sqlalchemy.orm import Session +from datetime import datetime, timedelta -from enums import MinigameEnum, CourseEnum +from sqlalchemy.orm import Session +import jwt + + +from enums import CourseEnum, MinigameEnum from models import CourseProgress, HighScore, User +from schemas.courseprogress import CourseProgressBase from schemas.highscores import HighScoreCreate from schemas.users import UserCreate, UserHighScore -from schemas.courseprogress import CourseProgressBase +from passlib.context import CryptContext + DEFAULT_NR_HIGH_SCORES = 10 +# JWT authentication setup +jwt_secret = "secret_key" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + def get_user_by_id(db: Session, user_id: int): - """ Fetches a User from the database by their id """ + """Fetches a User from the database by their id""" return db.query(User).filter(User.user_id == user_id).first() def get_user_by_username(db: Session, username: str): - """ Fetches a User from the database by their username """ + """Fetches a User from the database by their username""" return db.query(User).filter(User.username == username).first() def get_users(db: Session): - """ Fetch a list of all users """ + """Fetch a list of all users""" return db.query(User).all() -def create_user(db: Session, username: str, hashed_password: str): - """ Create a new user """ - db_user = User(username=username, hashed_password=hashed_password) +def authenticate_user(db: Session, user: UserCreate): + db_user = get_user_by_username(db=db, username=user.username) + if not db_user: + return False + hashed_password = db_user.hashed_password + if not hashed_password or not pwd_context.verify(user.password, hashed_password): + return False + return db_user + + +def register(db: Session, username: str, password: str): + """Register a new user""" + db_user = get_user_by_username(db, username) + if db_user: + raise HTTPException(status_code=400, detail="Username already registered") + db_user = User(username=username, hashed_password=pwd_context.hash(password)) db.add(db_user) db.commit() db.refresh(db_user) return db_user +def login(db: Session, user: UserCreate): + user = authenticate_user(db, user) + if not user: + raise HTTPException(status_code=401, detail="Invalid username or password") + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_payload = { + "sub": user.username, + "exp": datetime.utcnow() + access_token_expires, + } + access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) + return {"access_token": access_token} + + +def patch_user(db: Session, username: str, user: UserCreate): + db_user = get_user_by_username(db, username) + potential_duplicate = get_user_by_username(db, user.username) + if potential_duplicate: + if potential_duplicate.user_id != db_user.user_id: + raise HTTPException(status_code=400, detail="Username already registered") + + db_user.username = user.username + db_user.hashed_password = pwd_context.hash(user.password) + db.commit() + + def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): - """ Get the n highest scores of a given minigame """ + """Get the n highest scores of a given minigame""" + if nr_highest < 1: + raise HTTPException(status_code=400, detail="Invalid number of high scores") + user_high_scores = [] if not nr_highest: @@ -61,7 +115,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int): def create_high_score(db: Session, high_score: HighScoreCreate): - """ Create a new high score for a given minigame """ + """Create a new high score for a given minigame""" owner = db.query(User).filter(User.user_id == high_score.owner_id).first() if not owner: raise HTTPException(status_code=400, detail="User does not exist") @@ -98,12 +152,25 @@ def create_high_score(db: Session, high_score: HighScoreCreate): db.refresh(db_high_score) return db_high_score -def get_course_progress(db: Session, user: User, course: CourseEnum): - """ Get the progress a user has for a certain course """ + +def get_course_progress(db: Session, username: str, course: CourseEnum): + """Get the progress a user has for a certain course""" + user = get_user_by_username(db, username) if course != CourseEnum.All: - course_progress = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id, CourseProgress.course == course).first() + course_progress = ( + db.query(CourseProgress) + .filter( + CourseProgress.owner_id == user.user_id, CourseProgress.course == course + ) + .first() + ) if course_progress: - return [CourseProgressBase(progress_value = course_progress.progress_value, course = course_progress.course)] + return [ + CourseProgressBase( + progress_value=course_progress.progress_value, + course=course_progress.course, + ) + ] else: - return [CourseProgressBase(progress_value = 0, course = course)] + return [CourseProgressBase(progress_value=0, course=course)] return [] diff --git a/src/main.py b/src/main.py index bb88ecd..fb0d013 100644 --- a/src/main.py +++ b/src/main.py @@ -9,22 +9,15 @@ from sqlalchemy.orm import Session import crud from database import SessionLocal, engine -from enums import MinigameEnum, CourseEnum +from enums import CourseEnum, MinigameEnum from models import Base -from schemas import highscores, users, courseprogress +from schemas import courseprogress, highscores, users app = FastAPI() + Base.metadata.create_all(bind=engine) - -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -# JWT authentication setup -jwt_secret = "secret_key" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month - bearer_scheme = HTTPBearer() @@ -35,12 +28,11 @@ def get_db(): finally: db.close() - -def get_current_user( +def get_current_user_name( token: HTTPAuthorizationCredentials = Depends(bearer_scheme), ): try: - payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) + payload = jwt.decode(token.credentials, crud.jwt_secret, algorithms=[crud.ALGORITHM]) username = payload.get("sub") if username is None: raise HTTPException(status_code=401, detail="Invalid JWT token") @@ -60,45 +52,33 @@ async def read_users(db: Session = Depends(get_db)): @app.patch("/users") -async def patch_current_user(user: users.UserCreate, current_user = Depends(get_current_user), db: Session = Depends(get_db)): - db_user = crud.get_us - return users +async def patch_current_user( + user: users.UserCreate, + current_user_name = Depends(get_current_user_name), + db: Session = Depends(get_db), +): + crud.patch_user(db, current_user_name, user) @app.post("/register", response_model=users.User) async def register(user: users.UserCreate, db: Session = Depends(get_db)): - db_user = crud.get_user_by_username(db, username=user.username) - if db_user: - raise HTTPException(status_code=400, detail="Username already registered") - return crud.create_user( - db=db, username=user.username, hashed_password=pwd_context.hash(user.password) + return crud.register( + db, user.username, user.password ) @app.post("/login") async def login(user: users.UserCreate, db: Session = Depends(get_db)): - user = authenticate_user(user, db) - if not user: - raise HTTPException(status_code=401, detail="Invalid username or password") - access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) - access_token_payload = { - "sub": user.username, - "exp": datetime.utcnow() + access_token_expires, - } - access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM) - return {"access_token": access_token} + return crud.login(db, user) @app.get("/highscores", response_model=List[users.UserHighScore]) async def read_high_scores( db: Session = Depends(get_db), minigame: Optional[MinigameEnum] = None, - n_highest: Optional[int] = None, + nr_highest: Optional[int] = None, ): - if n_highest < 1: - raise HTTPException(status_code=400, detail="Invalid number of high scores") - high_scores = crud.get_high_scores(db, minigame, n_highest) - return high_scores + return crud.get_high_scores(db, minigame, nr_highest) @app.post("/highscores", response_model=highscores.HighScore) @@ -111,22 +91,9 @@ async def create_high_score( #### TESTING!! DELETE LATER -async def get_current_user( - token: HTTPAuthorizationCredentials = Depends(bearer_scheme), -): - try: - payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM]) - username = payload.get("sub") - if username is None: - raise HTTPException(status_code=401, detail="Invalid JWT token") - return username - except jwt.exceptions.DecodeError: - raise HTTPException(status_code=401, detail="Invalid JWT token") - - @app.get("/protected") -async def protected_route(current_user = Depends(get_current_user)): - return {"message": f"Hello, {current_user}!"} +async def protected_route(current_user_name=Depends(get_current_user_name)): + return {"message": f"Hello, {current_user_name}!"} def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): @@ -139,10 +106,10 @@ def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)): return db_user - - - @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) -async def get_course_progress(course: Optional[CourseEnum] = CourseEnum.All, current_user = Depends(get_current_user), db: Session = Depends(get_db)): - user = crud.get_user_by_username(db, current_user) - return crud.get_course_progress(db = db, user = user, course = course) +async def get_course_progress( + course: Optional[CourseEnum] = CourseEnum.All, + current_user_name=Depends(get_current_user_name), + db: Session = Depends(get_db), +): + return crud.get_course_progress(db, current_user_name, course) diff --git a/src/schemas/courseprogress.py b/src/schemas/courseprogress.py index d46486e..1f071de 100644 --- a/src/schemas/courseprogress.py +++ b/src/schemas/courseprogress.py @@ -2,10 +2,12 @@ from pydantic import BaseModel from enums import CourseEnum + class CourseProgressBase(BaseModel): progress_value: float course: CourseEnum + class CourseProgress(CourseProgressBase): course_progress_id: int owner_id: int