@@ -85,7 +85,7 @@ class ResponseSelector(DIETClassifier):
85
85
86
86
requires = [
87
87
any_of (DENSE_FEATURE_NAMES [TEXT ], SPARSE_FEATURE_NAMES [TEXT ]),
88
- any_of (DENSE_FEATURE_NAMES [RESPONSE ], SPARSE_FEATURE_NAMES [RESPONSE ], ),
88
+ any_of (DENSE_FEATURE_NAMES [RESPONSE ], SPARSE_FEATURE_NAMES [RESPONSE ]),
89
89
]
90
90
91
91
# default properties (DOC MARKER - don't remove)
@@ -175,6 +175,30 @@ class ResponseSelector(DIETClassifier):
175
175
}
176
176
# end default properties (DOC MARKER - don't remove)
177
177
178
+ def __init__ (
179
+ self ,
180
+ component_config : Optional [Dict [Text , Any ]] = None ,
181
+ inverted_label_dict : Optional [Dict [int , Text ]] = None ,
182
+ inverted_tag_dict : Optional [Dict [int , Text ]] = None ,
183
+ model : Optional [RasaModel ] = None ,
184
+ batch_tuple_sizes : Optional [Dict ] = None ,
185
+ ) -> None :
186
+
187
+ component_config = component_config or {}
188
+
189
+ # the following properties cannot be adapted for the ResponseSelector
190
+ component_config [INTENT_CLASSIFICATION ] = True
191
+ component_config [ENTITY_RECOGNITION ] = False
192
+ component_config [BILOU_FLAG ] = False
193
+
194
+ super ().__init__ (
195
+ component_config ,
196
+ inverted_label_dict ,
197
+ inverted_tag_dict ,
198
+ model ,
199
+ batch_tuple_sizes ,
200
+ )
201
+
178
202
@property
179
203
def label_key (self ) -> Text :
180
204
return "label_ids"
@@ -224,7 +248,7 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
224
248
)
225
249
226
250
model_data = self ._create_model_data (
227
- training_data .intent_examples , label_id_dict , label_attribute = RESPONSE ,
251
+ training_data .intent_examples , label_id_dict , label_attribute = RESPONSE
228
252
)
229
253
230
254
self .check_input_dimension_consistency (model_data )
@@ -306,7 +330,7 @@ def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
306
330
sequence_lengths_label = self ._get_sequence_lengths (mask_label )
307
331
308
332
label_transformed , _ , _ , _ = self ._create_sequence (
309
- self .tf_label_data ["label_features" ], mask_label , self .label_name ,
333
+ self .tf_label_data ["label_features" ], mask_label , self .label_name
310
334
)
311
335
cls_label = self ._last_token (label_transformed , sequence_lengths_label )
312
336
@@ -339,7 +363,7 @@ def batch_loss(
339
363
sequence_lengths_label = self ._get_sequence_lengths (mask_label )
340
364
341
365
label_transformed , _ , _ , _ = self ._create_sequence (
342
- tf_batch_data ["label_features" ], mask_label , self .label_name ,
366
+ tf_batch_data ["label_features" ], mask_label , self .label_name
343
367
)
344
368
345
369
losses = []
0 commit comments