🐛 CDK: fix bug with limit parameter for incremental stream (airbytehq#5833)

* CDK: fix bug with limit parameter for incremental stream

Co-authored-by: Dmytro Rezchykov <[email protected]>
avida and Dmytro Rezchykov authored Sep 9, 2021
1 parent 70513bc commit 6041f3d
Showing 4 changed files with 97 additions and 27 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
@@ -1,5 +1,8 @@
 # Changelog

+## 0.1.18
+Fix incremental streams not saving state when the internal limit config is set.
+
 ## 0.1.17
 Fix off-by-one mismatch between the number of records actually read and the number of records reported in logs: https://github.com/airbytehq/airbyte/pull/5767

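For context on what this entry means in practice: the limit travels inside the regular connector config under the reserved `_limit` key (exactly as the new test further down passes it). A minimal sketch of the fixed behavior, assuming a hypothetical `source` built on this CDK and an already-configured `catalog`:

```python
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import Type

# "api_key" stands in for any user-defined field; "_limit" is the internal CDK knob.
config = {"api_key": "...", "_limit": 2}

# `source` and `catalog` are assumed to exist; see the test fixtures below.
messages = list(source.read(AirbyteLogger(), config=config, catalog=catalog, state={}))
records = [m for m in messages if m.type == Type.RECORD]
states = [m for m in messages if m.type == Type.STATE]

assert len(records) == 2  # reading stops once the internal limit is hit
assert len(states) >= 1   # and, with this fix, state is still checkpointed
```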
44 changes: 35 additions & 9 deletions airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -137,29 +137,40 @@ def _read_stream(

         use_incremental = configured_stream.sync_mode == SyncMode.incremental and stream_instance.supports_incremental
         if use_incremental:
-            record_iterator = self._read_incremental(logger, stream_instance, configured_stream, connector_state)
+            record_iterator = self._read_incremental(logger, stream_instance, configured_stream, connector_state, internal_config)
         else:
-            record_iterator = self._read_full_refresh(stream_instance, configured_stream)
+            record_iterator = self._read_full_refresh(stream_instance, configured_stream, internal_config)

         record_counter = 0
         stream_name = configured_stream.stream.name
         logger.info(f"Syncing stream: {stream_name} ")
         for record in record_iterator:
             if record.type == MessageType.RECORD:
-                if internal_config.limit and record_counter >= internal_config.limit:
-                    logger.info(f"Reached limit defined by internal config ({internal_config.limit}), stop reading")
-                    break
                 record_counter += 1
             yield record

         logger.info(f"Read {record_counter} records from {stream_name} stream")

+    @staticmethod
+    def _limit_reached(internal_config: InternalConfig, records_counter: int) -> bool:
+        """
+        Check if the record count has reached the limit set by the internal config.
+        :param internal_config - internal CDK configuration separated from user defined config
+        :param records_counter - number of records already read
+        :return True if limit reached, False otherwise
+        """
+        if internal_config.limit:
+            if records_counter >= internal_config.limit:
+                return True
+        return False

     def _read_incremental(
         self,
         logger: AirbyteLogger,
         stream_instance: Stream,
         configured_stream: ConfiguredAirbyteStream,
         connector_state: MutableMapping[str, Any],
+        internal_config: InternalConfig,
     ) -> Iterator[AirbyteMessage]:
         stream_name = configured_stream.stream.name
         stream_state = connector_state.get(stream_name, {})
@@ -170,31 +181,46 @@ def _read_incremental(
         slices = stream_instance.stream_slices(
             cursor_field=configured_stream.cursor_field, sync_mode=SyncMode.incremental, stream_state=stream_state
         )
+        total_records_counter = 0
         for slice in slices:
-            record_counter = 0
             records = stream_instance.read_records(
                 sync_mode=SyncMode.incremental,
                 stream_slice=slice,
                 stream_state=stream_state,
                 cursor_field=configured_stream.cursor_field or None,
             )
-            for record_data in records:
-                record_counter += 1
+            for record_counter, record_data in enumerate(records, start=1):
                 yield self._as_airbyte_record(stream_name, record_data)
                 stream_state = stream_instance.get_updated_state(stream_state, record_data)
                 if checkpoint_interval and record_counter % checkpoint_interval == 0:
                     yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)

+                total_records_counter += 1
+                # This functionality should ideally live outside of this method
+                # but since state is managed inside this method, we keep track
+                # of it here.
+                if self._limit_reached(internal_config, total_records_counter):
+                    # Break from slice loop to save state and exit from _read_incremental function.
+                    break

             yield self._checkpoint_state(stream_name, stream_state, connector_state, logger)
+            if self._limit_reached(internal_config, total_records_counter):
+                return

-    def _read_full_refresh(self, stream_instance: Stream, configured_stream: ConfiguredAirbyteStream) -> Iterator[AirbyteMessage]:
+    def _read_full_refresh(
+        self, stream_instance: Stream, configured_stream: ConfiguredAirbyteStream, internal_config: InternalConfig
+    ) -> Iterator[AirbyteMessage]:
         slices = stream_instance.stream_slices(sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field)
+        total_records_counter = 0
         for slice in slices:
             records = stream_instance.read_records(
                 stream_slice=slice, sync_mode=SyncMode.full_refresh, cursor_field=configured_stream.cursor_field
             )
             for record in records:
                 yield self._as_airbyte_record(configured_stream.stream.name, record)
+                total_records_counter += 1
+                if self._limit_reached(internal_config, total_records_counter):
+                    return

     def _checkpoint_state(self, stream_name, stream_state, connector_state, logger):
         logger.info(f"Setting state of {stream_name} stream to {stream_state}")
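Why the change was needed: `_read_stream` previously did the limit check itself and simply broke out of its loop over the reader generator, discarding anything still pending inside `_read_incremental`, including the trailing state message. Moving the check into the readers lets them checkpoint before exiting. A self-contained toy illustration of that ordering (names are illustrative, not CDK API):

```python
from typing import Dict, Iterator, List


def read_with_limit(slices: List[List[dict]], limit: int) -> Iterator[Dict]:
    """Toy version of _read_incremental's checkpoint-before-exit pattern."""
    total = 0
    state: Dict = {}
    for slice_records in slices:
        for record in slice_records:
            yield {"type": "RECORD", "data": record}
            state = {"cursor": record["id"]}  # stand-in for get_updated_state()
            total += 1
            if limit and total >= limit:
                break  # leave the record loop, but fall through to the checkpoint
        yield {"type": "STATE", "data": state}  # checkpoint after every slice
        if limit and total >= limit:
            return  # exit only after state for the partial slice was emitted


# Two slices of two records each, limit of 3: the partial second slice still
# gets its state message before the generator stops.
msgs = list(read_with_limit([[{"id": 1}, {"id": 2}], [{"id": 3}, {"id": 4}]], 3))
assert [m["type"] for m in msgs] == ["RECORD", "RECORD", "STATE", "RECORD", "STATE"]
```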
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
@@ -35,7 +35,7 @@

 setup(
     name="airbyte-cdk",
-    version="0.1.17",
+    version="0.1.18",
     description="A framework for writing Airbyte Connectors.",
     long_description=README,
     long_description_content_type="text/markdown",
75 changes: 58 additions & 17 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -30,7 +30,7 @@

 import pytest
 from airbyte_cdk.logger import AirbyteLogger
-from airbyte_cdk.models import ConfiguredAirbyteCatalog
+from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode, Type
 from airbyte_cdk.sources import AbstractSource, Source
 from airbyte_cdk.sources.streams.core import Stream
 from airbyte_cdk.sources.streams.http.http import HttpStream
@@ -54,6 +54,25 @@ def source():
     return MockSource()


+@pytest.fixture
+def catalog():
+    configured_catalog = {
+        "streams": [
+            {
+                "stream": {"name": "mock_http_stream", "json_schema": {}},
+                "destination_sync_mode": "overwrite",
+                "sync_mode": "full_refresh",
+            },
+            {
+                "stream": {"name": "mock_stream", "json_schema": {}},
+                "destination_sync_mode": "overwrite",
+                "sync_mode": "full_refresh",
+            },
+        ]
+    }
+    return ConfiguredAirbyteCatalog.parse_obj(configured_catalog)


 @pytest.fixture
 def abstract_source(mocker):
     mocker.patch.multiple(HttpStream, __abstractmethods__=set())
@@ -63,6 +82,9 @@ class MockHttpStream(MagicMock, HttpStream):
         url_base = "http://example.com"
         path = "/dummy/path"

+        def supports_incremental(self):
+            return True
+
         def __init__(self, *args, **kvargs):
             MagicMock.__init__(self)
             HttpStream.__init__(self, *args, kvargs)
@@ -120,22 +142,7 @@ def test_read_catalog(source):
     assert actual == expected


-def test_internal_config(abstract_source):
-    configured_catalog = {
-        "streams": [
-            {
-                "stream": {"name": "mock_http_stream", "json_schema": {}},
-                "destination_sync_mode": "overwrite",
-                "sync_mode": "full_refresh",
-            },
-            {
-                "stream": {"name": "mock_stream", "json_schema": {}},
-                "destination_sync_mode": "overwrite",
-                "sync_mode": "full_refresh",
-            },
-        ]
-    }
-    catalog = ConfiguredAirbyteCatalog.parse_obj(configured_catalog)
+def test_internal_config(abstract_source, catalog):
     streams = abstract_source.streams(None)
     assert len(streams) == 2
     http_stream = streams[0]
@@ -175,3 +182,37 @@ def test_internal_config(abstract_source):
     assert http_stream.page_size == 2
     # Make sure page_size hasn't been set for non-HTTP streams
     assert not non_http_stream.page_size


+def test_internal_config_limit(abstract_source, catalog):
+    logger_mock = MagicMock()
+    del catalog.streams[1]
+    STREAM_LIMIT = 2
+    FULL_RECORDS_NUMBER = 3
+    streams = abstract_source.streams(None)
+    http_stream = streams[0]
+    http_stream.read_records.return_value = [{}] * FULL_RECORDS_NUMBER
+    internal_config = {"some_config": 100, "_limit": STREAM_LIMIT}
+
+    catalog.streams[0].sync_mode = SyncMode.full_refresh
+    records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
+    assert len(records) == STREAM_LIMIT
+    logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
+    # Check that the "Read ..." log line reflects the limit
+    read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
+    assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
+
+    # No limit: check that a state record is produced for the incremental stream
+    catalog.streams[0].sync_mode = SyncMode.incremental
+    records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
+    assert len(records) == FULL_RECORDS_NUMBER + 1
+    assert records[-1].type == Type.STATE
+
+    # Set the limit and check that state is still produced for the incremental stream
+    logger_mock.reset_mock()
+    records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
+    assert len(records) == STREAM_LIMIT + 1
+    assert records[-1].type == Type.STATE
+    logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
+    read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
+    assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")
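For context on the `internal_config` dict used above: keys prefixed with `_` are split off from the user-facing config before the read starts and parsed into the `InternalConfig` model consumed by `_limit_reached`. A hedged sketch of that separation; the `split_config` helper and its import path are assumptions about this CDK version, not shown in this diff:

```python
# Assumed import path; verify against the installed airbyte-cdk version.
from airbyte_cdk.sources.utils.schema_helpers import split_config

raw_config = {"some_config": 100, "_limit": 2}
config, internal_config = split_config(raw_config)

assert config == {"some_config": 100}  # what the connector's streams see
assert internal_config.limit == 2      # what _limit_reached consumes
```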
