Skip to content

Commit

Permalink
A more generic packing op for seqio
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 373368806
  • Loading branch information
0x0539 authored and SeqIO committed May 12, 2021
1 parent 2953766 commit 62a20f2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions seqio/feature_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,11 @@ def __call__(self, ds: tf.data.Dataset,
if self.pack:
for k, v in expected_features.items():
# Packing requires rank 1.
if v.rank != 1:
if v.rank != 1 and not self._use_custom_packing_ops:
raise ValueError(
f"When packing is enabled, expected ranks must be 1. Got "
f"expected rank {v.rank} for feature {k}.")
f"When packing is enabled, expected ranks must be 1 or "
f"use_custom_packing_ops must be set. Got expected rank {v.rank} "
f"for feature {k}.")
for k, v in self.PACKING_FEATURE_DTYPES.items():
expected_features[k] = FeatureConverter.FeatureSpec(rank=1, dtype=v)

Expand Down
5 changes: 3 additions & 2 deletions seqio/feature_converters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,9 @@ def test_validate_dataset_rank_2_with_pack(self):
"inputs": FeatureSpec(tf.int64, rank=2),
"targets": FeatureSpec(tf.int64)
}
expected_msg = ("When packing is enabled, expected ranks must be 1. Got "
"expected rank 2 for feature inputs.")
expected_msg = ("When packing is enabled, expected ranks must be 1 or "
"use_custom_packing_ops must be set. Got expected rank 2 "
"for feature inputs.")
with self.assertRaisesRegex(ValueError, expected_msg):
converter(ds, task_feature_lengths)

Expand Down

0 comments on commit 62a20f2

Please sign in to comment.