BIG refactors

This commit is contained in:
lvrossem
2023-03-31 06:16:40 -06:00
parent 5fe168937f
commit 49f8d7d713
3 changed files with 109 additions and 73 deletions

View File

@@ -1,42 +1,96 @@
from fastapi import HTTPException
from sqlalchemy import desc
from sqlalchemy.orm import Session
from datetime import datetime, timedelta
from enums import MinigameEnum, CourseEnum
from sqlalchemy.orm import Session
import jwt
from enums import CourseEnum, MinigameEnum
from models import CourseProgress, HighScore, User
from schemas.courseprogress import CourseProgressBase
from schemas.highscores import HighScoreCreate
from schemas.users import UserCreate, UserHighScore
from schemas.courseprogress import CourseProgressBase
from passlib.context import CryptContext
DEFAULT_NR_HIGH_SCORES = 10
# JWT authentication setup
jwt_secret = "secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def get_user_by_id(db: Session, user_id: int):
""" Fetches a User from the database by their id """
"""Fetches a User from the database by their id"""
return db.query(User).filter(User.user_id == user_id).first()
def get_user_by_username(db: Session, username: str):
""" Fetches a User from the database by their username """
"""Fetches a User from the database by their username"""
return db.query(User).filter(User.username == username).first()
def get_users(db: Session):
""" Fetch a list of all users """
"""Fetch a list of all users"""
return db.query(User).all()
def create_user(db: Session, username: str, hashed_password: str):
""" Create a new user """
db_user = User(username=username, hashed_password=hashed_password)
def authenticate_user(db: Session, user: UserCreate):
db_user = get_user_by_username(db=db, username=user.username)
if not db_user:
return False
hashed_password = db_user.hashed_password
if not hashed_password or not pwd_context.verify(user.password, hashed_password):
return False
return db_user
def register(db: Session, username: str, password: str):
"""Register a new user"""
db_user = get_user_by_username(db, username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(username=username, hashed_password=pwd_context.hash(password))
db.add(db_user)
db.commit()
db.refresh(db_user)
return db_user
def login(db: Session, user: UserCreate):
user = authenticate_user(db, user)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_payload = {
"sub": user.username,
"exp": datetime.utcnow() + access_token_expires,
}
access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM)
return {"access_token": access_token}
def patch_user(db: Session, username: str, user: UserCreate):
db_user = get_user_by_username(db, username)
potential_duplicate = get_user_by_username(db, user.username)
if potential_duplicate:
if potential_duplicate.user_id != db_user.user_id:
raise HTTPException(status_code=400, detail="Username already registered")
db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password)
db.commit()
def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
""" Get the n highest scores of a given minigame """
"""Get the n highest scores of a given minigame"""
if nr_highest < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
user_high_scores = []
if not nr_highest:
@@ -61,7 +115,7 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
def create_high_score(db: Session, high_score: HighScoreCreate):
""" Create a new high score for a given minigame """
"""Create a new high score for a given minigame"""
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
if not owner:
raise HTTPException(status_code=400, detail="User does not exist")
@@ -98,12 +152,25 @@ def create_high_score(db: Session, high_score: HighScoreCreate):
db.refresh(db_high_score)
return db_high_score
def get_course_progress(db: Session, user: User, course: CourseEnum):
""" Get the progress a user has for a certain course """
def get_course_progress(db: Session, username: str, course: CourseEnum):
"""Get the progress a user has for a certain course"""
user = get_user_by_username(db, username)
if course != CourseEnum.All:
course_progress = db.query(CourseProgress).filter(CourseProgress.owner_id == user.user_id, CourseProgress.course == course).first()
course_progress = (
db.query(CourseProgress)
.filter(
CourseProgress.owner_id == user.user_id, CourseProgress.course == course
)
.first()
)
if course_progress:
return [CourseProgressBase(progress_value = course_progress.progress_value, course = course_progress.course)]
return [
CourseProgressBase(
progress_value=course_progress.progress_value,
course=course_progress.course,
)
]
else:
return [CourseProgressBase(progress_value = 0, course = course)]
return [CourseProgressBase(progress_value=0, course=course)]
return []

