Skip to content

Commit

Permalink
Improved allowed_objects and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed Sep 22, 2020
1 parent 0e87fa0 commit c9d1fd4
Show file tree
Hide file tree
Showing 12 changed files with 134 additions and 30 deletions.
1 change: 1 addition & 0 deletions news/382.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for objects in config
37 changes: 28 additions & 9 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,11 @@ def _raise_missing_error(obj: Any, name: str) -> None:
)


def get_attr_data(obj: Any) -> Dict[str, Any]:
from omegaconf.omegaconf import _maybe_wrap
def get_attr_data(obj: Any, allow_objects: Optional[bool] = None) -> Dict[str, Any]:
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
dummy_parent = OmegaConf.create(flags=flags)

d = {}
is_type = isinstance(obj, type)
Expand All @@ -215,14 +218,23 @@ def get_attr_data(obj: Any) -> Dict[str, Any]:
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))

d[name] = _maybe_wrap(
ref_type=type_, is_optional=is_optional, key=name, value=value, parent=None
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=dummy_parent,
)
d[name]._set_parent(None)
return d


def get_dataclass_data(obj: Any) -> Dict[str, Any]:
from omegaconf.omegaconf import _maybe_wrap
def get_dataclass_data(
obj: Any, allow_objects: Optional[bool] = None
) -> Dict[str, Any]:
from omegaconf.omegaconf import OmegaConf, _maybe_wrap

flags = {"allow_objects": allow_objects} if allow_objects is not None else {}
dummy_parent = OmegaConf.create(flags=flags)
d = {}
for field in dataclasses.fields(obj):
name = field.name
Expand Down Expand Up @@ -251,8 +263,13 @@ def get_dataclass_data(obj: Any) -> Dict[str, Any]:
)
format_and_raise(node=None, key=None, value=value, cause=e, msg=str(e))
d[name] = _maybe_wrap(
ref_type=type_, is_optional=is_optional, key=name, value=value, parent=None
ref_type=type_,
is_optional=is_optional,
key=name,
value=value,
parent=dummy_parent,
)
d[name]._set_parent(None)
return d


Expand Down Expand Up @@ -305,11 +322,13 @@ def is_structured_config_frozen(obj: Any) -> bool:
return False


def get_structured_config_data(obj: Any) -> Dict[str, Any]:
def get_structured_config_data(
obj: Any, allow_objects: Optional[bool] = None
) -> Dict[str, Any]:
if is_dataclass(obj):
return get_dataclass_data(obj)
return get_dataclass_data(obj, allow_objects=allow_objects)
elif is_attr_class(obj):
return get_attr_data(obj)
return get_attr_data(obj, allow_objects=allow_objects)
else:
raise ValueError(f"Unsupported type: {type(obj).__name__}")

Expand Down
7 changes: 5 additions & 2 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def __eq__(self, other: Any) -> bool:
if other is None:
return self.__dict__["_content"] is None
if is_primitive_dict(other) or is_structured_config(other):
other = DictConfig(other)
other = DictConfig(other, flags={"allow_objects": True})
return DictConfig._dict_conf_eq(self, other)
if isinstance(other, DictConfig):
return DictConfig._dict_conf_eq(self, other)
Expand Down Expand Up @@ -548,7 +548,10 @@ def _set_value(self, value: Any) -> None:
self.__dict__["_content"] = {}
if is_structured_config(value):
self._metadata.object_type = None
data = get_structured_config_data(value)
data = get_structured_config_data(
value,
allow_objects=self._get_flag("allow_objects"),
)
for k, v in data.items():
self.__setitem__(k, v)
self._metadata.object_type = get_type_of(value)
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def key1(x: Any) -> Any:

