diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 894ebfd..89f9409 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -418,6 +418,28 @@ def construct_fields_and_filters( return fields, filters +class SQLAlchemyPrimaryKeySerializer(object): + """ + Serializes/unserializes primary keys + """ + + DEFAULT = None + + def __init__(self, serialize, deserialize): + self.serialize = serialize + self.deserialize = deserialize + + @classmethod + def default(cls): + if cls.DEFAULT is None: + cls.DEFAULT = cls( + serialize=lambda keys: str(tuple(keys)) if len(keys) > 1 else keys[0], + deserialize=lambda id: id, + ) + + return cls.DEFAULT + + class SQLAlchemyBase(BaseType): """ This class contains initialization code that is common to both ObjectTypes @@ -441,6 +463,7 @@ def __init_subclass_with_meta__( connection_field_factory=None, _meta=None, create_filters=True, + serializer=None, **options, ): # We always want to bypass this hook unless we're defining a concrete @@ -530,6 +553,12 @@ def __init_subclass_with_meta__( cls.connection = connection # Public way to get the connection + if serializer is None: + cls.serializer = SQLAlchemyPrimaryKeySerializer.default() + + else: + cls.serializer = serializer + super(SQLAlchemyBase, cls).__init_subclass_with_meta__( _meta=_meta, interfaces=interfaces, **options ) @@ -557,9 +586,11 @@ def get_query(cls, info): @classmethod def get_node(cls, info, id): + key = cls.serializer.deserialize(id) + if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4: try: - return cls.get_query(info).get(id) + return cls.get_query(info).get(key) except NoResultFound: return None @@ -567,18 +598,23 @@ def get_node(cls, info, id): if isinstance(session, AsyncSession): async def get_result() -> Any: - return await session.get(cls._meta.model, id) + return await session.get(cls._meta.model, key) return get_result() try: - return cls.get_query(info).get(id) + return cls.get_query(info).get(key) except NoResultFound: return None def resolve_id(self, info): - # graphene_type = info.parent_type.graphene_type + graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) - return str(tuple(keys)) if len(keys) > 1 else keys[0] + + try: + return graphene_type.serializer.serialize(keys) + + except Exception as e: + raise ValueError(f"Non-serializable primary key: {e}") from e @classmethod def enum_for_field(cls, field_name):