Skip to content

Commit

Permalink
Add missing typing.Optional type annotations to function parameters.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 376287121
  • Loading branch information
rchen152 authored and SeqIO committed May 28, 2021
1 parent 5b6447a commit 97273a4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
25 changes: 12 additions & 13 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,7 +1109,8 @@ def __init__(self,
Sequence[Tuple[str, Union[int, float,
Callable[[Task],
float]]]]],
default_rate: Union[float, Callable[[Task], float]] = None):
default_rate: Optional[Union[float, Callable[[Task],
float]]] = None):
"""Mixture constructor.
A mixture specifies a set of tasks with associated mixing rates.
Expand Down Expand Up @@ -1428,18 +1429,16 @@ def get_subtasks(task_or_mixture):
return task_or_mixture.tasks


def get_dataset(
mixture_or_task_name: str,
task_feature_lengths: Mapping[str, int],
feature_converter: FeatureConverter,
dataset_split: str = "train",
use_cached: bool = False,
shuffle: bool = False,
num_epochs: Optional[int] = 1,
shard_info: ShardInfo = None,
verbose: bool = True,
seed: Optional[int] = None
) -> tf.data.Dataset:
def get_dataset(mixture_or_task_name: str,
task_feature_lengths: Mapping[str, int],
feature_converter: FeatureConverter,
dataset_split: str = "train",
use_cached: bool = False,
shuffle: bool = False,
num_epochs: Optional[int] = 1,
shard_info: Optional[ShardInfo] = None,
verbose: bool = True,
seed: Optional[int] = None) -> tf.data.Dataset:
"""Get processed dataset with the model features.
In order to use options specific to a feature converter, e.g., packing,
Expand Down
8 changes: 4 additions & 4 deletions seqio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,8 @@ def test_postprocessing(
class PredictCallable(evaluation.PredictFnCallable):

def __call__(self,
dataset: tf.data.Dataset = None,
model_feature_lengths: Mapping[str, int] = None):
dataset: Optional[tf.data.Dataset] = None,
model_feature_lengths: Optional[Mapping[str, int]] = None):
if predict_output is None:
return []
task = dataset_providers.get_mixture_or_task(task_name)
Expand All @@ -646,8 +646,8 @@ class ScoreCallable(evaluation.PredictFnCallable):

def __call__(
self,
dataset: tf.data.Dataset = None,
model_feature_lengths: Mapping[str, int] = None,
dataset: Optional[tf.data.Dataset] = None,
model_feature_lengths: Optional[Mapping[str, int]] = None,
):
if score_output is None:
return []
Expand Down

0 comments on commit 97273a4

Please sign in to comment.