Refactoring: auth tests pass

This commit is contained in:
lvrossem
2023-04-16 07:15:03 -06:00
parent d38d4d5c16
commit 0bf764a0f4
14 changed files with 116 additions and 93 deletions

View File

@@ -47,15 +47,15 @@ def authenticate_user(db: Session, username: str, password: str):
return db_user return db_user
def register(db: Session, username: str, password: str, avatar: str): def register(db: Session, username: str, password: str, avatar_index: int):
"""Register a new user""" """Register a new user"""
check_empty_fields(username, password, avatar) check_empty_fields(username, password, avatar_index)
db_user = get_user_by_username(db, username) db_user = get_user_by_username(db, username)
if db_user: if db_user:
raise HTTPException(status_code=400, detail="Username already registered") raise HTTPException(status_code=400, detail="Username already registered")
db_user = User( db_user = User(
username=username, hashed_password=pwd_context.hash(password), avatar=avatar username=username, hashed_password=pwd_context.hash(password), avatar_index=avatar_index, playtime=0.0
) )
db.add(db_user) db.add(db_user)
db.commit() db.commit()

View File

@@ -27,15 +27,15 @@ def get_course_progress(db: Session, user: User, course: CourseEnum):
if course_progress: if course_progress:
result.append( result.append(
CourseProgressParent( CourseProgressParent(
progress_value=course_progress.progress_value, course=course progress=course_progress.progress, course=course
) )
) )
else: else:
db.add( db.add(
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) CourseProgress(progress=0.0, course=course, owner_id=user.user_id)
) )
db.commit() db.commit()
result.append(CourseProgressParent(progress_value=0.0, course=course)) result.append(CourseProgressParent(progress=0.0, course=course))
return result return result
@@ -45,7 +45,7 @@ def initialize_user(db: Session, user: User):
for course in CourseEnum: for course in CourseEnum:
if course != CourseEnum.All: if course != CourseEnum.All:
db.add( db.add(
CourseProgress(progress_value=0.0, course=course, owner_id=user.user_id) CourseProgress(progress=0.0, course=course, owner_id=user.user_id)
) )
db.commit() db.commit()
@@ -54,7 +54,7 @@ def patch_course_progress(
db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase db: Session, user: User, course: CourseEnum, course_progress: CourseProgressBase
): ):
"""Change the progress value for a given course""" """Change the progress value for a given course"""
if course_progress.progress_value > 1 or course_progress.progress_value < 0: if course_progress.progress > 1 or course_progress.progress < 0:
raise HTTPException(status_code=400, detail="Invalid progress value") raise HTTPException(status_code=400, detail="Invalid progress value")
db_course_progress_list = [] db_course_progress_list = []
@@ -75,10 +75,10 @@ def patch_course_progress(
) )
for db_course_progress in db_course_progress_list: for db_course_progress in db_course_progress_list:
db_course_progress.progress_value = course_progress.progress_value db_course_progress.progress = course_progress.progress
db.commit() db.commit()
return [ return [
CourseProgressParent(course=db_cp.course, progress_value=db_cp.progress_value) CourseProgressParent(course=db_cp.course, progress=db_cp.progress)
for db_cp in db_course_progress_list for db_cp in db_course_progress_list
] ]

View File

@@ -34,7 +34,7 @@ def get_high_scores(
UserHighScore( UserHighScore(
username=user.username, username=user.username,
score_value=high_score.score_value, score_value=high_score.score_value,
avatar=user.avatar, avatar_index=user.avatar_index,
) )
] ]
else: else:
@@ -58,7 +58,7 @@ def get_high_scores(
UserHighScore( UserHighScore(
username=owner.username, username=owner.username,
score_value=high_score.score_value, score_value=high_score.score_value,
avatar=owner.avatar, avatar_index=owner.avatar_index,
) )
) )
return user_high_scores return user_high_scores

View File

