Skip to content
This repository has been archived by the owner. It is now read-only.

Commit

Permalink
Add some features to JSONField (tortoise#667)
Browse files Browse the repository at this point in the history
- contains
 - contained_by
 - filter
  + by key => equal, is_null, is_not_null, not
  + by index => equal, is_null, is_not_null, not
  • Loading branch information
ahmadgh74 authored Jun 7, 2021
1 parent 850a2ca commit d69b9e5
Show file tree
Hide file tree
Showing 8 changed files with 456 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Changelog
0.17.4
------
- Fix `update_or_create`. (#782)
- Add `contains`, `contained_by` and `filter` to `JSONField`

0.17.3
------
Expand Down
40 changes: 40 additions & 0 deletions docs/query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,46 @@ Specially, you can filter date part with one of following, note that current onl
teams = await Team.filter(created_at__month=12)
teams = await Team.filter(created_at__day=5)
In PostgreSQL and MYSQL, you can use the ``contains``, ``contained_by`` and ``filter`` options in ``JSONField``:

.. code-block:: python3
class JSONModel:
data = fields.JSONField()
await JSONModel.create(data=["text", 3, {"msg": "msg2"}])
obj = await JSONModel.filter(data__contains=[{"msg": "msg2"}]).first()
await JSONModel.create(data=["text"])
await JSONModel.create(data=["tortoise", "msg"])
await JSONModel.create(data=["tortoise"])
objects = await JSONModel.filter(data__contained_by=["text", "tortoise", "msg"])
.. code-block:: python3
class JSONModel:
data = fields.JSONField()
await JSONModel.create(data={"breed": "labrador",
"owner": {
"name": "Boby",
"last": None,
"other_pets": [
{
"name": "Fishy",
}
],
},
})
obj1 = await JSONModel.filter(data__filter={"breed": "labrador"}).first()
obj2 = await JSONModel.filter(data__filter={"owner__name": "Boby"}).first()
obj3 = await JSONModel.filter(data__filter={"owner__other_pets__0__name": "Fishy"}).first()
obj4 = await JSONModel.filter(data__filter={"breed__not": "a"}).first()
obj5 = await JSONModel.filter(data__filter={"owner__name__isnull": True}).first()
obj6 = await JSONModel.filter(data__filter={"owner__last__not_isnull": False}).first()
Complex prefetch
================

Expand Down
176 changes: 176 additions & 0 deletions tests/fields/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,182 @@ async def test_list(self):
obj2 = await testmodels.JSONFields.get(id=obj.id)
self.assertEqual(obj, obj2)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
async def test_list_contains(self):
await testmodels.JSONFields.create(data=["text", 3, {"msg": "msg2"}])
obj = await testmodels.JSONFields.filter(data__contains=[{"msg": "msg2"}]).first()
self.assertEqual(obj.data, ["text", 3, {"msg": "msg2"}])
await obj.save()
obj2 = await testmodels.JSONFields.get(id=obj.id)
self.assertEqual(obj, obj2)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
async def test_list_contained_by(self):
obj0 = await testmodels.JSONFields.create(data=["text"])
obj1 = await testmodels.JSONFields.create(data=["tortoise", "msg"])
obj2 = await testmodels.JSONFields.create(data=["tortoise"])
obj3 = await testmodels.JSONFields.create(data=["new_message", "some_message"])
objs = set(
await testmodels.JSONFields.filter(data__contained_by=["text", "tortoise", "msg"])
)
created_objs = {obj0, obj1, obj2}
self.assertSetEqual(created_objs, objs)
self.assertTrue(obj3 not in objs)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
async def test_filter(self):
obj0 = await testmodels.JSONFields.create(
data={
"breed": "labrador",
"owner": {
"name": "Bob",
"last": None,
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)
obj1 = await testmodels.JSONFields.create(
data={
"breed": "husky",
"owner": {
"name": "Goldast",
"last": None,
"other_pets": [
{
"name": None,
}
],
},
}
)
obj = await testmodels.JSONFields.get(data__filter={"breed": "labrador"})
obj2 = await testmodels.JSONFields.get(data__filter={"owner__name": "Goldast"})
obj3 = await testmodels.JSONFields.get(data__filter={"owner__other_pets__0__name": "Fishy"})

self.assertEqual(obj0, obj)
self.assertEqual(obj1, obj2)
self.assertEqual(obj0, obj3)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
async def test_filter_not_condition(self):
obj0 = await testmodels.JSONFields.create(
data={
"breed": "labrador",
"owner": {
"name": "Bob",
"last": None,
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)
obj1 = await testmodels.JSONFields.create(
data={
"breed": "husky",
"owner": {
"name": "Goldast",
"last": None,
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)

obj2 = await testmodels.JSONFields.get(data__filter={"breed__not": "husky"})
obj3 = await testmodels.JSONFields.get(data__filter={"breed__not": "labrador"})
self.assertEqual(obj0, obj2)
self.assertEqual(obj1, obj3)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
async def test_filter_is_null_condition(self):
obj0 = await testmodels.JSONFields.create(
data={
"breed": "labrador",
"owner": {
"name": "Boby",
"last": "Cloud",
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)

obj1 = await testmodels.JSONFields.create(
data={
"breed": "labrador",
"owner": {
"name": None,
"last": "Cloud",
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)

obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": False})
obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__isnull": True})
self.assertEqual(obj0, obj2)
self.assertEqual(obj1, obj3)

@test.requireCapability(dialect="mysql")
@test.requireCapability(dialect="postgres")
async def test_filter_not_is_null_condition(self):
obj0 = await testmodels.JSONFields.create(
data={
"breed": "labrador",
"owner": {
"name": "Boby",
"last": "Cloud",
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)

obj1 = await testmodels.JSONFields.create(
data={
"breed": "labrador",
"owner": {
"name": None,
"last": "Cloud",
"other_pets": [
{
"name": "Fishy",
}
],
},
}
)

obj2 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": True})
obj3 = await testmodels.JSONFields.get(data__filter={"owner__name__not_isnull": False})
self.assertEqual(obj0, obj2)
self.assertEqual(obj1, obj3)

async def test_values(self):
obj0 = await testmodels.JSONFields.create(data={"some": ["text", 3]})
values = await testmodels.JSONFields.filter(id=obj0.id).values("data")
Expand Down
14 changes: 12 additions & 2 deletions tortoise/backends/asyncpg/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.postgres.json_functions import (
postgres_json_contained_by,
postgres_json_contains,
postgres_json_filter,
)
from tortoise.contrib.postgres.search import SearchCriterion
from tortoise.filters import search
from tortoise.filters import json_contained_by, json_contains, json_filter, search


def postgres_search(field: Term, value: Term):
Expand All @@ -18,7 +23,12 @@ def postgres_search(field: Term, value: Term):
class AsyncpgExecutor(BaseExecutor):
EXPLAIN_PREFIX = "EXPLAIN (FORMAT JSON, VERBOSE)"
DB_NATIVE = BaseExecutor.DB_NATIVE | {bool, uuid.UUID}
FILTER_FUNC_OVERRIDE = {search: postgres_search}
FILTER_FUNC_OVERRIDE = {
search: postgres_search,
json_contains: postgres_json_contains,
json_contained_by: postgres_json_contained_by,
json_filter: postgres_json_filter,
}

def parameter(self, pos: int) -> Parameter:
return Parameter("$%d" % (pos + 1,))
Expand Down
11 changes: 11 additions & 0 deletions tortoise/backends/mysql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@

from tortoise import Model
from tortoise.backends.base.executor import BaseExecutor
from tortoise.contrib.mysql.json_functions import (
mysql_json_contained_by,
mysql_json_contains,
mysql_json_filter,
)
from tortoise.contrib.mysql.search import SearchCriterion
from tortoise.fields import BigIntField, IntField, SmallIntField
from tortoise.filters import (
Expand All @@ -17,6 +22,9 @@
insensitive_ends_with,
insensitive_exact,
insensitive_starts_with,
json_contained_by,
json_contains,
json_filter,
search,
starts_with,
)
Expand Down Expand Up @@ -97,6 +105,9 @@ class MySQLExecutor(BaseExecutor):
insensitive_starts_with: mysql_insensitive_starts_with,
insensitive_ends_with: mysql_insensitive_ends_with,
search: mysql_search,
json_contains: mysql_json_contains,
json_contained_by: mysql_json_contained_by,
json_filter: mysql_json_filter,
}
EXPLAIN_PREFIX = "EXPLAIN FORMAT=JSON"

Expand Down
85 changes: 85 additions & 0 deletions tortoise/contrib/mysql/json_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import json
import operator
from typing import Any, Dict, List

from pypika.functions import Cast
from pypika.terms import Criterion
from pypika.terms import Function as PypikaFunction
from pypika.terms import Term, ValueWrapper

from tortoise.filters import not_equal


class JSONContains(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, target_list: Term):
super(JSONContains, self).__init__("JSON_CONTAINS", column_name, target_list)


class JSONExtract(PypikaFunction): # type: ignore
def __init__(self, column_name: Term, query_list: List[Term]):
query = self.make_query(query_list)
super(JSONExtract, self).__init__("JSON_EXTRACT", column_name, query)

@classmethod
def serialize_value(cls, value: Any):
if isinstance(value, int):
return f"[{value}]"
if isinstance(value, str):
return f".{value}"

def make_query(self, query_list: List[Term]):
query = ["$"]
for value in query_list:
query.append(self.serialize_value(value))

return "".join(query)


def mysql_json_contains(field: Term, value: str) -> Criterion:
return JSONContains(field, ValueWrapper(value))


def mysql_json_contained_by(field: Term, value_str: str) -> Criterion:
values = json.loads(value_str)
contained_by = None
for value in values:
if contained_by is None:
contained_by = JSONContains(field, ValueWrapper(json.dumps([value])))
else:
contained_by |= JSONContains(field, ValueWrapper(json.dumps([value]))) # type: ignore
return contained_by


def _mysql_json_is_null(left: Term, is_null: bool):
if is_null is True:
return operator.eq(left, Cast("null", "JSON"))
else:
return not_equal(left, Cast("null", "JSON"))


def _mysql_json_not_is_null(left: Term, is_null: bool):
return _mysql_json_is_null(left, not is_null)


operator_keywords = {
"not": not_equal,
"isnull": _mysql_json_is_null,
"not_isnull": _mysql_json_not_is_null,
}


def _serialize_value(value: Any):
if type(value) in [dict, list]:
return json.dumps(value)
return value


def mysql_json_filter(field: Term, value: Dict) -> Criterion:
((key, filter_value),) = value.items()
filter_value = _serialize_value(filter_value)
key_parts = list(map(lambda item: int(item) if item.isdigit() else str(item), key.split("__")))
operator_ = operator.eq
if key_parts[-1] in operator_keywords:
operator_ = operator_keywords[str(key_parts.pop(-1))]

return operator_(JSONExtract(field, key_parts), filter_value)
Loading

0 comments on commit d69b9e5

Please sign in to comment.