Skip to content

Commit

Permalink
refactor: select builder rework
Browse files Browse the repository at this point in the history
  • Loading branch information
turisesonia committed May 23, 2024
1 parent fbb6a2b commit dbe547f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 25 deletions.
10 changes: 10 additions & 0 deletions fluent_alchemy/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,13 @@ def macro(self, name: str, callable_: Callable):
self._macros[name] = callable_

return self


class WhereBase:
def __init__(self):
self._where_clauses = ()

def where(self, *express: BinaryExpression):
self._where_clauses += (*express,)

return self
59 changes: 34 additions & 25 deletions fluent_alchemy/builders/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,94 +6,100 @@
from sqlalchemy.sql.elements import BinaryExpression, UnaryExpression
from sqlalchemy.engine.result import Result
from sqlalchemy.orm.strategy_options import Load
from sqlalchemy.sql import Select

from . import _M
from .base import BaseBuilder
from .base import BaseBuilder, WhereBase


class SelectBuilder(BaseBuilder):
class SelectBuilder(BaseBuilder, WhereBase):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._select_stmt: Select = None

def _initial(self):
if self._stmt is None:
self._stmt = select(self.get_model_class())
if self._select_stmt is None:
self._select_stmt = select(self.get_model_class())

def select(self, *entities):
if self._stmt is not None:
if self._select_stmt is not None:
# TODO select statement is already initial
raise Exception("")

if not entities:
self._stmt = select(self.get_model_class())
self._select_stmt = select(self.get_model_class())
else:
self._stmt = select(*entities)

return self

def where(self, *express: BinaryExpression):
self._initial()

self._stmt = self._stmt.where(*express)
self._select_stmt = select(*entities)

return self

def offset(self, offset: int):
self._initial()

self._stmt = self._stmt.offset(offset)
self._select_stmt = self._select_stmt.offset(offset)

return self

def limit(self, limit: int):
self._initial()

self._stmt = self._stmt.limit(limit)
self._select_stmt = self._select_stmt.limit(limit)

return self

def group_by(self, *entities):
self._initial()

self._stmt = self._stmt.group_by(*entities)
self._select_stmt = self._select_stmt.group_by(*entities)

return self

def having(self, *express: BinaryExpression):
self._initial()

self._stmt = self._stmt.having(*express)
self._select_stmt = self._select_stmt.having(*express)

return self

def order_by(self, *express: UnaryExpression):
self._initial()

self._stmt = self._stmt.order_by(*express)
self._select_stmt = self._select_stmt.order_by(*express)

return self

def options(self, *options: Load):
self._initial()

self._stmt = self._stmt.options(*options)
self._select_stmt = self._select_stmt.options(*options)

return self

def execute(
self, stmt: Optional[Executable] = None, *args, **kwargs
) -> Result[Any]:
stmt = stmt if stmt is not None else self._stmt
stmt = stmt if stmt is not None else self._select_stmt

return self._session.execute(stmt, *args, **kwargs)

def first(self, specific_fields: bool = False) -> Optional[_M]:
result = self.execute(self._stmt)
if self._where_clauses:
self._select_stmt.where(*self._where_clauses)

result = self.execute(self._select_stmt)

if specific_fields:
return result.first()

return result.scalars().first()

def get(self, specific_fields: bool = False) -> Iterable[_M]:
result = self.execute(self._stmt)
if self._where_clauses:
self._select_stmt.where(*self._where_clauses)

result = self.execute(self._select_stmt)

if specific_fields:
return result.all()
Expand All @@ -104,9 +110,12 @@ def paginate(self, page: int = 1, per_page: int = 30) -> dict:
self.offset((page - 1) * per_page)
self.limit(per_page)

if self._where_clauses:
self._select_stmt.where(*self._where_clauses)

total_stmt = select(func.count()).select_from(self.get_model_class())
if self._stmt.whereclause is not None:
total_stmt = total_stmt.where(self._stmt.whereclause)
if self._select_stmt.whereclause is not None:
total_stmt = total_stmt.where(self._select_stmt.whereclause)

total_rows = self.execute(total_stmt).scalars().first()

Expand Down

0 comments on commit dbe547f

Please sign in to comment.