@@ -8,9 +8,9 @@ from src.schemas.users import UserCreate
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def check_empty_fields(username: str, password: str, avatar: str): def check_empty_fields(username: str, password: str, avatar_index: int):
"Checks if any user fields are empty" "Checks if any user fields are empty"
if len(avatar) == 0: if avatar_index < 0:
raise HTTPException(status_code=400, detail="No avatar was provided") raise HTTPException(status_code=400, detail="No avatar was provided")
if len(username) == 0: if len(username) == 0:
raise HTTPException(status_code=400, detail="No username was provided") raise HTTPException(status_code=400, detail="No username was provided")
@@ -20,7 +20,7 @@ def check_empty_fields(username: str, password: str, avatar: str):
def patch_user(db: Session, username: str, user: UserCreate): def patch_user(db: Session, username: str, user: UserCreate):
"""Changes the username and/or the password of a User""" """Changes the username and/or the password of a User"""
check_empty_fields(user.username, user.password, user.avatar) check_empty_fields(user.username, user.password, user.avatar_index)
db_user = get_user_by_username(db, username) db_user = get_user_by_username(db, username)
potential_duplicate = get_user_by_username(db, user.username) potential_duplicate = get_user_by_username(db, user.username)
if potential_duplicate: if potential_duplicate:
@@ -28,7 +28,7 @@ def patch_user(db: Session, username: str, user: UserCreate):
raise HTTPException(status_code=400, detail="Username already registered") raise HTTPException(status_code=400, detail="Username already registered")
db_user.username = user.username db_user.username = user.username
db_user.hashed_password = pwd_context.hash(user.password) db_user.hashed_password = pwd_context.hash(user.password)
db_user.avatar = user.avatar db_user.avatar_index = user.avatar_index
db.commit() db.commit()

View File

