Skip to content

Commit

Permalink
refactor: rename statement variable
Browse files Browse the repository at this point in the history
  • Loading branch information
turisesonia committed May 31, 2024
1 parent 3620158 commit 52b5a0e
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 132 deletions.
106 changes: 2 additions & 104 deletions fluent_alchemy/builders/base.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,15 @@
from typing import Any, Optional, Generic, Callable, Union
from typing import Optional, Generic, Callable

from sqlalchemy import Executable, Select, Insert, Delete, Update
from sqlalchemy.orm import Session
from sqlalchemy.orm.strategy_options import Load
from sqlalchemy.engine.result import Result
from sqlalchemy.sql import Select
from sqlalchemy.sql.dml import Delete
from sqlalchemy.sql.elements import BinaryExpression, UnaryExpression
from sqlalchemy.sql.elements import BinaryExpression

from . import _M


# class BaseBuilder(Generic[_M]):
# _model: _M

# def __init__(self, session: Session, model: _M):
# self._session: Session = session
# self._model: _M = model
# self._select_entities = []
# self._where_clauses = []
# self._group_clauses = []
# self._having_clauses = []
# self._order_clauses = []
# self._offset: Optional[int] = None
# self._limit: Optional[int] = None
# self._options = []
# self._execution_options: Optional[dict] = None
# self._returnings = []
# self._scopes = {}
# self._macros = {}

# def select(self, *entities):
# self._select_entities.extend(entities)

# return self

# def where(self, *express: BinaryExpression):
# self._where_clauses.extend(express)

# return self

# def offset(self, offset: int):
# self._offset = offset

# return self

# def limit(self, limit: int):
# self._limit = limit

# return self

# def group_by(self, *entities):
# self._group_clauses.extend(entities)

# return self

# def having(self, *express: BinaryExpression):
# self._having_clauses.extend(express)

# return self

# def order_by(self, *express: UnaryExpression):
# self._order_clauses.extend(express)

# return self

# def returning(self, *entities):
# self._returnings.extend(entities)

# return self

# def options(self, *options: Load):
# self._options.extend(options)

# return self

# def execution_options(self, **options):
# if self._execution_options is None:
# self._execution_options = {}

# self._execution_options.update(options)

# return self

# def execute(self, stmt: Union[Select, Delete], *args, **kwargs) -> Result[Any]:
# return self._session.execute(stmt, *args, **kwargs)

# def get_model_class(self):
# return self._model.__class__

# def apply_scopes(self, scopes: dict = {}):
# self._scopes = scopes
# self._on_delete: Optional[Callable] = None

# for _, scope in self._scopes.items():
# scope.boot(self)

# return self

# def macro(self, name: str, callable_: Callable):
# if callable(callable_):
# self._macros[name] = callable_

# return self


class BaseBuilder(Generic[_M]):
_model: _M

def __init__(self, session: Session, model: _M):
self._session: Session = session
self._model: _M = model
self._stmt: Union[Select, Delete, Insert, Update] = None

self._scopes = {}
self._macros = {}
Expand Down
22 changes: 14 additions & 8 deletions fluent_alchemy/builders/insert.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,49 @@
from typing import Any

from sqlalchemy import insert
from sqlalchemy import Insert, insert
from sqlalchemy.engine.result import Result

from .base import BaseBuilder


class InsertBuilder(BaseBuilder):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._insert_stmt: Insert = None

def _initial(self):
if self._stmt is None:
self._stmt = insert(self.get_model_class())
if self._insert_stmt is None:
self._insert_stmt = insert(self.get_model_class())

def returning(self, *entities):
self._initial()

self._stmt = self._stmt.returning(*entities)
self._insert_stmt = self._insert_stmt.returning(*entities)

return self

def execution_options(self, **options):
self._initial()

self._stmt = self._stmt.execution_options(**options)
self._insert_stmt = self._insert_stmt.execution_options(**options)

return self

def values(self, *args, **kwargs):
self._initial()

self._stmt = self._stmt.values(*args, **kwargs)
self._insert_stmt = self._insert_stmt.values(*args, **kwargs)

return self

def execute(self, autocommit: bool = True, *args, **kwargs) -> Result[Any]:
if self._stmt is None:
if self._insert_stmt is None:
# todo error message
raise ValueError("")

result = self._session.execute(self._stmt, *args, **kwargs)
result = self._session.execute(self._insert_stmt, *args, **kwargs)

if autocommit:
self._commit()
Expand Down
22 changes: 14 additions & 8 deletions fluent_alchemy/builders/update.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,50 @@
from typing import Any

from sqlalchemy import update
from sqlalchemy import Update, update
from sqlalchemy.engine.result import Result
from sqlalchemy.sql.elements import BinaryExpression

from .base import BaseBuilder


class UpdateBuilder(BaseBuilder):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._update_stmt: Update = None