def __eq__(self, other: Any) -> bool:
if isinstance(other, (list, tuple)) or other is None:
other = ListConfig(other)
other = ListConfig(other, flags={"allow_objects": True})
return ListConfig._list_eq(self, other)
if other is None or isinstance(other, ListConfig):
return ListConfig._list_eq(self, other)
Expand Down
6 changes: 2 additions & 4 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,10 @@ def __init__(
def validate_and_convert(self, value: Any) -> Any:
from ._utils import is_primitive_type

# _allow_non_primitive_ is internal and not an official API. use at your own risk.
# allow_objects is internal and not an official API. use at your own risk.
# Please be aware that this support is subject to change without notice.
# If this is deemed useful and supportable it may become an official API.
if self._get_flag(
"_allow_non_primitive_"
) is not True and not is_primitive_type(value):
if self._get_flag("allow_objects") is not True and not is_primitive_type(value):
t = get_type_of(value)
raise UnsupportedValueType(
f"Value '{t.__name__}' is not a supported primitive type"
Expand Down
6 changes: 3 additions & 3 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,12 @@ def _create_impl( # noqa F811
if isinstance(obj, str):
obj = yaml.load(obj, Loader=get_yaml_loader())
if obj is None:
return OmegaConf.create({})
return OmegaConf.create({}, flags=flags)
elif isinstance(obj, str):
return OmegaConf.create({obj: None})
return OmegaConf.create({obj: None}, flags=flags)
else:
assert isinstance(obj, (list, dict))
return OmegaConf.create(obj)
return OmegaConf.create(obj, flags=flags)

else:
if (
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys # pragma: no cover

__version__ = "2.1.0dev1"
__version__ = "2.1.0dev3"

msg = """OmegaConf 2.0 and above is compatible with Python 3.6 and newer.
You have the following options:
Expand Down
29 changes: 27 additions & 2 deletions tests/structured_conf/test_structured_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import pytest

from omegaconf import OmegaConf, ValidationError, _utils
from omegaconf.errors import ConfigKeyError
from omegaconf import OmegaConf, ValidationError, _utils, flag_override
from omegaconf.errors import ConfigKeyError, UnsupportedValueType
from tests import IllegalType


@pytest.mark.parametrize(
Expand Down Expand Up @@ -199,3 +200,27 @@ def test_native_missing(self, class_type: str) -> None:
),
):
OmegaConf.create(module.WithNativeMISSING)

def test_allow_objects(self, class_type: str) -> None:
module: Any = import_module(class_type)
cfg = OmegaConf.structured(module.Plugin)
iv = IllegalType()
with pytest.raises(UnsupportedValueType):
cfg.params = iv
cfg = OmegaConf.structured(module.Plugin, flags={"allow_objects": True})
cfg.params = iv
assert cfg.params == iv

cfg = OmegaConf.structured(module.Plugin)
with flag_override(cfg, "allow_objects", True):
cfg.params = iv
assert cfg.params == iv

cfg = OmegaConf.structured({"plugin": module.Plugin})
pwo = module.Plugin(name="foo", params=iv)
with pytest.raises(UnsupportedValueType):
cfg.plugin = pwo

with flag_override(cfg, "allow_objects", True):
cfg.plugin = pwo
assert cfg.plugin == pwo
15 changes: 13 additions & 2 deletions tests/test_basic_ops_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UnsupportedValueType,
ValidationError,
_utils,
flag_override,
open_dict,
)
from omegaconf.basecontainer import BaseContainer
Expand Down Expand Up @@ -429,14 +430,24 @@ def test_dict_len(d: Any, expected: Any) -> None:

def test_dict_assign_illegal_value() -> None:
c = OmegaConf.create()
iv = IllegalType()
with pytest.raises(UnsupportedValueType, match=re.escape("key: a")):
c.a = IllegalType()
c.a = iv

with flag_override(c, "allow_objects", True):
c.a = iv
assert c.a == iv


def test_dict_assign_illegal_value_nested() -> None:
c = OmegaConf.create({"a": {}})
iv = IllegalType()
with pytest.raises(UnsupportedValueType, match=re.escape("key: a.b")):
c.a.b = IllegalType()
c.a.b = iv

with flag_override(c, "allow_objects", True):
c.a.b = iv
assert c.a.b == iv


def test_assign_dict_in_dict() -> None:
Expand Down
17 changes: 13 additions & 4 deletions tests/test_basic_ops_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from omegaconf import AnyNode, ListConfig, OmegaConf
from omegaconf import AnyNode, ListConfig, OmegaConf, flag_override
from omegaconf.errors import (
ConfigKeyError,
ConfigTypeError,
Expand Down Expand Up @@ -461,21 +461,30 @@ def test_sort() -> None:

def test_insert_throws_not_changing_list() -> None:
c = OmegaConf.create([])
iv = IllegalType()
with pytest.raises(ValueError):
c.insert(0, IllegalType())
c.insert(0, iv)
assert len(c) == 0
assert c == []

with flag_override(c, "allow_objects", True):
c.insert(0, iv)
assert c == [iv]


def test_append_throws_not_changing_list() -> None:
c = OmegaConf.create([])
v = IllegalType()
iv = IllegalType()
with pytest.raises(ValueError):
c.append(v)
c.append(iv)
assert len(c) == 0
assert c == []
validate_list_keys(c)

with flag_override(c, "allow_objects", True):
c.append(iv)
assert c == [iv]


def test_hash() -> None:
c1 = OmegaConf.create([10])
Expand Down
12 changes: 11 additions & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
StringNode,
ValueNode,
)
from omegaconf.errors import ValidationError
from omegaconf.errors import UnsupportedValueType, ValidationError

from . import Color, IllegalType, User

Expand Down Expand Up @@ -550,3 +550,13 @@ def test_set_valuenode() -> None:
assert id(cfg._get_node("age")) == id(a_before)
with pytest.raises(ValidationError):
cfg.age = []


def test_allow_objects() -> None:
c = OmegaConf.create({"foo": AnyNode()})
with pytest.raises(UnsupportedValueType):
c.foo = IllegalType()
c._set_flag("allow_objects", True)
iv = IllegalType()
c.foo = iv
assert c.foo == iv
30 changes: 29 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from omegaconf import DictConfig, ListConfig, Node, OmegaConf, _utils
from omegaconf._utils import is_dict_annotation, is_list_annotation
from omegaconf.errors import ValidationError
from omegaconf.errors import UnsupportedValueType, ValidationError
from omegaconf.nodes import (
AnyNode,
BooleanNode,
Expand Down Expand Up @@ -136,6 +136,16 @@ class _TestAttrsClass:
dict1: Dict[str, int] = {}


@dataclass
class _TestDataclassIllegalValue:
x: Any = IllegalType()


@attr.s(auto_attribs=True)
class _TestAttrllegalValue:
x: Any = IllegalType()


class _TestUserClass:
pass

Expand Down Expand Up @@ -183,6 +193,24 @@ def test_get_structured_config_data(test_cls_or_obj: Any, expectation: Any) -> N
assert d["dict1"] == {}


@pytest.mark.parametrize( # type: ignore
"test_cls",
[
_TestDataclassIllegalValue,
_TestAttrllegalValue,
],
)
def test_get_structured_config_data_illegal_value(test_cls: Any) -> None:
with pytest.raises(UnsupportedValueType):
_utils.get_structured_config_data(test_cls, allow_objects=None)

with pytest.raises(UnsupportedValueType):
_utils.get_structured_config_data(test_cls, allow_objects=False)

d = _utils.get_structured_config_data(test_cls, allow_objects=True)
assert d["x"] == IllegalType()


def test_is_dataclass(mocker: Any) -> None:
@dataclass
class Foo:
Expand Down

0 comments on commit c9d1fd4

Please sign in to comment.