@@ -49,7 +49,7 @@ async def patch_current_user(
@app.post("/register") @app.post("/register")
async def register(user: users.UserCreate, db: Session = Depends(get_db)): async def register(user: users.UserCreate, db: Session = Depends(get_db)):
access_token = crud_authentication.register( access_token = crud_authentication.register(
db, user.username, user.password, user.avatar db, user.username, user.password, user.avatar_index
) )
user = crud_users.get_user_by_username(db, user.username) user = crud_users.get_user_by_username(db, user.username)
crud_courseprogress.initialize_user(db, user) crud_courseprogress.initialize_user(db, user)

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Float, ForeignKey, Integer, String from sqlalchemy import Column, Float, ForeignKey, Integer, String, Boolean
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from src.database import Base from src.database import Base
@@ -12,7 +12,8 @@ class User(Base):
user_id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, primary_key=True, index=True)
username = Column(String, unique=True, index=True, nullable=False) username = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False) hashed_password = Column(String, nullable=False)
avatar = Column(String, nullable=False) avatar_index = Column(Integer, nullable=False)
playtime = Column(Float, nullable=False)
high_scores = relationship( high_scores = relationship(
"HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic" "HighScore", back_populates="owner", cascade="all, delete", lazy="dynamic"
@@ -29,6 +30,7 @@ class HighScore(Base):
high_score_id = Column(Integer, primary_key=True, index=True) high_score_id = Column(Integer, primary_key=True, index=True)
score_value = Column(Float, nullable=False) score_value = Column(Float, nullable=False)
time = Column(String, nullable=False)
minigame = Column(String, nullable=False) minigame = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id")) owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="high_scores") owner = relationship("User", back_populates="high_scores")
@@ -40,7 +42,22 @@ class CourseProgress(Base):
__tablename__ = "course_progress" __tablename__ = "course_progress"
course_progress_id = Column(Integer, primary_key=True, index=True) course_progress_id = Column(Integer, primary_key=True, index=True)
progress_value = Column(Float, nullable=False) progress = Column(Float, nullable=False)
course = Column(String, nullable=False) course = Column(String, nullable=False)
owner_id = Column(Integer, ForeignKey("users.user_id")) owner_id = Column(Integer, ForeignKey("users.user_id"))
owner = relationship("User", back_populates="course_progress") owner = relationship("User", back_populates="course_progress")
learnables = relationship("LearnableProgress", back_populates="course")
class LearnableProgress(Base):
"""The database model for learnable progress"""
__tablename__ = "learnable_progress"
learnable_progress_id = Column(Integer, primary_key=True, index=True)
index = Column(Integer, nullable=False)
in_use = Column(Boolean, nullable=False)
name = Column(String, nullable=False)
progress = Column(Float, nullable=False)
course_id = Column(Integer, ForeignKey("course_progress.course_progress_id"))
course = relationship("CourseProgress", back_populates="learnables")

View File

@@ -4,7 +4,7 @@ from src.enums import CourseEnum
class CourseProgressBase(BaseModel): class CourseProgressBase(BaseModel):
progress_value: float progress: float
class CourseProgressParent(CourseProgressBase): class CourseProgressParent(CourseProgressBase):

39
src/schemas/saved_data.py Normal file
View File

@@ -0,0 +1,39 @@
from pydantic import BaseModel
from src.enums import CourseEnum, MinigameEnum
from typing import List
class SavedUser(BaseModel):
username: str
avatar_index: int = -1
playtime: float
minigames: List[SavedMinigameProgress]
courses: List[SavedCourseProgress]
class Score(BaseModel):
score_value: int
time: str
class SavedLearnableProgress(BaseModel):
index: int
in_use: bool
name: str
progress: float
class SavedCourseProgress(BaseModel):
course_index: CourseEnum
progress: float
completed_learnables: int
in_use_learnables: int
total_learnables: int
learnables: List[SavedLearnableProgress]
class SavedMinigameProgress(BaseModel):
minigame_index: MinigameEnum
lastest_scores: List[Score]
highest_scores: List[Score]

View File

@@ -3,7 +3,7 @@ from pydantic import BaseModel
class UserBase(BaseModel): class UserBase(BaseModel):
username: str username: str
avatar: str = "" avatar_index: int = -1
class User(UserBase): class User(UserBase):

View File

@@ -13,4 +13,15 @@ client = TestClient(app)
username = "user1" username = "user1"
password = "password" password = "password"
avatar = "lion" avatar_index = 1
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar_index": avatar_index},
)
assert response.status_code == 200
return response.json()["access_token"]

View File

@@ -2,22 +2,10 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from src.main import app, get_db from src.main import app, get_db
from tests.base import avatar, client, password, username from tests.base import avatar_index, client, password, username, register_user
from tests.config.database import clear_db, override_get_db from tests.config.database import clear_db, override_get_db
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
return response.json()["access_token"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register(): async def test_register():
"""Test the register endpoint""" """Test the register endpoint"""
@@ -26,7 +14,7 @@ async def test_register():
response = client.post( response = client.post(
"/register", "/register",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar}, json={"username": username, "password": password, "avatar_index": avatar_index},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -42,7 +30,7 @@ async def test_register_duplicate_name_should_fail():
response = client.post( response = client.post(
"/register", "/register",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar}, json={"username": username, "password": password, "avatar_index": avatar_index},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -57,7 +45,7 @@ async def test_register_without_username_should_fail():
response = client.post( response = client.post(
"/register", "/register",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
json={"password": password, "avatar": avatar}, json={"password": password, "avatar_index": avatar_index},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -72,7 +60,7 @@ async def test_register_without_password_should_fail():
response = client.post( response = client.post(
"/register", "/register",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
json={"username": username, "avatar": avatar}, json={"username": username, "avatar_index": avatar_index},
) )
assert response.status_code == 422 assert response.status_code == 422

View File

@@ -5,22 +5,10 @@ from fastapi.testclient import TestClient
from src.enums import CourseEnum from src.enums import CourseEnum
from src.main import app, get_db from src.main import app, get_db
from tests.base import avatar, client, password, username from tests.base import client, register_user
from tests.config.database import clear_db, override_get_db from tests.config.database import clear_db, override_get_db
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
return response.json()["access_token"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_creates_progress_of_zero(): async def test_register_creates_progress_of_zero():
"""Test whether registering a new user initializes all progress values to 0.0""" """Test whether registering a new user initializes all progress values to 0.0"""
@@ -36,7 +24,7 @@ async def test_register_creates_progress_of_zero():
response = response.json()[0] response = response.json()[0]
assert response["progress_value"] == 0.0 assert response["progress"] == 0.0
assert response["course"] == course assert response["course"] == course
@@ -54,11 +42,11 @@ async def test_get_all_returns_all():
for course in CourseEnum: for course in CourseEnum:
if course != CourseEnum.All: if course != CourseEnum.All:
assert {"progress_value": 0.0, "course": course} in response assert {"progress": 0.0, "course": course} in response
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_course_progress_value_without_auth_should_fail(): async def test_get_course_progress_without_auth_should_fail():
"""Test whether fetching a course progress value without authentication fails""" """Test whether fetching a course progress value without authentication fails"""
clear_db() clear_db()
@@ -94,16 +82,16 @@ async def test_patch_course_progress():
for course in CourseEnum: for course in CourseEnum:
if course != CourseEnum.All: if course != CourseEnum.All:
progress_value = random.uniform(0, 1) progress = random.uniform(0, 1)
response = client.patch( response = client.patch(
f"/courseprogress/{course}", f"/courseprogress/{course}",
headers=headers, headers=headers,
json={"progress_value": progress_value}, json={"progress": progress},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()[0]["progress_value"] == progress_value assert response.json()[0]["progress"] == progress
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -114,12 +102,12 @@ async def test_patch_all_should_patch_all_courses():
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
progress_value = random.uniform(0, 1) progress = random.uniform(0, 1)
response = client.patch( response = client.patch(
"/courseprogress/All", "/courseprogress/All",
headers=headers, headers=headers,
json={"progress_value": progress_value}, json={"progress": progress},
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -131,7 +119,7 @@ async def test_patch_all_should_patch_all_courses():
for course in CourseEnum: for course in CourseEnum:
if course != CourseEnum.All: if course != CourseEnum.All:
assert {"progress_value": progress_value, "course": course} in response assert {"progress": progress, "course": course} in response
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -144,12 +132,12 @@ async def test_patch_nonexisting_course_should_fail():
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
progress_value = random.uniform(0, 1) progress = random.uniform(0, 1)
response = client.patch( response = client.patch(
f"/courseprogress/{fake_course}", f"/courseprogress/{fake_course}",
headers=headers, headers=headers,
json={"progress_value": progress_value}, json={"progress": progress},
) )
assert response.status_code == 422 assert response.status_code == 422
@@ -163,13 +151,13 @@ async def test_patch_course_with_invalid_value_should_fail():
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
too_high_progress_value = random.uniform(0, 1) + 2 too_high_progress = random.uniform(0, 1) + 2
too_low_progress_value = random.uniform(0, 1) - 2 too_low_progress = random.uniform(0, 1) - 2
response = client.patch( response = client.patch(
"/courseprogress/All", "/courseprogress/All",
headers=headers, headers=headers,
json={"progress_value": too_high_progress_value}, json={"progress": too_high_progress},
) )
assert response.status_code == 400 assert response.status_code == 400
@@ -177,14 +165,14 @@ async def test_patch_course_with_invalid_value_should_fail():
response = client.patch( response = client.patch(
"/courseprogress/All", "/courseprogress/All",
headers=headers, headers=headers,
json={"progress_value": too_low_progress_value}, json={"progress": too_low_progress},
) )
assert response.status_code == 400 assert response.status_code == 400
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_patch_course_progress_value_without_auth_should_fail(): async def test_patch_course_progress_without_auth_should_fail():
"""Test whether updating a course progress value without authentication fails""" """Test whether updating a course progress value without authentication fails"""
clear_db() clear_db()
@@ -194,7 +182,7 @@ async def test_patch_course_progress_value_without_auth_should_fail():
response = client.patch( response = client.patch(
f"/courseprogress/{course}", f"/courseprogress/{course}",
headers=headers, headers=headers,
json={"progress_value": random.uniform(0, 1)}, json={"progress": random.uniform(0, 1)},
) )
assert response.status_code == 403 assert response.status_code == 403

View File

@@ -5,22 +5,10 @@ from fastapi.testclient import TestClient
from src.enums import MinigameEnum from src.enums import MinigameEnum
from src.main import app, get_db from src.main import app, get_db
from tests.base import avatar, client, password, username from tests.base import avatar_index, client, password, username, register_user
from tests.config.database import clear_db, override_get_db from tests.config.database import clear_db, override_get_db
async def register_user():
response = client.post(
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
return response.json()["access_token"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_put_highscore(): async def test_put_highscore():
"""Test whether putting a new high score succeeds""" """Test whether putting a new high score succeeds"""

View File

@@ -2,7 +2,7 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from src.main import app, get_db from src.main import app, get_db
from tests.base import avatar, client, password, username from tests.base import avatar_index, client, password, username, register_user
from tests.config.database import clear_db, override_get_db from tests.config.database import clear_db, override_get_db
patched_username = "New name" patched_username = "New name"
@@ -15,15 +15,7 @@ async def test_get_current_user():
"""Test the GET /users endpoint to get info about the current user""" """Test the GET /users endpoint to get info about the current user"""
clear_db() clear_db()
response = client.post( token = await register_user()
"/register",
headers={"Content-Type": "application/json"},
json={"username": username, "password": password, "avatar": avatar},
)
assert response.status_code == 200
token = response.json()["access_token"]
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
response = client.get("/users", headers=headers) response = client.get("/users", headers=headers)