From e4ec43e45805af9aba1e4f2e4278b76bd49fbabb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 20 Sep 2021 08:23:25 +0200 Subject: [PATCH 01/22] Inject a session instead of an engine --- fastapi_users_db_sqlmodel/__init__.py | 176 +++++++++++------------- requirements.txt | 2 +- tests/test_fastapi_users_db_sqlmodel.py | 87 +++++++----- 3 files changed, 137 insertions(+), 128 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 1623191..aa8a5b7 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -1,13 +1,12 @@ """FastAPI Users database adapter for SQLModel.""" import uuid -from typing import Callable, Generic, Optional, Type, TypeVar +from typing import Generic, Optional, Type, TypeVar from fastapi_users.db.base import BaseUserDatabase from fastapi_users.models import BaseOAuthAccount, BaseUserDB from pydantic import UUID4, EmailStr -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession -from sqlalchemy.future import Engine -from sqlalchemy.orm import selectinload, sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select __version__ = "0.0.3" @@ -48,80 +47,74 @@ class SQLModelUserDatabase(Generic[UD, OA], BaseUserDatabase[UD]): Database adapter for SQLModel. :param user_db_model: SQLModel model of a DB representation of a user. - :param engine: SQLAlchemy engine. + :param session: SQLAlchemy session. """ - engine: Engine + session: Session oauth_account_model: Optional[Type[OA]] def __init__( self, user_db_model: Type[UD], - engine: Engine, + session: Session, oauth_account_model: Optional[Type[OA]] = None, ): super().__init__(user_db_model) - self.engine = engine + self.session = session self.oauth_account_model = oauth_account_model async def get(self, id: UUID4) -> Optional[UD]: """Get a single user by id.""" - with Session(self.engine) as session: - return session.get(self.user_db_model, id) + return self.session.get(self.user_db_model, id) async def get_by_email(self, email: str) -> Optional[UD]: """Get a single user by email.""" - with Session(self.engine) as session: - statement = select(self.user_db_model).where( - func.lower(self.user_db_model.email) == func.lower(email) - ) - results = session.exec(statement) - return results.first() + statement = select(self.user_db_model).where( + func.lower(self.user_db_model.email) == func.lower(email) + ) + results = self.session.exec(statement) + return results.first() async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: """Get a single user by OAuth account id.""" if not self.oauth_account_model: raise NotSetOAuthAccountTableError() - with Session(self.engine) as session: - statement = ( - select(self.oauth_account_model) - .where(self.oauth_account_model.oauth_name == oauth) - .where(self.oauth_account_model.account_id == account_id) - ) - results = session.exec(statement) - oauth_account = results.first() - if oauth_account: - user = oauth_account.user # type: ignore - return user - return None + statement = ( + select(self.oauth_account_model) + .where(self.oauth_account_model.oauth_name == oauth) + .where(self.oauth_account_model.account_id == account_id) + ) + results = self.session.exec(statement) + oauth_account = results.first() + if oauth_account: + user = oauth_account.user # type: ignore + return user + return None async def create(self, user: UD) -> UD: """Create a user.""" - with Session(self.engine) as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - session.commit() - session.refresh(user) - return user + self.session.add(user) + if self.oauth_account_model is not None: + for oauth_account in user.oauth_accounts: # type: ignore + self.session.add(oauth_account) + self.session.commit() + self.session.refresh(user) + return user async def update(self, user: UD) -> UD: """Update a user.""" - with Session(self.engine) as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - session.commit() - session.refresh(user) - return user + self.session.add(user) + if self.oauth_account_model is not None: + for oauth_account in user.oauth_accounts: # type: ignore + self.session.add(oauth_account) + self.session.commit() + self.session.refresh(user) + return user async def delete(self, user: UD) -> None: """Delete a user.""" - with Session(self.engine) as session: - session.delete(user) - session.commit() + self.session.delete(user) + self.session.commit() class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]): @@ -132,81 +125,72 @@ class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]): :param engine: SQLAlchemy async engine. """ - engine: AsyncEngine + session: AsyncSession oauth_account_model: Optional[Type[OA]] def __init__( self, user_db_model: Type[UD], - engine: AsyncEngine, + session: AsyncSession, oauth_account_model: Optional[Type[OA]] = None, ): super().__init__(user_db_model) - self.engine = engine + self.session = session self.oauth_account_model = oauth_account_model - self.session_maker: Callable[[], AsyncSession] = sessionmaker( - self.engine, class_=AsyncSession, expire_on_commit=False - ) async def get(self, id: UUID4) -> Optional[UD]: """Get a single user by id.""" - async with self.session_maker() as session: - return await session.get(self.user_db_model, id) + return await self.session.get(self.user_db_model, id) async def get_by_email(self, email: str) -> Optional[UD]: """Get a single user by email.""" - async with self.session_maker() as session: - statement = select(self.user_db_model).where( - func.lower(self.user_db_model.email) == func.lower(email) - ) - results = await session.execute(statement) - object = results.first() - if object is None: - return None - return object[0] + statement = select(self.user_db_model).where( + func.lower(self.user_db_model.email) == func.lower(email) + ) + results = await self.session.execute(statement) + object = results.first() + if object is None: + return None + return object[0] async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: """Get a single user by OAuth account id.""" if not self.oauth_account_model: raise NotSetOAuthAccountTableError() - async with self.session_maker() as session: - statement = ( - select(self.oauth_account_model) - .where(self.oauth_account_model.oauth_name == oauth) - .where(self.oauth_account_model.account_id == account_id) - .options(selectinload(self.oauth_account_model.user)) # type: ignore - ) - results = await session.execute(statement) - oauth_account = results.first() - if oauth_account: - user = oauth_account[0].user - return user - return None + statement = ( + select(self.oauth_account_model) + .where(self.oauth_account_model.oauth_name == oauth) + .where(self.oauth_account_model.account_id == account_id) + .options(selectinload(self.oauth_account_model.user)) # type: ignore + ) + results = await self.session.execute(statement) + oauth_account = results.first() + if oauth_account: + user = oauth_account[0].user + return user + return None async def create(self, user: UD) -> UD: """Create a user.""" - async with self.session_maker() as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - await session.commit() - await session.refresh(user) - return user + self.session.add(user) + if self.oauth_account_model is not None: + for oauth_account in user.oauth_accounts: # type: ignore + self.session.add(oauth_account) + await self.session.commit() + await self.session.refresh(user) + return user async def update(self, user: UD) -> UD: """Update a user.""" - async with self.session_maker() as session: - session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - session.add(oauth_account) - await session.commit() - await session.refresh(user) - return user + self.session.add(user) + if self.oauth_account_model is not None: + for oauth_account in user.oauth_accounts: # type: ignore + self.session.add(oauth_account) + await self.session.commit() + await self.session.refresh(user) + return user async def delete(self, user: UD) -> None: """Delete a user.""" - async with self.session_maker() as session: - await session.delete(user) - await session.commit() + await self.session.delete(user) + await self.session.commit() diff --git a/requirements.txt b/requirements.txt index 6669104..ece98a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiosqlite >= 0.17.0 -fastapi-users >= 6.1.2 +fastapi-users >= 8.0.0b3 sqlmodel >=0.0.4,<0.1.0 diff --git a/tests/test_fastapi_users_db_sqlmodel.py b/tests/test_fastapi_users_db_sqlmodel.py index 5998fcf..9a9107b 100644 --- a/tests/test_fastapi_users_db_sqlmodel.py +++ b/tests/test_fastapi_users_db_sqlmodel.py @@ -3,9 +3,9 @@ import pytest from sqlalchemy import exc -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlalchemy.future import Engine -from sqlmodel import SQLModel, create_engine +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker +from sqlmodel import SQLModel, create_engine, Session from fastapi_users_db_sqlmodel import ( NotSetOAuthAccountTableError, @@ -17,59 +17,65 @@ safe_uuid = uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") -async def init_sync_engine(url: str) -> AsyncGenerator[Engine, None]: + +async def init_sync_session(url: str) -> AsyncGenerator[Session, None]: engine = create_engine(url, connect_args={"check_same_thread": False}) SQLModel.metadata.create_all(engine) - yield engine + with Session(engine) as session: + yield session SQLModel.metadata.drop_all(engine) -async def init_async_engine(url: str) -> AsyncGenerator[AsyncEngine, None]: +async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: engine = create_async_engine(url, connect_args={"check_same_thread": False}) + make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - yield engine + async with make_session() as session: + yield session await conn.run_sync(SQLModel.metadata.drop_all) @pytest.fixture( params=[ - (init_sync_engine, "sqlite:///./test-sqlmodel-user.db", SQLModelUserDatabase), + (init_sync_session, "sqlite:///./test-sqlmodel-user.db", SQLModelUserDatabase), ( - init_async_engine, + init_async_session, "sqlite+aiosqlite:///./test-sqlmodel-user.db", SQLModelUserDatabaseAsync, ), - ] + ], + ids=["sync", "async"], ) async def sqlmodel_user_db(request) -> AsyncGenerator[SQLModelUserDatabase, None]: - create_engine = request.param[0] + create_session = request.param[0] database_url = request.param[1] database_class = request.param[2] - async for engine in create_engine(database_url): - yield database_class(UserDB, engine) + async for session in create_session(database_url): + yield database_class(UserDB, session) @pytest.fixture( params=[ ( - init_sync_engine, + init_sync_session, "sqlite:///./test-sqlmodel-user-oauth.db", SQLModelUserDatabase, ), ( - init_async_engine, + init_async_session, "sqlite+aiosqlite:///./test-sqlmodel-user-oauth.db", SQLModelUserDatabaseAsync, ), - ] + ], + ids=["sync", "async"], ) async def sqlmodel_user_db_oauth(request) -> AsyncGenerator[SQLModelUserDatabase, None]: - create_engine = request.param[0] + create_session = request.param[0] database_url = request.param[1] database_class = request.param[2] - async for engine in create_engine(database_url): - yield database_class(UserDBOAuth, engine, OAuthAccount) + async for session in create_session(database_url): + yield database_class(UserDBOAuth, session, OAuthAccount) @pytest.mark.asyncio @@ -108,18 +114,6 @@ async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccou assert email_user is not None assert email_user.id == user_db.id - # Exception when inserting existing email - with pytest.raises(exc.IntegrityError): - await sqlmodel_user_db.create( - UserDB(id=safe_uuid, email=user_db.email, hashed_password="guinevere") - ) - - # Exception when inserting non-nullable fields - with pytest.raises(exc.IntegrityError): - wrong_user = UserDB(id=safe_uuid, email="lancelot@camelot.bt", hashed_password="aaa") - wrong_user.email = None # type: ignore - await sqlmodel_user_db.create(wrong_user) - # Unknown user unknown_user = await sqlmodel_user_db.get_by_email("galahad@camelot.bt") assert unknown_user is None @@ -134,6 +128,37 @@ async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccou await sqlmodel_user_db.get_by_oauth_account("foo", "bar") +@pytest.mark.asyncio +@pytest.mark.db +async def test_insert_existing_email( + sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount] +): + user = UserDB( + id=safe_uuid, + email="lancelot@camelot.bt", + hashed_password="guinevere", + ) + await sqlmodel_user_db.create(user) + + with pytest.raises(exc.IntegrityError): + await sqlmodel_user_db.create( + UserDB(id=safe_uuid, email=user.email, hashed_password="guinevere") + ) + + +@pytest.mark.asyncio +@pytest.mark.db +async def test_insert_non_nullable_fields( + sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount] +): + with pytest.raises(exc.IntegrityError): + wrong_user = UserDB( + id=safe_uuid, email="lancelot@camelot.bt", hashed_password="aaa" + ) + wrong_user.email = None # type: ignore + await sqlmodel_user_db.create(wrong_user) + + @pytest.mark.asyncio @pytest.mark.db async def test_queries_custom_fields( From bb9207966d0424ebcace03bfd47ce8d63b3befc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 20 Sep 2021 10:05:10 +0200 Subject: [PATCH 02/22] =?UTF-8?q?Bump=20version:=200.0.3=20=E2=86=92=200.0?= =?UTF-8?q?.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index aa8a5b7..052d739 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.0.3" +__version__ = "0.0.4" class SQLModelBaseUserDB(BaseUserDB, SQLModel): diff --git a/setup.cfg b/setup.cfg index dad3708..7620d81 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.0.3 +current_version = 0.0.4 commit = True tag = True From c5cf0eeda16baa8b8efec52a5f00b407da4c11eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 13 Oct 2021 13:23:49 +0200 Subject: [PATCH 03/22] Enforce is_active/is_superuser/is_verified to be not nullable --- fastapi_users_db_sqlmodel/__init__.py | 10 ++++++++-- requirements.txt | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 052d739..19da878 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -13,8 +13,14 @@ class SQLModelBaseUserDB(BaseUserDB, SQLModel): - id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) - email: EmailStr = Field(sa_column_kwargs={"unique": True, "index": True}) + id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False) + email: EmailStr = Field( + sa_column_kwargs={"unique": True, "index": True}, nullable=False + ) + + is_active: bool = Field(True, nullable=False) + is_superuser: bool = Field(False, nullable=False) + is_verified: bool = Field(False, nullable=False) class Config: orm_mode = True diff --git a/requirements.txt b/requirements.txt index ece98a3..c0c3bbb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiosqlite >= 0.17.0 -fastapi-users >= 8.0.0b3 +fastapi-users >= 8.1.1 sqlmodel >=0.0.4,<0.1.0 From ce228ed7be73e0af901952371e765e233b8e0920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 13 Oct 2021 13:25:10 +0200 Subject: [PATCH 04/22] =?UTF-8?q?Bump=20version:=200.0.4=20=E2=86=92=200.0?= =?UTF-8?q?.5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 19da878..4aed3d6 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.0.4" +__version__ = "0.0.5" class SQLModelBaseUserDB(BaseUserDB, SQLModel): diff --git a/setup.cfg b/setup.cfg index 7620d81..3b88c62 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.0.4 +current_version = 0.0.5 commit = True tag = True From 4eac6f9dfb2722c725cb386172c6e5de59f57af8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 24 Nov 2021 17:40:55 +0100 Subject: [PATCH 05/22] Loosen version constraint for SQLModel dependency --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c85388d..18dfb27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ description-file = "README.md" requires-python = ">=3.7" requires = [ "fastapi-users >= 7.0.0", - "sqlmodel >=0.0.4,<0.1.0", + "sqlmodel", ] [tool.flit.metadata.urls] diff --git a/requirements.txt b/requirements.txt index c0c3bbb..3a5c620 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiosqlite >= 0.17.0 fastapi-users >= 8.1.1 -sqlmodel >=0.0.4,<0.1.0 +sqlmodel From d61ee21007c070b296af9e0f82492caf7ff7a052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 24 Nov 2021 17:41:13 +0100 Subject: [PATCH 06/22] =?UTF-8?q?Bump=20version:=200.0.5=20=E2=86=92=200.0?= =?UTF-8?q?.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 4aed3d6..94192a1 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.0.5" +__version__ = "0.0.6" class SQLModelBaseUserDB(BaseUserDB, SQLModel): diff --git a/setup.cfg b/setup.cfg index 3b88c62..0161b38 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.0.5 +current_version = 0.0.6 commit = True tag = True From f3a73eeeca57237daa3b515952be47ea4096c9e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 31 Dec 2021 14:50:02 +0100 Subject: [PATCH 07/22] Implement access token strategy db adapter --- fastapi_users_db_sqlmodel/__init__.py | 7 +- fastapi_users_db_sqlmodel/access_token.py | 114 ++++++++++++++ pyproject.toml | 2 +- requirements.txt | 2 +- tests/conftest.py | 7 +- tests/test_access_token.py | 141 ++++++++++++++++++ ...api_users_db_sqlmodel.py => test_users.py} | 5 +- 7 files changed, 270 insertions(+), 8 deletions(-) create mode 100644 fastapi_users_db_sqlmodel/access_token.py create mode 100644 tests/test_access_token.py rename tests/{test_fastapi_users_db_sqlmodel.py => test_users.py} (98%) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 94192a1..02ef030 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -13,6 +13,8 @@ class SQLModelBaseUserDB(BaseUserDB, SQLModel): + __tablename__ = "user" + id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False) email: EmailStr = Field( sa_column_kwargs={"unique": True, "index": True}, nullable=False @@ -27,7 +29,10 @@ class Config: class SQLModelBaseOAuthAccount(BaseOAuthAccount, SQLModel): + __tablename__ = "oauthaccount" + id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) + user_id: UUID4 = Field(foreign_key="user.id", nullable=False) class Config: orm_mode = True @@ -128,7 +133,7 @@ class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]): Database adapter for SQLModel working purely asynchronously. :param user_db_model: SQLModel model of a DB representation of a user. - :param engine: SQLAlchemy async engine. + :param session: SQLAlchemy async session. """ session: AsyncSession diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py new file mode 100644 index 0000000..a836ed8 --- /dev/null +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -0,0 +1,114 @@ +from datetime import datetime, timezone +from typing import Generic, Optional, Type, TypeVar + +from fastapi_users.authentication.strategy.db import AccessTokenDatabase +from fastapi_users.authentication.strategy.db.models import BaseAccessToken +from pydantic import UUID4 +from sqlalchemy import Column, types +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import Field, Session, SQLModel, select + + +def now_utc(): + return datetime.now(timezone.utc) + + +class SQLModelBaseAccessToken(BaseAccessToken, SQLModel): + __tablename__ = "accesstoken" + + token: str = Field( + sa_column=Column("token", types.String(length=43), primary_key=True) + ) + created_at: datetime = Field(default_factory=now_utc, nullable=False) + user_id: UUID4 = Field(foreign_key="user.id", nullable=False) + + class Config: + orm_mode = True + + +A = TypeVar("A", bound=SQLModelBaseAccessToken) + + +class SQLModelAccessTokenDatabase(Generic[A], AccessTokenDatabase[A]): + """ + Access token database adapter for SQLModel. + + :param user_db_model: SQLModel model of a DB representation of an access token. + :param session: SQLAlchemy session. + """ + + def __init__(self, access_token_model: Type[A], session: Session): + self.access_token_model = access_token_model + self.session = session + + async def get_by_token( + self, token: str, max_age: Optional[datetime] = None + ) -> Optional[A]: + statement = select(self.access_token_model).where( + self.access_token_model.token == token + ) + if max_age is not None: + statement = statement.where(self.access_token_model.created_at >= max_age) + + results = self.session.exec(statement) + return results.first() + + async def create(self, access_token: A) -> A: + self.session.add(access_token) + self.session.commit() + self.session.refresh(access_token) + return access_token + + async def update(self, access_token: A) -> A: + self.session.add(access_token) + self.session.commit() + self.session.refresh(access_token) + return access_token + + async def delete(self, access_token: A) -> None: + self.session.delete(access_token) + self.session.commit() + + +class SQLModelAccessTokenDatabaseAsync(Generic[A], AccessTokenDatabase[A]): + """ + Access token database adapter for SQLModel working purely asynchronously. + + :param user_db_model: SQLModel model of a DB representation of an access token. + :param session: SQLAlchemy async session. + """ + + def __init__(self, access_token_model: Type[A], session: AsyncSession): + self.access_token_model = access_token_model + self.session = session + + async def get_by_token( + self, token: str, max_age: Optional[datetime] = None + ) -> Optional[A]: + statement = select(self.access_token_model).where( + self.access_token_model.token == token + ) + if max_age is not None: + statement = statement.where(self.access_token_model.created_at >= max_age) + + results = await self.session.execute(statement) + object = results.first() + if object is None: + return None + return object[0] + + async def create(self, access_token: A) -> A: + self.session.add(access_token) + await self.session.commit() + await self.session.refresh(access_token) + return access_token + + async def update(self, access_token: A) -> A: + self.session.add(access_token) + await self.session.commit() + await self.session.refresh(access_token) + return access_token + + async def delete(self, access_token: A) -> None: + await self.session.delete(access_token) + await self.session.commit() diff --git a/pyproject.toml b/pyproject.toml index 18dfb27..ed04e58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ description-file = "README.md" requires-python = ">=3.7" requires = [ - "fastapi-users >= 7.0.0", + "fastapi-users >= 9.1.0", "sqlmodel", ] diff --git a/requirements.txt b/requirements.txt index 3a5c620..3c9bc3d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ aiosqlite >= 0.17.0 -fastapi-users >= 8.1.1 +fastapi-users >= 9.1.0 sqlmodel diff --git a/tests/conftest.py b/tests/conftest.py index f20fa05..8c42dda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import asyncio +import uuid from typing import List, Optional import pytest @@ -31,7 +32,7 @@ class UserOAuth(User): class UserDBOAuth(SQLModelBaseUserDB, table=True): - __tablename__ = "user" + __tablename__ = "user_oauth" oauth_accounts: List["OAuthAccount"] = Relationship( back_populates="user", sa_relationship_kwargs={"lazy": "joined", "cascade": "all, delete"}, @@ -39,7 +40,7 @@ class UserDBOAuth(SQLModelBaseUserDB, table=True): class OAuthAccount(SQLModelBaseOAuthAccount, table=True): - user_id: UUID4 = Field(foreign_key="user.id") + user_id: UUID4 = Field(foreign_key="user_oauth.id") user: Optional[UserDBOAuth] = Relationship(back_populates="oauth_accounts") @@ -53,6 +54,7 @@ def event_loop(): @pytest.fixture def oauth_account1() -> OAuthAccount: return OAuthAccount( + id=uuid.UUID("b9089e5d-2642-406d-a7c0-cbc641aca0ec"), oauth_name="service1", access_token="TOKEN", expires_at=1579000751, @@ -64,6 +66,7 @@ def oauth_account1() -> OAuthAccount: @pytest.fixture def oauth_account2() -> OAuthAccount: return OAuthAccount( + id=uuid.UUID("c9089e5d-2642-406d-a7c0-cbc641aca0ec"), oauth_name="service2", access_token="TOKEN", expires_at=1579000751, diff --git a/tests/test_access_token.py b/tests/test_access_token.py new file mode 100644 index 0000000..822fac3 --- /dev/null +++ b/tests/test_access_token.py @@ -0,0 +1,141 @@ +import uuid +from datetime import datetime, timedelta, timezone +from typing import AsyncGenerator + +import pytest +from pydantic import UUID4 +from sqlalchemy import exc +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import Session, SQLModel, create_engine + +from fastapi_users_db_sqlmodel import SQLModelUserDatabase, SQLModelUserDatabaseAsync +from fastapi_users_db_sqlmodel.access_token import ( + SQLModelAccessTokenDatabase, + SQLModelAccessTokenDatabaseAsync, + SQLModelBaseAccessToken, +) +from tests.conftest import UserDB + + +class AccessToken(SQLModelBaseAccessToken, table=True): + pass + + +@pytest.fixture +def user_id() -> UUID4: + return uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") + + +async def init_sync_session(url: str) -> AsyncGenerator[Session, None]: + engine = create_engine(url, connect_args={"check_same_thread": False}) + SQLModel.metadata.create_all(engine) + with Session(engine) as session: + yield session + SQLModel.metadata.drop_all(engine) + + +async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: + engine = create_async_engine(url, connect_args={"check_same_thread": False}) + make_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + async with make_session() as session: + yield session + await conn.run_sync(SQLModel.metadata.drop_all) + + +@pytest.fixture( + params=[ + ( + init_sync_session, + "sqlite:///./test-sqlmodel-access-token.db", + SQLModelAccessTokenDatabase, + SQLModelUserDatabase, + ), + ( + init_async_session, + "sqlite+aiosqlite:///./test-sqlmodel-access-token.db", + SQLModelAccessTokenDatabaseAsync, + SQLModelUserDatabaseAsync, + ), + ], + ids=["sync", "async"], +) +async def sqlmodel_access_token_db( + request, user_id: UUID4 +) -> AsyncGenerator[SQLModelAccessTokenDatabase, None]: + create_session = request.param[0] + database_url = request.param[1] + access_token_database_class = request.param[2] + user_database_class = request.param[3] + async for session in create_session(database_url): + user = UserDB( + id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" + ) + user_db = user_database_class(UserDB, session) + await user_db.create(user) + yield access_token_database_class(AccessToken, session) + + +@pytest.mark.asyncio +@pytest.mark.db +async def test_queries( + sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], + user_id: UUID4, +): + access_token = AccessToken(token="TOKEN", user_id=user_id) + + # Create + access_token_db = await sqlmodel_access_token_db.create(access_token) + assert access_token_db.token == "TOKEN" + assert access_token_db.user_id == user_id + + # Update + access_token_db.created_at = datetime.now(timezone.utc) + await sqlmodel_access_token_db.update(access_token_db) + + # Get by token + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + access_token_db.token + ) + assert access_token_by_token is not None + + # Get by token expired + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) + ) + assert access_token_by_token is None + + # Get by token not expired + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) + ) + assert access_token_by_token is not None + + # Get by token unknown + access_token_by_token = await sqlmodel_access_token_db.get_by_token( + "NOT_EXISTING_TOKEN" + ) + assert access_token_by_token is None + + # Delete token + await sqlmodel_access_token_db.delete(access_token_db) + deleted_access_token = await sqlmodel_access_token_db.get_by_token( + access_token_db.token + ) + assert deleted_access_token is None + + +@pytest.mark.asyncio +@pytest.mark.db +async def test_insert_existing_token( + sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], user_id: UUID4 +): + access_token = AccessToken(token="TOKEN", user_id=user_id) + await sqlmodel_access_token_db.create(access_token) + + with pytest.raises(exc.IntegrityError): + await sqlmodel_access_token_db.create( + AccessToken(token="TOKEN", user_id=user_id) + ) diff --git a/tests/test_fastapi_users_db_sqlmodel.py b/tests/test_users.py similarity index 98% rename from tests/test_fastapi_users_db_sqlmodel.py rename to tests/test_users.py index 9a9107b..4c63d03 100644 --- a/tests/test_fastapi_users_db_sqlmodel.py +++ b/tests/test_users.py @@ -3,9 +3,9 @@ import pytest from sqlalchemy import exc -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker -from sqlmodel import SQLModel, create_engine, Session +from sqlmodel import Session, SQLModel, create_engine from fastapi_users_db_sqlmodel import ( NotSetOAuthAccountTableError, @@ -14,7 +14,6 @@ ) from tests.conftest import OAuthAccount, UserDB, UserDBOAuth - safe_uuid = uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") From e2e4c740f0054f4ed166c601c6a07bfb1b8148ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 31 Dec 2021 14:57:25 +0100 Subject: [PATCH 08/22] =?UTF-8?q?Bump=20version:=200.0.6=20=E2=86=92=200.1?= =?UTF-8?q?.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 02ef030..f8c9e3b 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.0.6" +__version__ = "0.1.0" class SQLModelBaseUserDB(BaseUserDB, SQLModel): diff --git a/setup.cfg b/setup.cfg index 0161b38..0890b39 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.0.6 +current_version = 0.1.0 commit = True tag = True From 702c618bb9fac418c2b4fbc17dcf5b9a49b39856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 31 Dec 2021 15:15:45 +0100 Subject: [PATCH 09/22] Fix AccessToken.created_at column definition to support timezone --- fastapi_users_db_sqlmodel/access_token.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py index a836ed8..2616c2c 100644 --- a/fastapi_users_db_sqlmodel/access_token.py +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -19,7 +19,12 @@ class SQLModelBaseAccessToken(BaseAccessToken, SQLModel): token: str = Field( sa_column=Column("token", types.String(length=43), primary_key=True) ) - created_at: datetime = Field(default_factory=now_utc, nullable=False) + created_at: datetime = Field( + default_factory=now_utc, + sa_column=Column( + "created_at", types.DateTime(timezone=True), nullable=False, index=True + ), + ) user_id: UUID4 = Field(foreign_key="user.id", nullable=False) class Config: From 3a46b80399f129aa07a834a1b40bf49d08c37be1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 31 Dec 2021 15:15:52 +0100 Subject: [PATCH 10/22] =?UTF-8?q?Bump=20version:=200.1.0=20=E2=86=92=200.1?= =?UTF-8?q?.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index f8c9e3b..e11fe8f 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.1.0" +__version__ = "0.1.1" class SQLModelBaseUserDB(BaseUserDB, SQLModel): diff --git a/setup.cfg b/setup.cfg index 0890b39..3542fc1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.1.0 +current_version = 0.1.1 commit = True tag = True From 6ff379c362e59ce723641c7f8deaf6a9c864f0d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 28 Apr 2022 09:21:14 +0200 Subject: [PATCH 11/22] Pin SQLAlchemy to prevent bug https://github.com/tiangolo/sqlmodel/issues/315 --- pyproject.toml | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ed04e58..118892c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ requires-python = ">=3.7" requires = [ "fastapi-users >= 9.1.0", "sqlmodel", + "sqlalchemy[asyncio] >=1.4,<1.4.36", # Pin SQLAlchemy to prevent bug https://github.com/tiangolo/sqlmodel/issues/315 ] [tool.flit.metadata.urls] diff --git a/requirements.txt b/requirements.txt index 3c9bc3d..6a5054d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ aiosqlite >= 0.17.0 fastapi-users >= 9.1.0 sqlmodel +sqlalchemy[asyncio] >=1.4,<1.4.36 From 94c9090a109cb93fd48950df4f451ec058572229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 28 Apr 2022 09:48:17 +0200 Subject: [PATCH 12/22] =?UTF-8?q?Bump=20version:=200.1.1=20=E2=86=92=200.1?= =?UTF-8?q?.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index e11fe8f..ded583c 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.1.1" +__version__ = "0.1.2" class SQLModelBaseUserDB(BaseUserDB, SQLModel): diff --git a/setup.cfg b/setup.cfg index 3542fc1..9f90210 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.1.1 +current_version = 0.1.2 commit = True tag = True From 508a00eb458a329b98765e85d72911da5f8ef42a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 9 May 2022 10:57:28 +0200 Subject: [PATCH 13/22] Upgrade for FastAPI Users V10 --- fastapi_users_db_sqlmodel/__init__.py | 182 +++++++++++++--------- fastapi_users_db_sqlmodel/access_token.py | 69 ++++---- fastapi_users_db_sqlmodel/generics.py | 24 +++ pyproject.toml | 2 +- requirements.txt | 2 +- setup.cfg | 7 +- tests/conftest.py | 66 +++----- tests/test_access_token.py | 55 ++++--- tests/test_users.py | 155 +++++++++--------- 9 files changed, 295 insertions(+), 267 deletions(-) create mode 100644 fastapi_users_db_sqlmodel/generics.py diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index ded583c..2441e15 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -1,9 +1,9 @@ """FastAPI Users database adapter for SQLModel.""" import uuid -from typing import Generic, Optional, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type from fastapi_users.db.base import BaseUserDatabase -from fastapi_users.models import BaseOAuthAccount, BaseUserDB +from fastapi_users.models import ID, OAP, UP from pydantic import UUID4, EmailStr from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -12,13 +12,17 @@ __version__ = "0.1.2" -class SQLModelBaseUserDB(BaseUserDB, SQLModel): +class SQLModelBaseUserDB(SQLModel): __tablename__ = "user" id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True, nullable=False) - email: EmailStr = Field( - sa_column_kwargs={"unique": True, "index": True}, nullable=False - ) + if TYPE_CHECKING: # pragma: no cover + email: str + else: + email: EmailStr = Field( + sa_column_kwargs={"unique": True, "index": True}, nullable=False + ) + hashed_password: str is_active: bool = Field(True, nullable=False) is_superuser: bool = Field(False, nullable=False) @@ -28,68 +32,59 @@ class Config: orm_mode = True -class SQLModelBaseOAuthAccount(BaseOAuthAccount, SQLModel): +class SQLModelBaseOAuthAccount(SQLModel): __tablename__ = "oauthaccount" id: UUID4 = Field(default_factory=uuid.uuid4, primary_key=True) user_id: UUID4 = Field(foreign_key="user.id", nullable=False) + oauth_name: str = Field(index=True, nullable=False) + access_token: str = Field(nullable=False) + expires_at: Optional[int] = Field(nullable=True) + refresh_token: Optional[str] = Field(nullable=True) + account_id: str = Field(index=True, nullable=False) + account_email: str = Field(nullable=False) class Config: orm_mode = True -UD = TypeVar("UD", bound=SQLModelBaseUserDB) -OA = TypeVar("OA", bound=SQLModelBaseOAuthAccount) - - -class NotSetOAuthAccountTableError(Exception): - """ - OAuth table was not set in DB adapter but was needed. - - Raised when trying to create/update a user with OAuth accounts set - but no table were specified in the DB adapter. - """ - - pass - - -class SQLModelUserDatabase(Generic[UD, OA], BaseUserDatabase[UD]): +class SQLModelUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ Database adapter for SQLModel. - :param user_db_model: SQLModel model of a DB representation of a user. :param session: SQLAlchemy session. """ session: Session - oauth_account_model: Optional[Type[OA]] + user_model: Type[UP] + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] def __init__( self, - user_db_model: Type[UD], session: Session, - oauth_account_model: Optional[Type[OA]] = None, + user_model: Type[UP], + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] = None, ): - super().__init__(user_db_model) self.session = session + self.user_model = user_model self.oauth_account_model = oauth_account_model - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: ID) -> Optional[UP]: """Get a single user by id.""" - return self.session.get(self.user_db_model, id) + return self.session.get(self.user_model, id) - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" - statement = select(self.user_db_model).where( - func.lower(self.user_db_model.email) == func.lower(email) + statement = select(self.user_model).where( + func.lower(self.user_model.email) == func.lower(email) ) results = self.session.exec(statement) return results.first() - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: """Get a single user by OAuth account id.""" - if not self.oauth_account_model: - raise NotSetOAuthAccountTableError() + if self.oauth_account_model is None: + raise NotImplementedError() statement = ( select(self.oauth_account_model) .where(self.oauth_account_model.oauth_name == oauth) @@ -102,61 +97,82 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD return user return None - async def create(self, user: UD) -> UD: + async def create(self, create_dict: Dict[str, Any]) -> UP: """Create a user.""" + user = self.user_model(**create_dict) self.session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - self.session.add(oauth_account) self.session.commit() self.session.refresh(user) return user - async def update(self, user: UD) -> UD: - """Update a user.""" + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: + for key, value in update_dict.items(): + setattr(user, key, value) self.session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - self.session.add(oauth_account) self.session.commit() self.session.refresh(user) return user - async def delete(self, user: UD) -> None: - """Delete a user.""" + async def delete(self, user: UP) -> None: self.session.delete(user) self.session.commit() + async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() -class SQLModelUserDatabaseAsync(Generic[UD, OA], BaseUserDatabase[UD]): + oauth_account = self.oauth_account_model(**create_dict) + user.oauth_accounts.append(oauth_account) # type: ignore + self.session.add(user) + + self.session.commit() + + return user + + async def update_oauth_account( + self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + ) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + for key, value in update_dict.items(): + setattr(oauth_account, key, value) + self.session.add(oauth_account) + self.session.commit() + + return user + + +class SQLModelUserDatabaseAsync(Generic[UP, ID], BaseUserDatabase[UP, ID]): """ Database adapter for SQLModel working purely asynchronously. - :param user_db_model: SQLModel model of a DB representation of a user. + :param user_model: SQLModel model of a DB representation of a user. :param session: SQLAlchemy async session. """ session: AsyncSession - oauth_account_model: Optional[Type[OA]] + user_model: Type[UP] + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] def __init__( self, - user_db_model: Type[UD], session: AsyncSession, - oauth_account_model: Optional[Type[OA]] = None, + user_model: Type[UP], + oauth_account_model: Optional[Type[SQLModelBaseOAuthAccount]] = None, ): - super().__init__(user_db_model) self.session = session + self.user_model = user_model self.oauth_account_model = oauth_account_model - async def get(self, id: UUID4) -> Optional[UD]: + async def get(self, id: ID) -> Optional[UP]: """Get a single user by id.""" - return await self.session.get(self.user_db_model, id) + return await self.session.get(self.user_model, id) - async def get_by_email(self, email: str) -> Optional[UD]: + async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" - statement = select(self.user_db_model).where( - func.lower(self.user_db_model.email) == func.lower(email) + statement = select(self.user_model).where( + func.lower(self.user_model.email) == func.lower(email) ) results = await self.session.execute(statement) object = results.first() @@ -164,10 +180,10 @@ async def get_by_email(self, email: str) -> Optional[UD]: return None return object[0] - async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD]: + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: """Get a single user by OAuth account id.""" - if not self.oauth_account_model: - raise NotSetOAuthAccountTableError() + if self.oauth_account_model is None: + raise NotImplementedError() statement = ( select(self.oauth_account_model) .where(self.oauth_account_model.oauth_name == oauth) @@ -177,31 +193,51 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UD results = await self.session.execute(statement) oauth_account = results.first() if oauth_account: - user = oauth_account[0].user + user = oauth_account[0].user # type: ignore return user return None - async def create(self, user: UD) -> UD: + async def create(self, create_dict: Dict[str, Any]) -> UP: """Create a user.""" + user = self.user_model(**create_dict) self.session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - self.session.add(oauth_account) await self.session.commit() await self.session.refresh(user) return user - async def update(self, user: UD) -> UD: - """Update a user.""" + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: + for key, value in update_dict.items(): + setattr(user, key, value) self.session.add(user) - if self.oauth_account_model is not None: - for oauth_account in user.oauth_accounts: # type: ignore - self.session.add(oauth_account) await self.session.commit() await self.session.refresh(user) return user - async def delete(self, user: UD) -> None: - """Delete a user.""" + async def delete(self, user: UP) -> None: await self.session.delete(user) await self.session.commit() + + async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + oauth_account = self.oauth_account_model(**create_dict) + user.oauth_accounts.append(oauth_account) # type: ignore + self.session.add(user) + + await self.session.commit() + + return user + + async def update_oauth_account( + self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + ) -> UP: + if self.oauth_account_model is None: + raise NotImplementedError() + + for key, value in update_dict.items(): + setattr(oauth_account, key, value) + self.session.add(oauth_account) + await self.session.commit() + + return user diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py index 2616c2c..7395a43 100644 --- a/fastapi_users_db_sqlmodel/access_token.py +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -1,19 +1,16 @@ -from datetime import datetime, timezone -from typing import Generic, Optional, Type, TypeVar +from datetime import datetime +from typing import Any, Dict, Generic, Optional, Type -from fastapi_users.authentication.strategy.db import AccessTokenDatabase -from fastapi_users.authentication.strategy.db.models import BaseAccessToken +from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase from pydantic import UUID4 from sqlalchemy import Column, types from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import Field, Session, SQLModel, select +from fastapi_users_db_sqlmodel.generics import TIMESTAMPAware, now_utc -def now_utc(): - return datetime.now(timezone.utc) - -class SQLModelBaseAccessToken(BaseAccessToken, SQLModel): +class SQLModelBaseAccessToken(SQLModel): __tablename__ = "accesstoken" token: str = Field( @@ -22,7 +19,7 @@ class SQLModelBaseAccessToken(BaseAccessToken, SQLModel): created_at: datetime = Field( default_factory=now_utc, sa_column=Column( - "created_at", types.DateTime(timezone=True), nullable=False, index=True + "created_at", TIMESTAMPAware(timezone=True), nullable=False, index=True ), ) user_id: UUID4 = Field(foreign_key="user.id", nullable=False) @@ -31,65 +28,68 @@ class Config: orm_mode = True -A = TypeVar("A", bound=SQLModelBaseAccessToken) - - -class SQLModelAccessTokenDatabase(Generic[A], AccessTokenDatabase[A]): +class SQLModelAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): """ Access token database adapter for SQLModel. - :param user_db_model: SQLModel model of a DB representation of an access token. :param session: SQLAlchemy session. + :param access_token_model: SQLModel access token model. """ - def __init__(self, access_token_model: Type[A], session: Session): - self.access_token_model = access_token_model + def __init__(self, session: Session, access_token_model: Type[AP]): self.session = session + self.access_token_model = access_token_model async def get_by_token( self, token: str, max_age: Optional[datetime] = None - ) -> Optional[A]: + ) -> Optional[AP]: statement = select(self.access_token_model).where( self.access_token_model.token == token ) if max_age is not None: statement = statement.where(self.access_token_model.created_at >= max_age) - results = self.session.exec(statement) - return results.first() + results = self.session.execute(statement) + access_token = results.first() + if access_token is None: + return None + return access_token[0] - async def create(self, access_token: A) -> A: + async def create(self, create_dict: Dict[str, Any]) -> AP: + access_token = self.access_token_model(**create_dict) self.session.add(access_token) self.session.commit() self.session.refresh(access_token) return access_token - async def update(self, access_token: A) -> A: + async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: + for key, value in update_dict.items(): + setattr(access_token, key, value) self.session.add(access_token) self.session.commit() self.session.refresh(access_token) return access_token - async def delete(self, access_token: A) -> None: + async def delete(self, access_token: AP) -> None: self.session.delete(access_token) self.session.commit() -class SQLModelAccessTokenDatabaseAsync(Generic[A], AccessTokenDatabase[A]): +class SQLModelAccessTokenDatabaseAsync(Generic[AP], AccessTokenDatabase[AP]): """ Access token database adapter for SQLModel working purely asynchronously. - :param user_db_model: SQLModel model of a DB representation of an access token. :param session: SQLAlchemy async session. + :param access_token_model: SQLModel access token model. """ - def __init__(self, access_token_model: Type[A], session: AsyncSession): - self.access_token_model = access_token_model + def __init__(self, session: AsyncSession, access_token_model: Type[AP]): self.session = session + self.access_token_model = access_token_model async def get_by_token( self, token: str, max_age: Optional[datetime] = None - ) -> Optional[A]: + ) -> Optional[AP]: statement = select(self.access_token_model).where( self.access_token_model.token == token ) @@ -97,23 +97,26 @@ async def get_by_token( statement = statement.where(self.access_token_model.created_at >= max_age) results = await self.session.execute(statement) - object = results.first() - if object is None: + access_token = results.first() + if access_token is None: return None - return object[0] + return access_token[0] - async def create(self, access_token: A) -> A: + async def create(self, create_dict: Dict[str, Any]) -> AP: + access_token = self.access_token_model(**create_dict) self.session.add(access_token) await self.session.commit() await self.session.refresh(access_token) return access_token - async def update(self, access_token: A) -> A: + async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: + for key, value in update_dict.items(): + setattr(access_token, key, value) self.session.add(access_token) await self.session.commit() await self.session.refresh(access_token) return access_token - async def delete(self, access_token: A) -> None: + async def delete(self, access_token: AP) -> None: await self.session.delete(access_token) await self.session.commit() diff --git a/fastapi_users_db_sqlmodel/generics.py b/fastapi_users_db_sqlmodel/generics.py new file mode 100644 index 0000000..bfe2fda --- /dev/null +++ b/fastapi_users_db_sqlmodel/generics.py @@ -0,0 +1,24 @@ +from datetime import datetime, timezone + +from sqlalchemy import TIMESTAMP, TypeDecorator + + +def now_utc(): + return datetime.now(timezone.utc) + + +class TIMESTAMPAware(TypeDecorator): # pragma: no cover + """ + MySQL and SQLite will always return naive-Python datetimes. + + We store everything as UTC, but we want to have + only offset-aware Python datetimes, even with MySQL and SQLite. + """ + + impl = TIMESTAMP + cache_ok = True + + def process_result_value(self, value: datetime, dialect): + if dialect.name != "postgresql": + return value.replace(tzinfo=timezone.utc) + return value diff --git a/pyproject.toml b/pyproject.toml index 118892c..e9d555a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ description-file = "README.md" requires-python = ">=3.7" requires = [ - "fastapi-users >= 9.1.0", + "fastapi-users >= 10.0.2", "sqlmodel", "sqlalchemy[asyncio] >=1.4,<1.4.36", # Pin SQLAlchemy to prevent bug https://github.com/tiangolo/sqlmodel/issues/315 ] diff --git a/requirements.txt b/requirements.txt index 6a5054d..5663706 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ aiosqlite >= 0.17.0 fastapi-users >= 9.1.0 -sqlmodel +git+https://github.com/andrewbolster/sqlmodel.git@patch-1 sqlalchemy[asyncio] >=1.4,<1.4.36 diff --git a/setup.cfg b/setup.cfg index 9f90210..2827268 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,9 +18,4 @@ profile = black [tool:pytest] addopts = --ignore=test_build.py -markers = - authentication - db - fastapi_users - oauth - router +asyncio_mode = strict diff --git a/tests/conftest.py b/tests/conftest.py index 8c42dda..ee12b2c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,75 +1,55 @@ import asyncio -import uuid -from typing import List, Optional +from typing import Any, Dict, List, Optional import pytest -from fastapi_users import models from pydantic import UUID4 from sqlmodel import Field, Relationship from fastapi_users_db_sqlmodel import SQLModelBaseOAuthAccount, SQLModelBaseUserDB -class User(models.BaseUser): +class User(SQLModelBaseUserDB, table=True): first_name: Optional[str] -class UserCreate(models.BaseUserCreate): - first_name: Optional[str] - - -class UserUpdate(models.BaseUserUpdate): - pass - - -class UserDB(SQLModelBaseUserDB, User, table=True): - class Config: - orm_mode = True - - -class UserOAuth(User): - pass - - -class UserDBOAuth(SQLModelBaseUserDB, table=True): +class UserOAuth(SQLModelBaseUserDB, table=True): __tablename__ = "user_oauth" oauth_accounts: List["OAuthAccount"] = Relationship( back_populates="user", - sa_relationship_kwargs={"lazy": "joined", "cascade": "all, delete"}, + sa_relationship_kwargs={"lazy": "selectin", "cascade": "all, delete"}, ) class OAuthAccount(SQLModelBaseOAuthAccount, table=True): user_id: UUID4 = Field(foreign_key="user_oauth.id") - user: Optional[UserDBOAuth] = Relationship(back_populates="oauth_accounts") + user: Optional[UserOAuth] = Relationship(back_populates="oauth_accounts") @pytest.fixture(scope="session") def event_loop(): """Force the pytest-asyncio loop to be the main one.""" - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() yield loop + loop.close() @pytest.fixture -def oauth_account1() -> OAuthAccount: - return OAuthAccount( - id=uuid.UUID("b9089e5d-2642-406d-a7c0-cbc641aca0ec"), - oauth_name="service1", - access_token="TOKEN", - expires_at=1579000751, - account_id="user_oauth1", - account_email="king.arthur@camelot.bt", - ) +def oauth_account1() -> Dict[str, Any]: + return { + "oauth_name": "service1", + "access_token": "TOKEN", + "expires_at": 1579000751, + "account_id": "user_oauth1", + "account_email": "king.arthur@camelot.bt", + } @pytest.fixture -def oauth_account2() -> OAuthAccount: - return OAuthAccount( - id=uuid.UUID("c9089e5d-2642-406d-a7c0-cbc641aca0ec"), - oauth_name="service2", - access_token="TOKEN", - expires_at=1579000751, - account_id="user_oauth2", - account_email="king.arthur@camelot.bt", - ) +def oauth_account2() -> Dict[str, Any]: + return { + "oauth_name": "service2", + "access_token": "TOKEN", + "expires_at": 1579000751, + "account_id": "user_oauth2", + "account_email": "king.arthur@camelot.bt", + } diff --git a/tests/test_access_token.py b/tests/test_access_token.py index 822fac3..ee93e04 100644 --- a/tests/test_access_token.py +++ b/tests/test_access_token.py @@ -3,6 +3,7 @@ from typing import AsyncGenerator import pytest +import pytest_asyncio from pydantic import UUID4 from sqlalchemy import exc from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -15,7 +16,7 @@ SQLModelAccessTokenDatabaseAsync, SQLModelBaseAccessToken, ) -from tests.conftest import UserDB +from tests.conftest import User class AccessToken(SQLModelBaseAccessToken, table=True): @@ -45,7 +46,7 @@ async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: await conn.run_sync(SQLModel.metadata.drop_all) -@pytest.fixture( +@pytest_asyncio.fixture( params=[ ( init_sync_session, @@ -70,46 +71,53 @@ async def sqlmodel_access_token_db( access_token_database_class = request.param[2] user_database_class = request.param[3] async for session in create_session(database_url): - user = UserDB( - id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" + user_db = user_database_class(session, User) + await user_db.create( + { + "id": user_id, + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } ) - user_db = user_database_class(UserDB, session) - await user_db.create(user) - yield access_token_database_class(AccessToken, session) + yield access_token_database_class(session, AccessToken) @pytest.mark.asyncio -@pytest.mark.db async def test_queries( sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], user_id: UUID4, ): - access_token = AccessToken(token="TOKEN", user_id=user_id) + access_token_create = {"token": "TOKEN", "user_id": user_id} # Create - access_token_db = await sqlmodel_access_token_db.create(access_token) - assert access_token_db.token == "TOKEN" - assert access_token_db.user_id == user_id + access_token = await sqlmodel_access_token_db.create(access_token_create) + assert access_token.token == "TOKEN" + assert access_token.user_id == user_id # Update - access_token_db.created_at = datetime.now(timezone.utc) - await sqlmodel_access_token_db.update(access_token_db) + update_dict = {"created_at": datetime.now(timezone.utc)} + updated_access_token = await sqlmodel_access_token_db.update( + access_token, update_dict + ) + assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ + "created_at" + ].replace(microsecond=0) # Get by token access_token_by_token = await sqlmodel_access_token_db.get_by_token( - access_token_db.token + access_token.token ) assert access_token_by_token is not None # Get by token expired access_token_by_token = await sqlmodel_access_token_db.get_by_token( - access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) + access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) ) assert access_token_by_token is None # Get by token not expired access_token_by_token = await sqlmodel_access_token_db.get_by_token( - access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) + access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) ) assert access_token_by_token is not None @@ -120,22 +128,19 @@ async def test_queries( assert access_token_by_token is None # Delete token - await sqlmodel_access_token_db.delete(access_token_db) + await sqlmodel_access_token_db.delete(access_token) deleted_access_token = await sqlmodel_access_token_db.get_by_token( - access_token_db.token + access_token.token ) assert deleted_access_token is None @pytest.mark.asyncio -@pytest.mark.db async def test_insert_existing_token( sqlmodel_access_token_db: SQLModelAccessTokenDatabase[AccessToken], user_id: UUID4 ): - access_token = AccessToken(token="TOKEN", user_id=user_id) - await sqlmodel_access_token_db.create(access_token) + access_token_create = {"token": "TOKEN", "user_id": user_id} + await sqlmodel_access_token_db.create(access_token_create) with pytest.raises(exc.IntegrityError): - await sqlmodel_access_token_db.create( - AccessToken(token="TOKEN", user_id=user_id) - ) + await sqlmodel_access_token_db.create(access_token_create) diff --git a/tests/test_users.py b/tests/test_users.py index 4c63d03..dc32aab 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,18 +1,16 @@ import uuid -from typing import AsyncGenerator +from typing import Any, AsyncGenerator, Dict import pytest +import pytest_asyncio +from pydantic import UUID4 from sqlalchemy import exc from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from sqlmodel import Session, SQLModel, create_engine -from fastapi_users_db_sqlmodel import ( - NotSetOAuthAccountTableError, - SQLModelUserDatabase, - SQLModelUserDatabaseAsync, -) -from tests.conftest import OAuthAccount, UserDB, UserDBOAuth +from fastapi_users_db_sqlmodel import SQLModelUserDatabase, SQLModelUserDatabaseAsync +from tests.conftest import OAuthAccount, User, UserOAuth safe_uuid = uuid.UUID("a9089e5d-2642-406d-a7c0-cbc641aca0ec") @@ -35,7 +33,7 @@ async def init_async_session(url: str) -> AsyncGenerator[AsyncSession, None]: await conn.run_sync(SQLModel.metadata.drop_all) -@pytest.fixture( +@pytest_asyncio.fixture( params=[ (init_sync_session, "sqlite:///./test-sqlmodel-user.db", SQLModelUserDatabase), ( @@ -51,10 +49,10 @@ async def sqlmodel_user_db(request) -> AsyncGenerator[SQLModelUserDatabase, None database_url = request.param[1] database_class = request.param[2] async for session in create_session(database_url): - yield database_class(UserDB, session) + yield database_class(session, User) -@pytest.fixture( +@pytest_asyncio.fixture( params=[ ( init_sync_session, @@ -74,44 +72,42 @@ async def sqlmodel_user_db_oauth(request) -> AsyncGenerator[SQLModelUserDatabase database_url = request.param[1] database_class = request.param[2] async for session in create_session(database_url): - yield database_class(UserDBOAuth, session, OAuthAccount) + yield database_class(session, UserOAuth, OAuthAccount) @pytest.mark.asyncio -@pytest.mark.db -async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount]): - user = UserDB( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - ) +async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[User, UUID4]): + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } # Create - user_db = await sqlmodel_user_db.create(user) - assert user_db.id is not None - assert user_db.is_active is True - assert user_db.is_superuser is False - assert user_db.email == user.email + user = await sqlmodel_user_db.create(user_create) + assert user.id is not None + assert user.is_active is True + assert user.is_superuser is False + assert user.email == user_create["email"] # Update - user_db.is_superuser = True - await sqlmodel_user_db.update(user_db) + updated_user = await sqlmodel_user_db.update(user, {"is_superuser": True}) + assert updated_user.is_superuser is True # Get by id id_user = await sqlmodel_user_db.get(user.id) assert id_user is not None - assert id_user.id == user_db.id + assert id_user.id == user.id assert id_user.is_superuser is True # Get by email - email_user = await sqlmodel_user_db.get_by_email(str(user.email)) + email_user = await sqlmodel_user_db.get_by_email(str(user_create["email"])) assert email_user is not None - assert email_user.id == user_db.id + assert email_user.id == user.id # Get by uppercased email email_user = await sqlmodel_user_db.get_by_email("Lancelot@camelot.bt") assert email_user is not None - assert email_user.id == user_db.id + assert email_user.id == user.id # Unknown user unknown_user = await sqlmodel_user_db.get_by_email("galahad@camelot.bt") @@ -122,55 +118,41 @@ async def test_queries(sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccou deleted_user = await sqlmodel_user_db.get(user.id) assert deleted_user is None - # Exception when trying to get by OAuth account - with pytest.raises(NotSetOAuthAccountTableError): + # OAuth without defined table + with pytest.raises(NotImplementedError): await sqlmodel_user_db.get_by_oauth_account("foo", "bar") + with pytest.raises(NotImplementedError): + await sqlmodel_user_db.add_oauth_account(user, {}) + with pytest.raises(NotImplementedError): + oauth_account = OAuthAccount() + await sqlmodel_user_db.update_oauth_account(user, oauth_account, {}) @pytest.mark.asyncio -@pytest.mark.db async def test_insert_existing_email( - sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount] + sqlmodel_user_db: SQLModelUserDatabase[User, UUID4] ): - user = UserDB( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - ) - await sqlmodel_user_db.create(user) - - with pytest.raises(exc.IntegrityError): - await sqlmodel_user_db.create( - UserDB(id=safe_uuid, email=user.email, hashed_password="guinevere") - ) - + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } + await sqlmodel_user_db.create(user_create) -@pytest.mark.asyncio -@pytest.mark.db -async def test_insert_non_nullable_fields( - sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount] -): with pytest.raises(exc.IntegrityError): - wrong_user = UserDB( - id=safe_uuid, email="lancelot@camelot.bt", hashed_password="aaa" - ) - wrong_user.email = None # type: ignore - await sqlmodel_user_db.create(wrong_user) + await sqlmodel_user_db.create(user_create) @pytest.mark.asyncio -@pytest.mark.db async def test_queries_custom_fields( - sqlmodel_user_db: SQLModelUserDatabase[UserDB, OAuthAccount], + sqlmodel_user_db: SQLModelUserDatabase[User, UUID4], ): """It should output custom fields in query result.""" - user = UserDB( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - first_name="Lancelot", - ) - await sqlmodel_user_db.create(user) + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + "first_name": "Lancelot", + } + user = await sqlmodel_user_db.create(user_create) id_user = await sqlmodel_user_db.get(user.id) assert id_user is not None @@ -179,48 +161,51 @@ async def test_queries_custom_fields( @pytest.mark.asyncio -@pytest.mark.db async def test_queries_oauth( - sqlmodel_user_db_oauth: SQLModelUserDatabase[UserDBOAuth, OAuthAccount], - oauth_account1, - oauth_account2, + sqlmodel_user_db_oauth: SQLModelUserDatabase[UserOAuth, UUID4], + oauth_account1: Dict[str, Any], + oauth_account2: Dict[str, Any], ): - user = UserDBOAuth( - id=safe_uuid, - email="lancelot@camelot.bt", - hashed_password="guinevere", - oauth_accounts=[oauth_account1, oauth_account2], - ) + user_create = { + "email": "lancelot@camelot.bt", + "hashed_password": "guinevere", + } # Create - user_db = await sqlmodel_user_db_oauth.create(user) - assert user_db.id is not None - assert hasattr(user_db, "oauth_accounts") - assert len(user_db.oauth_accounts) == 2 + user = await sqlmodel_user_db_oauth.create(user_create) + assert user.id is not None + + # Add OAuth account + user = await sqlmodel_user_db_oauth.add_oauth_account(user, oauth_account1) + user = await sqlmodel_user_db_oauth.add_oauth_account(user, oauth_account2) + assert len(user.oauth_accounts) == 2 + assert user.oauth_accounts[1].account_id == oauth_account2["account_id"] + assert user.oauth_accounts[0].account_id == oauth_account1["account_id"] # Update - user_db.oauth_accounts[0].access_token = "NEW_TOKEN" - await sqlmodel_user_db_oauth.update(user_db) + user = await sqlmodel_user_db_oauth.update_oauth_account( + user, user.oauth_accounts[0], {"access_token": "NEW_TOKEN"} + ) + assert user.oauth_accounts[0].access_token == "NEW_TOKEN" # Get by id id_user = await sqlmodel_user_db_oauth.get(user.id) assert id_user is not None - assert id_user.id == user_db.id + assert id_user.id == user.id assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" # Get by email - email_user = await sqlmodel_user_db_oauth.get_by_email(str(user.email)) + email_user = await sqlmodel_user_db_oauth.get_by_email(user_create["email"]) assert email_user is not None - assert email_user.id == user_db.id + assert email_user.id == user.id assert len(email_user.oauth_accounts) == 2 # Get by OAuth account oauth_user = await sqlmodel_user_db_oauth.get_by_oauth_account( - oauth_account1.oauth_name, oauth_account1.account_id + oauth_account1["oauth_name"], oauth_account1["account_id"] ) assert oauth_user is not None assert oauth_user.id == user.id - assert len(oauth_user.oauth_accounts) == 2 # Unknown OAuth account unknown_oauth_user = await sqlmodel_user_db_oauth.get_by_oauth_account("foo", "bar") From 1060fd761012c6db1c027f6f71d7577c93fc7ee1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 9 May 2022 10:59:33 +0200 Subject: [PATCH 14/22] =?UTF-8?q?Bump=20version:=200.1.2=20=E2=86=92=200.2?= =?UTF-8?q?.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_users_db_sqlmodel/__init__.py | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 2441e15..9a7e486 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.1.2" +__version__ = "0.2.0" class SQLModelBaseUserDB(SQLModel): diff --git a/setup.cfg b/setup.cfg index 2827268..03dd6e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.1.2 +current_version = 0.2.0 commit = True tag = True From 940c7a42788d37ae45c441eda218b5d6e20051f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 18 Jul 2022 11:00:30 +0200 Subject: [PATCH 15/22] Update FUNDING.yml --- .github/FUNDING.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index d4860d5..332f3ca 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,2 +1 @@ -# These are supported funding model platforms -custom: https://www.buymeacoffee.com/frankie567 +github: frankie567 From 342f47a4611b3086928a064d5d04125ef2b95cff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 18 Jul 2022 11:06:10 +0200 Subject: [PATCH 16/22] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ef102de..8c900c2 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [![PyPI version](https://badge.fury.io/py/fastapi-users-db-sqlmodel.svg)](https://badge.fury.io/py/fastapi-users-db-sqlmodel) [![Downloads](https://pepy.tech/badge/fastapi-users-db-sqlmodel)](https://pepy.tech/project/fastapi-users-db-sqlmodel)

- +

--- From fa7ae4efa811c85c8eeada31cb24c9707c3bdb59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Mon, 18 Jul 2022 18:01:08 +0200 Subject: [PATCH 17/22] Fix md-buttons URL --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 8c900c2..4654684 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [![PyPI version](https://badge.fury.io/py/fastapi-users-db-sqlmodel.svg)](https://badge.fury.io/py/fastapi-users-db-sqlmodel) [![Downloads](https://pepy.tech/badge/fastapi-users-db-sqlmodel)](https://pepy.tech/project/fastapi-users-db-sqlmodel)

- +

--- From 171120000af8119351869ea0fa72357d9ecdd81f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Thu, 11 Aug 2022 17:24:45 +0200 Subject: [PATCH 18/22] Fix BMAC button --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4654684..fe00d26 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [![PyPI version](https://badge.fury.io/py/fastapi-users-db-sqlmodel.svg)](https://badge.fury.io/py/fastapi-users-db-sqlmodel) [![Downloads](https://pepy.tech/badge/fastapi-users-db-sqlmodel)](https://pepy.tech/project/fastapi-users-db-sqlmodel)

- +

--- From f4beec6fcf99f028f278e8d12cca2c481fa2f795 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 3 May 2023 09:05:08 +0200 Subject: [PATCH 19/22] Remove SQLAlchemy dependency lock --- pyproject.toml | 1 - requirements.txt | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e9d555a..1455110 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ requires-python = ">=3.7" requires = [ "fastapi-users >= 10.0.2", "sqlmodel", - "sqlalchemy[asyncio] >=1.4,<1.4.36", # Pin SQLAlchemy to prevent bug https://github.com/tiangolo/sqlmodel/issues/315 ] [tool.flit.metadata.urls] diff --git a/requirements.txt b/requirements.txt index 5663706..df49004 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -aiosqlite >= 0.17.0 -fastapi-users >= 9.1.0 -git+https://github.com/andrewbolster/sqlmodel.git@patch-1 -sqlalchemy[asyncio] >=1.4,<1.4.36 +aiosqlite >= 0.19.0 +fastapi-users >= 10 +sqlmodel +greenlet From 4d3f4ae75d5274f4108ac143fc1fc5491809cb5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 3 May 2023 09:06:34 +0200 Subject: [PATCH 20/22] Enable Python 3.10 and 3.11 in CI test matrix --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 827440e..1114219 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python_version: [3.7, 3.8, 3.9] + python_version: [3.7, 3.8, 3.9, '3.10', '3.11'] steps: - uses: actions/checkout@v1 From cf498970db35203c1c3ebbdc8e632bb0f61e2843 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 3 May 2023 09:17:36 +0200 Subject: [PATCH 21/22] Move to Hatch for package management --- .github/workflows/build.yml | 84 +++++++++++++---------- Makefile | 17 ----- fastapi_users_db_sqlmodel/__init__.py | 4 +- fastapi_users_db_sqlmodel/access_token.py | 4 +- pyproject.toml | 78 +++++++++++++++++---- requirements.dev.txt | 18 ----- requirements.txt | 4 -- setup.cfg | 21 ------ 8 files changed, 119 insertions(+), 111 deletions(-) delete mode 100644 Makefile delete mode 100644 requirements.dev.txt delete mode 100644 requirements.txt delete mode 100644 setup.cfg diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1114219..ae04238 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,26 +11,32 @@ jobs: python_version: [3.7, 3.8, 3.9, '3.10', '3.11'] steps: - - uses: actions/checkout@v1 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python_version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.dev.txt - - name: Test with pytest - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - run: | - pytest --cov=fastapi_users_db_sqlmodel/ - codecov - - name: Build and install it on system host - run: | - flit build - flit install --python $(which python) - python test_build.py + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python_version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + hatch env create + - name: Lint and typecheck + run: | + hatch run lint-check + - name: Test + run: | + hatch run test-cov-xml + - uses: codecov/codecov-action@v3 + with: + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: true + verbose: true + - name: Build and install it on system host + run: | + hatch build + pip install dist/fastapi_users_db_sqlmodel-*.whl + python test_build.py release: runs-on: ubuntu-latest @@ -38,18 +44,26 @@ jobs: if: startsWith(github.ref, 'refs/tags/') steps: - - uses: actions/checkout@v1 - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: 3.7 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements.dev.txt - - name: Release on PyPI - env: - FLIT_USERNAME: ${{ secrets.FLIT_USERNAME }} - FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }} - run: | - flit publish + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.7 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch + - name: Build and publish on PyPI + env: + HATCH_INDEX_USER: ${{ secrets.HATCH_INDEX_USER }} + HATCH_INDEX_AUTH: ${{ secrets.HATCH_INDEX_AUTH }} + run: | + hatch build + hatch publish + - name: Create release + uses: ncipollo/release-action@v1 + with: + draft: true + body: ${{ github.event.head_commit.message }} + artifacts: dist/*.whl,dist/*.tar.gz + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/Makefile b/Makefile deleted file mode 100644 index 91608b0..0000000 --- a/Makefile +++ /dev/null @@ -1,17 +0,0 @@ -isort: - isort ./fastapi_users_db_sqlmodel ./tests - -format: isort - black . - -test: - pytest --cov=fastapi_users_db_sqlmodel/ --cov-report=term-missing --cov-fail-under=100 - -bumpversion-major: - bumpversion major - -bumpversion-minor: - bumpversion minor - -bumpversion-patch: - bumpversion patch diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 9a7e486..902c368 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -75,7 +75,7 @@ async def get(self, id: ID) -> Optional[UP]: async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" - statement = select(self.user_model).where( + statement = select(self.user_model).where( # type: ignore func.lower(self.user_model.email) == func.lower(email) ) results = self.session.exec(statement) @@ -171,7 +171,7 @@ async def get(self, id: ID) -> Optional[UP]: async def get_by_email(self, email: str) -> Optional[UP]: """Get a single user by email.""" - statement = select(self.user_model).where( + statement = select(self.user_model).where( # type: ignore func.lower(self.user_model.email) == func.lower(email) ) results = await self.session.execute(statement) diff --git a/fastapi_users_db_sqlmodel/access_token.py b/fastapi_users_db_sqlmodel/access_token.py index 7395a43..8a4519e 100644 --- a/fastapi_users_db_sqlmodel/access_token.py +++ b/fastapi_users_db_sqlmodel/access_token.py @@ -43,7 +43,7 @@ def __init__(self, session: Session, access_token_model: Type[AP]): async def get_by_token( self, token: str, max_age: Optional[datetime] = None ) -> Optional[AP]: - statement = select(self.access_token_model).where( + statement = select(self.access_token_model).where( # type: ignore self.access_token_model.token == token ) if max_age is not None: @@ -90,7 +90,7 @@ def __init__(self, session: AsyncSession, access_token_model: Type[AP]): async def get_by_token( self, token: str, max_age: Optional[datetime] = None ) -> Optional[AP]: - statement = select(self.access_token_model).where( + statement = select(self.access_token_model).where( # type: ignore self.access_token_model.token == token ) if max_age is not None: diff --git a/pyproject.toml b/pyproject.toml index 1455110..971fac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,30 +1,84 @@ +[tool.pytest.ini_options] +asyncio_mode = "auto" +addopts = "--ignore=test_build.py" + +[tool.ruff] +extend-select = ["I"] + +[tool.hatch] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.version] +source = "regex_commit" +commit_extra_args = ["-e"] +path = "fastapi_users_db_sqlmodel/__init__.py" + +[tool.hatch.envs.default] +dependencies = [ + "aiosqlite", + "pytest", + "pytest-asyncio", + "black", + "mypy", + "pytest-cov", + "pytest-mock", + "httpx", + "asgi_lifespan", + "ruff", +] + +[tool.hatch.envs.default.scripts] +test = "pytest --cov=fastapi_users_db_sqlmodel/ --cov-report=term-missing --cov-fail-under=100" +test-cov-xml = "pytest --cov=fastapi_users_db_sqlmodel/ --cov-report=xml --cov-fail-under=100" +lint = [ + "black . ", + "ruff --fix .", + "mypy fastapi_users_db_sqlmodel/", +] +lint-check = [ + "black --check .", + "ruff .", + "mypy fastapi_users_db_sqlmodel/", +] + +[tool.hatch.build.targets.sdist] +support-legacy = true # Create setup.py + [build-system] -requires = ["flit_core >=2,<3"] -build-backend = "flit_core.buildapi" - -[tool.flit.metadata] -module = "fastapi_users_db_sqlmodel" -dist-name = "fastapi-users-db-sqlmodel" -author = "François Voron" -author-email = "fvoron@gmail.com" -home-page = "https://github.com/fastapi-users/fastapi-users-db-sqlmodel" +requires = ["hatchling", "hatch-regex-commit"] +build-backend = "hatchling.build" + +[project] +name = "fastapi-users-db-sqlmodel" +authors = [ + { name = "François Voron", email = "fvoron@gmail.com" }, +] +description = "FastAPI Users database adapter for SQLModel" +readme = "README.md" +dynamic = ["version"] classifiers = [ "License :: OSI Approved :: MIT License", "Development Status :: 5 - Production/Stable", + "Framework :: FastAPI", "Framework :: AsyncIO", "Intended Audience :: Developers", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", "Topic :: Internet :: WWW/HTTP :: Session", ] -description-file = "README.md" requires-python = ">=3.7" -requires = [ +dependencies = [ "fastapi-users >= 10.0.2", + "greenlet", "sqlmodel", ] -[tool.flit.metadata.urls] +[project.urls] Documentation = "https://fastapi-users.github.io/fastapi-users" +Source = "https://github.com/fastapi-users/fastapi-users-db-sqlmodel" diff --git a/requirements.dev.txt b/requirements.dev.txt deleted file mode 100644 index e8ebaaf..0000000 --- a/requirements.dev.txt +++ /dev/null @@ -1,18 +0,0 @@ --r requirements.txt - -flake8 -pytest -requests -isort -pytest-asyncio -flake8-docstrings -black -mypy -codecov -pytest-cov -pytest-mock -asynctest -flit -bumpversion -httpx -asgi_lifespan diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index df49004..0000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -aiosqlite >= 0.19.0 -fastapi-users >= 10 -sqlmodel -greenlet diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 03dd6e4..0000000 --- a/setup.cfg +++ /dev/null @@ -1,21 +0,0 @@ -[bumpversion] -current_version = 0.2.0 -commit = True -tag = True - -[bumpversion:file:fastapi_users_db_sqlmodel/__init__.py] -search = __version__ = "{current_version}" -replace = __version__ = "{new_version}" - -[flake8] -exclude = docs -max-line-length = 88 -docstring-convention = numpy -ignore = D1 - -[isort] -profile = black - -[tool:pytest] -addopts = --ignore=test_build.py -asyncio_mode = strict From 83980d7f20886120f4636a102ab1822b4c366f63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Wed, 3 May 2023 09:20:43 +0200 Subject: [PATCH 22/22] =?UTF-8?q?Bump=20version=200.2.0=20=E2=86=92=200.3.?= =?UTF-8?q?0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements ------------ * Use latest version of SQLModel * Remove SQLAlchemy dependency pin, it's now in sync with the one asked by SQLModel --- fastapi_users_db_sqlmodel/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi_users_db_sqlmodel/__init__.py b/fastapi_users_db_sqlmodel/__init__.py index 902c368..695c5e2 100644 --- a/fastapi_users_db_sqlmodel/__init__.py +++ b/fastapi_users_db_sqlmodel/__init__.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import selectinload from sqlmodel import Field, Session, SQLModel, func, select -__version__ = "0.2.0" +__version__ = "0.3.0" class SQLModelBaseUserDB(SQLModel):