Minor stuff
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from crud.users import get_user_by_username, pwd_context
|
from crud.users import get_user_by_username, pwd_context
|
||||||
@@ -15,6 +16,25 @@ jwt_secret = "secret_key"
|
|||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
|
ACCESS_TOKEN_EXPIRE_MINUTES = 44640 # One month
|
||||||
|
|
||||||
|
bearer_scheme = HTTPBearer()
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_name(
|
||||||
|
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token.credentials,
|
||||||
|
jwt_secret,
|
||||||
|
algorithms=[ALGORITHM],
|
||||||
|
)
|
||||||
|
username = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid JWT token")
|
||||||
|
return username
|
||||||
|
except jwt.exceptions.DecodeError:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid JWT token")
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(db: Session, username: str, password: str):
|
def authenticate_user(db: Session, username: str, password: str):
|
||||||
"""Checks whether the provided credentials match with an existing User"""
|
"""Checks whether the provided credentials match with an existing User"""
|
||||||
@@ -29,10 +49,14 @@ def authenticate_user(db: Session, username: str, password: str):
|
|||||||
|
|
||||||
def register(db: Session, username: str, password: str, avatar: str):
|
def register(db: Session, username: str, password: str, avatar: str):
|
||||||
"""Register a new user"""
|
"""Register a new user"""
|
||||||
|
if avatar == "":
|
||||||
|
raise HTTPException(status_code=400, detail="No avatar was provided")
|
||||||
db_user = get_user_by_username(db, username)
|
db_user = get_user_by_username(db, username)
|
||||||
if db_user:
|
if db_user:
|
||||||
raise HTTPException(status_code=400, detail="Username already registered")
|
raise HTTPException(status_code=400, detail="Username already registered")
|
||||||
db_user = User(username = username, hashed_password = pwd_context.hash(password), avatar = avatar)
|
db_user = User(
|
||||||
|
username=username, hashed_password=pwd_context.hash(password), avatar=avatar
|
||||||
|
)
|
||||||
db.add(db_user)
|
db.add(db_user)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_user)
|
db.refresh(db_user)
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ from fastapi import HTTPException
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from enums import CourseEnum, course_enum_list
|
from enums import CourseEnum, course_enum_list
|
||||||
from models import User, CourseProgress
|
from models import CourseProgress, User
|
||||||
|
from schemas.courseprogress import CourseProgressBase
|
||||||
|
|
||||||
|
|
||||||
def get_course_progress(db: Session, user: User, course: CourseEnum):
|
def get_course_progress(db: Session, user: User, course: CourseEnum):
|
||||||
@@ -15,21 +16,25 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
|
|||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if course_progress:
|
if course_progress:
|
||||||
return [
|
return [
|
||||||
CourseProgressBase(
|
CourseProgressBase(
|
||||||
progress_value = course_progress.progress_value,
|
progress_value=course_progress.progress_value,
|
||||||
course = course_progress.course,
|
course=course_progress.course,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return [CourseProgressBase(progress_value = 0, course=course)]
|
db.add(
|
||||||
|
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id)
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
return [CourseProgressBase(progress_value=0.0, course=course)]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def initialize_user(db: Session, user: User):
|
def initialize_user(db: Session, user: User):
|
||||||
|
"""Create CourseProgress records with a value of 0 for a new user"""
|
||||||
for course in course_enum_list:
|
for course in course_enum_list:
|
||||||
db.add(CourseProgress(progress_value = 0.0, course = course, owner_id = user.user_id))
|
db.add(CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id))
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,19 +31,24 @@ def get_high_scores(db: Session, minigame: MinigameEnum, nr_highest: int):
|
|||||||
for high_score in high_scores:
|
for high_score in high_scores:
|
||||||
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
|
owner = db.query(User).filter(User.user_id == high_score.owner_id).first()
|
||||||
user_high_scores.append(
|
user_high_scores.append(
|
||||||
UserHighScore(username = owner.username, score_value = high_score.score_value, avatar = owner.avatar)
|
UserHighScore(
|
||||||
|
username=owner.username,
|
||||||
|
score_value=high_score.score_value,
|
||||||
|
avatar=owner.avatar,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return user_high_scores
|
return user_high_scores
|
||||||
|
|
||||||
|
|
||||||
def create_high_score(db: Session, user: User, high_score: HighScoreBase):
|
def create_high_score(db: Session, user: User, high_score: HighScoreBase):
|
||||||
"""Create a new high score for a given minigame"""
|
"""Create a new high score for a given minigame"""
|
||||||
|
|
||||||
def add_to_db():
|
def add_to_db():
|
||||||
"""Helper function that adds new score to database; prevents code duplication"""
|
"""Helper function that adds new score to database; prevents code duplication"""
|
||||||
db_high_score = HighScore(
|
db_high_score = HighScore(
|
||||||
score_value = high_score.score_value,
|
score_value=high_score.score_value,
|
||||||
minigame = high_score.minigame,
|
minigame=high_score.minigame,
|
||||||
owner_id = user.user_id,
|
owner_id=user.user_id,
|
||||||
)
|
)
|
||||||
db.add(db_high_score)
|
db.add(db_high_score)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ def patch_user(db: Session, username: str, user: UserCreate):
|
|||||||
|
|
||||||
db_user.username = user.username
|
db_user.username = user.username
|
||||||
db_user.hashed_password = pwd_context.hash(user.password)
|
db_user.hashed_password = pwd_context.hash(user.password)
|
||||||
|
db_user.avatar = user.avatar
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,3 +9,11 @@ engine = create_engine(SQLALCHEMY_DATABASE_URL)
|
|||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|||||||
@@ -28,20 +28,21 @@ class MinigameEnum(StrEnum):
|
|||||||
|
|
||||||
class CourseEnum(StrEnum):
|
class CourseEnum(StrEnum):
|
||||||
Fingerspelling = "Fingerspelling"
|
Fingerspelling = "Fingerspelling"
|
||||||
#Basics = "Basics"
|
Basics = "Basics"
|
||||||
Hobbies = "Hobbies"
|
Hobbies = "Hobbies"
|
||||||
Animals = "Animals"
|
Animals = "Animals"
|
||||||
Colors = "Colors"
|
Colors = "Colors"
|
||||||
FruitsVegetables = "FruitsVegetables"
|
FruitsVegetables = "FruitsVegetables"
|
||||||
All = "All"
|
All = "All"
|
||||||
|
|
||||||
|
|
||||||
# This is needed because for some reason iterating over an enum doesn't work properly...
|
# This is needed because for some reason iterating over an enum doesn't work properly...
|
||||||
course_enum_list = [
|
course_enum_list = [
|
||||||
CourseEnum.Fingerspelling,
|
CourseEnum.Fingerspelling,
|
||||||
#CourseEnum.Basics,
|
# CourseEnum.Basics,
|
||||||
CourseEnum.Hobbies,
|
CourseEnum.Hobbies,
|
||||||
CourseEnum.Animals,
|
CourseEnum.Animals,
|
||||||
CourseEnum.Colors,
|
CourseEnum.Colors,
|
||||||
CourseEnum.FruitsVegetables,
|
CourseEnum.FruitsVegetables,
|
||||||
CourseEnum.All
|
CourseEnum.All,
|
||||||
]
|
]
|
||||||
|
|||||||
61
src/main.py
61
src/main.py
@@ -2,46 +2,19 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from fastapi import Depends, FastAPI, HTTPException
|
from fastapi import Depends, FastAPI, HTTPException
|
||||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from crud import authentication as crud_authentication
|
from crud import authentication as crud_authentication
|
||||||
from crud import courseprogress as crud_courseprogress
|
from crud import courseprogress as crud_courseprogress
|
||||||
from crud import highscores as crud_highscores
|
from crud import highscores as crud_highscores
|
||||||
from crud import users as crud_users
|
from crud import users as crud_users
|
||||||
from database import SessionLocal, engine
|
from database import SessionLocal, engine, get_db
|
||||||
from enums import CourseEnum, MinigameEnum
|
from enums import CourseEnum, MinigameEnum
|
||||||
from models import Base
|
from models import Base
|
||||||
from schemas import courseprogress, highscores, users
|
from schemas import courseprogress, highscores, users
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
bearer_scheme = HTTPBearer()
|
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_name(
|
|
||||||
token: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token.credentials,
|
|
||||||
crud_authentication.jwt_secret,
|
|
||||||
algorithms=[crud_authentication.ALGORITHM],
|
|
||||||
)
|
|
||||||
username = payload.get("sub")
|
|
||||||
if username is None:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid JWT token")
|
|
||||||
return username
|
|
||||||
except jwt.exceptions.DecodeError:
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid JWT token")
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@@ -57,15 +30,17 @@ async def read_users(db: Session = Depends(get_db)):
|
|||||||
@app.patch("/users")
|
@app.patch("/users")
|
||||||
async def patch_current_user(
|
async def patch_current_user(
|
||||||
user: users.UserCreate,
|
user: users.UserCreate,
|
||||||
current_user_name=Depends(get_current_user_name),
|
current_user_name=Depends(crud_authentication.get_current_user_name),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
crud_users.patch_user(db, current_user_name, user)
|
crud_users.patch_user(db, current_user_name, user)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/register")
|
@app.post("/register")
|
||||||
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
|
async def register(user: users.UserCreate, db: Session = Depends(get_db)):
|
||||||
access_token = crud_authentication.register(db, user.username, user.password, user.avatar)
|
access_token = crud_authentication.register(
|
||||||
|
db, user.username, user.password, user.avatar
|
||||||
|
)
|
||||||
user = crud_users.get_user_by_username(db, user.username)
|
user = crud_users.get_user_by_username(db, user.username)
|
||||||
crud_courseprogress.initialize_user(db, user)
|
crud_courseprogress.initialize_user(db, user)
|
||||||
return access_token
|
return access_token
|
||||||
@@ -80,7 +55,7 @@ async def login(user: users.UserCreate, db: Session = Depends(get_db)):
|
|||||||
async def get_high_scores(
|
async def get_high_scores(
|
||||||
minigame: Optional[MinigameEnum] = None,
|
minigame: Optional[MinigameEnum] = None,
|
||||||
nr_highest: Optional[int] = None,
|
nr_highest: Optional[int] = None,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
return crud_highscores.get_high_scores(db, minigame, nr_highest)
|
return crud_highscores.get_high_scores(db, minigame, nr_highest)
|
||||||
|
|
||||||
@@ -88,8 +63,8 @@ async def get_high_scores(
|
|||||||
@app.post("/highscores", response_model=highscores.HighScore)
|
@app.post("/highscores", response_model=highscores.HighScore)
|
||||||
async def create_high_score(
|
async def create_high_score(
|
||||||
high_score: highscores.HighScoreBase,
|
high_score: highscores.HighScoreBase,
|
||||||
current_user_name = Depends(get_current_user_name),
|
current_user_name=Depends(crud_authentication.get_current_user_name),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
current_user = crud_users.get_user_by_username(db, current_user_name)
|
current_user = crud_users.get_user_by_username(db, current_user_name)
|
||||||
return crud_highscores.create_high_score(db, current_user, high_score)
|
return crud_highscores.create_high_score(db, current_user, high_score)
|
||||||
@@ -99,15 +74,27 @@ async def create_high_score(
|
|||||||
|
|
||||||
|
|
||||||
@app.get("/protected")
|
@app.get("/protected")
|
||||||
async def protected_route(current_user_name=Depends(get_current_user_name)):
|
async def protected_route(
|
||||||
|
current_user_name: str = Depends(crud_authentication.get_current_user_name),
|
||||||
|
):
|
||||||
return {"message": f"Hello, {current_user_name}!"}
|
return {"message": f"Hello, {current_user_name}!"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])
|
@app.get("/courseprogress", response_model=List[courseprogress.CourseProgressBase])
|
||||||
async def get_course_progress(
|
async def get_course_progress(
|
||||||
course: Optional[CourseEnum] = CourseEnum.All,
|
course: Optional[CourseEnum] = CourseEnum.All,
|
||||||
current_user_name=Depends(get_current_user_name),
|
current_user_name: str = Depends(crud_authentication.get_current_user_name),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
current_user = crud_users.get_user_by_username(db, current_user_name)
|
||||||
|
return crud_courseprogress.get_course_progress(db, current_user, course)
|
||||||
|
|
||||||
|
|
||||||
|
@app.patch("/courseprogress")
|
||||||
|
async def get_course_progress(
|
||||||
|
current_user_name: str = Depends(crud_authentication.get_current_user_name),
|
||||||
|
course: Optional[CourseEnum] = CourseEnum.All,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
current_user = crud_users.get_user_by_username(db, current_user_name)
|
current_user = crud_users.get_user_by_username(db, current_user_name)
|
||||||
return crud_courseprogress.get_course_progress(db, current_user, course)
|
return crud_courseprogress.get_course_progress(db, current_user, course)
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ from enums import CourseEnum, MinigameEnum, StrEnumType
|
|||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
|
"""The database model for users"""
|
||||||
|
|
||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
|
|
||||||
user_id = Column(Integer, primary_key=True, index=True)
|
user_id = Column(Integer, primary_key=True, index=True)
|
||||||
@@ -23,6 +25,8 @@ class User(Base):
|
|||||||
|
|
||||||
|
|
||||||
class HighScore(Base):
|
class HighScore(Base):
|
||||||
|
"""The database model for high scores"""
|
||||||
|
|
||||||
__tablename__ = "high_scores"
|
__tablename__ = "high_scores"
|
||||||
|
|
||||||
high_score_id = Column(Integer, primary_key=True, index=True)
|
high_score_id = Column(Integer, primary_key=True, index=True)
|
||||||
@@ -33,6 +37,8 @@ class HighScore(Base):
|
|||||||
|
|
||||||
|
|
||||||
class CourseProgress(Base):
|
class CourseProgress(Base):
|
||||||
|
"""The database model for course progress"""
|
||||||
|
|
||||||
__tablename__ = "course_progress"
|
__tablename__ = "course_progress"
|
||||||
|
|
||||||
course_progress_id = Column(Integer, primary_key=True, index=True)
|
course_progress_id = Column(Integer, primary_key=True, index=True)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
avatar: str
|
avatar: str = ""
|
||||||
|
|
||||||
|
|
||||||
class User(UserBase):
|
class User(UserBase):
|
||||||
|
|||||||
Reference in New Issue
Block a user