View File

@@ -9,22 +9,15 @@ from sqlalchemy.orm import Session
import crud
from database import SessionLocal, engine
from enums import MinigameEnum, CourseEnum
from enums import CourseEnum, MinigameEnum
from models import Base
from schemas import highscores, users, courseprogress
from schemas import courseprogress, highscores, users
app = FastAPI()
Base.metadata.create_all(bind=engine)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# JWT authentication setup
jwt_secret = "secret_key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
bearer_scheme = HTTPBearer()
@@ -35,12 +28,11 @@ def get_db():
finally:
db.close()
def get_current_user(
def get_current_user_name(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM])
payload = jwt.decode(token.credentials, crud.jwt_secret, algorithms=[crud.ALGORITHM])
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
@@ -60,45 +52,33 @@ async def read_users(db: Session = Depends(get_db)):
@app.patch("/users")
async def patch_current_user(user: users.UserCreate, current_user = Depends(get_current_user), db: Session = Depends(get_db)):
db_user = crud.get_us
return users
async def patch_current_user(
user: users.UserCreate,
current_user_name = Depends(get_current_user_name),
db: Session = Depends(get_db),
):
crud.patch_user(db, current_user_name, user)
@app.post("/register", response_model=users.User)
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
db_user = crud.get_user_by_username(db, username=user.username)
if db_user:
raise HTTPException(status_code=400, detail="Username already registered")
return crud.create_user(
db=db, username=user.username, hashed_password=pwd_context.hash(user.password)
return crud.register(
db, user.username, user.password
)
@app.post("/login")
async def login(user: users.UserCreate, db: Session = Depends(get_db)):
user = authenticate_user(user, db)
if not user:
raise HTTPException(status_code=401, detail="Invalid username or password")
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token_payload = {
"sub": user.username,
"exp": datetime.utcnow() + access_token_expires,
}
access_token = jwt.encode(access_token_payload, jwt_secret, algorithm=ALGORITHM)
return {"access_token": access_token}
return crud.login(db, user)
@app.get("/highscores", response_model=List[users.UserHighScore])
async def read_high_scores(
db: Session = Depends(get_db),
minigame: Optional[MinigameEnum] = None,
n_highest: Optional[int] = None,
nr_highest: Optional[int] = None,
):
if n_highest < 1:
raise HTTPException(status_code=400, detail="Invalid number of high scores")
high_scores = crud.get_high_scores(db, minigame, n_highest)
return high_scores
return crud.get_high_scores(db, minigame, nr_highest)
@app.post("/highscores", response_model=highscores.HighScore)
@@ -111,22 +91,9 @@ async def create_high_score(
#### TESTING!! DELETE LATER
async def get_current_user(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(token.credentials, jwt_secret, algorithms=[ALGORITHM])
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
return username
except jwt.exceptions.DecodeError:
raise HTTPException(status_code=401, detail="Invalid JWT token")
@app.get("/protected")
async def protected_route(current_user = Depends(get_current_user)):
return {"message": f"Hello, {current_user}!"}
async def protected_route(current_user_name=Depends(get_current_user_name)):
return {"message": f"Hello, {current_user_name}!"}
def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)):
@@ -139,10 +106,10 @@ def authenticate_user(user: users.UserCreate, db: Session = Depends(get_db)):
return db_user
@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])
async def get_course_progress(course: Optional[CourseEnum] = CourseEnum.All, current_user = Depends(get_current_user), db: Session = Depends(get_db)):
user = crud.get_user_by_username(db, current_user)
return crud.get_course_progress(db = db, user = user, course = course)
async def get_course_progress(
course: Optional[CourseEnum] = CourseEnum.All,
current_user_name=Depends(get_current_user_name),
db: Session = Depends(get_db),
):
return crud.get_course_progress(db, current_user_name, course)

View File

@@ -2,10 +2,12 @@ from pydantic import BaseModel
from enums import CourseEnum
class CourseProgressBase(BaseModel):
progress_value: float
course: CourseEnum
class CourseProgress(CourseProgressBase):
course_progress_id: int
owner_id: int