diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index d4860d5..332f3ca 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1 @@ -# These are supported funding model platforms -custom: https://www.buymeacoffee.com/frankie567 +github: frankie567 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 827440e..ae04238 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,29 +8,35 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python_version: [3.7, 3.8, 3.9] + python_version: [3.7, 3.8, 3.9, '3.10', '3.11'] steps: - - uses: actions/checkout@v1 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python_version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.dev.txt - - name: Test with pytest - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - run: | - pytest --cov=fastapi_users_db_sqlmodel/ - codecov - - name: Build and install it on system host - run: | - flit build - flit install --python $(which python) - python test_build.py + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + hatch env create + - name: Lint and typecheck + run: | + hatch run lint-check + - name: Test + run: | + hatch run test-cov-xml + - uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true + verbose: true + - name: Build and install it on system host + run: | + hatch build + pip install dist/fastapi_users_db_sqlmodel-*.whl + python test_build.py release: runs-on: ubuntu-latest @@ -38,18 +44,26 @@ jobs: if: startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/checkout@v1 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: 3.7 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.dev.txt - - name: Release on PyPI - env: - FLIT_USERNAME: ${{ secrets.FLIT_USERNAME }} - FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }} - run: | - flit publish + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + - name: Build and publish on PyPI + env: + HATCH_INDEX_USER: ${{ secrets.HATCH_INDEX_USER }} + HATCH_INDEX_AUTH: ${{ secrets.HATCH_INDEX_AUTH }} + run: | + hatch build + hatch publish + - name: Create release + uses: ncipollo/release-action@v1 + with: + draft: true + body: ${{ github.event.head_commit.message }} + artifacts: dist/*.whl,dist/*.tar.gz + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/Makefile b/Makefile deleted file mode 100644 index 91608b0..0000000 --- a/Makefile +++ /dev/null @@ -1,17 +0,0 @@ -isort: - isort ./fastapi_users_db_sqlmodel ./tests - -format: isort - black . - -test: - pytest --cov=fastapi_users_db_sqlmodel/ --cov-report=term-missing --cov-fail-under=100 - -bumpversion-major: - bumpversion major - -bumpversion-minor: - bumpversion minor - -bumpversion-patch: - bumpversion patch diff --git a/README.md b/README.md index ef102de..fe00d26 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [![PyPI version](https://badge.fury.io/py/fastapi-users-db-sqlmodel.svg)](https://badge.fury.io/py/fastapi-users-db-sqlmodel) [![Downloads](https://pepy.tech/badge/fastapi-users-db-sqlmodel)](https://pepy.tech/project/fastapi-users-db-sqlmodel)

- +

--- diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 1623191..695c5e2 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -1,212 +1,243 @@ """FastAPI Users database adapter for SQLModel.""" import uuid -from typing import Callable, Generic, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import BaseOAuthAccount, BaseUserDB +from fastapi_users.models import ID, OAP, UP from pydantic import UUID4, EmailStr -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession -from sqlalchemy.future import Engine -from sqlalchemy.orm import selectinload, sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.0.3" +__version__ = "0.3.0" -class SQLModelBaseUserDB(BaseUserDB, SQLModel): - id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) - email: EmailStr = Field(sa_column_kwargs={"unique": True, "index": True}) +class SQLModelBaseUserDB(SQLModel): + __tablename__ = "user" - class Config: - orm_mode = True + id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False) + if TYPE_CHECKING: # pragma: no cover + email: str + else: + email: EmailStr = Field( + sa_column_kwargs={"unique": True, "index": True}, nullable=False + ) + hashed_password: str - -class SQLModelBaseOAuthAccount(BaseOAuthAccount, SQLModel): - id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) + is_active: bool = Field(True, nullable=False) + is_superuser: bool = Field(False, nullable=False) + is_verified: bool = Field(False, nullable=False) class Config: orm_mode = True -UD = TypeVar("UD", bound=SQLModelBaseUserDB) -OA = TypeVar("OA", bound=SQLModelBaseOAuthAccount) - - -class NotSetOAuthAccountTableError(Exception): - """ - OAuth table was not set in DB adapter but was needed. +class SQLModelBaseOAuthAccount(SQLModel): + __tablename__ = "oauthaccount" - Raised when trying to create/update a user with OAuth accounts set - but no table were specified in the DB adapter. - """ + id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) + user_id: UUID4 = Field(foreign_key="user.id", nullable=False) + oauth_name: str = Field(index=True, nullable=False) + access_token: str = Field(nullable=False) + expires_at: Optional[int] = Field(nullable=True) + refresh_token: Optional[str] = Field(nullable=True) + account_id: str = Field(index=True, nullable=False) + account_email: str = Field(nullable=False) - pass + class Config: + orm_mode = True -class SQLModelUserDatabase(Generic[UD, OA], BaseUserDatabase[UD]): +class SQLModelUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ Database adapter for SQLModel. - :param user_db_model: SQLModel model of a DB representation of a user. - :param engine: SQLAlchemy engine. + :param session: SQLAlchemy session. """ - engine: Engine - oauth_account_model: Optional[Type[OA]] + session: Session + user_model: Type[UP] + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] def __init__( self, - user_db_model: Type[UD], - engine: Engine, - oauth_account_model: Optional[Type[OA]] = None, + session: Session, + user_model: Type[UP], + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] = None, ): - super().__init__(user_db_model) - self.engine = engine + self.session = session + self.user_model = user_model self.oauth_account_model = oauth_account_model - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: ID) -> Optional[UP]: """Get a single user by id.""" - with Session(self.engine) as session: - return session.get(self.user_db_model, id) + return self.session.get(self.user_model, id) - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" - with Session(self.engine) as session: - statement = select(self.user_db_model).where( - func.lower(self.user_db_model.email) == func.lower(email) - ) - results = session.exec(statement) - return results.first() - - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: + statement = select(self.user_model).where( # type: ignore + func.lower(self.user_model.email) == func.lower(email) + ) + results = self.session.exec(statement) + return results.first() + + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: """Get a single user by OAuth account id.""" - if not self.oauth_account_model: - raise NotSetOAuthAccountTableError() - with Session(self.engine) as session: - statement = ( - select(self.oauth_account_model) - .where(self.oauth_account_model.oauth_name == oauth) - .where(self.oauth_account_model.account_id == account_id) - ) - results = session.exec(statement) - oauth_account = results.first() - if oauth_account: - user = oauth_account.user # type: ignore - return user - return None + if self.oauth_account_model is None: + raise NotImplementedError() + statement = ( + select(self.oauth_account_model) + .where(self.oauth_account_model.oauth_name == oauth) + .where(self.oauth_account_model.account_id == account_id) + ) + results = self.session.exec(statement) + oauth_account = results.first() + if oauth_account: + user = oauth_account.user # type: ignore + return user + return None - async def create(self, user: UD) -> UD: + async def create(self, create_dict: Dict[str, Any]) -> UP: """Create a user.""" - with Session(self.engine) as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - session.commit() - session.refresh(user) - return user + user = self.user_model(**create_dict) + self.session.add(user) + self.session.commit() + self.session.refresh(user) + return user - async def update(self, user: UD) -> UD: - """Update a user.""" - with Session(self.engine) as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - session.commit() - session.refresh(user) - return user + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: + for key, value in update_dict.items(): + setattr(user, key, value) + self.session.add(user) + self.session.commit() + self.session.refresh(user) + return user + + async def delete(self, user: UP) -> None: + self.session.delete(user) + self.session.commit() + + async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + oauth_account = self.oauth_account_model(**create_dict) + user.oauth_accounts.append(oauth_account) # type: ignore + self.session.add(user) + + self.session.commit() + + return user + + async def update_oauth_account( + self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + ) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + for key, value in update_dict.items(): + setattr(oauth_account, key, value) + self.session.add(oauth_account) + self.session.commit() - async def delete(self, user: UD) -> None: - """Delete a user.""" - with Session(self.engine) as session: - session.delete(user) - session.commit() + return user -class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]): +class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ Database adapter for SQLModel working purely asynchronously. - :param user_db_model: SQLModel model of a DB representation of a user. - :param engine: SQLAlchemy async engine. + :param user_model: SQLModel model of a DB representation of a user. + :param session: SQLAlchemy async session. """ - engine: AsyncEngine - oauth_account_model: Optional[Type[OA]] + session: AsyncSession + user_model: Type[UP] + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] def __init__( self, - user_db_model: Type[UD], - engine: AsyncEngine, - oauth_account_model: Optional[Type[OA]] = None, + session: AsyncSession, + user_model: Type[UP], + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] = None, ): - super().__init__(user_db_model) - self.engine = engine + self.session = session + self.user_model = user_model self.oauth_account_model = oauth_account_model - self.session_maker: Callable[[], AsyncSession] = sessionmaker( - self.engine, class_=AsyncSession, expire_on_commit=False - ) - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: ID) -> Optional[UP]: """Get a single user by id.""" - async with self.session_maker() as session: - return await session.get(self.user_db_model, id) + return await self.session.get(self.user_model, id) - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" - async with self.session_maker() as session: - statement = select(self.user_db_model).where( - func.lower(self.user_db_model.email) == func.lower(email) - ) - results = await session.execute(statement) - object = results.first() - if object is None: - return None - return object[0] - - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: - """Get a single user by OAuth account id.""" - if not self.oauth_account_model: - raise NotSetOAuthAccountTableError() - async with self.session_maker() as session: - statement = ( - select(self.oauth_account_model) - .where(self.oauth_account_model.oauth_name == oauth) - .where(self.oauth_account_model.account_id == account_id) - .options(selectinload(self.oauth_account_model.user)) # type: ignore - ) - results = await session.execute(statement) - oauth_account = results.first() - if oauth_account: - user = oauth_account[0].user - return user + statement = select(self.user_model).where( # type: ignore + func.lower(self.user_model.email) == func.lower(email) + ) + results = await self.session.execute(statement) + object = results.first() + if object is None: return None + return object[0] - async def create(self, user: UD) -> UD: - """Create a user.""" - async with self.session_maker() as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - await session.commit() - await session.refresh(user) - return user - - async def update(self, user: UD) -> UD: - """Update a user.""" - async with self.session_maker() as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - await session.commit() - await session.refresh(user) + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: + """Get a single user by OAuth account id.""" + if self.oauth_account_model is None: + raise NotImplementedError() + statement = ( + select(self.oauth_account_model) + .where(self.oauth_account_model.oauth_name == oauth) + .where(self.oauth_account_model.account_id == account_id) + .options(selectinload(self.oauth_account_model.user)) # type: ignore + ) + results = await self.session.execute(statement) + oauth_account = results.first() + if oauth_account: + user = oauth_account[0].user # type: ignore return user + return None - async def delete(self, user: UD) -> None: - """Delete a user.""" - async with self.session_maker() as session: - await session.delete(user) - await session.commit() + async def create(self, create_dict: Dict[str, Any]) -> UP: + """Create a user.""" + user = self.user_model(**create_dict) + self.session.add(user) + await self.session.commit() + await self.session.refresh(user) + return user + + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: + for key, value in update_dict.items(): + setattr(user, key, value) + self.session.add(user) + await self.session.commit() + await self.session.refresh(user) + return user + + async def delete(self, user: UP) -> None: + await self.session.delete(user) + await self.session.commit() + + async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + oauth_account = self.oauth_account_model(**create_dict) + user.oauth_accounts.append(oauth_account) # type: ignore + self.session.add(user) + + await self.session.commit() + + return user + + async def update_oauth_account( + self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + ) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + for key, value in update_dict.items(): + setattr(oauth_account, key, value) + self.session.add(oauth_account) + await self.session.commit() + + return user diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py new file mode 100644 index 0000000..8a4519e --- /dev/null +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -0,0 +1,122 @@ +from datetime import datetime +from typing import Any, Dict, Generic, Optional, Type + +from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase +from pydantic import UUID4 +from sqlalchemy import Column, types +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import Field, Session, SQLModel, select + +from fastapi_users_db_sqlmodel.generics import TIMESTAMPAware, now_utc + + +class SQLModelBaseAccessToken(SQLModel): + __tablename__ = "accesstoken" + + token: str = Field( + sa_column=Column("token", types.String(length=43), primary_key=True) + ) + created_at: datetime = Field( + default_factory=now_utc, + sa_column=Column( + "created_at", TIMESTAMPAware(timezone=True), nullable=False, index=True + ), + ) + user_id: UUID4 = Field(foreign_key="user.id", nullable=False) + + class Config: + orm_mode = True + + +class SQLModelAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): + """ + Access token database adapter for SQLModel. + + :param session: SQLAlchemy session. + :param access_token_model: SQLModel access token model. + """ + + def __init__(self, session: Session, access_token_model: Type[AP]): + self.session = session + self.access_token_model = access_token_model + + async def get_by_token( + self, token: str, max_age: Optional[datetime] = None + ) -> Optional[AP]: + statement = select(self.access_token_model).where( # type: ignore + self.access_token_model.token == token + ) + if max_age is not None: + statement = statement.where(self.access_token_model.created_at >= max_age) + + results = self.session.execute(statement) + access_token = results.first() + if access_token is None: + return None + return access_token[0] + + async def create(self, create_dict: Dict[str, Any]) -> AP: + access_token = self.access_token_model(**create_dict) + self.session.add(access_token) + self.session.commit() + self.session.refresh(access_token) + return access_token + + async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: + for key, value in update_dict.items(): + setattr(access_token, key, value) + self.session.add(access_token) + self.session.commit() + self.session.refresh(access_token) + return access_token + + async def delete(self, access_token: AP) -> None: + self.session.delete(access_token) + self.session.commit() + + +class SQLModelAccessTokenDatabaseAsync(Generic[AP], AccessTokenDatabase[AP]): + """ + Access token database adapter for SQLModel working purely asynchronously. + + :param session: SQLAlchemy async session. + :param access_token_model: SQLModel access token model. + """ + + def __init__(self, session: AsyncSession, access_token_model: Type[AP]): + self.session = session + self.access_token_model = access_token_model + + async def get_by_token( + self, token: str, max_age: Optional[datetime] = None + ) -> Optional[AP]: + statement = select(self.access_token_model).where( # type: ignore + self.access_token_model.token == token + ) + if max_age is not None: + statement = statement.where(self.access_token_model.created_at >= max_age) + + results = await self.session.execute(statement) + access_token = results.first() + if access_token is None: + return None + return access_token[0] + + async def create(self, create_dict: Dict[str, Any]) -> AP: + access_token = self.access_token_model(**create_dict) + self.session.add(access_token) + await self.session.commit() + await self.session.refresh(access_token) + return access_token + + async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: + for key, value in update_dict.items(): + setattr(access_token, key, value) + self.session.add(access_token) + await self.session.commit() + await self.session.refresh(access_token) + return access_token + + async def delete(self, access_token: AP) -> None: + await self.session.delete(access_token) + await self.session.commit() diff --git a/fastapi_users_db_sqlmodel/generics.py b/fastapi_users_db_sqlmodel/generics.py new file mode 100644 index 0000000..bfe2fda --- /dev/null +++ b/fastapi_users_db_sqlmodel/generics.py @@ -0,0 +1,24 @@ +from datetime import datetime, timezone + +from sqlalchemy import TIMESTAMP, TypeDecorator + + +def now_utc(): + return datetime.now(timezone.utc) + + +class TIMESTAMPAware(TypeDecorator): # pragma: no cover + """ + MySQL and SQLite will always return naive-Python datetimes. + + We store everything as UTC, but we want to have + only offset-aware Python datetimes, even with MySQL and SQLite. + """ + + impl = TIMESTAMP + cache_ok = True + + def process_result_value(self, value: datetime, dialect): + if dialect.name != "postgresql": + return value.replace(tzinfo=timezone.utc) + return value diff --git a/pyproject.toml b/pyproject.toml index c85388d..971fac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,30 +1,84 @@ +[tool.pytest.ini_options] +asyncio_mode = "auto" +addopts = "--ignore=test_build.py" + +[tool.ruff] +extend-select = ["I"] + +[tool.hatch] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.version] +source = "regex_commit" +commit_extra_args = ["-e"] +path = "fastapi_users_db_sqlmodel/__init__.py" + +[tool.hatch.envs.default] +dependencies = [ + "aiosqlite", + "pytest", + "pytest-asyncio", + "black", + "mypy", + "pytest-cov", + "pytest-mock", + "httpx", + "asgi_lifespan", + "ruff", +] + +[tool.hatch.envs.default.scripts] +test = "pytest --cov=fastapi_users_db_sqlmodel/ --cov-report=term-missing --cov-fail-under=100" +test-cov-xml = "pytest --cov=fastapi_users_db_sqlmodel/ --cov-report=xml --cov-fail-under=100" +lint = [ + "black . ", + "ruff --fix .", + "mypy fastapi_users_db_sqlmodel/", +] +lint-check = [ + "black --check .", + "ruff .", + "mypy fastapi_users_db_sqlmodel/", +] + +[tool.hatch.build.targets.sdist] +support-legacy = true # Create setup.py + [build-system] -requires = ["flit_core >=2,<3"] -build-backend = "flit_core.buildapi" - -[tool.flit.metadata] -module = "fastapi_users_db_sqlmodel" -dist-name = "fastapi-users-db-sqlmodel" -author = "François Voron" -author-email = "fvoron@gmail.com" -home-page = "https://github.com/fastapi-users/fastapi-users-db-sqlmodel" +requires = ["hatchling", "hatch-regex-commit"] +build-backend = "hatchling.build" + +[project] +name = "fastapi-users-db-sqlmodel" +authors = [ + { name = "François Voron", email = "fvoron@gmail.com" }, +] +description = "FastAPI Users database adapter for SQLModel" +readme = "README.md" +dynamic = ["version"] classifiers = [ "License :: OSI Approved :: MIT License", "Development Status :: 5 - Production/Stable", + "Framework :: FastAPI", "Framework :: AsyncIO", "Intended Audience :: Developers", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", "Topic :: Internet :: WWW/HTTP :: Session", ] -description-file = "README.md" requires-python = ">=3.7" -requires = [ - "fastapi-users >= 7.0.0", - "sqlmodel >=0.0.4,<0.1.0", +dependencies = [ + "fastapi-users >= 10.0.2", + "greenlet", + "sqlmodel", ] -[tool.flit.metadata.urls] +[project.urls] Documentation = "https://fastapi-users.github.io/fastapi-users" +Source = "https://github.com/fastapi-users/fastapi-users-db-sqlmodel" diff --git a/requirements.dev.txt b/requirements.dev.txt deleted file mode 100644 index e8ebaaf..0000000 --- a/requirements.dev.txt +++ /dev/null @@ -1,18 +0,0 @@ --r requirements.txt - -flake8 -pytest -requests -isort -pytest-asyncio -flake8-docstrings -black -mypy -codecov -pytest-cov -pytest-mock -asynctest -flit -bumpversion -httpx -asgi_lifespan diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 6669104..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -aiosqlite >= 0.17.0 -fastapi-users >= 6.1.2 -sqlmodel >=0.0.4,<0.1.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index dad3708..0000000 --- a/setup.cfg +++ /dev/null @@ -1,26 +0,0 @@ -[bumpversion] -current_version = 0.0.3 -commit = True -tag = True - -[bumpversion:file:fastapi_users_db_sqlmodel/__init__.py] -search = __version__ = "{current_version}" -replace = __version__ = "{new_version}" - -[flake8] -exclude = docs -max-line-length = 88 -docstring-convention = numpy -ignore = D1 - -[isort] -profile = black - -[tool:pytest] -addopts = --ignore=test_build.py -markers = - authentication - db - fastapi_users - oauth - router diff --git a/tests/conftest.py b/tests/conftest.py index f20fa05..ee12b2c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,72 +1,55 @@ import asyncio -from typing import List, Optional +from typing import Any, Dict, List, Optional import pytest -from fastapi_users import models from pydantic import UUID4 from sqlmodel import Field, Relationship from fastapi_users_db_sqlmodel import SQLModelBaseOAuthAccount, SQLModelBaseUserDB -class User(models.BaseUser): +class User(SQLModelBaseUserDB, table=True): first_name: Optional[str] -class UserCreate(models.BaseUserCreate): - first_name: Optional[str] - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(SQLModelBaseUserDB, User, table=True): - class Config: - orm_mode = True - - -class UserOAuth(User): - pass - - -class UserDBOAuth(SQLModelBaseUserDB, table=True): - __tablename__ = "user" +class UserOAuth(SQLModelBaseUserDB, table=True): + __tablename__ = "user_oauth" oauth_accounts: List["OAuthAccount"] = Relationship( back_populates="user", - sa_relationship_kwargs={"lazy": "joined", "cascade": "all, delete"}, + sa_relationship_kwargs={"lazy": "selectin", "cascade": "all, delete"}, ) class OAuthAccount(SQLModelBaseOAuthAccount, table=True): - user_id: UUID4 = Field(foreign_key="user.id") - user: Optional[UserDBOAuth] = Relationship(back_populates="oauth_accounts") + user_id: UUID4 = Field(foreign_key="user_oauth.id") + user: Optional[UserOAuth] = Relationship(back_populates="oauth_accounts") @pytest.fixture(scope="session") def event_loop(): """Force the pytest-asyncio loop to be the main one.""" - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() yield loop + loop.close() @pytest.fixture -def oauth_account1() -> OAuthAccount: - return OAuthAccount( - oauth_name="service1", - access_token="TOKEN", - expires_at=1579000751, - account_id="user_oauth1", - account_email="king.arthur@camelot.bt", - ) +def oauth_account1() -> Dict[str, Any]: + return { + "oauth_name": "service1", + "access_token": "TOKEN", + "expires_at": 1579000751, + "account_id": "user_oauth1", + "account_email": "king.arthur@camelot.bt", + } @pytest.fixture -def oauth_account2() -> OAuthAccount: - return OAuthAccount( - oauth_name="service2", - access_token="TOKEN", - expires_at=1579000751, - account_id="user_oauth2", - account_email="king.arthur@camelot.bt", - ) +def oauth_account2() -> Dict[str, Any]: + return { + "oauth_name": "service2", + "access_token": "TOKEN", + "expires_at": 1579000751, + "account_id": "user_oauth2", + "account_email": "king.arthur@camelot.bt", + } diff --git a/tests/test_access_token.py b/tests/test_access_token.py new file mode 100644 index 0000000..ee93e04 --- /dev/null +++ b/tests/test_access_token.py @@ -0,0 +1,146 @@ +import uuid +from datetime import datetime, timedelta, timezone +from typing import AsyncGenerator + +import pytest +import pytest_asyncio +from pydantic import UUID4 +from sqlalchemy import exc +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import Session, SQLModel, create_engine + +from fastapi_users_db_sqlmodel import SQLModelUserDatabase, SQLModelUserDatabaseAsync +from fastapi_users_db_sqlmodel.access_token import ( + SQLModelAccessTokenDatabase, + SQLModelAccessTokenDatabaseAsync, + SQLModelBaseAccessToken, +) +from tests.conftest import User + + +class AccessToken(SQLModelBaseAccessToken, table=True): + pass + + +@pytest.fixture +def user_id() -> UUID4: + return uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") + + +async def init_sync_session(url: str) -> AsyncGenerator[Session, None]: + engine = create_engine(url, connect_args={"check_same_thread": False}) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + SQLModel.metadata.drop_all(engine) + + +async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: + engine = create_async_engine(url, connect_args={"check_same_thread": False}) + make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + async with make_session() as session: + yield session + await conn.run_sync(SQLModel.metadata.drop_all) + + +@pytest_asyncio.fixture( + params=[ + ( + init_sync_session, + "sqlite:///./test-sqlmodel-access-token.db", + SQLModelAccessTokenDatabase, + SQLModelUserDatabase, + ), + ( + init_async_session, + "sqlite+aiosqlite:///./test-sqlmodel-access-token.db", + SQLModelAccessTokenDatabaseAsync, + SQLModelUserDatabaseAsync, + ), + ], + ids=["sync", "async"], +) +async def sqlmodel_access_token_db( + request, user_id: UUID4 +) -> AsyncGenerator[SQLModelAccessTokenDatabase, None]: + create_session = request.param[0] + database_url = request.param[1] + access_token_database_class = request.param[2] + user_database_class = request.param[3] + async for session in create_session(database_url): + user_db = user_database_class(session, User) + await user_db.create( + { + "id": user_id, + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + ) + yield access_token_database_class(session, AccessToken) + + +@pytest.mark.asyncio +async def test_queries( + sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], + user_id: UUID4, +): + access_token_create = {"token": "TOKEN", "user_id": user_id} + + # Create + access_token = await sqlmodel_access_token_db.create(access_token_create) + assert access_token.token == "TOKEN" + assert access_token.user_id == user_id + + # Update + update_dict = {"created_at": datetime.now(timezone.utc)} + updated_access_token = await sqlmodel_access_token_db.update( + access_token, update_dict + ) + assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ + "created_at" + ].replace(microsecond=0) + + # Get by token + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + access_token.token + ) + assert access_token_by_token is not None + + # Get by token expired + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) + ) + assert access_token_by_token is None + + # Get by token not expired + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) + ) + assert access_token_by_token is not None + + # Get by token unknown + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + "NOT_EXISTING_TOKEN" + ) + assert access_token_by_token is None + + # Delete token + await sqlmodel_access_token_db.delete(access_token) + deleted_access_token = await sqlmodel_access_token_db.get_by_token( + access_token.token + ) + assert deleted_access_token is None + + +@pytest.mark.asyncio +async def test_insert_existing_token( + sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], user_id: UUID4 +): + access_token_create = {"token": "TOKEN", "user_id": user_id} + await sqlmodel_access_token_db.create(access_token_create) + + with pytest.raises(exc.IntegrityError): + await sqlmodel_access_token_db.create(access_token_create) diff --git a/tests/test_fastapi_users_db_sqlmodel.py b/tests/test_fastapi_users_db_sqlmodel.py deleted file mode 100644 index 5998fcf..0000000 --- a/tests/test_fastapi_users_db_sqlmodel.py +++ /dev/null @@ -1,203 +0,0 @@ -import uuid -from typing import AsyncGenerator - -import pytest -from sqlalchemy import exc -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlalchemy.future import Engine -from sqlmodel import SQLModel, create_engine - -from fastapi_users_db_sqlmodel import ( - NotSetOAuthAccountTableError, - SQLModelUserDatabase, - SQLModelUserDatabaseAsync, -) -from tests.conftest import OAuthAccount, UserDB, UserDBOAuth - - -safe_uuid = uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") - -async def init_sync_engine(url: str) -> AsyncGenerator[Engine, None]: - engine = create_engine(url, connect_args={"check_same_thread": False}) - SQLModel.metadata.create_all(engine) - yield engine - SQLModel.metadata.drop_all(engine) - - -async def init_async_engine(url: str) -> AsyncGenerator[AsyncEngine, None]: - engine = create_async_engine(url, connect_args={"check_same_thread": False}) - async with engine.begin() as conn: - await conn.run_sync(SQLModel.metadata.create_all) - yield engine - await conn.run_sync(SQLModel.metadata.drop_all) - - -@pytest.fixture( - params=[ - (init_sync_engine, "sqlite:///./test-sqlmodel-user.db", SQLModelUserDatabase), - ( - init_async_engine, - "sqlite+aiosqlite:///./test-sqlmodel-user.db", - SQLModelUserDatabaseAsync, - ), - ] -) -async def sqlmodel_user_db(request) -> AsyncGenerator[SQLModelUserDatabase, None]: - create_engine = request.param[0] - database_url = request.param[1] - database_class = request.param[2] - async for engine in create_engine(database_url): - yield database_class(UserDB, engine) - - -@pytest.fixture( - params=[ - ( - init_sync_engine, - "sqlite:///./test-sqlmodel-user-oauth.db", - SQLModelUserDatabase, - ), - ( - init_async_engine, - "sqlite+aiosqlite:///./test-sqlmodel-user-oauth.db", - SQLModelUserDatabaseAsync, - ), - ] -) -async def sqlmodel_user_db_oauth(request) -> AsyncGenerator[SQLModelUserDatabase, None]: - create_engine = request.param[0] - database_url = request.param[1] - database_class = request.param[2] - async for engine in create_engine(database_url): - yield database_class(UserDBOAuth, engine, OAuthAccount) - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount]): - user = UserDB( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - ) - - # Create - user_db = await sqlmodel_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email - - # Update - user_db.is_superuser = True - await sqlmodel_user_db.update(user_db) - - # Get by id - id_user = await sqlmodel_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.is_superuser is True - - # Get by email - email_user = await sqlmodel_user_db.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - - # Get by uppercased email - email_user = await sqlmodel_user_db.get_by_email("Lancelot@camelot.bt") - assert email_user is not None - assert email_user.id == user_db.id - - # Exception when inserting existing email - with pytest.raises(exc.IntegrityError): - await sqlmodel_user_db.create( - UserDB(id=safe_uuid, email=user_db.email, hashed_password="guinevere") - ) - - # Exception when inserting non-nullable fields - with pytest.raises(exc.IntegrityError): - wrong_user = UserDB(id=safe_uuid, email="lancelot@camelot.bt", hashed_password="aaa") - wrong_user.email = None # type: ignore - await sqlmodel_user_db.create(wrong_user) - - # Unknown user - unknown_user = await sqlmodel_user_db.get_by_email("galahad@camelot.bt") - assert unknown_user is None - - # Delete user - await sqlmodel_user_db.delete(user) - deleted_user = await sqlmodel_user_db.get(user.id) - assert deleted_user is None - - # Exception when trying to get by OAuth account - with pytest.raises(NotSetOAuthAccountTableError): - await sqlmodel_user_db.get_by_oauth_account("foo", "bar") - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_custom_fields( - sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount], -): - """It should output custom fields in query result.""" - user = UserDB( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - first_name="Lancelot", - ) - await sqlmodel_user_db.create(user) - - id_user = await sqlmodel_user_db.get(user.id) - assert id_user is not None - assert id_user.id == user.id - assert id_user.first_name == user.first_name - - -@pytest.mark.asyncio -@pytest.mark.db -async def test_queries_oauth( - sqlmodel_user_db_oauth: SQLModelUserDatabase[UserDBOAuth, OAuthAccount], - oauth_account1, - oauth_account2, -): - user = UserDBOAuth( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - oauth_accounts=[oauth_account1, oauth_account2], - ) - - # Create - user_db = await sqlmodel_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 - - # Update - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await sqlmodel_user_db_oauth.update(user_db) - - # Get by id - id_user = await sqlmodel_user_db_oauth.get(user.id) - assert id_user is not None - assert id_user.id == user_db.id - assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" - - # Get by email - email_user = await sqlmodel_user_db_oauth.get_by_email(str(user.email)) - assert email_user is not None - assert email_user.id == user_db.id - assert len(email_user.oauth_accounts) == 2 - - # Get by OAuth account - oauth_user = await sqlmodel_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id - ) - assert oauth_user is not None - assert oauth_user.id == user.id - assert len(oauth_user.oauth_accounts) == 2 - - # Unknown OAuth account - unknown_oauth_user = await sqlmodel_user_db_oauth.get_by_oauth_account("foo", "bar") - assert unknown_oauth_user is None diff --git a/tests/test_users.py b/tests/test_users.py new file mode 100644 index 0000000..dc32aab --- /dev/null +++ b/tests/test_users.py @@ -0,0 +1,212 @@ +import uuid +from typing import Any, AsyncGenerator, Dict + +import pytest +import pytest_asyncio +from pydantic import UUID4 +from sqlalchemy import exc +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import Session, SQLModel, create_engine + +from fastapi_users_db_sqlmodel import SQLModelUserDatabase, SQLModelUserDatabaseAsync +from tests.conftest import OAuthAccount, User, UserOAuth + +safe_uuid = uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") + + +async def init_sync_session(url: str) -> AsyncGenerator[Session, None]: + engine = create_engine(url, connect_args={"check_same_thread": False}) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + SQLModel.metadata.drop_all(engine) + + +async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: + engine = create_async_engine(url, connect_args={"check_same_thread": False}) + make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + async with make_session() as session: + yield session + await conn.run_sync(SQLModel.metadata.drop_all) + + +@pytest_asyncio.fixture( + params=[ + (init_sync_session, "sqlite:///./test-sqlmodel-user.db", SQLModelUserDatabase), + ( + init_async_session, + "sqlite+aiosqlite:///./test-sqlmodel-user.db", + SQLModelUserDatabaseAsync, + ), + ], + ids=["sync", "async"], +) +async def sqlmodel_user_db(request) -> AsyncGenerator[SQLModelUserDatabase, None]: + create_session = request.param[0] + database_url = request.param[1] + database_class = request.param[2] + async for session in create_session(database_url): + yield database_class(session, User) + + +@pytest_asyncio.fixture( + params=[ + ( + init_sync_session, + "sqlite:///./test-sqlmodel-user-oauth.db", + SQLModelUserDatabase, + ), + ( + init_async_session, + "sqlite+aiosqlite:///./test-sqlmodel-user-oauth.db", + SQLModelUserDatabaseAsync, + ), + ], + ids=["sync", "async"], +) +async def sqlmodel_user_db_oauth(request) -> AsyncGenerator[SQLModelUserDatabase, None]: + create_session = request.param[0] + database_url = request.param[1] + database_class = request.param[2] + async for session in create_session(database_url): + yield database_class(session, UserOAuth, OAuthAccount) + + +@pytest.mark.asyncio +async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[User, UUID4]): + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + + # Create + user = await sqlmodel_user_db.create(user_create) + assert user.id is not None + assert user.is_active is True + assert user.is_superuser is False + assert user.email == user_create["email"] + + # Update + updated_user = await sqlmodel_user_db.update(user, {"is_superuser": True}) + assert updated_user.is_superuser is True + + # Get by id + id_user = await sqlmodel_user_db.get(user.id) + assert id_user is not None + assert id_user.id == user.id + assert id_user.is_superuser is True + + # Get by email + email_user = await sqlmodel_user_db.get_by_email(str(user_create["email"])) + assert email_user is not None + assert email_user.id == user.id + + # Get by uppercased email + email_user = await sqlmodel_user_db.get_by_email("Lancelot@camelot.bt") + assert email_user is not None + assert email_user.id == user.id + + # Unknown user + unknown_user = await sqlmodel_user_db.get_by_email("galahad@camelot.bt") + assert unknown_user is None + + # Delete user + await sqlmodel_user_db.delete(user) + deleted_user = await sqlmodel_user_db.get(user.id) + assert deleted_user is None + + # OAuth without defined table + with pytest.raises(NotImplementedError): + await sqlmodel_user_db.get_by_oauth_account("foo", "bar") + with pytest.raises(NotImplementedError): + await sqlmodel_user_db.add_oauth_account(user, {}) + with pytest.raises(NotImplementedError): + oauth_account = OAuthAccount() + await sqlmodel_user_db.update_oauth_account(user, oauth_account, {}) + + +@pytest.mark.asyncio +async def test_insert_existing_email( + sqlmodel_user_db: SQLModelUserDatabase[User, UUID4] +): + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + await sqlmodel_user_db.create(user_create) + + with pytest.raises(exc.IntegrityError): + await sqlmodel_user_db.create(user_create) + + +@pytest.mark.asyncio +async def test_queries_custom_fields( + sqlmodel_user_db: SQLModelUserDatabase[User, UUID4], +): + """It should output custom fields in query result.""" + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + "first_name": "Lancelot", + } + user = await sqlmodel_user_db.create(user_create) + + id_user = await sqlmodel_user_db.get(user.id) + assert id_user is not None + assert id_user.id == user.id + assert id_user.first_name == user.first_name + + +@pytest.mark.asyncio +async def test_queries_oauth( + sqlmodel_user_db_oauth: SQLModelUserDatabase[UserOAuth, UUID4], + oauth_account1: Dict[str, Any], + oauth_account2: Dict[str, Any], +): + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + + # Create + user = await sqlmodel_user_db_oauth.create(user_create) + assert user.id is not None + + # Add OAuth account + user = await sqlmodel_user_db_oauth.add_oauth_account(user, oauth_account1) + user = await sqlmodel_user_db_oauth.add_oauth_account(user, oauth_account2) + assert len(user.oauth_accounts) == 2 + assert user.oauth_accounts[1].account_id == oauth_account2["account_id"] + assert user.oauth_accounts[0].account_id == oauth_account1["account_id"] + + # Update + user = await sqlmodel_user_db_oauth.update_oauth_account( + user, user.oauth_accounts[0], {"access_token": "NEW_TOKEN"} + ) + assert user.oauth_accounts[0].access_token == "NEW_TOKEN" + + # Get by id + id_user = await sqlmodel_user_db_oauth.get(user.id) + assert id_user is not None + assert id_user.id == user.id + assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" + + # Get by email + email_user = await sqlmodel_user_db_oauth.get_by_email(user_create["email"]) + assert email_user is not None + assert email_user.id == user.id + assert len(email_user.oauth_accounts) == 2 + + # Get by OAuth account + oauth_user = await sqlmodel_user_db_oauth.get_by_oauth_account( + oauth_account1["oauth_name"], oauth_account1["account_id"] + ) + assert oauth_user is not None + assert oauth_user.id == user.id + + # Unknown OAuth account + unknown_oauth_user = await sqlmodel_user_db_oauth.get_by_oauth_account("foo", "bar") + assert unknown_oauth_user is None