Minor stuff

This commit is contained in:
lvrossem
2023-04-01 10:03:18 -06:00
parent 65d1a2a6e4
commit d2933a95ba
9 changed files with 91 additions and 54 deletions

View File

@@ -1,7 +1,8 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import jwt import jwt
from fastapi import HTTPException from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from crud.users import get_user_by_username, pwd_context from crud.users import get_user_by_username, pwd_context
@@ -15,6 +16,25 @@ jwt_secret = "secret_key"
ALGORITHM = "HS256" ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
bearer_scheme = HTTPBearer()
def get_current_user_name(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(
token.credentials,
jwt_secret,
algorithms=[ALGORITHM],
)
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
return username
except jwt.exceptions.DecodeError:
raise HTTPException(status_code=401, detail="Invalid JWT token")
def authenticate_user(db: Session, username: str, password: str): def authenticate_user(db: Session, username: str, password: str):
"""Checks whether the provided credentials match with an existing User""" """Checks whether the provided credentials match with an existing User"""
@@ -29,10 +49,14 @@ def authenticate_user(db: Session, username: str, password: str):
def register(db: Session, username: str, password: str, avatar: str): def register(db: Session, username: str, password: str, avatar: str):
"""Register a new user""" """Register a new user"""
if avatar == "":
raise HTTPException(status_code=400, detail="No avatar was provided")
db_user = get_user_by_username(db, username) db_user = get_user_by_username(db, username)
if db_user: if db_user:
raise HTTPException(status_code=400, detail="Username already registered") raise HTTPException(status_code=400, detail="Username already registered")
db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar) db_user = User(
username=username, hashed_password=pwd_context.hash(password), avatar=avatar
)
db.add(db_user) db.add(db_user)
db.commit() db.commit()
db.refresh(db_user) db.refresh(db_user)

View File

@@ -2,7 +2,8 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from enums import CourseEnum, course_enum_list from enums import CourseEnum, course_enum_list
from models import User, CourseProgress from models import CourseProgress, User
from schemas.courseprogress import CourseProgressBase
def get_course_progress(db: Session, user: User, course: CourseEnum): def get_course_progress(db: Session, user: User, course: CourseEnum):
@@ -15,21 +16,25 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
) )
.first() .first()
) )
if course_progress: if course_progress:
return [ return [
CourseProgressBase( CourseProgressBase(
progress_value = course_progress.progress_value, progress_value=course_progress.progress_value,
course = course_progress.course, course=course_progress.course,
) )
] ]
else: else:
return [CourseProgressBase(progress_value = 0, course=course)] db.add(
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)
)
db.commit()
return [CourseProgressBase(progress_value=0.0, course=course)]
return [] return []
def initialize_user(db: Session, user: User): def initialize_user(db: Session, user: User):
"""Create CourseProgress records with a value of 0 for a new user"""
for course in course_enum_list: for course in course_enum_list:
db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id)) db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id))
db.commit() db.commit()

View File

@@ -31,19 +31,24 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
for high_score in high_scores: for high_score in high_scores:
owner = db.query(User).filter(User.user_id == high_score.owner_id).first() owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
user_high_scores.append( user_high_scores.append(
UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar) UserHighScore(
username=owner.username,
score_value=high_score.score_value,
avatar=owner.avatar,
)
) )
return user_high_scores return user_high_scores
def create_high_score(db: Session, user: User, high_score: HighScoreBase): def create_high_score(db: Session, user: User, high_score: HighScoreBase):
"""Create a new high score for a given minigame""" """Create a new high score for a given minigame"""
def add_to_db(): def add_to_db():
"""Helper function that adds new score to database; prevents code duplication""" """Helper function that adds new score to database; prevents code duplication"""
db_high_score = HighScore( db_high_score = HighScore(
score_value = high_score.score_value, score_value=high_score.score_value,
minigame = high_score.minigame, minigame=high_score.minigame,
owner_id = user.user_id, owner_id=user.user_id,
) )
db.add(db_high_score) db.add(db_high_score)
db.commit() db.commit()

View File

@@ -18,6 +18,7 @@ def patch_user(db: Session, username: str, user: UserCreate):
db_user.username = user.username db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password) db_user.hashed_password = pwd_context.hash(user.password)
db_user.avatar = user.avatar
db.commit() db.commit()

View File

@@ -9,3 +9,11 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -28,20 +28,21 @@ class MinigameEnum(StrEnum):
class CourseEnum(StrEnum): class CourseEnum(StrEnum):
Fingerspelling = "Fingerspelling" Fingerspelling = "Fingerspelling"
#Basics = "Basics" Basics = "Basics"
Hobbies = "Hobbies" Hobbies = "Hobbies"
Animals = "Animals" Animals = "Animals"
Colors = "Colors" Colors = "Colors"
FruitsVegetables = "FruitsVegetables" FruitsVegetables = "FruitsVegetables"
All = "All" All = "All"
# This is needed because for some reason iterating over an enum doesn't work properly... # This is needed because for some reason iterating over an enum doesn't work properly...
course_enum_list = [ course_enum_list = [
CourseEnum.Fingerspelling, CourseEnum.Fingerspelling,
#CourseEnum.Basics, # CourseEnum.Basics,
CourseEnum.Hobbies, CourseEnum.Hobbies,
CourseEnum.Animals, CourseEnum.Animals,
CourseEnum.Colors, CourseEnum.Colors,
CourseEnum.FruitsVegetables, CourseEnum.FruitsVegetables,
CourseEnum.All CourseEnum.All,
] ]

View File

