Updating to support categories

This commit is contained in:
2023-03-11 12:32:41 +00:00
parent bdab151dae
commit ec7bd6dde5
27 changed files with 807 additions and 42 deletions

View File

@@ -51,11 +51,13 @@ app.add_middleware(
)
# include the routers
from .routers import auth_router, signs_router, signvideo_router
from .routers import (auth_router, category_router, signs_router,
signvideo_router)
app.include_router(auth_router)
app.include_router(signs_router)
app.include_router(signvideo_router)
app.include_router(category_router)
# Add the exception handlers
app.add_exception_handler(BaseException, base_exception_handler)

View File

@@ -3,6 +3,7 @@ from typing import List, Optional
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from src.database.crud import delete, read_all_where, read_where, update
@@ -34,8 +35,8 @@ class SQLModelExtended(SQLModel):
return res
@classmethod
async def get_all_where(self, *args, session: AsyncSession):
res = await read_all_where(self, *args, session=session)
async def get_all_where(self, *args, select_in_load: List = [], session: AsyncSession):
res = await read_all_where(self, *args, select_in_load=select_in_load, session=session)
return res
async def delete(self, session: AsyncSession) -> None:

5
backend/src/models/__init__.py Executable file
View File

@@ -0,0 +1,5 @@
import glob
from os.path import basename, dirname, isfile, join
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]

54
backend/src/models/category.py Executable file
View File

@@ -0,0 +1,54 @@
from typing import List
import numpy as np
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Field, Relationship
from src.models.sign import Sign, SignOut
from src.models.SQLModelExtended import SQLModelExtended
class Category(SQLModelExtended, table=True):
id: int = Field(primary_key=True)
name: str = Field(unique=True)
enabled: bool = Field(default=True)
# list of signs that belong to this category
signs: List[Sign] = Relationship(
back_populates="category",
sa_relationship_kwargs={"lazy": "selectin"},
)
@classmethod
async def get_random_sign(self, session: AsyncSession):
# get all categories
categories = await self.get_all_where(Category.enabled==True, select_in_load=[Category.signs],session=session)
# get all signs in one list with list comprehension
signs = [s for c in categories for s in c.signs]
sign_videos = [len(s.sign_videos) for s in signs]
random_prob = [1 / (x + 1) for x in sign_videos]
random_prob = random_prob / np.sum(random_prob)
# get random sign
sign = np.random.choice(signs, p=random_prob)
return sign
class CategoryPost(BaseModel):
name: str
class CategoryPut(BaseModel):
id: int
name: str
enabled: bool = True
class CategoryOut(BaseModel):
id: int
name: str
enabled: bool

View File

@@ -3,7 +3,6 @@ from typing import List
import numpy as np
import requests
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Field, Relationship, SQLModel
from src.exceptions.base_exception import BaseException
@@ -25,8 +24,12 @@ class Sign(SQLModelExtended, table=True):
sa_relationship_kwargs={"lazy": "selectin"},
)
def __init__(self, url):
category_id: int = Field(foreign_key="category.id")
category: "Category" = Relationship(back_populates="signs")
def __init__(self, url, category_id):
self.url = url
self.category_id = category_id
# get name and sign id from url
try:
@@ -53,23 +56,6 @@ class Sign(SQLModelExtended, table=True):
if self.video_url is None:
raise BaseException(404, "Video url not found")
@classmethod
async def get_random(self, session: AsyncSession):
signs = await self.get_all(select_in_load=[Sign.sign_videos],session=session)
sign_videos = [len(s.sign_videos) for s in signs]
# get probability based on number of videos, lower must be more likely
# the sum must be 1
random_prob = [1 / (x + 1) for x in sign_videos]
random_prob = random_prob / np.sum(random_prob)
# get random sign
sign = np.random.choice(signs, p=random_prob)
return sign
@@ -79,6 +65,7 @@ class SignOut(BaseModel):
name: str
sign_id: str
video_url: str
category_id: int
sign_videos: List[SignVideo] = []
@@ -87,4 +74,5 @@ class SignOutSimple(BaseModel):
url: str
name: str
sign_id: str
video_url: str
video_url: str
category_id: int

View File

@@ -1,3 +1,4 @@
from .auth import router as auth_router
from .category import router as category_router
from .signs import router as signs_router
from .signvideo import router as signvideo_router

106
backend/src/routers/category.py Executable file
View File

