forked from BeanieODM/beanie
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: issue#728 pydantic_core._pydantic_core.Url object is not iterable (
BeanieODM#730) * pre-commit * fix: change test for pydantic v1, v2 * fix: add encoder itself test, conditional default encoder
- Loading branch information
1 parent
d9eb71d
commit 0db4dd4
Showing
4 changed files
with
203 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,155 +1,169 @@ | ||
import dataclasses as dc | ||
import datetime | ||
import decimal | ||
import enum | ||
import ipaddress | ||
import operator | ||
import pathlib | ||
import re | ||
import uuid | ||
from typing import Any, Callable, Container, Iterable, Mapping, Optional, Tuple | ||
|
||
import bson | ||
import pydantic | ||
|
||
import beanie | ||
from beanie.odm.fields import Link, LinkTypes | ||
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, get_model_fields | ||
|
||
SingleArgCallable = Callable[[Any], Any] | ||
DEFAULT_CUSTOM_ENCODERS: Mapping[type, SingleArgCallable] = { | ||
ipaddress.IPv4Address: str, | ||
ipaddress.IPv4Interface: str, | ||
ipaddress.IPv4Network: str, | ||
ipaddress.IPv6Address: str, | ||
ipaddress.IPv6Interface: str, | ||
ipaddress.IPv6Network: str, | ||
pathlib.PurePath: str, | ||
pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value, | ||
pydantic.SecretStr: pydantic.SecretStr.get_secret_value, | ||
datetime.timedelta: operator.methodcaller("total_seconds"), | ||
enum.Enum: operator.attrgetter("value"), | ||
Link: operator.attrgetter("ref"), | ||
bytes: bson.Binary, | ||
decimal.Decimal: bson.Decimal128, | ||
uuid.UUID: bson.Binary.from_uuid, | ||
re.Pattern: bson.Regex.from_native, | ||
} | ||
BSON_SCALAR_TYPES = ( | ||
type(None), | ||
str, | ||
int, | ||
float, | ||
datetime.datetime, | ||
bson.Binary, | ||
bson.DBRef, | ||
bson.Decimal128, | ||
bson.ObjectId, | ||
) | ||
|
||
|
||
@dc.dataclass | ||
class Encoder: | ||
""" | ||
BSON encoding class | ||
""" | ||
|
||
exclude: Container[str] = frozenset() | ||
custom_encoders: Mapping[type, SingleArgCallable] = dc.field( | ||
default_factory=dict | ||
) | ||
to_db: bool = False | ||
keep_nulls: bool = True | ||
|
||
def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]: | ||
obj.parse_store() | ||
settings = obj.get_settings() | ||
obj_dict = {} | ||
if settings.union_doc is not None: | ||
obj_dict[settings.class_id] = ( | ||
settings.union_doc_alias or obj.__class__.__name__ | ||
) | ||
if obj._class_id: | ||
obj_dict[settings.class_id] = obj._class_id | ||
|
||
link_fields = obj.get_link_fields() or {} | ||
sub_encoder = Encoder( | ||
# don't propagate self.exclude to subdocuments | ||
custom_encoders=settings.bson_encoders, | ||
to_db=self.to_db, | ||
keep_nulls=self.keep_nulls, | ||
) | ||
for key, value in self._iter_model_items(obj): | ||
if key in link_fields: | ||
link_type = link_fields[key].link_type | ||
if link_type in (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT): | ||
if value is not None: | ||
value = value.to_ref() | ||
elif link_type in (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST): | ||
if value is not None: | ||
value = [link.to_ref() for link in value] | ||
elif self.to_db: | ||
continue | ||
obj_dict[key] = sub_encoder.encode(value) | ||
return obj_dict | ||
|
||
def encode(self, obj: Any) -> Any: | ||
if self.custom_encoders: | ||
encoder = _get_encoder(obj, self.custom_encoders) | ||
if encoder is not None: | ||
return encoder(obj) | ||
|
||
if isinstance(obj, BSON_SCALAR_TYPES): | ||
return obj | ||
|
||
encoder = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS) | ||
if encoder is not None: | ||
return encoder(obj) | ||
|
||
if isinstance(obj, beanie.Document): | ||
return self._encode_document(obj) | ||
if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel): | ||
return self.encode(obj.root) | ||
if isinstance(obj, pydantic.BaseModel): | ||
items = self._iter_model_items(obj) | ||
return {key: self.encode(value) for key, value in items} | ||
if isinstance(obj, Mapping): | ||
return {key: self.encode(value) for key, value in obj.items()} | ||
if isinstance(obj, Iterable): | ||
return [self.encode(value) for value in obj] | ||
|
||
errors = [] | ||
try: | ||
data = dict(obj) | ||
except Exception as e: | ||
errors.append(e) | ||
try: | ||
data = vars(obj) | ||
except Exception as e: | ||
errors.append(e) | ||
raise ValueError(errors) | ||
return self.encode(data) | ||
|
||
def _iter_model_items( | ||
self, obj: pydantic.BaseModel | ||
) -> Iterable[Tuple[str, Any]]: | ||
exclude, keep_nulls = self.exclude, self.keep_nulls | ||
get_model_field = get_model_fields(obj).get | ||
for key, value in obj.__iter__(): | ||
field_info = get_model_field(key) | ||
if field_info is not None: | ||
key = field_info.alias or key | ||
if key not in exclude and (value is not None or keep_nulls): | ||
yield key, value | ||
|
||
|
||
def _get_encoder( | ||
obj: Any, custom_encoders: Mapping[type, SingleArgCallable] | ||
) -> Optional[SingleArgCallable]: | ||
encoder = custom_encoders.get(type(obj)) | ||
if encoder is not None: | ||
return encoder | ||
for cls, encoder in custom_encoders.items(): | ||
if isinstance(obj, cls): | ||
return encoder | ||
return None | ||
import dataclasses as dc | ||
import datetime | ||
import decimal | ||
import enum | ||
import ipaddress | ||
import operator | ||
import pathlib | ||
import re | ||
import uuid | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Container, | ||
Iterable, | ||
Mapping, | ||
MutableMapping, | ||
Optional, | ||
Tuple, | ||
) | ||
|
||
import bson | ||
import pydantic | ||
|
||
import beanie | ||
from beanie.odm.fields import Link, LinkTypes | ||
from beanie.odm.utils.pydantic import IS_PYDANTIC_V2, get_model_fields | ||
|
||
SingleArgCallable = Callable[[Any], Any] | ||
DEFAULT_CUSTOM_ENCODERS: MutableMapping[type, SingleArgCallable] = { | ||
ipaddress.IPv4Address: str, | ||
ipaddress.IPv4Interface: str, | ||
ipaddress.IPv4Network: str, | ||
ipaddress.IPv6Address: str, | ||
ipaddress.IPv6Interface: str, | ||
ipaddress.IPv6Network: str, | ||
pathlib.PurePath: str, | ||
pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value, | ||
pydantic.SecretStr: pydantic.SecretStr.get_secret_value, | ||
datetime.timedelta: operator.methodcaller("total_seconds"), | ||
enum.Enum: operator.attrgetter("value"), | ||
Link: operator.attrgetter("ref"), | ||
bytes: bson.Binary, | ||
decimal.Decimal: bson.Decimal128, | ||
uuid.UUID: bson.Binary.from_uuid, | ||
re.Pattern: bson.Regex.from_native, | ||
} | ||
if IS_PYDANTIC_V2: | ||
from pydantic_core import Url | ||
|
||
DEFAULT_CUSTOM_ENCODERS[Url] = str | ||
|
||
BSON_SCALAR_TYPES = ( | ||
type(None), | ||
str, | ||
int, | ||
float, | ||
datetime.datetime, | ||
bson.Binary, | ||
bson.DBRef, | ||
bson.Decimal128, | ||
bson.ObjectId, | ||
) | ||
|
||
|
||
@dc.dataclass | ||
class Encoder: | ||
""" | ||
BSON encoding class | ||
""" | ||
|
||
exclude: Container[str] = frozenset() | ||
custom_encoders: Mapping[type, SingleArgCallable] = dc.field( | ||
default_factory=dict | ||
) | ||
to_db: bool = False | ||
keep_nulls: bool = True | ||
|
||
def _encode_document(self, obj: "beanie.Document") -> Mapping[str, Any]: | ||
obj.parse_store() | ||
settings = obj.get_settings() | ||
obj_dict = {} | ||
if settings.union_doc is not None: | ||
obj_dict[settings.class_id] = ( | ||
settings.union_doc_alias or obj.__class__.__name__ | ||
) | ||
if obj._class_id: | ||
obj_dict[settings.class_id] = obj._class_id | ||
|
||
link_fields = obj.get_link_fields() or {} | ||
sub_encoder = Encoder( | ||
# don't propagate self.exclude to subdocuments | ||
custom_encoders=settings.bson_encoders, | ||
to_db=self.to_db, | ||
keep_nulls=self.keep_nulls, | ||
) | ||
for key, value in self._iter_model_items(obj): | ||
if key in link_fields: | ||
link_type = link_fields[key].link_type | ||
if link_type in (LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT): | ||
if value is not None: | ||
value = value.to_ref() | ||
elif link_type in (LinkTypes.LIST, LinkTypes.OPTIONAL_LIST): | ||
if value is not None: | ||
value = [link.to_ref() for link in value] | ||
elif self.to_db: | ||
continue | ||
obj_dict[key] = sub_encoder.encode(value) | ||
return obj_dict | ||
|
||
def encode(self, obj: Any) -> Any: | ||
if self.custom_encoders: | ||
encoder = _get_encoder(obj, self.custom_encoders) | ||
if encoder is not None: | ||
return encoder(obj) | ||
|
||
if isinstance(obj, BSON_SCALAR_TYPES): | ||
return obj | ||
|
||
encoder = _get_encoder(obj, DEFAULT_CUSTOM_ENCODERS) | ||
if encoder is not None: | ||
return encoder(obj) | ||
|
||
if isinstance(obj, beanie.Document): | ||
return self._encode_document(obj) | ||
if IS_PYDANTIC_V2 and isinstance(obj, pydantic.RootModel): | ||
return self.encode(obj.root) | ||
if isinstance(obj, pydantic.BaseModel): | ||
items = self._iter_model_items(obj) | ||
return {key: self.encode(value) for key, value in items} | ||
if isinstance(obj, Mapping): | ||
return {key: self.encode(value) for key, value in obj.items()} | ||
if isinstance(obj, Iterable): | ||
return [self.encode(value) for value in obj] | ||
|
||
errors = [] | ||
try: | ||
data = dict(obj) | ||
except Exception as e: | ||
errors.append(e) | ||
try: | ||
data = vars(obj) | ||
except Exception as e: | ||
errors.append(e) | ||
raise ValueError(errors) | ||
return self.encode(data) | ||
|
||
def _iter_model_items( | ||
self, obj: pydantic.BaseModel | ||
) -> Iterable[Tuple[str, Any]]: | ||
exclude, keep_nulls = self.exclude, self.keep_nulls | ||
get_model_field = get_model_fields(obj).get | ||
for key, value in obj.__iter__(): | ||
field_info = get_model_field(key) | ||
if field_info is not None: | ||
key = field_info.alias or key | ||
if key not in exclude and (value is not None or keep_nulls): | ||
yield key, value | ||
|
||
|
||
def _get_encoder( | ||
obj: Any, custom_encoders: Mapping[type, SingleArgCallable] | ||
) -> Optional[SingleArgCallable]: | ||
encoder = custom_encoders.get(type(obj)) | ||
if encoder is not None: | ||
return encoder | ||
for cls, encoder in custom_encoders.items(): | ||
if isinstance(obj, cls): | ||
return encoder | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.