@@ -2,46 +2,19 @@ from typing import List, Optional
import jwt import jwt
from fastapi import Depends, FastAPI, HTTPException from fastapi import Depends, FastAPI, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from crud import authentication as crud_authentication from crud import authentication as crud_authentication
from crud import courseprogress as crud_courseprogress from crud import courseprogress as crud_courseprogress
from crud import highscores as crud_highscores from crud import highscores as crud_highscores
from crud import users as crud_users from crud import users as crud_users
from database import SessionLocal, engine from database import SessionLocal, engine, get_db
from enums import CourseEnum, MinigameEnum from enums import CourseEnum, MinigameEnum
from models import Base from models import Base
from schemas import courseprogress, highscores, users from schemas import courseprogress, highscores, users
app = FastAPI() app = FastAPI()
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
bearer_scheme = HTTPBearer()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
def get_current_user_name(
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
):
try:
payload = jwt.decode(
token.credentials,
crud_authentication.jwt_secret,
algorithms=[crud_authentication.ALGORITHM],
)
username = payload.get("sub")
if username is None:
raise HTTPException(status_code=401, detail="Invalid JWT token")
return username
except jwt.exceptions.DecodeError:
raise HTTPException(status_code=401, detail="Invalid JWT token")
@app.get("/") @app.get("/")
@@ -57,15 +30,17 @@ async def read_users(db: Session = Depends(get_db)):
@app.patch("/users") @app.patch("/users")
async def patch_current_user( async def patch_current_user(
user: users.UserCreate, user: users.UserCreate,
current_user_name=Depends(get_current_user_name), current_user_name=Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
crud_users.patch_user(db, current_user_name, user) crud_users.patch_user(db, current_user_name, user)
@app.post("/register") @app.post("/register")
async def register(user: users.UserCreate, db: Session = Depends(get_db)): async def register(user: users.UserCreate, db: Session = Depends(get_db)):
access_token = crud_authentication.register(db, user.username, user.password, user.avatar) access_token = crud_authentication.register(
db, user.username, user.password, user.avatar
)
user = crud_users.get_user_by_username(db, user.username) user = crud_users.get_user_by_username(db, user.username)
crud_courseprogress.initialize_user(db, user) crud_courseprogress.initialize_user(db, user)
return access_token return access_token
@@ -80,7 +55,7 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)):
async def get_high_scores( async def get_high_scores(
minigame: Optional[MinigameEnum] = None, minigame: Optional[MinigameEnum] = None,
nr_highest: Optional[int] = None, nr_highest: Optional[int] = None,
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
return crud_highscores.get_high_scores(db, minigame, nr_highest) return crud_highscores.get_high_scores(db, minigame, nr_highest)
@@ -88,8 +63,8 @@ async def get_high_scores(
@app.post("/highscores", response_model=highscores.HighScore) @app.post("/highscores", response_model=highscores.HighScore)
async def create_high_score( async def create_high_score(
high_score: highscores.HighScoreBase, high_score: highscores.HighScoreBase,
current_user_name = Depends(get_current_user_name), current_user_name=Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db) db: Session = Depends(get_db),
): ):
current_user = crud_users.get_user_by_username(db, current_user_name) current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_highscores.create_high_score(db, current_user, high_score) return crud_highscores.create_high_score(db, current_user, high_score)
@@ -99,15 +74,27 @@ async def create_high_score(
@app.get("/protected") @app.get("/protected")
async def protected_route(current_user_name=Depends(get_current_user_name)): async def protected_route(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
):
return {"message": f"Hello, {current_user_name}!"} return {"message": f"Hello, {current_user_name}!"}
@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase]) @app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])
async def get_course_progress( async def get_course_progress(
course: Optional[CourseEnum] = CourseEnum.All, course: Optional[CourseEnum] = CourseEnum.All,
current_user_name=Depends(get_current_user_name), current_user_name: str = Depends(crud_authentication.get_current_user_name),
db: Session = Depends(get_db) db: Session = Depends(get_db),
):
current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_courseprogress.get_course_progress(db, current_user, course)
@app.patch("/courseprogress")
async def get_course_progress(
current_user_name: str = Depends(crud_authentication.get_current_user_name),
course: Optional[CourseEnum] = CourseEnum.All,
db: Session = Depends(get_db),
): ):
current_user = crud_users.get_user_by_username(db, current_user_name) current_user = crud_users.get_user_by_username(db, current_user_name)
return crud_courseprogress.get_course_progress(db, current_user, course) return crud_courseprogress.get_course_progress(db, current_user, course)

View File

@@ -7,6 +7,8 @@ from enums import CourseEnum, MinigameEnum, StrEnumType
class User(Base): class User(Base):
"""The database model for users"""
__tablename__ = "users" __tablename__ = "users"
user_id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, primary_key=True, index=True)
@@ -23,6 +25,8 @@ class User(Base):
class HighScore(Base): class HighScore(Base):
"""The database model for high scores"""
__tablename__ = "high_scores" __tablename__ = "high_scores"
high_score_id = Column(Integer, primary_key=True, index=True) high_score_id = Column(Integer, primary_key=True, index=True)
@@ -33,6 +37,8 @@ class HighScore(Base):
class CourseProgress(Base): class CourseProgress(Base):
"""The database model for course progress"""
__tablename__ = "course_progress" __tablename__ = "course_progress"
course_progress_id = Column(Integer, primary_key=True, index=True) course_progress_id = Column(Integer, primary_key=True, index=True)

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
class UserBase(BaseModel): class UserBase(BaseModel):
username: str username: str
avatar: str avatar: str = ""
class User(UserBase): class User(UserBase):