Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,28 @@ def construct_fields_and_filters(
return fields, filters


class SQLAlchemyPrimaryKeySerializer(object):
"""
Serializes/unserializes primary keys
"""

DEFAULT = None

def __init__(self, serialize, deserialize):
self.serialize = serialize
self.deserialize = deserialize

@classmethod
def default(cls):
if cls.DEFAULT is None:
cls.DEFAULT = cls(
serialize=lambda keys: str(tuple(keys)) if len(keys) > 1 else keys[0],
deserialize=lambda id: id,
)

return cls.DEFAULT


class SQLAlchemyBase(BaseType):
"""
This class contains initialization code that is common to both ObjectTypes
Expand All @@ -441,6 +463,7 @@ def __init_subclass_with_meta__(
connection_field_factory=None,
_meta=None,
create_filters=True,
serializer=None,
**options,
):
# We always want to bypass this hook unless we're defining a concrete
Expand Down Expand Up @@ -530,6 +553,12 @@ def __init_subclass_with_meta__(

cls.connection = connection # Public way to get the connection

if serializer is None:
cls.serializer = SQLAlchemyPrimaryKeySerializer.default()

else:
cls.serializer = serializer

super(SQLAlchemyBase, cls).__init_subclass_with_meta__(
_meta=_meta, interfaces=interfaces, **options
)
Expand Down Expand Up @@ -557,28 +586,35 @@ def get_query(cls, info):

@classmethod
def get_node(cls, info, id):
key = cls.serializer.deserialize(id)

if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
try:
return cls.get_query(info).get(id)
return cls.get_query(info).get(key)
except NoResultFound:
return None

session = get_session(info.context)
if isinstance(session, AsyncSession):

async def get_result() -> Any:
return await session.get(cls._meta.model, id)
return await session.get(cls._meta.model, key)

return get_result()
try:
return cls.get_query(info).get(id)
return cls.get_query(info).get(key)
except NoResultFound:
return None

def resolve_id(self, info):
# graphene_type = info.parent_type.graphene_type
graphene_type = info.parent_type.graphene_type
keys = self.__mapper__.primary_key_from_instance(self)
return str(tuple(keys)) if len(keys) > 1 else keys[0]

try:
return graphene_type.serializer.serialize(keys)

except Exception as e:
raise ValueError(f"Non-serializable primary key: {e}") from e

@classmethod
def enum_for_field(cls, field_name):
Expand Down