From 60d7946e9ff8933a8da481814b3ff2c811ba81aa Mon Sep 17 00:00:00 2001 From: isra17 Date: Thu, 18 Sep 2025 11:20:14 -0400 Subject: [PATCH] Object support serializing AttrDict. Fixes #3075 --- elasticsearch/dsl/field.py | 3 ++- test_elasticsearch/test_dsl/test_field.py | 7 +++++++ utils/templates/field.py.tpl | 13 +++++++------ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/elasticsearch/dsl/field.py b/elasticsearch/dsl/field.py index 84dbe0c4a..5acbad984 100644 --- a/elasticsearch/dsl/field.py +++ b/elasticsearch/dsl/field.py @@ -571,7 +571,8 @@ def _serialize( # somebody assigned raw dict to the field, we should tolerate that if isinstance(data, collections.abc.Mapping): return data - + if isinstance(data, AttrDict): + return data.to_dict() return data.to_dict(skip_empty=skip_empty) def clean(self, data: Any) -> Any: diff --git a/test_elasticsearch/test_dsl/test_field.py b/test_elasticsearch/test_dsl/test_field.py index 423936ae3..39ce5efee 100644 --- a/test_elasticsearch/test_dsl/test_field.py +++ b/test_elasticsearch/test_dsl/test_field.py @@ -24,6 +24,7 @@ from dateutil import tz from elasticsearch.dsl import InnerDoc, Range, ValidationException, field +from elasticsearch.dsl.utils import AttrDict def test_date_range_deserialization() -> None: @@ -232,3 +233,9 @@ class Inner(InnerDoc): with pytest.raises(ValidationException): field.Object(doc_class=Inner, dynamic=False) + + +def test_object_with_attrdict() -> None: + f = field.Object(dynamic=True) + assert f.deserialize(AttrDict({"a": "b"})).to_dict() == {"a": "b"} + assert f.serialize(AttrDict({"a": "b"})) == {"a": "b"} diff --git a/utils/templates/field.py.tpl b/utils/templates/field.py.tpl index 8699d852e..41047786f 100644 --- a/utils/templates/field.py.tpl +++ b/utils/templates/field.py.tpl @@ -333,7 +333,8 @@ class {{ k.name }}({{ k.parent }}): # somebody assigned raw dict to the field, we should tolerate that if isinstance(data, collections.abc.Mapping): return data - + if isinstance(data, AttrDict): + return data.to_dict() return data.to_dict(skip_empty=skip_empty) def clean(self, data: Any) -> Any: @@ -388,7 +389,7 @@ class {{ k.name }}({{ k.parent }}): # Divide by a float to preserve milliseconds on the datetime. return datetime.utcfromtimestamp(data / 1000.0) - raise ValidationException(f"Could not parse date from the value ({data!r})") + raise ValidationException(f"Could not parse date from the value ({data!r})") {% elif k.field == "boolean" %} super().__init__(*args, **kwargs) @@ -402,7 +403,7 @@ class {{ k.name }}({{ k.parent }}): data = self.deserialize(data) if data is None and self._required: raise ValidationException("Value required for this field.") - return data # type: ignore[no-any-return] + return data # type: ignore[no-any-return] {% elif k.field == "float" %} super().__init__(*args, **kwargs) @@ -432,7 +433,7 @@ class {{ k.name }}({{ k.parent }}): super().__init__(*args, **kwargs) def _deserialize(self, data: Any) -> int: - return int(data) + return int(data) {% elif k.field == "ip" %} super().__init__(*args, **kwargs) @@ -443,7 +444,7 @@ class {{ k.name }}({{ k.parent }}): def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None - return str(data) + return str(data) {% elif k.field == "binary" %} super().__init__(*args, **kwargs) @@ -458,7 +459,7 @@ class {{ k.name }}({{ k.parent }}): def _serialize(self, data: Any, skip_empty: bool) -> Optional[str]: if data is None: return None - return base64.b64encode(data).decode() + return base64.b64encode(data).decode() {% elif k.field == "percolator" %} super().__init__(*args, **kwargs)