Updating to support categories
This commit is contained in:
@@ -51,11 +51,13 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# include the routers
|
||||
from .routers import auth_router, signs_router, signvideo_router
|
||||
from .routers import (auth_router, category_router, signs_router,
|
||||
signvideo_router)
|
||||
|
||||
app.include_router(auth_router)
|
||||
app.include_router(signs_router)
|
||||
app.include_router(signvideo_router)
|
||||
app.include_router(category_router)
|
||||
|
||||
# Add the exception handlers
|
||||
app.add_exception_handler(BaseException, base_exception_handler)
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.database.crud import delete, read_all_where, read_where, update
|
||||
|
||||
|
||||
@@ -34,8 +35,8 @@ class SQLModelExtended(SQLModel):
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
async def get_all_where(self, *args, session: AsyncSession):
|
||||
res = await read_all_where(self, *args, session=session)
|
||||
async def get_all_where(self, *args, select_in_load: List = [], session: AsyncSession):
|
||||
res = await read_all_where(self, *args, select_in_load=select_in_load, session=session)
|
||||
return res
|
||||
|
||||
async def delete(self, session: AsyncSession) -> None:
|
||||
|
||||
5
backend/src/models/__init__.py
Executable file
5
backend/src/models/__init__.py
Executable file
@@ -0,0 +1,5 @@
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile, join
|
||||
|
||||
modules = glob.glob(join(dirname(__file__), "*.py"))
|
||||
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
|
||||
54
backend/src/models/category.py
Executable file
54
backend/src/models/category.py
Executable file
@@ -0,0 +1,54 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from src.models.sign import Sign, SignOut
|
||||
from src.models.SQLModelExtended import SQLModelExtended
|
||||
|
||||
|
||||
class Category(SQLModelExtended, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
|
||||
name: str = Field(unique=True)
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
# list of signs that belong to this category
|
||||
signs: List[Sign] = Relationship(
|
||||
back_populates="category",
|
||||
sa_relationship_kwargs={"lazy": "selectin"},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_random_sign(self, session: AsyncSession):
|
||||
# get all categories
|
||||
|
||||
categories = await self.get_all_where(Category.enabled==True, select_in_load=[Category.signs],session=session)
|
||||
|
||||
# get all signs in one list with list comprehension
|
||||
signs = [s for c in categories for s in c.signs]
|
||||
|
||||
sign_videos = [len(s.sign_videos) for s in signs]
|
||||
|
||||
random_prob = [1 / (x + 1) for x in sign_videos]
|
||||
random_prob = random_prob / np.sum(random_prob)
|
||||
|
||||
# get random sign
|
||||
sign = np.random.choice(signs, p=random_prob)
|
||||
|
||||
return sign
|
||||
|
||||
class CategoryPost(BaseModel):
|
||||
name: str
|
||||
|
||||
class CategoryPut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
enabled: bool = True
|
||||
|
||||
class CategoryOut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
enabled: bool
|
||||
@@ -3,7 +3,6 @@ from typing import List
|
||||
import numpy as np
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
from src.exceptions.base_exception import BaseException
|
||||
@@ -25,8 +24,12 @@ class Sign(SQLModelExtended, table=True):
|
||||
sa_relationship_kwargs={"lazy": "selectin"},
|
||||
)
|
||||
|
||||
def __init__(self, url):
|
||||
category_id: int = Field(foreign_key="category.id")
|
||||
category: "Category" = Relationship(back_populates="signs")
|
||||
|
||||
def __init__(self, url, category_id):
|
||||
self.url = url
|
||||
self.category_id = category_id
|
||||
|
||||
# get name and sign id from url
|
||||
try:
|
||||
@@ -53,23 +56,6 @@ class Sign(SQLModelExtended, table=True):
|
||||
if self.video_url is None:
|
||||
raise BaseException(404, "Video url not found")
|
||||
|
||||
@classmethod
|
||||
async def get_random(self, session: AsyncSession):
|
||||
signs = await self.get_all(select_in_load=[Sign.sign_videos],session=session)
|
||||
|
||||
sign_videos = [len(s.sign_videos) for s in signs]
|
||||
|
||||
# get probability based on number of videos, lower must be more likely
|
||||
# the sum must be 1
|
||||
|
||||
random_prob = [1 / (x + 1) for x in sign_videos]
|
||||
random_prob = random_prob / np.sum(random_prob)
|
||||
|
||||
# get random sign
|
||||
sign = np.random.choice(signs, p=random_prob)
|
||||
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +65,7 @@ class SignOut(BaseModel):
|
||||
name: str
|
||||
sign_id: str
|
||||
video_url: str
|
||||
category_id: int
|
||||
|
||||
sign_videos: List[SignVideo] = []
|
||||
|
||||
@@ -87,4 +74,5 @@ class SignOutSimple(BaseModel):
|
||||
url: str
|
||||
name: str
|
||||
sign_id: str
|
||||
video_url: str
|
||||
video_url: str
|
||||
category_id: int
|
||||
@@ -1,3 +1,4 @@
|
||||
from .auth import router as auth_router
|
||||
from .category import router as category_router
|
||||
from .signs import router as signs_router
|
||||
from .signvideo import router as signvideo_router
|
||||
|
||||
106
backend/src/routers/category.py
Executable file
106
backend/src/routers/category.py
Executable file
@@ -0,0 +1,106 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, status
|
||||
from fastapi_jwt_auth import AuthJWT
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.database.database import get_session
|
||||
from src.exceptions.base_exception import BaseException
|
||||
from src.exceptions.login_exception import LoginException
|
||||
from src.models.auth import User
|
||||
from src.models.category import (Category, CategoryOut, CategoryPost,
|
||||
CategoryPut)
|
||||
from src.models.sign import Sign, SignOut
|
||||
|
||||
router = APIRouter(prefix="/categories")
|
||||
|
||||
@router.get("/", response_model=List[Category])
|
||||
async def get_categories(Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
Authorize.jwt_required()
|
||||
|
||||
user = Authorize.get_jwt_subject()
|
||||
user = await User.get_by_id(id=user, session=session)
|
||||
|
||||
if not user:
|
||||
raise LoginException("User not found")
|
||||
|
||||
categories = await Category.get_all(session=session, select_in_load=[Category.signs])
|
||||
return categories
|
||||
|
||||
@router.get("/{category_id}/signs", response_model=List[SignOut])
|
||||
async def get_signs_by_category(category_id: int, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
Authorize.jwt_required()
|
||||
|
||||
user = Authorize.get_jwt_subject()
|
||||
user = await User.get_by_id(id=user, session=session)
|
||||
|
||||
if not user:
|
||||
raise LoginException("User not found")
|
||||
|
||||
signs = await Sign.get_all_where(Sign.category_id==category_id, select_in_load=[Sign.sign_videos], session=session)
|
||||
return signs
|
||||
|
||||
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=CategoryOut)
|
||||
async def create_category(category: CategoryPost, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
Authorize.jwt_required()
|
||||
|
||||
user = Authorize.get_jwt_subject()
|
||||
user = await User.get_by_id(id=user, session=session)
|
||||
|
||||
if not user:
|
||||
raise LoginException("User not found")
|
||||
|
||||
# Category name cannot be empty or only exist of spaces
|
||||
if not category.name or category.name.isspace():
|
||||
raise BaseException(message="Category name cannot be empty", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
try:
|
||||
c = Category(name=category.name)
|
||||
await c.save(session=session)
|
||||
except Exception as e:
|
||||
raise BaseException(message="Category already exists", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
return c
|
||||
|
||||
@router.put("/", response_model=CategoryOut)
|
||||
async def update_category(category: CategoryPut, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
Authorize.jwt_required()
|
||||
|
||||
user = Authorize.get_jwt_subject()
|
||||
user = await User.get_by_id(id=user, session=session)
|
||||
|
||||
if not user:
|
||||
raise LoginException("User not found")
|
||||
|
||||
# Category name cannot be empty or only exist of spaces
|
||||
if not category.name or category.name.isspace():
|
||||
raise BaseException(message="Category name cannot be empty", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
c = await Category.get_by_id(id=category.id, session=session)
|
||||
if not c:
|
||||
raise BaseException(message="Category not found", status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
c.name = category.name
|
||||
c.enabled = category.enabled
|
||||
await c.save(session=session)
|
||||
return c
|
||||
|
||||
@router.delete("/{category_id}")
|
||||
async def delete_category(category_id: int, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
Authorize.jwt_required()
|
||||
|
||||
user = Authorize.get_jwt_subject()
|
||||
user = await User.get_by_id(id=user, session=session)
|
||||
|
||||
if not user:
|
||||
raise LoginException("User not found")
|
||||
|
||||
c = await Category.get_by_id(id=category_id, session=session)
|
||||
print(c)
|
||||
if not c:
|
||||
raise BaseException(message="Category not found", status_code=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
if (len(c.signs) > 0):
|
||||
raise BaseException(message="Category cannot be deleted because it contains signs", status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
await c.delete(session=session)
|
||||
return {"message": "Category deleted successfully"}
|
||||
@@ -12,24 +12,25 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import src.settings as settings
|
||||
from src.database.database import get_session
|
||||
from src.exceptions.login_exception import LoginException
|
||||
from src.models.auth import Login, User
|
||||
from src.models.auth import User
|
||||
from src.models.category import Category
|
||||
from src.models.sign import Sign, SignOut, SignOutSimple
|
||||
from src.models.signvideo import SignVideo
|
||||
from src.models.token import TokenExtended
|
||||
|
||||
router = APIRouter(prefix="/signs")
|
||||
|
||||
class SignUrl(BaseModel):
|
||||
class SignIn(BaseModel):
|
||||
category: int
|
||||
url: str
|
||||
|
||||
@router.get("/random", status_code=status.HTTP_200_OK, response_model=SignOutSimple)
|
||||
async def get_random_sign(session: AsyncSession = Depends(get_session)):
|
||||
# get a random sign where there is not much data from yet
|
||||
sign = await Sign.get_random(session=session)
|
||||
sign = await Category.get_random_sign(session=session)
|
||||
return sign
|
||||
|
||||
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=SignOut)
|
||||
async def add_sign(url: SignUrl, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
async def add_sign(sign: SignIn, Authorize: AuthJWT = Depends(), session: AsyncSession = Depends(get_session)):
|
||||
Authorize.jwt_required()
|
||||
|
||||
user = Authorize.get_jwt_subject()
|
||||
@@ -38,7 +39,7 @@ async def add_sign(url: SignUrl, Authorize: AuthJWT = Depends(), session: AsyncS
|
||||
if not user:
|
||||
raise LoginException("User not found")
|
||||
|
||||
sign = Sign(url.url)
|
||||
sign = Sign(url=sign.url, category_id=sign.category)
|
||||
|
||||
# check if the sign already exists
|
||||
signs = await Sign.get_all_where(Sign.url == sign.url, session=session)
|
||||
|
||||
Reference in New Issue
Block a user