Skip to content

Commit

Permalink
support semi incremental by adding extractor record filter (airbytehq…
Browse files Browse the repository at this point in the history
…#13520)

* support semi incremental by adding extractor record filter

* refactor extractor into a record_selector that supports extraction and filtering of response records
  • Loading branch information
brianjlai authored Jun 23, 2022
1 parent c6d83b3 commit a612248
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 29 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from typing import Any, List, Mapping

import requests
from airbyte_cdk.sources.declarative.types import Record


class HttpSelector(ABC):
    """
    Interface for components that turn an HTTP response into a list of records.
    """

    @abstractmethod
    def select_records(
        self,
        response: requests.Response,
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> List[Record]:
        """
        Select the records contained in an HTTP response.

        :param response: the HTTP response to select records from
        :param stream_state: the current state of the stream
        :param stream_slice: the stream slice being read, if any
        :param next_page_token: the token of the page being read, if any
        :return: the list of selected records
        """
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

import requests
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
from airbyte_cdk.sources.declarative.extractors.http_extractor import HttpExtractor
from airbyte_cdk.sources.declarative.interpolation.jinja import JinjaInterpolation
from airbyte_cdk.sources.declarative.types import Record
from jello import lib as jello_lib


class JelloExtractor(HttpExtractor):
class JelloExtractor:
default_transform = "."

def __init__(self, transform: str, decoder: Decoder, config, kwargs=None):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping

from airbyte_cdk.sources.declarative.interpolation.interpolated_boolean import InterpolatedBoolean
from airbyte_cdk.sources.declarative.types import Record


class RecordFilter:
    """
    Keeps only the records for which an interpolated boolean condition evaluates to a truthy value.
    """

    def __init__(self, config, condition: str = None):
        """
        :param config: connector configuration, handed to the interpolator on every evaluation
        :param condition: template string wrapped in an InterpolatedBoolean and evaluated once per record
        """
        self._config = config
        self._filter_interpolator = InterpolatedBoolean(condition)

    def filter_records(
        self,
        records: List[Record],
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> List[Record]:
        """
        Return the subset of `records` that satisfies the condition, in their original order.

        The condition is evaluated with `record`, `stream_state`, `stream_slice` and `next_page_token`
        available as keyword arguments, in addition to the config.

        :param records: candidate records
        :param stream_state: the current state of the stream
        :param stream_slice: the stream slice being read, if any
        :param next_page_token: the token of the page being read, if any
        :return: the records for which the condition evaluated to a truthy value
        """
        selected = []
        for record in records:
            keep = self._filter_interpolator.eval(
                self._config,
                record=record,
                stream_state=stream_state,
                stream_slice=stream_slice,
                next_page_token=next_page_token,
            )
            if keep:
                selected.append(record)
        return selected
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping

import requests
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.types import Record


class RecordSelector(HttpSelector):
    """
    Translates an HTTP response into a list of records: the extractor pulls the records out of the response,
    then the optional record filter narrows them down.
    """

    def __init__(self, extractor: JelloExtractor, record_filter: RecordFilter = None):
        """
        :param extractor: component used to extract the records from the response
        :param record_filter: optional filter applied to the extracted records; when not set, nothing is filtered out
        """
        self._extractor = extractor
        self._record_filter = record_filter

    def select_records(
        self,
        response: requests.Response,
        stream_state: Mapping[str, Any],
        stream_slice: Mapping[str, Any] = None,
        next_page_token: Mapping[str, Any] = None,
    ) -> List[Record]:
        """
        Extract the records from `response` and, if a record filter is configured, filter them.

        :param response: the HTTP response to extract records from
        :param stream_state: forwarded to the record filter
        :param stream_slice: forwarded to the record filter
        :param next_page_token: forwarded to the record filter
        :return: the extracted (and possibly filtered) records
        """
        extracted = self._extractor.extract_records(response)
        # Without a filter every extracted record is selected.
        if not self._record_filter:
            return extracted
        return self._record_filter.filter_records(
            extracted, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
        )
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@ class ConditionalPaginator:
A paginator that performs pagination by incrementing a page number and stops based on a provided stop condition.
"""

def __init__(self, stop_condition_template: str, state: DictState, decoder: Decoder, config):
self._stop_condition_template = InterpolatedBoolean(stop_condition_template)
def __init__(self, stop_condition: str, state: DictState, decoder: Decoder, config):
self._stop_condition_interpolator = InterpolatedBoolean(stop_condition)
self._state: DictState = state
self._decoder = decoder
self._config = config

def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
decoded_response = self._decoder.decode(response)
headers = response.headers
should_stop = self._stop_condition_template.eval(
should_stop = self._stop_condition_interpolator.eval(
self._config, decoded_response=decoded_response, headers=headers, last_records=last_records
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import requests
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.extractors.http_extractor import HttpExtractor
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
from airbyte_cdk.sources.declarative.requesters.paginators.paginator import Paginator
from airbyte_cdk.sources.declarative.requesters.requester import Requester
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
Expand All @@ -22,15 +22,15 @@ def __init__(
primary_key,
requester: Requester,
paginator: Paginator,
extractor: HttpExtractor,
record_selector: HttpSelector,
stream_slicer: StreamSlicer,
state: State,
):
self._name = name
self._primary_key = primary_key
self._paginator = paginator
self._requester = requester
self._extractor = extractor
self._record_selector = record_selector
super().__init__(self._requester.get_authenticator())
self._iterator: StreamSlicer = stream_slicer
self._state: State = state.deep_copy()
Expand Down Expand Up @@ -190,7 +190,9 @@ def parse_response(
next_page_token: Mapping[str, Any] = None,
) -> Iterable[Mapping]:
self._last_response = response
records = self._extractor.extract_records(response)
records = self._record_selector.select_records(
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
)
self._last_records = records
return records

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import pytest
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter


@pytest.mark.parametrize(
    "test_name, filter_template, records, expected_records",
    [
        (
            "test_using_state_filter",
            "{{ record['created_at'] > stream_state['created_at'] }}",
            [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
            [{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
        ),
        (
            "test_with_slice_filter",
            "{{ record['last_seen'] >= stream_slice['last_seen'] }}",
            [{"id": 1, "last_seen": "06-06-21"}, {"id": 2, "last_seen": "06-07-21"}, {"id": 3, "last_seen": "06-10-21"}],
            [{"id": 3, "last_seen": "06-10-21"}],
        ),
        (
            "test_with_next_page_token_filter",
            "{{ record['id'] >= next_page_token['last_seen_id'] }}",
            [{"id": 11}, {"id": 12}, {"id": 13}, {"id": 14}, {"id": 15}],
            [{"id": 14}, {"id": 15}],
        ),
        (
            "test_missing_filter_fields_return_no_results",
            "{{ record['id'] >= next_page_token['path_to_nowhere'] }}",
            [{"id": 11}, {"id": 12}, {"id": 13}, {"id": 14}, {"id": 15}],
            [],
        ),
    ],
)
def test_record_filter(test_name, filter_template, records, expected_records):
    """Check that RecordFilter keeps exactly the records matching a condition on state, slice or page token."""
    # Interpolation context shared by every case; each template reads only the entry it needs.
    # NOTE(review): the date-like values are plain strings, so the comparisons are presumably lexicographic - confirm.
    context = {
        "stream_state": {"created_at": "06-06-21"},
        "stream_slice": {"last_seen": "06-10-21"},
        "next_page_token": {"last_seen_id": 14},
    }
    record_filter = RecordFilter(config={"response_override": "stop_if_you_see_me"}, condition=filter_template)

    assert record_filter.filter_records(records, **context) == expected_records
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import json

import pytest
import requests
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
from airbyte_cdk.sources.declarative.extractors.jello import JelloExtractor
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector


@pytest.mark.parametrize(
    "test_name, transform_template, filter_template, body, expected_records",
    [
        (
            "test_with_extractor_and_filter",
            "_.data",
            "{{ record['created_at'] > stream_state['created_at'] }}",
            {"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}]},
            [{"id": 2, "created_at": "06-07-21"}, {"id": 3, "created_at": "06-08-21"}],
        ),
        (
            "test_no_record_filter_returns_all_records",
            "_.data",
            None,
            {"data": [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}]},
            [{"id": 1, "created_at": "06-06-21"}, {"id": 2, "created_at": "06-07-21"}],
        ),
    ],
)
def test_record_filter(test_name, transform_template, filter_template, body, expected_records):
    """Check that RecordSelector extracts records from a response and applies the record filter when one is set."""
    # NOTE(review): this exercises RecordSelector, so the function name looks copy-pasted from the RecordFilter tests.
    config = {"response_override": "stop_if_you_see_me"}

    extractor = JelloExtractor(transform=transform_template, decoder=JsonDecoder(), config=config, kwargs={})
    # A None template means "no filter": the selector must then return every extracted record.
    record_filter = RecordFilter(config=config, condition=filter_template) if filter_template is not None else None
    record_selector = RecordSelector(extractor=extractor, record_filter=record_filter)

    actual_records = record_selector.select_records(
        response=create_response(body),
        stream_state={"created_at": "06-06-21"},
        stream_slice={"last_seen": "06-10-21"},
        next_page_token={"last_seen_id": 14},
    )
    assert actual_records == expected_records


def create_response(body):
    """Build a ``requests.Response`` whose content is ``body`` serialized to UTF-8 encoded JSON."""
    encoded_body = json.dumps(body).encode("utf-8")
    response = requests.Response()
    # Set the content buffer directly so the response can be decoded without issuing a request.
    response._content = encoded_body
    return response
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test():
next_page_token = {"cursor": "cursor_value"}
paginator.next_page_token.return_value = next_page_token

extractor = MagicMock()
extractor.extract_records.return_value = records
record_selector = MagicMock()
record_selector.select_records.return_value = records

iterator = MagicMock()
stream_slices = [{"date": "2022-01-01"}, {"date": "2022-01-02"}]
Expand Down Expand Up @@ -62,7 +62,7 @@ def test():
use_cache = True
requester.use_cache = use_cache

retriever = SimpleRetriever("stream_name", primary_key, requester, paginator, extractor, iterator, state)
retriever = SimpleRetriever("stream_name", primary_key, requester, paginator, record_selector, iterator, state)

# hack because we clone the state...
retriever._state = state
Expand Down
18 changes: 16 additions & 2 deletions airbyte-cdk/python/unit_tests/sources/declarative/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.decoders.json_decoder import JsonDecoder
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
from airbyte_cdk.sources.declarative.parsers.factory import DeclarativeComponentFactory
from airbyte_cdk.sources.declarative.parsers.yaml_parser import YamlParser
from airbyte_cdk.sources.declarative.requesters.request_options.interpolated_request_options_provider import (
Expand Down Expand Up @@ -86,6 +88,13 @@ def test_full_config():
extractor:
class_name: airbyte_cdk.sources.declarative.extractors.jello.JelloExtractor
decoder: "*ref(decoder)"
selector:
class_name: airbyte_cdk.sources.declarative.extractors.record_selector.RecordSelector
extractor:
decoder: "*ref(decoder)"
record_filter:
class_name: airbyte_cdk.sources.declarative.extractors.record_filter.RecordFilter
condition: "{{ record['id'] > stream_state['id'] }}"
metadata_paginator:
class_name: "airbyte_cdk.sources.declarative.requesters.paginators.next_page_url_paginator.NextPageUrlPaginator"
next_page_token_template:
Expand Down Expand Up @@ -139,6 +148,8 @@ def test_full_config():
default: "marketing/lists"
paginator:
ref: "*ref(metadata_paginator)"
record_selector:
ref: "*ref(selector)"
check:
class_name: airbyte_cdk.sources.declarative.checks.check_stream.CheckStream
stream_names: ["list_stream"]
Expand All @@ -156,8 +167,11 @@ def test_full_config():
assert type(stream._retriever) == SimpleRetriever
assert stream._retriever._requester._method == HttpMethod.GET
assert stream._retriever._requester._authenticator._tokens == ["verysecrettoken"]
assert type(stream._retriever._extractor._decoder) == JsonDecoder
assert stream._retriever._extractor._transform == ".result[]"
assert type(stream._retriever._record_selector) == RecordSelector
assert type(stream._retriever._record_selector._extractor._decoder) == JsonDecoder
assert stream._retriever._record_selector._extractor._transform == ".result[]"
assert type(stream._retriever._record_selector._record_filter) == RecordFilter
assert stream._retriever._record_selector._record_filter._filter_interpolator._condition == "{{ record['id'] > stream_state['id'] }}"
assert stream._schema_loader._file_path._string == "./source_sendgrid/schemas/lists.json"

checker = factory.create_component(config["check"], input_config)()
Expand Down

0 comments on commit a612248

Please sign in to comment.