Updating to support categories
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.database.crud import delete, read_all_where, read_where, update
|
||||
|
||||
|
||||
@@ -34,8 +35,8 @@ class SQLModelExtended(SQLModel):
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
async def get_all_where(self, *args, session: AsyncSession):
|
||||
res = await read_all_where(self, *args, session=session)
|
||||
async def get_all_where(self, *args, select_in_load: List = [], session: AsyncSession):
|
||||
res = await read_all_where(self, *args, select_in_load=select_in_load, session=session)
|
||||
return res
|
||||
|
||||
async def delete(self, session: AsyncSession) -> None:
|
||||
|
||||
5
backend/src/models/__init__.py
Executable file
5
backend/src/models/__init__.py
Executable file
@@ -0,0 +1,5 @@
|
||||
import glob
|
||||
from os.path import basename, dirname, isfile, join
|
||||
|
||||
modules = glob.glob(join(dirname(__file__), "*.py"))
|
||||
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
|
||||
54
backend/src/models/category.py
Executable file
54
backend/src/models/category.py
Executable file
@@ -0,0 +1,54 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from src.models.sign import Sign, SignOut
|
||||
from src.models.SQLModelExtended import SQLModelExtended
|
||||
|
||||
|
||||
class Category(SQLModelExtended, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
|
||||
name: str = Field(unique=True)
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
# list of signs that belong to this category
|
||||
signs: List[Sign] = Relationship(
|
||||
back_populates="category",
|
||||
sa_relationship_kwargs={"lazy": "selectin"},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_random_sign(self, session: AsyncSession):
|
||||
# get all categories
|
||||
|
||||
categories = await self.get_all_where(Category.enabled==True, select_in_load=[Category.signs],session=session)
|
||||
|
||||
# get all signs in one list with list comprehension
|
||||
signs = [s for c in categories for s in c.signs]
|
||||
|
||||
sign_videos = [len(s.sign_videos) for s in signs]
|
||||
|
||||
random_prob = [1 / (x + 1) for x in sign_videos]
|
||||
random_prob = random_prob / np.sum(random_prob)
|
||||
|
||||
# get random sign
|
||||
sign = np.random.choice(signs, p=random_prob)
|
||||
|
||||
return sign
|
||||
|
||||
class CategoryPost(BaseModel):
|
||||
name: str
|
||||
|
||||
class CategoryPut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
enabled: bool = True
|
||||
|
||||
class CategoryOut(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
enabled: bool
|
||||
@@ -3,7 +3,6 @@ from typing import List
|
||||
import numpy as np
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlmodel import Field, Relationship, SQLModel
|
||||
|
||||
from src.exceptions.base_exception import BaseException
|
||||
@@ -25,8 +24,12 @@ class Sign(SQLModelExtended, table=True):
|
||||
sa_relationship_kwargs={"lazy": "selectin"},
|
||||
)
|
||||
|
||||
def __init__(self, url):
|
||||
category_id: int = Field(foreign_key="category.id")
|
||||
category: "Category" = Relationship(back_populates="signs")
|
||||
|
||||
def __init__(self, url, category_id):
|
||||
self.url = url
|
||||
self.category_id = category_id
|
||||
|
||||
# get name and sign id from url
|
||||
try:
|
||||
@@ -53,23 +56,6 @@ class Sign(SQLModelExtended, table=True):
|
||||
if self.video_url is None:
|
||||
raise BaseException(404, "Video url not found")
|
||||
|
||||
@classmethod
|
||||
async def get_random(self, session: AsyncSession):
|
||||
signs = await self.get_all(select_in_load=[Sign.sign_videos],session=session)
|
||||
|
||||
sign_videos = [len(s.sign_videos) for s in signs]
|
||||
|
||||
# get probability based on number of videos, lower must be more likely
|
||||
# the sum must be 1
|
||||
|
||||
random_prob = [1 / (x + 1) for x in sign_videos]
|
||||
random_prob = random_prob / np.sum(random_prob)
|
||||
|
||||
# get random sign
|
||||
sign = np.random.choice(signs, p=random_prob)
|
||||
|
||||
return sign
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +65,7 @@ class SignOut(BaseModel):
|
||||
name: str
|
||||
sign_id: str
|
||||
video_url: str
|
||||
category_id: int
|
||||
|
||||
sign_videos: List[SignVideo] = []
|
||||
|
||||
@@ -87,4 +74,5 @@ class SignOutSimple(BaseModel):
|
||||
url: str
|
||||
name: str
|
||||
sign_id: str
|
||||
video_url: str
|
||||
video_url: str
|
||||
category_id: int
|
||||
Reference in New Issue
Block a user