Skip to content

Commit

Permalink
handle unserializable asset selections on external sensors (dagster-i…
Browse files Browse the repository at this point in the history
…o#18750)

## Summary & Motivation

Nothing prevents users from making their own asset selection subclasses.
We offer one of our own with `DbtManifestAssetSelection`. This handles
these by converting them to `KeysAssetSelection`, instead of erroring.

## How I Tested These Changes
  • Loading branch information
sryza authored Dec 21, 2023
1 parent ff0fb79 commit dcb99a4
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,18 @@ def from_coercible(cls, selection: CoercibleToAssetSelection) -> "AssetSelection
f" {type(selection)}."
)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return AssetSelection.keys(*self.resolve(asset_graph))


@whitelist_for_serdes
class AllSelection(AssetSelection, NamedTuple("_AllSelection", [])):
def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
return asset_graph.materializable_asset_keys

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class AllAssetCheckSelection(AssetSelection, NamedTuple("_AllAssetChecksSelection", [])):
Expand All @@ -388,6 +394,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[AssetCheckKey]:
return asset_graph.asset_check_keys

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class AssetChecksForAssetKeysSelection(
Expand All @@ -404,6 +413,9 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A
if handle.asset_key in self.selected_asset_keys
}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class AssetCheckKeysSelection(
Expand All @@ -422,6 +434,9 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A
if handle in self.selected_asset_check_keys
}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class AndAssetSelection(
Expand All @@ -436,6 +451,12 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A
asset_graph
)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(
left=self.left.to_serializable_asset_selection(asset_graph),
right=self.right.to_serializable_asset_selection(asset_graph),
)


@whitelist_for_serdes
class SubtractAssetSelection(
Expand All @@ -450,6 +471,12 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A
asset_graph
)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(
left=self.left.to_serializable_asset_selection(asset_graph),
right=self.right.to_serializable_asset_selection(asset_graph),
)


