Skip to content
This repository has been archived by the owner on Jun 6, 2023. It is now read-only.

Commit

Permalink
[low-code connectors] Bugfix transformations (airbytehq#14810)
Browse files Browse the repository at this point in the history
  • Loading branch information
sherifnada authored Jul 18, 2022
1 parent 2628fd3 commit 52e3755
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.declarative.types import Config, StreamSlice
from airbyte_cdk.sources.streams.core import Stream


Expand All @@ -23,6 +24,7 @@ def __init__(
primary_key,
schema_loader: SchemaLoader,
retriever: Retriever,
config: Config,
cursor_field: Optional[List[str]] = None,
transformations: List[RecordTransformation] = None,
checkpoint_interval: Optional[int] = None,
Expand All @@ -38,6 +40,7 @@ def __init__(
in the order in which they are defined.
"""
self._name = name
self._config = config
self._primary_key = primary_key
self._cursor_field = cursor_field or []
self._schema_loader = schema_loader
Expand Down Expand Up @@ -98,12 +101,12 @@ def read_records(
stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
for record in self._retriever.read_records(sync_mode, cursor_field, stream_slice, stream_state):
yield self._apply_transformations(record)
yield self._apply_transformations(record, self._config, stream_slice)

def _apply_transformations(self, record: Mapping[str, Any]):
def _apply_transformations(self, record: Mapping[str, Any], config: Config, stream_slice: StreamSlice):
output_record = record
for transformation in self._transformations:
output_record = transformation.transform(record)
output_record = transformation.transform(record, config=config, stream_state=self.state, stream_slice=stream_slice)

return output_record

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def build(self, class_or_class_name: Union[str, Type], config, **kwargs):
kwargs["options"] = {k: self._create_subcomponent(k, v, kwargs, config, class_) for k, v in kwargs["options"].items()}

updated_kwargs = {k: self._create_subcomponent(k, v, kwargs, config, class_) for k, v in kwargs.items()}

return create(class_, config=config, **updated_kwargs)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RecordTransformation(ABC):

@abstractmethod
def transform(
self, record: Mapping[str, Any], config: Config = None, state: StreamState = None, slice: StreamSlice = None
self, record: Mapping[str, Any], config: Config = None, stream_state: StreamState = None, stream_slice: StreamSlice = None
) -> Mapping[str, Any]:
"""
:param record: the input record to be transformed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,32 @@ def test():
retriever.stream_slices.return_value = stream_slices

no_op_transform = mock.create_autospec(spec=RecordTransformation)
no_op_transform.transform = MagicMock(side_effect=lambda x: x)
no_op_transform.transform = MagicMock(side_effect=lambda record, config, stream_slice, stream_state: record)
transformations = [no_op_transform]

config = {"api_key": "open_sesame"}

stream = DeclarativeStream(
name=name,
primary_key=primary_key,
cursor_field=cursor_field,
schema_loader=schema_loader,
retriever=retriever,
config=config,
transformations=transformations,
checkpoint_interval=checkpoint_interval,
)

assert stream.name == name
assert stream.get_json_schema() == json_schema
assert stream.state == state
assert list(stream.read_records(SyncMode.full_refresh, cursor_field, None, None)) == records
input_slice = stream_slices[0]
assert list(stream.read_records(SyncMode.full_refresh, cursor_field, input_slice, state)) == records
assert stream.primary_key == primary_key
assert stream.cursor_field == cursor_field
assert stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=cursor_field, stream_state=None) == stream_slices
assert stream.state_checkpoint_interval == checkpoint_interval
for transformation in transformations:
assert len(transformation.transform.call_args_list) == len(records)
expected_calls = [call(record) for record in records]
expected_calls = [call(record, config=config, stream_slice=input_slice, stream_state=state) for record in records]
transformation.transform.assert_has_calls(expected_calls, any_order=False)

0 comments on commit 52e3755

Please sign in to comment.