Skip to content

Commit

Permalink
Upgrade vkit.
Browse files Browse the repository at this point in the history
  • Loading branch information
huntzhan committed Aug 16, 2022
1 parent b608db1 commit 3f7ce7c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
2 changes: 1 addition & 1 deletion tests/test_adaptive_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def sample_adaptive_scaling_dataset(
dataset = AdaptiveScalingIterableDataset(
AdaptiveScalingIterableDatasetConfig(
steps_json=(
'$VKIT_ARTIFACT_PACK/pipeline/text_detection/dev_adaptive_scaling_dataset_steps.json'
'$VKIT_ARTIFACT_PACK/pipeline/text_detection/dev_adaptive_scaling_dataset_steps.json' # noqa
),
num_samples=num_samples,
rng_seed=13,
Expand Down
31 changes: 23 additions & 8 deletions vkit_open_model/dataset/adaptive_scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,46 @@
from vkit.element import Image, Mask, ScoreMap, Box
from vkit.utility import PathType
from vkit.pipeline import (
PipelineState,
PageCroppingStep,
NoneTypePipelinePostProcessorConfig,
pipeline_step_collection_factory,
PageCroppingStepOutput,
PipelinePostProcessor,
PipelinePostProcessorFactory,
PipelineRunRngStateOutput,
Pipeline,
PipelinePool,
pipeline_step_collection_factory,
)

logger = logging.getLogger(__name__)

Sample = Tuple[Image, Tuple[int, int], Box, Mask, ScoreMap, Mapping]


@attrs.define
class AdaptiveScalingPipelinePostProcessorConfig:
pass


@attrs.define
class AdaptiveScalingPipelinePostProcessorInput:
pipeline_run_rng_state_output: PipelineRunRngStateOutput
page_cropping_step_output: PageCroppingStepOutput


class AdaptiveScalingPipelinePostProcessor(
PipelinePostProcessor[
NoneTypePipelinePostProcessorConfig,
AdaptiveScalingPipelinePostProcessorConfig,
AdaptiveScalingPipelinePostProcessorInput,
Sequence[Sample],
]
): # yapf: disable

def generate_output(self, state: PipelineState, rng: RandomGenerator):
rng_state = state.get_value('_rng_state', Mapping)
page_cropping_step_output = state.get_pipeline_step_output(PageCroppingStep)
def generate_output(
self,
input: AdaptiveScalingPipelinePostProcessorInput,
rng: RandomGenerator,
):
rng_state = input.pipeline_run_rng_state_output.rng_state
page_cropping_step_output = input.page_cropping_step_output
samples: List[Sample] = []
for cropped_page in page_cropping_step_output.cropped_pages:
downsampled_label = cropped_page.downsampled_label
Expand Down

0 comments on commit 3f7ce7c

Please sign in to comment.