Updating to support categories

This commit is contained in:
2023-03-11 12:32:41 +00:00
parent bdab151dae
commit ec7bd6dde5
27 changed files with 807 additions and 42 deletions

View File

@@ -3,6 +3,7 @@ from typing import List, Optional
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import SQLModel
from src.database.crud import delete, read_all_where, read_where, update
@@ -34,8 +35,8 @@ class SQLModelExtended(SQLModel):
return res
@classmethod
async def get_all_where(self, *args, session: AsyncSession):
res = await read_all_where(self, *args, session=session)
async def get_all_where(self, *args, select_in_load: List = [], session: AsyncSession):
res = await read_all_where(self, *args, select_in_load=select_in_load, session=session)
return res
async def delete(self, session: AsyncSession) -> None:

5
backend/src/models/__init__.py Executable file
View File

@@ -0,0 +1,5 @@
import glob
from os.path import basename, dirname, isfile, join
modules = glob.glob(join(dirname(__file__), "*.py"))
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]

54
backend/src/models/category.py Executable file
View File

@@ -0,0 +1,54 @@
from typing import List
import numpy as np
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Field, Relationship
from src.models.sign import Sign, SignOut
from src.models.SQLModelExtended import SQLModelExtended
class Category(SQLModelExtended, table=True):
id: int = Field(primary_key=True)
name: str = Field(unique=True)
enabled: bool = Field(default=True)
# list of signs that belong to this category
signs: List[Sign] = Relationship(
back_populates="category",
sa_relationship_kwargs={"lazy": "selectin"},
)
@classmethod
async def get_random_sign(self, session: AsyncSession):
# get all categories
categories = await self.get_all_where(Category.enabled==True, select_in_load=[Category.signs],session=session)
# get all signs in one list with list comprehension
signs = [s for c in categories for s in c.signs]
sign_videos = [len(s.sign_videos) for s in signs]
random_prob = [1 / (x + 1) for x in sign_videos]
random_prob = random_prob / np.sum(random_prob)
# get random sign
sign = np.random.choice(signs, p=random_prob)
return sign
class CategoryPost(BaseModel):
name: str
class CategoryPut(BaseModel):
id: int
name: str
enabled: bool = True
class CategoryOut(BaseModel):
id: int
name: str
enabled: bool

View File

@@ -3,7 +3,6 @@ from typing import List
import numpy as np
import requests
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import Field, Relationship, SQLModel
from src.exceptions.base_exception import BaseException
@@ -25,8 +24,12 @@ class Sign(SQLModelExtended, table=True):
sa_relationship_kwargs={"lazy": "selectin"},
)
def __init__(self, url):
category_id: int = Field(foreign_key="category.id")
category: "Category" = Relationship(back_populates="signs")
def __init__(self, url, category_id):
self.url = url
self.category_id = category_id
# get name and sign id from url
try:
@@ -53,23 +56,6 @@ class Sign(SQLModelExtended, table=True):
if self.video_url is None:
raise BaseException(404, "Video url not found")
@classmethod
async def get_random(self, session: AsyncSession):
signs = await self.get_all(select_in_load=[Sign.sign_videos],session=session)
sign_videos = [len(s.sign_videos) for s in signs]
# get probability based on number of videos, lower must be more likely
# the sum must be 1
random_prob = [1 / (x + 1) for x in sign_videos]
random_prob = random_prob / np.sum(random_prob)
# get random sign
sign = np.random.choice(signs, p=random_prob)
return sign
@@ -79,6 +65,7 @@ class SignOut(BaseModel):
name: str
sign_id: str
video_url: str
category_id: int
sign_videos: List[SignVideo] = []
@@ -87,4 +74,5 @@ class SignOutSimple(BaseModel):
url: str
name: str
sign_id: str
video_url: str
video_url: str
category_id: int