def _initial(self):
if self._stmt is None:
self._stmt = update(self.get_model_class())
if self._update_stmt is None:
self._update_stmt = update(self.get_model_class())

def where(self, *express: BinaryExpression):
self._initial()

self._stmt = self._stmt.where(*express)
self._update_stmt = self._update_stmt.where(*express)

return self

def returning(self, *entities):
self._initial()

self._stmt = self._stmt.returning(*entities)
self._update_stmt = self._update_stmt.returning(*entities)

return self

def values(self, *args, **kwargs):
self._initial()

self._stmt = self._stmt.values(*args, **kwargs)
self._update_stmt = self._update_stmt.values(*args, **kwargs)

return self

def execute(self, autocommit: bool = True, *args, **kwargs) -> Result[Any]:
if self._stmt is None:
if self._update_stmt is None:
# todo error message
raise ValueError("")

result = self._session.execute(self._stmt, *args, **kwargs)
result = self._session.execute(self._update_stmt, *args, **kwargs)

if autocommit:
self._commit()
Expand Down
10 changes: 5 additions & 5 deletions tests/test_builders_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_insert_values_stmt(faker, builder: InsertBuilder):

builder.values(**values)

stmt = builder._stmt
stmt = builder._insert_stmt

assert isinstance(stmt, Insert)
assert stmt.is_dml
Expand All @@ -44,7 +44,7 @@ def test_insert_values_stmt(faker, builder: InsertBuilder):
def test_insert_returning_all(builder: InsertBuilder):
builder.returning(User)

stmt = builder._stmt
stmt = builder._insert_stmt

assert len(stmt._returning) > 0

Expand All @@ -55,7 +55,7 @@ def test_insert_returning_all(builder: InsertBuilder):
def test_insert_returning_specific_fields(builder: InsertBuilder):
builder.returning(User.id, User.email)

stmt = builder._stmt
stmt = builder._insert_stmt

assert len(stmt._returning) > 0

Expand All @@ -68,7 +68,7 @@ def test_insert_returning_specific_fields(builder: InsertBuilder):
def test_insert_execution_options(builder: InsertBuilder):
builder.execution_options(render_nulls=True)

stmt = builder._stmt
stmt = builder._insert_stmt

execution_options = stmt.get_execution_options()

Expand Down Expand Up @@ -97,5 +97,5 @@ def test_insert_execute(mocker, faker, session: Session, builder: InsertBuilder)
builder.values(values)
builder.execute()

mock_execute.assert_called_once_with(builder._stmt)
mock_execute.assert_called_once_with(builder._insert_stmt)
mock_commit.assert_called_once()
14 changes: 7 additions & 7 deletions tests/test_builders_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,23 @@ def test_update_builder_initial(builder: UpdateBuilder):

def test_build_update_stmt(builder: UpdateBuilder):
builder._initial()
stmt = builder._stmt
stmt = builder._update_stmt

assert isinstance(stmt, Update)
assert stmt.is_dml


def test_build_update_stmt_with_single_where_clause(faker, builder: UpdateBuilder):
builder.where(User.email == faker.email())
stmt = builder._stmt
stmt = builder._update_stmt

assert isinstance(stmt.whereclause, BinaryExpression)
assert stmt.whereclause.left.name == "email"


def test_build_update_stmt_with_multiple_where_clauses(faker, builder: UpdateBuilder):
builder.where(User.email == faker.email()).where(User.state.is_(False))
stmt = builder._stmt
stmt = builder._update_stmt

assert isinstance(stmt.whereclause, BooleanClauseList)

Expand All @@ -59,7 +59,7 @@ def test_build_update_stmt_with_multiple_where_clauses(faker, builder: UpdateBui
def test_build_update_stmt_with_returning(builder: UpdateBuilder):
builder.returning(User.id, User.email)

stmt = builder._stmt
stmt = builder._update_stmt

assert len(stmt._returning) > 0

Expand All @@ -72,7 +72,7 @@ def test_build_update_stmt_with_returning(builder: UpdateBuilder):
def test_build_update_stmt_with_returning_all(builder: UpdateBuilder):
builder.returning(User)

stmt = builder._stmt
stmt = builder._update_stmt

assert len(stmt._returning) > 0

Expand All @@ -88,7 +88,7 @@ def test_build_update_stmt_with_values(faker, builder: UpdateBuilder):

builder.values(**values)

stmt = builder._stmt
stmt = builder._update_stmt

for column, param in stmt._values.items():
assert values[column.name] == param.value
Expand All @@ -107,5 +107,5 @@ def test_update_execute(mocker, faker, session: Session, builder: UpdateBuilder)
builder.values(values)
builder.execute()

mock_execute.assert_called_once_with(builder._stmt)
mock_execute.assert_called_once_with(builder._update_stmt)
mock_commit.assert_called_once()

0 comments on commit 52b5a0e

Please sign in to comment.