Skip to content

Commit d32522f

Browse files
authored
Merge pull request RasaHQ#6054 from RasaHQ/share-hidden-layers
Fix option `share_hidden_layers`
2 parents de8e585 + f803204 commit d32522f

File tree

4 files changed

+10
-8
lines changed

4 files changed

+10
-8
lines changed

changelog/6053.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Update ``FeatureSignature`` to store just the feature dimension instead of the complete shape. This change fixes the
2+
usage of the option ``share_hidden_layers`` in the ``DIETClassifier``.

rasa/nlu/classifiers/diet_classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,14 +1183,14 @@ def _prepare_sparse_dense_layers(
11831183
) -> None:
11841184
sparse = False
11851185
dense = False
1186-
for is_sparse, shape in feature_signatures:
1186+
for is_sparse, feature_dimension in feature_signatures:
11871187
if is_sparse:
11881188
sparse = True
11891189
else:
11901190
dense = True
11911191
# if dense features are present
11921192
# use the feature dimension of the dense features
1193-
dense_dim = shape[-1]
1193+
dense_dim = feature_dimension
11941194

11951195
if sparse:
11961196
self._tf_layers[f"sparse_to_dense.{name}"] = layers.DenseForSparse(

rasa/utils/tensorflow/model_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class FeatureSignature(NamedTuple):
2525
"""Stores the shape and the type (sparse vs dense) of features."""
2626

2727
is_sparse: bool
28-
shape: List[int]
28+
feature_dimension: Optional[int]
2929

3030

3131
class RasaModelData:
@@ -210,7 +210,7 @@ def get_signature(self) -> Dict[Text, List[FeatureSignature]]:
210210
key: [
211211
FeatureSignature(
212212
True if isinstance(v[0], scipy.sparse.spmatrix) else False,
213-
v[0].shape,
213+
v[0].shape[-1] if v[0].shape else None,
214214
)
215215
for v in values
216216
]
@@ -357,8 +357,8 @@ def _balanced_data(self, data: Data, batch_size: int, shuffle: bool) -> Data:
357357
if num_data_cycles[index] > 0 and not skipped[index]:
358358
skipped[index] = True
359359
continue
360-
else:
361-
skipped[index] = False
360+
361+
skipped[index] = False
362362

363363
index_batch_size = (
364364
int(counts_label_ids[index] / self.num_examples * batch_size) + 1

rasa/utils/tensorflow/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,15 +417,15 @@ def batch_to_model_data_format(
417417

418418
idx = 0
419419
for k, signature in data_signature.items():
420-
for is_sparse, shape in signature:
420+
for is_sparse, feature_dimension in signature:
421421
if is_sparse:
422422
# explicitly substitute last dimension in shape with known
423423
# static value
424424
batch_data[k].append(
425425
tf.SparseTensor(
426426
batch[idx],
427427
batch[idx + 1],
428-
[batch[idx + 2][0], batch[idx + 2][1], shape[-1]],
428+
[batch[idx + 2][0], batch[idx + 2][1], feature_dimension],
429429
)
430430
)
431431
idx += 3

0 commit comments

Comments
 (0)