diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 93c55353c..a8ab3a218 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -16,6 +16,7 @@ Changelog
 0.17.4
 ------
 - Fix `update_or_create`. (#782)
+- Add `contains`, `contained_by` and `filter` to `JSONField`
 
 0.17.3
 ------
diff --git a/docs/query.rst b/docs/query.rst
index 860347279..c41345878 100644
--- a/docs/query.rst
+++ b/docs/query.rst
@@ -286,6 +286,51 @@ Specially, you can filter date part with one of following, note that current onl
     teams = await Team.filter(created_at__month=12)
     teams = await Team.filter(created_at__day=5)
 
+In PostgreSQL and MySQL, you can use the ``contains``, ``contained_by`` and ``filter`` lookups on ``JSONField``. ``contains`` matches rows whose JSON value contains the given fragment, while ``contained_by`` is the inverse:
+
+.. code-block:: python3
+
+    class JSONModel(Model):
+        data = fields.JSONField()
+
+    await JSONModel.create(data=["text", 3, {"msg": "msg2"}])
+    obj = await JSONModel.filter(data__contains=[{"msg": "msg2"}]).first()
+
+    await JSONModel.create(data=["text"])
+    await JSONModel.create(data=["tortoise", "msg"])
+    await JSONModel.create(data=["tortoise"])
+
+    objects = await JSONModel.filter(data__contained_by=["text", "tortoise", "msg"])
+
+The ``filter`` lookup matches against nested values: join object keys and array indexes with ``__``, and append ``not``, ``isnull`` or ``not_isnull`` to the path to change the comparison:
+
+.. code-block:: python3
+
+    class JSONModel(Model):
+        data = fields.JSONField()
+
+    await JSONModel.create(
+        data={
+            "breed": "labrador",
+            "owner": {
+                "name": "Boby",
+                "last": None,
+                "other_pets": [
+                    {
+                        "name": "Fishy",
+                    }
+                ],
+            },
+        }
+    )
+
+    obj1 = await JSONModel.filter(data__filter={"breed": "labrador"}).first()
+    obj2 = await JSONModel.filter(data__filter={"owner__name": "Boby"}).first()
+    obj3 = await JSONModel.filter(data__filter={"owner__other_pets__0__name": "Fishy"}).first()
+    obj4 = await JSONModel.filter(data__filter={"breed__not": "husky"}).first()
+    obj5 = await JSONModel.filter(data__filter={"owner__name__isnull": True}).first()
+    obj6 = await JSONModel.filter(data__filter={"owner__last__not_isnull": False}).first()
+
 Complex prefetch
 ================
 
diff --git a/tests/fields/test_json.py b/tests/fields/test_json.py
index 4118c3008..81efc30b5 100644
--- a/tests/fields/test_json.py
+++ b/tests/fields/test_json.py
@@ -65,6 +65,182 @@ async def test_list(self):
         obj2 = await testmodels.JSONFields.get(id=obj.id)
         self.assertEqual(obj, obj2)
 
+    @test.requireCapability(dialect="mysql")
+    @test.requireCapability(dialect="postgres")
+    async def test_list_contains(self):
+        await testmodels.JSONFields.create(data=["text", 3, {"msg": "msg2"}])
+        obj = await testmodels.JSONFields.filter(data__contains=[{"msg": "msg2"}]).first()
+        self.assertEqual(obj.data, ["text", 3, {"msg": "msg2"}])
+        await obj.save()
+        obj2 = await testmodels.JSONFields.get(id=obj.id)
+        self.assertEqual(obj, obj2)
+
+    @test.requireCapability(dialect="mysql")
+    @test.requireCapability(dialect="postgres")
+    async def test_list_contained_by(self):
+        obj0 = await testmodels.JSONFields.create(data=["text"])
+        obj1 = await testmodels.JSONFields.create(data=["tortoise", "msg"])
+        obj2 = await testmodels.JSONFields.create(data=["tortoise"])
+        obj3 = await testmodels.JSONFields.create(data=["new_message", "some_message"])
+        objs = set(
+            await testmodels.JSONFields.filter(data__contained_by=["text", "tortoise", "msg"])
+        )
+        created_objs = {obj0, obj1, obj2}
+        self.assertSetEqual(created_objs, objs)
+        self.assertNotIn(obj3, objs)
+
+    @test.requireCapability(dialect="mysql")
+    @test.requireCapability(dialect="postgres")
+    async def test_filter(self):
+        obj0 = await testmodels.JSONFields.create(
+            data={
+                "breed": "labrador",
+                "owner": {
+                    "name": "Bob",
+                    "last": None,
+                    "other_pets": [
+                        {
+                            "name": "Fishy",
+                        }
+                    ],
+ }, + } + ) + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "husky", + "owner": { + "name": "Goldast", + "last": None, + "other_pets": [ + { + "name": None, + } + ], + }, + } + ) + obj = await testmodels.JSONFields.get(data__filter={"breed": "labrador"}) + obj2 = await testmodels.JSONFields.get(data__filter={"owner__name": "Goldast"}) + obj3 = await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "Fishy"}) + + self.assertEqual(obj0, obj) + self.assertEqual(obj1, obj2) + self.assertEqual(obj0, obj3) + + @test.requireCapability(dialect="mysql") + @test.requireCapability(dialect="postgres") + async def test_filter_not_condition(self): + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Bob", + "last": None, + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "husky", + "owner": { + "name": "Goldast", + "last": None, + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj2 = await testmodels.JSONFields.get(data__filter={"breed__not": "husky"}) + obj3 = await testmodels.JSONFields.get(data__filter={"breed__not": "labrador"}) + self.assertEqual(obj0, obj2) + self.assertEqual(obj1, obj3) + + @test.requireCapability(dialect="mysql") + @test.requireCapability(dialect="postgres") + async def test_filter_is_null_condition(self): + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Boby", + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": None, + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": False}) + obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": True}) + self.assertEqual(obj0, obj2) + self.assertEqual(obj1, obj3) + + @test.requireCapability(dialect="mysql") + @test.requireCapability(dialect="postgres") + async def test_filter_not_is_null_condition(self): + obj0 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": "Boby", + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj1 = await testmodels.JSONFields.create( + data={ + "breed": "labrador", + "owner": { + "name": None, + "last": "Cloud", + "other_pets": [ + { + "name": "Fishy", + } + ], + }, + } + ) + + obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": True}) + obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": False}) + self.assertEqual(obj0, obj2) + self.assertEqual(obj1, obj3) + async def test_values(self): obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]}) values = await testmodels.JSONFields.filter(id=obj0.id).values("data") diff --git a/tortoise/backends/asyncpg/executor.py b/tortoise/backends/asyncpg/executor.py index 00f2ecf3e..0313dd2ab 100644 --- a/tortoise/backends/asyncpg/executor.py +++ b/tortoise/backends/asyncpg/executor.py @@ -7,8 +7,13 @@ from tortoise import Model from tortoise.backends.base.executor import BaseExecutor +from tortoise.contrib.postgres.json_functions import ( + postgres_json_contained_by, + postgres_json_contains, + postgres_json_filter, +) from tortoise.contrib.postgres.search import SearchCriterion -from tortoise.filters import search +from tortoise.filters import 
json_contained_by, json_contains, json_filter, search def postgres_search(field: Term, value: Term): @@ -18,7 +23,12 @@ def postgres_search(field: Term, value: Term): class AsyncpgExecutor(BaseExecutor): EXPLAIN_PREFIX = "EXPLAIN (FORMAT JSON, VERBOSE)" DB_NATIVE = BaseExecutor.DB_NATIVE | {bool, uuid.UUID} - FILTER_FUNC_OVERRIDE = {search: postgres_search} + FILTER_FUNC_OVERRIDE = { + search: postgres_search, + json_contains: postgres_json_contains, + json_contained_by: postgres_json_contained_by, + json_filter: postgres_json_filter, + } def parameter(self, pos: int) -> Parameter: return Parameter("$%d" % (pos + 1,)) diff --git a/tortoise/backends/mysql/executor.py b/tortoise/backends/mysql/executor.py index 76f83e2f5..b998db966 100644 --- a/tortoise/backends/mysql/executor.py +++ b/tortoise/backends/mysql/executor.py @@ -4,6 +4,11 @@ from tortoise import Model from tortoise.backends.base.executor import BaseExecutor +from tortoise.contrib.mysql.json_functions import ( + mysql_json_contained_by, + mysql_json_contains, + mysql_json_filter, +) from tortoise.contrib.mysql.search import SearchCriterion from tortoise.fields import BigIntField, IntField, SmallIntField from tortoise.filters import ( @@ -17,6 +22,9 @@ insensitive_ends_with, insensitive_exact, insensitive_starts_with, + json_contained_by, + json_contains, + json_filter, search, starts_with, ) @@ -97,6 +105,9 @@ class MySQLExecutor(BaseExecutor): insensitive_starts_with: mysql_insensitive_starts_with, insensitive_ends_with: mysql_insensitive_ends_with, search: mysql_search, + json_contains: mysql_json_contains, + json_contained_by: mysql_json_contained_by, + json_filter: mysql_json_filter, } EXPLAIN_PREFIX = "EXPLAIN FORMAT=JSON" diff --git a/tortoise/contrib/mysql/json_functions.py b/tortoise/contrib/mysql/json_functions.py new file mode 100644 index 000000000..dd2b93034 --- /dev/null +++ b/tortoise/contrib/mysql/json_functions.py @@ -0,0 +1,85 @@ +import json +import operator +from typing import Any, Dict, List + +from pypika.functions import Cast +from pypika.terms import Criterion +from pypika.terms import Function as PypikaFunction +from pypika.terms import Term, ValueWrapper + +from tortoise.filters import not_equal + + +class JSONContains(PypikaFunction): # type: ignore + def __init__(self, column_name: Term, target_list: Term): + super(JSONContains, self).__init__("JSON_CONTAINS", column_name, target_list) + + +class JSONExtract(PypikaFunction): # type: ignore + def __init__(self, column_name: Term, query_list: List[Term]): + query = self.make_query(query_list) + super(JSONExtract, self).__init__("JSON_EXTRACT", column_name, query) + + @classmethod + def serialize_value(cls, value: Any): + if isinstance(value, int): + return f"[{value}]" + if isinstance(value, str): + return f".{value}" + + def make_query(self, query_list: List[Term]): + query = ["$"] + for value in query_list: + query.append(self.serialize_value(value)) + + return "".join(query) + + +def mysql_json_contains(field: Term, value: str) -> Criterion: + return JSONContains(field, ValueWrapper(value)) + + +def mysql_json_contained_by(field: Term, value_str: str) -> Criterion: + values = json.loads(value_str) + contained_by = None + for value in values: + if contained_by is None: + contained_by = JSONContains(field, ValueWrapper(json.dumps([value]))) + else: + contained_by |= JSONContains(field, ValueWrapper(json.dumps([value]))) # type: ignore + return contained_by + + +def _mysql_json_is_null(left: Term, is_null: bool): + if is_null is True: + return 
operator.eq(left, Cast("null", "JSON")) + else: + return not_equal(left, Cast("null", "JSON")) + + +def _mysql_json_not_is_null(left: Term, is_null: bool): + return _mysql_json_is_null(left, not is_null) + + +operator_keywords = { + "not": not_equal, + "isnull": _mysql_json_is_null, + "not_isnull": _mysql_json_not_is_null, +} + + +def _serialize_value(value: Any): + if type(value) in [dict, list]: + return json.dumps(value) + return value + + +def mysql_json_filter(field: Term, value: Dict) -> Criterion: + ((key, filter_value),) = value.items() + filter_value = _serialize_value(filter_value) + key_parts = list(map(lambda item: int(item) if item.isdigit() else str(item), key.split("__"))) + operator_ = operator.eq + if key_parts[-1] in operator_keywords: + operator_ = operator_keywords[str(key_parts.pop(-1))] + + return operator_(JSONExtract(field, key_parts), filter_value) diff --git a/tortoise/contrib/postgres/json_functions.py b/tortoise/contrib/postgres/json_functions.py new file mode 100644 index 000000000..513a50a30 --- /dev/null +++ b/tortoise/contrib/postgres/json_functions.py @@ -0,0 +1,64 @@ +import json +import operator +from typing import Any, Callable, Dict, List + +from pypika.enums import JSONOperators +from pypika.terms import BasicCriterion, Criterion, Term, ValueWrapper + +from tortoise.filters import is_null, not_equal, not_null + + +def postgres_json_contains(field: Term, value: str) -> Criterion: + return BasicCriterion(JSONOperators.CONTAINS, field, ValueWrapper(value)) + + +def postgres_json_contained_by(field: Term, value: str) -> Criterion: + return BasicCriterion(JSONOperators.CONTAINED_BY, field, ValueWrapper(value)) + + +operator_keywords = { + "not": not_equal, + "isnull": is_null, + "not_isnull": not_null, +} + + +def _get_json_criterion(items: List): + if len(items) == 2: + left = items.pop(0) + right = items.pop(0) + return BasicCriterion(JSONOperators.GET_TEXT_VALUE, ValueWrapper(left), ValueWrapper(right)) + + left = items.pop(0) + return BasicCriterion( + JSONOperators.GET_JSON_VALUE, ValueWrapper(left), _get_json_criterion(items) + ) + + +def _create_json_criterion(items: List, field_term: Term, operator_: Callable, value: str): + if len(items) == 1: + term = items.pop(0) + return operator_( + BasicCriterion(JSONOperators.GET_TEXT_VALUE, field_term, ValueWrapper(term)), value + ) + + return operator_( + BasicCriterion(JSONOperators.GET_JSON_VALUE, field_term, _get_json_criterion(items)), value + ) + + +def _serialize_value(value: Any): + if type(value) in [dict, list]: + return json.dumps(value) + return value + + +def postgres_json_filter(field: Term, value: Dict) -> Criterion: + ((key, filter_value),) = value.items() + filter_value = _serialize_value(filter_value) + key_parts = list(map(lambda item: int(item) if item.isdigit() else str(item), key.split("__"))) + operator_ = operator.eq + if key_parts[-1] in operator_keywords: + operator_ = operator_keywords[str(key_parts.pop(-1))] + + return _create_json_criterion(key_parts, field, operator_, filter_value) diff --git a/tortoise/filters.py b/tortoise/filters.py index bd51441b8..639239ab6 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -16,7 +16,7 @@ format_quotes, ) -from tortoise.fields import Field +from tortoise.fields import Field, JSONField from tortoise.fields.relational import BackwardFKRelation, ManyToManyFieldInstance if TYPE_CHECKING: # pragma: nocoverage @@ -109,6 +109,10 @@ def string_encoder(value: Any, instance: "Model", field: Field) -> str: return str(value) +def 
json_encoder(value: Any, instance: "Model", field: Field) -> Dict: + return value + + ############################################################################## # Operators # Should be type: (field: Term, value: Any) -> Criterion: @@ -224,6 +228,21 @@ def extract_microsecond_equal(field: Term, value: int) -> Criterion: return Extract(DatePart.microsecond, field).eq(value) +def json_contains(field: Term, value: str) -> Criterion: + # will be override in each executor + pass + + +def json_contained_by(field: Term, value: str) -> Criterion: + # will be override in each executor + pass + + +def json_filter(field: Term, value: Dict) -> Criterion: + # will be override in each executor + pass + + ############################################################################## # Filter resolvers ############################################################################## @@ -309,6 +328,50 @@ def get_backward_fk_filters(field_name: str, field: BackwardFKRelation) -> Dict[ } +def get_json_filter(field_name: str, source_field: str): + actual_field_name = field_name + return { + field_name: { + "field": actual_field_name, + "source_field": source_field, + "operator": operator.eq, + }, + f"{field_name}__not": { + "field": actual_field_name, + "source_field": source_field, + "operator": not_equal, + }, + f"{field_name}__isnull": { + "field": actual_field_name, + "source_field": source_field, + "operator": is_null, + "value_encoder": bool_encoder, + }, + f"{field_name}__not_isnull": { + "field": actual_field_name, + "source_field": source_field, + "operator": not_null, + "value_encoder": bool_encoder, + }, + f"{field_name}__contains": { + "field": actual_field_name, + "source_field": source_field, + "operator": json_contains, + }, + f"{field_name}__contained_by": { + "field": actual_field_name, + "source_field": source_field, + "operator": json_contained_by, + }, + f"{field_name}__filter": { + "field": actual_field_name, + "source_field": source_field, + "operator": json_filter, + "value_encoder": json_encoder, + }, + } + + def get_filters_for_field( field_name: str, field: Optional[Field], source_field: str ) -> Dict[str, dict]: @@ -316,6 +379,9 @@ def get_filters_for_field( return get_m2m_filters(field_name, field) if isinstance(field, BackwardFKRelation): return get_backward_fk_filters(field_name, field) + if isinstance(field, JSONField): + return get_json_filter(field_name, source_field) + actual_field_name = field_name if field_name == "pk" and field: actual_field_name = field.model_field_name
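
A note on how the new ``filter`` lookup resolves its key argument, since both ``mysql_json_filter`` and ``postgres_json_filter`` share the same first step: the key is split on ``__``, purely numeric parts become list indexes, and a trailing ``not`` / ``isnull`` / ``not_isnull`` keyword is peeled off to pick the comparison. The sketch below is illustrative only; ``split_json_filter_key`` is a hypothetical helper name, the patch performs this parsing inline in each dialect's function:

.. code-block:: python3

    from typing import Any, List, Tuple

    OPERATOR_KEYWORDS = ("not", "isnull", "not_isnull")

    def split_json_filter_key(key: str) -> Tuple[List[Any], str]:
        # "__" separates nested keys, digits become list indexes, and a trailing
        # keyword switches the comparison away from plain equality.
        parts: List[Any] = [int(p) if p.isdigit() else p for p in key.split("__")]
        operator_ = "eq"
        if isinstance(parts[-1], str) and parts[-1] in OPERATOR_KEYWORDS:
            operator_ = str(parts.pop(-1))
        return parts, operator_

    assert split_json_filter_key("owner__other_pets__0__name") == (
        ["owner", "other_pets", 0, "name"],
        "eq",
    )
    assert split_json_filter_key("owner__name__isnull") == (["owner", "name"], "isnull")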
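
On MySQL the parsed path is turned into a JSON path string and compared via ``JSON_EXTRACT``, so ``data__filter={"owner__other_pets__0__name": "Fishy"}`` becomes, roughly, ``JSON_EXTRACT(data, '$.owner.other_pets[0].name') = 'Fishy'``. A small sketch of the path building done by ``JSONExtract.make_query``; ``mysql_json_path`` is an illustrative name, not part of the patch:

.. code-block:: python3

    from typing import Any, List

    def mysql_json_path(parts: List[Any]) -> str:
        # Mirrors JSONExtract.serialize_value: ints address array elements,
        # strings address object keys, and the path always starts at "$".
        out = ["$"]
        for part in parts:
            out.append(f"[{part}]" if isinstance(part, int) else f".{part}")
        return "".join(out)

    assert mysql_json_path(["owner", "other_pets", 0, "name"]) == "$.owner.other_pets[0].name"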
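
On PostgreSQL the same path is expressed through pypika's ``JSONOperators``: ``GET_JSON_VALUE`` (``->``) for every intermediate key and ``GET_TEXT_VALUE`` (``->>``) for the last one, so the comparison runs on the extracted text value. The sketch below only approximates the shape of the generated expression (pypika's own quoting of the column and values will differ); ``postgres_json_path_sql`` is an illustrative name:

.. code-block:: python3

    from typing import Any, List

    def postgres_json_path_sql(column: str, parts: List[Any]) -> str:
        def render(part: Any) -> str:
            return str(part) if isinstance(part, int) else f"'{part}'"

        *inner, last = parts
        sql = column
        for part in inner:
            sql += f"->{render(part)}"  # -> keeps descending into the JSON document
        return sql + f"->>{render(last)}"  # ->> extracts the final value as text

    assert (
        postgres_json_path_sql("data", ["owner", "other_pets", 0, "name"])
        == "data->'owner'->'other_pets'->0->>'name'"
    )
    assert postgres_json_path_sql("data", ["breed"]) == "data->>'breed'"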