WIP: Shell-style pipes #359

Open · wants to merge 5 commits into master

Changes from 1 commit
Make all popular transformers pipes
Dmitriy Serdyuk authored and dmitriy-serdyuk committed Jul 28, 2016
commit 9eb1327f93edbcccda1e266358710490d54d69cb
fuel/transformers/simple.py (82 changes: 51 additions & 31 deletions)
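
Throughout the hunks below the pattern is the same: each transformer subclasses Pipe instead of Transformer, every constructor argument gains a None default so the object can be built before the stream it will wrap exists, and formerly required positional arguments are validated explicitly. A hypothetical sketch of the intent; how a lazily built Pipe is eventually bound to a stream lives in fuel/transformers/pipes.py, which this diff does not show, and per the TODO below some checks are skipped on the lazy path:

    from fuel.schemes import ConstantScheme
    from fuel.transformers import Batch

    # Describe the transformer up front; no data_stream yet.
    batcher = Batch(iteration_scheme=ConstantScheme(64))
    # Required arguments are still checked eagerly:
    # Batch() with no iteration_scheme raises ValueError.
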
@@ -10,6 +10,7 @@
 from fuel.streams import AbstractDataStream
 from fuel.schemes import BatchSizeScheme
 from .base import Transformer
+from .pipes import Pipe

 log = logging.getLogger(__name__)

@@ -34,7 +35,7 @@ def transform_batch(self, batch):
         return self.transform_any(batch)


-class Mapping(Transformer):
+class Mapping(Pipe):
     """Applies a mapping to the data of the wrapped data stream.

     Parameters
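
For context, Mapping's callable receives each example (or batch) as a tuple of source values and returns a tuple of the same arity. A minimal sketch against the eager API, not part of this diff:

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.streams import DataStream
    from fuel.transformers import Mapping

    stream = DataStream(IterableDataset({'features': numpy.arange(5)}))
    # Doubles every source value: (0,), (2,), (4,), (6,), (8,)
    doubled = Mapping(stream, lambda data: tuple(2 * x for x in data))
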
@@ -78,7 +79,7 @@ def transform_example(self, example):


 @add_metaclass(ABCMeta)
-class SourcewiseTransformer(Transformer):
+class SourcewiseTransformer(Pipe):
     """Applies a transformation sourcewise.

     Subclasses must define `transform_source_example` (to transform
@@ -94,8 +95,8 @@ class SourcewiseTransformer(Transformer):
     which case the mapping is applied to all sources.

     """
-    def __init__(self, data_stream, produces_examples, which_sources=None,
-                 **kwargs):
+    def __init__(self, data_stream=None, produces_examples=None,
+                 which_sources=None, **kwargs):
         if which_sources is None:
             which_sources = data_stream.sources
         self.which_sources = which_sources
@@ -178,6 +179,10 @@ def transform_any_source(self, source_data, source_name):

         """

+    @property
+    def produces_examples(self):
+        return self.data_stream.produces_examples
+

 class Flatten(SourcewiseTransformer):
     """Flattens selected sources.
@@ -189,7 +194,7 @@ class Flatten(SourcewiseTransformer):
     `numpy.asarray`).

     """
-    def __init__(self, data_stream, **kwargs):
+    def __init__(self, data_stream=None, **kwargs):
         # Modify the axis_labels dict to reflect the fact that all non-batch
         # axes will be grouped together under the same 'feature' axis.
         if data_stream.axis_labels:
@@ -198,7 +203,7 @@ def __init__(self, data_stream, **kwargs):
                 'axis_labels',
                 self._infer_axis_labels(data_stream, which_sources))
         super(Flatten, self).__init__(
-            data_stream, data_stream.produces_examples, **kwargs)
+            data_stream, **kwargs)

     def _infer_axis_labels(self, data_stream, which_sources):
         axis_labels = {}
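
For context, Flatten keeps the batch axis and collapses all remaining axes into a single 'feature' axis, which is what the axis-label inference above encodes. A usage sketch, not part of this diff:

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.schemes import ConstantScheme
    from fuel.streams import DataStream
    from fuel.transformers import Batch, Flatten

    images = numpy.zeros((4, 28, 28))  # four toy 28x28 "images"
    stream = Batch(DataStream(IterableDataset({'features': images})),
                   iteration_scheme=ConstantScheme(2))
    flat = Flatten(stream)  # each batch now has shape (2, 784)
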
@@ -235,13 +240,15 @@ class ScaleAndShift(AgnosticSourcewiseTransformer):
         Shifting factor.

     """
-    def __init__(self, data_stream, scale, shift, **kwargs):
+    def __init__(self, data_stream=None, scale=None, shift=None, **kwargs):
+        if scale is None or shift is None:
+            raise ValueError('scale or shift cannot be None')
         self.scale = scale
         self.shift = shift
         if data_stream.axis_labels:
             kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
         super(ScaleAndShift, self).__init__(
-            data_stream, data_stream.produces_examples, **kwargs)
+            data_stream, **kwargs)

     def transform_any_source(self, source_data, _):
         return numpy.asarray(source_data) * self.scale + self.shift
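
The transform is literally data * scale + shift, so rescaling uint8 pixels into [0, 1] looks like this (usage sketch, not part of the diff):

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.streams import DataStream
    from fuel.transformers import ScaleAndShift

    pixels = numpy.array([[0, 128, 255]], dtype='uint8')
    stream = DataStream(IterableDataset({'features': pixels}))
    rescaled = ScaleAndShift(stream, scale=1 / 255.0, shift=0.0)
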
@@ -260,26 +267,28 @@ class Cast(AgnosticSourcewiseTransformer):
         in which case ``fuel.config.floatX`` is used.

     """
-    def __init__(self, data_stream, dtype, **kwargs):
+    def __init__(self, data_stream=None, dtype=None, **kwargs):
+        if dtype is None:
+            raise ValueError('dtype cannot be None')
         if dtype == 'floatX':
             dtype = config.floatX
         self.dtype = dtype
         if data_stream.axis_labels:
             kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
         super(Cast, self).__init__(
-            data_stream, data_stream.produces_examples, **kwargs)
+            data_stream, **kwargs)

     def transform_any_source(self, source_data, _):
         return numpy.asarray(source_data, dtype=self.dtype)


 class ForceFloatX(AgnosticSourcewiseTransformer):
     """Force all floating point numpy arrays to be floatX."""
-    def __init__(self, data_stream, **kwargs):
+    def __init__(self, data_stream=None, **kwargs):
         if data_stream.axis_labels:
             kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
         super(ForceFloatX, self).__init__(
-            data_stream, data_stream.produces_examples, **kwargs)
+            data_stream, **kwargs)

     def transform_any_source(self, source_data, _):
         source_needs_casting = (isinstance(source_data, numpy.ndarray) and
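
Cast and ForceFloatX differ in that Cast converts the selected sources unconditionally, while ForceFloatX only touches arrays that are already floating point. A usage sketch:

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.streams import DataStream
    from fuel.transformers import Cast, ForceFloatX

    stream = DataStream(IterableDataset({'targets': numpy.arange(4)}))
    stream = Cast(stream, 'int64')  # unconditional cast
    stream = ForceFloatX(stream)    # no-op here: integers are left alone
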
@@ -290,7 +299,7 @@ def transform_any_source(self, source_data, _):
         return source_data


-class Filter(Transformer):
+class Filter(Pipe):
     """Filters samples that meet a predicate.

     Parameters
@@ -301,19 +310,21 @@ class Filter(Transformer):
         Should return ``True`` for the samples to be kept.

     """
-    def __init__(self, data_stream, predicate, **kwargs):
+    def __init__(self, data_stream=None, predicate=None, **kwargs):
+        if predicate is None:
+            raise ValueError('predicate cannot be None')
         if data_stream.axis_labels:
             kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
         super(Filter, self).__init__(
-            data_stream, data_stream.produces_examples, **kwargs)
+            data_stream, **kwargs)
         self.predicate = predicate

     def get_epoch_iterator(self, **kwargs):
         super(Filter, self).get_epoch_iterator(**kwargs)
         return ifilter(self.predicate, self.child_epoch_iterator)


-class Cache(Transformer):
+class Cache(Pipe):
     """Cache examples when sequentially reading a dataset.

     Given a data stream which reads large chunks of data, this data
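
Since Filter.get_epoch_iterator just runs ifilter over the child iterator, the predicate receives the whole example tuple. A usage sketch keeping even-valued features:

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.streams import DataStream
    from fuel.transformers import Filter

    stream = DataStream(IterableDataset({'features': numpy.arange(6)}))
    evens = Filter(stream, lambda data: data[0] % 2 == 0)
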
@@ -337,7 +348,7 @@ class Cache(Transformer):
     refilled when needed through the :meth:`get_data` method.

     """
-    def __init__(self, data_stream, iteration_scheme, **kwargs):
+    def __init__(self, data_stream=None, iteration_scheme=None, **kwargs):
         # Note: produces_examples will always be False because of this
         # restriction: the only iteration schemes allowed are BatchSizeScheme,
         # which produce batches.
@@ -402,7 +413,7 @@ def __call__(self, batch):
         return output


-class Batch(Transformer):
+class Batch(Pipe):
     """Creates minibatches from data streams providing single examples.

     Some datasets only return one example at a time, e.g. when reading text
@@ -425,8 +436,12 @@ class Batch(Transformer):
         raised if a batch of the requested size cannot be provided.

     """
-    def __init__(self, data_stream, iteration_scheme, strictness=0, **kwargs):
-        if not data_stream.produces_examples:
+    def __init__(self, data_stream=None, iteration_scheme=None, strictness=0,
+                 **kwargs):
+        if iteration_scheme is None:
+            raise ValueError('iteration_scheme cannot be None')
+        if data_stream and not data_stream.produces_examples:
+            # TODO: This check is not performed if lazy
             raise ValueError('the wrapped data stream must produce examples, '
                              'not batches of examples.')
         # The value for `produces_examples` is inferred from the iteration
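
For context, the usual way to batch an example stream is with a BatchSizeScheme such as ConstantScheme; with this commit, iteration_scheme is a keyword argument that is checked for None. A usage sketch:

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.schemes import ConstantScheme
    from fuel.streams import DataStream
    from fuel.transformers import Batch

    stream = DataStream(IterableDataset({'features': numpy.arange(10)}))
    batches = Batch(stream, iteration_scheme=ConstantScheme(4))
    # Yields batches of 4, 4, and 2 examples; strictness=1 drops the
    # ragged final batch, strictness=2 raises instead.
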
@@ -465,7 +480,7 @@ def get_data(self, request=None):
         return tuple(numpy.asarray(source_data) for source_data in data)


-class Unpack(Transformer):
+class Unpack(Pipe):
     """Unpacks batches to compose a stream of examples.

     This class is the inverse of the Batch class: it turns a minibatch into
@@ -477,8 +492,8 @@ class Unpack(Transformer):
         The data stream to unpack

     """
-    def __init__(self, data_stream, **kwargs):
-        if data_stream.produces_examples:
+    def __init__(self, data_stream=None, **kwargs):
+        if data_stream and data_stream.produces_examples:
             raise ValueError('the wrapped data stream must produce batches of '
                              'examples, not examples')
         if data_stream.axis_labels:
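
Continuing the Batch sketch above, Unpack is its inverse:

    from fuel.transformers import Unpack

    examples = Unpack(batches)  # back to one example at a time
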
@@ -503,7 +518,7 @@ def get_data(self, request=None):
             return self.get_data()


-class Padding(Transformer):
+class Padding(Pipe):
     """Adds padding to variable-length sequences.

     When your batches consist of variable-length sequences, use this class
@@ -528,9 +543,9 @@ class Padding(Transformer):
         be used.

     """
-    def __init__(self, data_stream, mask_sources=None, mask_dtype=None,
+    def __init__(self, data_stream=None, mask_sources=None, mask_dtype=None,
                  **kwargs):
-        if data_stream.produces_examples:
+        if data_stream and data_stream.produces_examples:
             raise ValueError('the wrapped data stream must produce batches of '
                              'examples, not examples')
         super(Padding, self).__init__(
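
For context, Padding consumes batches of variable-length sequences, zero-pads each source up to the longest sequence in the batch, and adds a companion mask source per padded source (assumed here to be named '<source>_mask', per the usual Fuel convention). A usage sketch:

    from fuel.datasets import IterableDataset
    from fuel.schemes import ConstantScheme
    from fuel.streams import DataStream
    from fuel.transformers import Batch, Padding

    seqs = [[1, 2], [3, 4, 5], [6]]
    stream = Batch(DataStream(IterableDataset({'features': seqs})),
                   iteration_scheme=ConstantScheme(3))
    padded = Padding(stream)  # sources: ('features', 'features_mask')
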
@@ -683,7 +698,7 @@ def get_next_data(self):
         return self.batches.get()


-class MultiProcessing(Transformer):
+class MultiProcessing(Pipe):
     """Cache batches from the stream in a separate process.

     To speed up training of your model, it can be worthwhile to load and
@@ -707,8 +722,8 @@ class MultiProcessing(Transformer):
     robust approach might need to be considered.

     """
-    def __init__(self, data_stream, max_store=100, **kwargs):
-        if data_stream.axis_labels:
+    def __init__(self, data_stream=None, max_store=100, **kwargs):
+        if data_stream and data_stream.axis_labels:
             kwargs.setdefault('axis_labels', data_stream.axis_labels.copy())
         super(MultiProcessing, self).__init__(
             data_stream, data_stream.produces_examples, **kwargs)
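
MultiProcessing is a drop-in wrapper: a background process keeps up to max_store batches buffered ahead of the consumer. A sketch, where expensive_stream stands for any stream with slow I/O or preprocessing:

    from fuel.transformers import MultiProcessing

    # expensive_stream is hypothetical; up to 100 batches are prefetched.
    fast_stream = MultiProcessing(expensive_stream, max_store=100)
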
@@ -742,7 +757,10 @@ class Rename(AgnosticTransformer):
         description of possible values. Default is 'raise'.

     """
-    def __init__(self, data_stream, names, on_non_existent='raise', **kwargs):
+    def __init__(self, data_stream=None, names=None, on_non_existent='raise',
+                 **kwargs):
+        if names is None:
+            raise ValueError('names cannot be None')
         if on_non_existent not in ('raise', 'ignore', 'warn'):
             raise ValueError("on_non_existent must be one of 'raise', "
                              "'ignore', 'warn'")
@@ -806,7 +824,9 @@ class FilterSources(AgnosticTransformer):
         Must be a subset of the sources given by the stream.

     """
-    def __init__(self, data_stream, sources, **kwargs):
+    def __init__(self, data_stream=None, sources=None, **kwargs):
+        if sources is None:
+            raise ValueError('sources cannot be None')
         if any(source not in data_stream.sources for source in sources):
             raise ValueError("sources must all be contained in "
                              "data_stream.sources")
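
FilterSources, for context, keeps only the requested subset of the stream's sources, e.g. when a consumer expects a single source. A usage sketch:

    import numpy
    from fuel.datasets import IterableDataset
    from fuel.streams import DataStream
    from fuel.transformers import FilterSources

    stream = DataStream(IterableDataset(
        {'features': numpy.arange(3), 'targets': numpy.arange(3)}))
    only_x = FilterSources(stream, sources=('features',))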