Skip to content

Commit

Permalink
Merge pull request #8 from Xenia101/feature/refactor-by-xenia101
Browse files Browse the repository at this point in the history
Refactoring 🦥
  • Loading branch information
jujumilk3 authored Jul 5, 2024
2 parents 07b6ce1 + 990cd7e commit ffcf7fd
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 25 deletions.
7 changes: 6 additions & 1 deletion app/api/v1/endpoints/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from app.core.dependencies import get_current_active_user
from app.model.user import User
from app.schema.base_schema import Blank
from app.schema.post_tag_schema import FindPost, FindPostWithTagsResult, PostWithTags, UpsertPostWithTags
from app.schema.post_tag_schema import (
FindPost,
FindPostWithTagsResult,
PostWithTags,
UpsertPostWithTags,
)
from app.services.post_service import PostService

router = APIRouter(
Expand Down
4 changes: 2 additions & 2 deletions app/core/database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Callable
from typing import Any, Generator

from sqlalchemy import create_engine, orm
from sqlalchemy.ext.declarative import as_declarative, declared_attr
Expand Down Expand Up @@ -32,7 +32,7 @@ def create_database(self) -> None:
BaseModel.metadata.create_all(self._engine)

@contextmanager
def session(self) -> Callable[..., AbstractContextManager[Session]]:
def session(self) -> Generator[Any, Any, AbstractContextManager[Session]]:
session: Session = self._session_factory()
try:
yield session
Expand Down
3 changes: 2 additions & 1 deletion app/core/security.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime, timedelta
from typing import Tuple

from fastapi import Request
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand All @@ -12,7 +13,7 @@
ALGORITHM = "HS256"


def create_access_token(subject: dict, expires_delta: timedelta = None) -> (str, str):
def create_access_token(subject: dict, expires_delta: timedelta = None) -> Tuple[str, str]:
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
Expand Down
23 changes: 13 additions & 10 deletions app/repository/base_repository.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
from contextlib import AbstractContextManager
from typing import Callable
from typing import Any, Callable, Type, TypeVar

from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload

from app.core.config import configs
from app.core.exceptions import DuplicatedError, NotFoundError
from app.model.base_model import BaseModel
from app.util.query_builder import dict_to_sqlalchemy_filter_options

T = TypeVar("T", bound=BaseModel)


class BaseRepository:
def __init__(self, session_factory: Callable[..., AbstractContextManager[Session]], model) -> None:
def __init__(self, session_factory: Callable[..., AbstractContextManager[Session]], model: Type[T]) -> None:
self.session_factory = session_factory
self.model = model

def read_by_options(self, schema, eager=False):
def read_by_options(self, schema: T, eager: bool = False) -> dict:
with self.session_factory() as session:
schema_as_dict = schema.dict(exclude_none=True)
ordering = schema_as_dict.get("ordering", configs.ORDERING)
schema_as_dict: dict = schema.dict(exclude_none=True)
ordering: str = schema_as_dict.get("ordering", configs.ORDERING)
order_query = (
getattr(self.model, ordering[1:]).desc()
if ordering.startswith("-")
Expand Down Expand Up @@ -47,7 +50,7 @@ def read_by_options(self, schema, eager=False):
},
}

def read_by_id(self, id: int, eager=False):
def read_by_id(self, id: int, eager: bool = False):
with self.session_factory() as session:
query = session.query(self.model)
if eager:
Expand All @@ -58,7 +61,7 @@ def read_by_id(self, id: int, eager=False):
raise NotFoundError(detail=f"not found id : {id}")
return query

def create(self, schema):
def create(self, schema: T):
with self.session_factory() as session:
query = self.model(**schema.dict())
try:
Expand All @@ -69,19 +72,19 @@ def create(self, schema):
raise DuplicatedError(detail=str(e.orig))
return query

def update(self, id: int, schema):
def update(self, id: int, schema: T):
with self.session_factory() as session:
session.query(self.model).filter(self.model.id == id).update(schema.dict(exclude_none=True))
session.commit()
return self.read_by_id(id)

def update_attr(self, id: int, column: str, value):
def update_attr(self, id: int, column: str, value: Any):
with self.session_factory() as session:
session.query(self.model).filter(self.model.id == id).update({column: value})
session.commit()
return self.read_by_id(id)

def whole_update(self, id: int, schema):
def whole_update(self, id: int, schema: T):
with self.session_factory() as session:
session.query(self.model).filter(self.model.id == id).update(schema.dict())
session.commit()
Expand Down
35 changes: 27 additions & 8 deletions app/services/base_service.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,43 @@
from typing import Any, Protocol


class RepositoryProtocol(Protocol):
def read_by_options(self, schema: Any) -> Any: ...

def read_by_id(self, id: int) -> Any: ...

def create(self, schema: Any) -> Any: ...

def update(self, id: int, schema: Any) -> Any: ...

def update_attr(self, id: int, attr: str, value: Any) -> Any: ...

def whole_update(self, id: int, schema: Any) -> Any: ...

def delete_by_id(self, id: int) -> Any: ...


class BaseService:
def __init__(self, repository) -> None:
def __init__(self, repository: RepositoryProtocol) -> None:
self._repository = repository

def get_list(self, schema):
def get_list(self, schema: Any) -> Any:
return self._repository.read_by_options(schema)

def get_by_id(self, id: int):
def get_by_id(self, id: int) -> Any:
return self._repository.read_by_id(id)

def add(self, schema):
def add(self, schema: Any) -> Any:
return self._repository.create(schema)

def patch(self, id: int, schema):
def patch(self, id: int, schema: Any) -> Any:
return self._repository.update(id, schema)

def patch_attr(self, id: int, attr: str, value):
def patch_attr(self, id: int, attr: str, value: Any) -> Any:
return self._repository.update_attr(id, attr, value)

def put_update(self, id: int, schema):
def put_update(self, id: int, schema: Any) -> Any:
return self._repository.whole_update(id, schema)

def remove_by_id(self, id):
def remove_by_id(self, id: int) -> Any:
return self._repository.delete_by_id(id)
2 changes: 1 addition & 1 deletion app/util/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
from pytz import timezone


def get_now():
def get_now() -> datetime:
return datetime.now(tz=timezone("UTC"))
2 changes: 1 addition & 1 deletion app/util/hash.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid


def get_rand_hash(length=16):
def get_rand_hash(length: int = 16) -> str:
return uuid.uuid4().hex[:length]
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ alembic==1.7.7
anyio==3.6.1
asgiref==3.5.2
asyncpg==0.26.0
attrs==21.4.0
attrs==22.2.0
bcrypt==3.2.2
certifi==2022.9.24
cffi==1.15.0
Expand Down

0 comments on commit ffcf7fd

Please sign in to comment.