Skip to content

Commit 4477ba8

Browse files
committed
fix tests and types
1 parent da4f26f commit 4477ba8

File tree

5 files changed

+42
-23
lines changed

5 files changed

+42
-23
lines changed

rasa/nlu/classifiers/diet_classifier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,9 +1281,9 @@ def _entity_loss(
12811281

12821282
# should call first to build weights
12831283
pred_ids = self._tf_layers["crf"](logits, sequence_lengths)
1284-
loss = self._tf_layers["crf"].loss(
1285-
logits, c, sequence_lengths
1286-
) # pytype: disable=attribute-error
1284+
# pytype: disable=attribute-error
1285+
loss = self._tf_layers["crf"].loss(logits, c, sequence_lengths)
1286+
# pytype: enable=attribute-error
12871287

12881288
# TODO check that f1 calculation is correct
12891289
# calculate f1 score for train predictions

rasa/nlu/classifiers/embedding_intent_classifier.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
from rasa.constants import DOCS_BASE_URL
55
from rasa.nlu.components import any_of
66
from rasa.nlu.classifiers.diet_classifier import DIETClassifier
7-
from rasa.nlu.constants import (
8-
TEXT,
9-
DENSE_FEATURE_NAMES,
10-
SPARSE_FEATURE_NAMES,
11-
)
7+
from rasa.nlu.constants import TEXT, DENSE_FEATURE_NAMES, SPARSE_FEATURE_NAMES
128
from rasa.utils.tensorflow.constants import (
139
LABEL,
1410
HIDDEN_LAYERS_SIZES,
@@ -125,7 +121,7 @@ def __init__(
125121

126122
component_config = component_config or {}
127123

128-
# the following properties are fixed for the EmbeddingIntentClassifier
124+
# the following properties cannot be adapted for the EmbeddingIntentClassifier
129125
component_config[INTENT_CLASSIFICATION] = True
130126
component_config[ENTITY_RECOGNITION] = False
131127
component_config[MASKED_LM] = False

rasa/nlu/extractors/crf_entity_extractor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from rasa.nlu.training_data import TrainingData, Message
1212
from rasa.constants import DOCS_BASE_URL
1313
from rasa.nlu.classifiers.diet_classifier import DIETClassifier
14-
from rasa.nlu.constants import (
15-
TEXT,
16-
ENTITIES,
17-
TOKENS_NAMES,
18-
)
14+
from rasa.nlu.constants import TEXT, ENTITIES, TOKENS_NAMES
1915
from rasa.utils.tensorflow.constants import (
2016
HIDDEN_LAYERS_SIZES,
2117
NUM_TRANSFORMER_LAYERS,
@@ -122,7 +118,7 @@ def __init__(
122118
) -> None:
123119
component_config = component_config or {}
124120

125-
# the following properties are fixed for the CRFEntityExtractor
121+
# the following properties cannot be adapted for the CRFEntityExtractor
126122
component_config[INTENT_CLASSIFICATION] = False
127123
component_config[ENTITY_RECOGNITION] = True
128124
component_config[MASKED_LM] = False

rasa/nlu/selectors/response_selector.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class ResponseSelector(DIETClassifier):
8585

8686
requires = [
8787
any_of(DENSE_FEATURE_NAMES[TEXT], SPARSE_FEATURE_NAMES[TEXT]),
88-
any_of(DENSE_FEATURE_NAMES[RESPONSE], SPARSE_FEATURE_NAMES[RESPONSE],),
88+
any_of(DENSE_FEATURE_NAMES[RESPONSE], SPARSE_FEATURE_NAMES[RESPONSE]),
8989
]
9090

9191
# default properties (DOC MARKER - don't remove)
@@ -175,6 +175,30 @@ class ResponseSelector(DIETClassifier):
175175
}
176176
# end default properties (DOC MARKER - don't remove)
177177

178+
def __init__(
179+
self,
180+
component_config: Optional[Dict[Text, Any]] = None,
181+
inverted_label_dict: Optional[Dict[int, Text]] = None,
182+
inverted_tag_dict: Optional[Dict[int, Text]] = None,
183+
model: Optional[RasaModel] = None,
184+
batch_tuple_sizes: Optional[Dict] = None,
185+
) -> None:
186+
187+
component_config = component_config or {}
188+
189+
# the following properties cannot be adapted for the ResponseSelector
190+
component_config[INTENT_CLASSIFICATION] = True
191+
component_config[ENTITY_RECOGNITION] = False
192+
component_config[BILOU_FLAG] = False
193+
194+
super().__init__(
195+
component_config,
196+
inverted_label_dict,
197+
inverted_tag_dict,
198+
model,
199+
batch_tuple_sizes,
200+
)
201+
178202
@property
179203
def label_key(self) -> Text:
180204
return "label_ids"
@@ -224,7 +248,7 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData:
224248
)
225249

226250
model_data = self._create_model_data(
227-
training_data.intent_examples, label_id_dict, label_attribute=RESPONSE,
251+
training_data.intent_examples, label_id_dict, label_attribute=RESPONSE
228252
)
229253

230254
self.check_input_dimension_consistency(model_data)
@@ -306,7 +330,7 @@ def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]:
306330
sequence_lengths_label = self._get_sequence_lengths(mask_label)
307331

308332
label_transformed, _, _, _ = self._create_sequence(
309-
self.tf_label_data["label_features"], mask_label, self.label_name,
333+
self.tf_label_data["label_features"], mask_label, self.label_name
310334
)
311335
cls_label = self._last_token(label_transformed, sequence_lengths_label)
312336

@@ -339,7 +363,7 @@ def batch_loss(
339363
sequence_lengths_label = self._get_sequence_lengths(mask_label)
340364

341365
label_transformed, _, _, _ = self._create_sequence(
342-
tf_batch_data["label_features"], mask_label, self.label_name,
366+
tf_batch_data["label_features"], mask_label, self.label_name
343367
)
344368

345369
losses = []

tests/cli/test_rasa_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,22 @@ def test_test_nlu_cross_validation(run_in_default_project: Callable[..., RunResu
5858
assert os.path.exists("results/confmat.png")
5959

6060

61-
def test_test_nlu_comparison(run_in_default_project: Callable[..., RunResult]):
61+
def test_test_nlu_comparison(
62+
run_in_default_project_without_models: Callable[..., RunResult]
63+
):
6264
copyfile("config.yml", "config-1.yml")
6365

64-
run_in_default_project(
66+
run_in_default_project_without_models(
6567
"test",
6668
"nlu",
67-
"-config",
69+
"--config",
6870
"config.yml",
6971
"config-1.yml",
7072
"--run",
7173
"2",
72-
"-percentages",
74+
"--percentages",
7375
"75",
76+
"25",
7477
)
7578

7679
assert os.path.exists("results/run_1")

0 commit comments

Comments
 (0)