@whitelist_for_serdes
class SinksAssetSelection(
Expand All @@ -460,6 +487,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
selection = self.child.resolve_inner(asset_graph)
return fetch_sinks(asset_graph.asset_dep_graph, selection)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(child=self.child.to_serializable_asset_selection(asset_graph))


@whitelist_for_serdes
class RequiredNeighborsAssetSelection(
Expand All @@ -473,6 +503,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
output.update(asset_graph.get_required_multi_asset_keys(asset_key))
return output

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(child=self.child.to_serializable_asset_selection(asset_graph))


@whitelist_for_serdes
class RootsAssetSelection(
Expand All @@ -483,6 +516,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
selection = self.child.resolve_inner(asset_graph)
return fetch_sources(asset_graph.asset_dep_graph, selection)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(child=self.child.to_serializable_asset_selection(asset_graph))


@whitelist_for_serdes
class DownstreamAssetSelection(
Expand Down Expand Up @@ -515,6 +551,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
selection if not self.include_self else set(),
)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(child=self.child.to_serializable_asset_selection(asset_graph))


@whitelist_for_serdes
class GroupsAssetSelection(
Expand All @@ -539,6 +578,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
if group in self.selected_groups and asset_key in base_set
}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class KeysAssetSelection(
Expand All @@ -556,6 +598,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
)
return specified_keys

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class KeyPrefixesAssetSelection(
Expand All @@ -577,6 +622,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
if any(key.has_prefix(prefix) for prefix in self.selected_key_prefixes)
}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self


@whitelist_for_serdes
class OrAssetSelection(
Expand All @@ -591,6 +639,12 @@ def resolve_checks_inner(self, asset_graph: InternalAssetGraph) -> AbstractSet[A
asset_graph
)

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(
left=self.left.to_serializable_asset_selection(asset_graph),
right=self.right.to_serializable_asset_selection(asset_graph),
)


def _fetch_all_upstream(
selection: AbstractSet[AssetKey],
Expand Down Expand Up @@ -636,6 +690,9 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
all_upstream = _fetch_all_upstream(selection, asset_graph, self.depth, self.include_self)
return {key for key in all_upstream if key not in asset_graph.source_asset_keys}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(child=self.child.to_serializable_asset_selection(asset_graph))


@whitelist_for_serdes
class ParentSourcesAssetSelection(
Expand All @@ -648,3 +705,6 @@ def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
return selection
all_upstream = _fetch_all_upstream(selection, asset_graph)
return {key for key in all_upstream if key in asset_graph.source_asset_keys}

def to_serializable_asset_selection(self, asset_graph: AssetGraph) -> "AssetSelection":
return self._replace(child=self.child.to_serializable_asset_selection(asset_graph))
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
from dagster._core.snap.mode import ResourceDefSnap, build_resource_def_snap
from dagster._core.storage.io_manager import IOManagerDefinition
from dagster._serdes import whitelist_for_serdes
from dagster._serdes.serdes import is_whitelisted_for_serdes_object
from dagster._utils.error import SerializableErrorInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -553,6 +554,13 @@ def __new__(
)
}

if asset_selection:
check.opt_inst_param(asset_selection, "asset_selection", AssetSelection)
check.invariant(
is_whitelisted_for_serdes_object(asset_selection),
"asset_selection must be serializable",
)

return super(ExternalSensorData, cls).__new__(
cls,
name=check.str_param(name, "name"),
Expand Down Expand Up @@ -2036,6 +2044,10 @@ def external_sensor_data_from_def(
)
for base_asset_job_name in repository_def.get_implicit_asset_job_names()
}

serializable_asset_selection = sensor_def.asset_selection.to_serializable_asset_selection(
repository_def.asset_graph
)
else:
target_dict = {
target.job_name: ExternalTargetData(
Expand All @@ -2046,6 +2058,8 @@ def external_sensor_data_from_def(
for target in sensor_def.targets
}

serializable_asset_selection = None

return ExternalSensorData(
name=sensor_def.name,
job_name=first_target.job_name if first_target else None,
Expand All @@ -2057,7 +2071,7 @@ def external_sensor_data_from_def(
metadata=ExternalSensorMetadata(asset_keys=asset_keys),
default_status=sensor_def.default_status,
sensor_type=sensor_def.sensor_type,
asset_selection=sensor_def.asset_selection,
asset_selection=serializable_asset_selection,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
SourceAsset,
StaticPartitionsDefinition,
TimeWindowPartitionMapping,
asset_check,
multi_asset,
)
from dagster._core.definitions import AssetSelection, asset
from dagster._core.definitions.asset_graph import AssetGraph
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.events import AssetKey
from dagster._serdes.serdes import _WHITELIST_MAP
Expand Down Expand Up @@ -392,3 +394,100 @@ def test_all_asset_selection_subclasses_serializable():
for asset_selection_subclass in asset_selection_subclasses:
if asset_selection_subclass != AssetSelection:
assert _WHITELIST_MAP.has_object_serializer(asset_selection_subclass.__name__)


def test_to_serializable_asset_selection():
class UnserializableAssetSelection(AssetSelection):
def resolve_inner(self, asset_graph: AssetGraph) -> AbstractSet[AssetKey]:
return asset_graph.materializable_asset_keys - {AssetKey("asset2")}

@asset
def asset1():
...

@asset
def asset2():
...

@asset_check(asset=asset1)
def check1():
...

asset_graph = AssetGraph.from_assets([asset1, asset2], asset_checks=[check1])

def assert_serializable_same(asset_selection: AssetSelection) -> None:
assert asset_selection.to_serializable_asset_selection(asset_graph) == asset_selection

assert_serializable_same(AssetSelection.groups("a"))
assert_serializable_same(AssetSelection.key_prefixes(["foo", "bar"]))
assert_serializable_same(AssetSelection.all())
assert_serializable_same(AssetSelection.all_asset_checks())
assert_serializable_same(AssetSelection.keys("asset1"))
assert_serializable_same(AssetSelection.checks_for_assets(asset1))
assert_serializable_same(AssetSelection.checks(check1))

assert_serializable_same(AssetSelection.sinks(AssetSelection.groups("a")))
assert_serializable_same(AssetSelection.downstream(AssetSelection.groups("a"), depth=1))
assert_serializable_same(AssetSelection.upstream(AssetSelection.groups("a"), depth=1))
assert_serializable_same(
AssetSelection.required_multi_asset_neighbors(AssetSelection.groups("a"))
)
assert_serializable_same(AssetSelection.roots(AssetSelection.groups("a")))
assert_serializable_same(AssetSelection.sources(AssetSelection.groups("a")))
assert_serializable_same(AssetSelection.upstream_source_assets(AssetSelection.groups("a")))

assert_serializable_same(AssetSelection.groups("a") & AssetSelection.groups("b"))
assert_serializable_same(AssetSelection.groups("a") | AssetSelection.groups("b"))
assert_serializable_same(AssetSelection.groups("a") - AssetSelection.groups("b"))

asset1_selection = AssetSelection.keys("asset1")
assert (
UnserializableAssetSelection().to_serializable_asset_selection(asset_graph)
== asset1_selection
)

assert AssetSelection.sinks(UnserializableAssetSelection()).to_serializable_asset_selection(
asset_graph
) == AssetSelection.sinks(asset1_selection)
assert AssetSelection.downstream(
UnserializableAssetSelection(), depth=1
).to_serializable_asset_selection(asset_graph) == AssetSelection.downstream(
asset1_selection, depth=1
)
assert AssetSelection.upstream(
UnserializableAssetSelection(), depth=1
).to_serializable_asset_selection(asset_graph) == AssetSelection.upstream(
asset1_selection, depth=1
)
assert AssetSelection.required_multi_asset_neighbors(
UnserializableAssetSelection()
).to_serializable_asset_selection(asset_graph) == AssetSelection.required_multi_asset_neighbors(
asset1_selection
)
assert AssetSelection.roots(UnserializableAssetSelection()).to_serializable_asset_selection(
asset_graph
) == AssetSelection.roots(asset1_selection)
assert AssetSelection.sources(UnserializableAssetSelection()).to_serializable_asset_selection(
asset_graph
) == AssetSelection.sources(asset1_selection)
assert AssetSelection.upstream_source_assets(
UnserializableAssetSelection()
).to_serializable_asset_selection(asset_graph) == AssetSelection.upstream_source_assets(
asset1_selection
)

assert (
UnserializableAssetSelection() & AssetSelection.groups("b")
).to_serializable_asset_selection(asset_graph) == (
asset1_selection & AssetSelection.groups("b")
)
assert (
UnserializableAssetSelection() | AssetSelection.groups("b")
).to_serializable_asset_selection(asset_graph) == (
asset1_selection | AssetSelection.groups("b")
)
assert (
UnserializableAssetSelection() - AssetSelection.groups("b")
).to_serializable_asset_selection(asset_graph) == (
asset1_selection - AssetSelection.groups("b")
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from dagster import AssetKey, asset, sensor
from dagster._core.definitions.decorators.repository_decorator import repository
from dagster._core.host_representation.external_data import external_sensor_data_from_def


def test_coerce_to_asset_selection():
Expand Down Expand Up @@ -29,31 +27,3 @@ def sensor2():
...

assert sensor2.asset_selection.resolve(assets) == {AssetKey("asset1"), AssetKey("asset2")}


def test_external_sensor_has_asset_selection():
@asset
def asset1():
...

@asset
def asset2():
...

@asset
def asset3():
...

@sensor(asset_selection=["asset1", "asset2"])
def sensor1():
...

@repository
def my_repo():
return [
sensor1,
]

assert (
external_sensor_data_from_def(sensor1, my_repo).asset_selection == sensor1.asset_selection
)
Loading

0 comments on commit dcb99a4

Please sign in to comment.