@@ -0,0 +1,106 @@
from typing import List
from fastapi import APIRouter, BackgroundTasks, Depends, status
from fastapi_jwt_auth import AuthJWT
from sqlalchemy.ext.asyncio import AsyncSession
from src.database.database import get_session
from src.exceptions.base_exception import BaseException
from src.exceptions.login_exception import LoginException
from src.models.auth import User
from src.models.category import (Category, CategoryOut, CategoryPost,
CategoryPut)
from src.models.sign import Sign, SignOut
router = APIRouter(prefix="/categories")
@router.get("/", response_model=List[Category])
async def get_categories(Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
Authorize.jwt_required()
user = Authorize.get_jwt_subject()
user = await User.get_by_id(id=user, session=session)
if not user:
raise LoginException("User not found")
categories = await Category.get_all(session=session, select_in_load=[Category.signs])
return categories
@router.get("/{category_id}/signs", response_model=List[SignOut])
async def get_signs_by_category(category_id: int, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
Authorize.jwt_required()
user = Authorize.get_jwt_subject()
user = await User.get_by_id(id=user, session=session)
if not user:
raise LoginException("User not found")
signs = await Sign.get_all_where(Sign.category_id==category_id, select_in_load=[Sign.sign_videos], session=session)
return signs
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=CategoryOut)
async def create_category(category: CategoryPost, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
Authorize.jwt_required()
user = Authorize.get_jwt_subject()
user = await User.get_by_id(id=user, session=session)
if not user:
raise LoginException("User not found")
# Category name cannot be empty or only exist of spaces
if not category.name or category.name.isspace():
raise BaseException(message="Category name cannot be empty", status_code=status.HTTP_400_BAD_REQUEST)
try:
c = Category(name=category.name)
await c.save(session=session)
except Exception as e:
raise BaseException(message="Category already exists", status_code=status.HTTP_400_BAD_REQUEST)
return c
@router.put("/", response_model=CategoryOut)
async def update_category(category: CategoryPut, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
Authorize.jwt_required()
user = Authorize.get_jwt_subject()
user = await User.get_by_id(id=user, session=session)
if not user:
raise LoginException("User not found")
# Category name cannot be empty or only exist of spaces
if not category.name or category.name.isspace():
raise BaseException(message="Category name cannot be empty", status_code=status.HTTP_400_BAD_REQUEST)
c = await Category.get_by_id(id=category.id, session=session)
if not c:
raise BaseException(message="Category not found", status_code=status.HTTP_404_NOT_FOUND)
c.name = category.name
c.enabled = category.enabled
await c.save(session=session)
return c
@router.delete("/{category_id}")
async def delete_category(category_id: int, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
Authorize.jwt_required()
user = Authorize.get_jwt_subject()
user = await User.get_by_id(id=user, session=session)
if not user:
raise LoginException("User not found")
c = await Category.get_by_id(id=category_id, session=session)
print(c)
if not c:
raise BaseException(message="Category not found", status_code=status.HTTP_404_NOT_FOUND)
if (len(c.signs) > 0):
raise BaseException(message="Category cannot be deleted because it contains signs", status_code=status.HTTP_400_BAD_REQUEST)
await c.delete(session=session)
return {"message": "Category deleted successfully"}

View File

@@ -12,24 +12,25 @@ from sqlalchemy.ext.asyncio import AsyncSession
import src.settings as settings
from src.database.database import get_session
from src.exceptions.login_exception import LoginException
from src.models.auth import Login, User
from src.models.auth import User
from src.models.category import Category
from src.models.sign import Sign, SignOut, SignOutSimple
from src.models.signvideo import SignVideo
from src.models.token import TokenExtended
router = APIRouter(prefix="/signs")
class SignUrl(BaseModel):
class SignIn(BaseModel):
category: int
url: str
@router.get("/random", status_code=status.HTTP_200_OK, response_model=SignOutSimple)
async def get_random_sign(session: AsyncSession = Depends(get_session)):
# get a random sign where there is not much data from yet
sign = await Sign.get_random(session=session)
sign = await Category.get_random_sign(session=session)
return sign
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=SignOut)
async def add_sign(url: SignUrl, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
async def add_sign(sign: SignIn, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
Authorize.jwt_required()
user = Authorize.get_jwt_subject()
@@ -38,7 +39,7 @@ async def add_sign(url: SignUrl, Authorize: AuthJWT = Depends(), session: AsyncS
if not user:
raise LoginException("User not found")
sign = Sign(url.url)
sign = Sign(url=sign.url, category_id=sign.category)
# check if the sign already exists
signs = await Sign.get_all_where(Sign.url == sign.url, session=session)