Skip to content

Commit ca6e0d4

Browse files
committed
exclude DIET from extractors if no entities trained
1 parent d55e868 commit ca6e0d4

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

rasa/nlu/test.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from rasa.nlu.model import Interpreter, Trainer, TrainingData
3838
from rasa.nlu.components import Component
3939
from rasa.nlu.tokenizers.tokenizer import Token
40+
from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION
4041

4142
logger = logging.getLogger(__name__)
4243

@@ -1022,12 +1023,17 @@ def get_entity_extractors(interpreter: Interpreter) -> Set[Text]:
10221023
10231024
Processors are removed since they do not detect the boundaries themselves.
10241025
"""
1025-
10261026
from rasa.nlu.extractors.extractor import EntityExtractor
10271027

1028-
extractors = {
1029-
c.name for c in interpreter.pipeline if isinstance(c, EntityExtractor)
1030-
}
1028+
extractors = set()
1029+
for c in interpreter.pipeline:
1030+
if isinstance(c, EntityExtractor):
1031+
if c.name == "DIETClassifier":
1032+
if c.component_config[ENTITY_RECOGNITION]:
1033+
extractors.add(c.name)
1034+
else:
1035+
extractors.add(c.name)
1036+
10311037
return extractors - ENTITY_PROCESSORS
10321038

10331039

tests/nlu/test_evaluation.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from _pytest.tmpdir import TempdirFactory
88

99
import rasa.utils.io
10+
from rasa.nlu.classifiers.diet_classifier import DIETClassifier
1011
from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
1112
from rasa.test import compare_nlu_models
1213
from rasa.nlu.extractors.extractor import EntityExtractor
@@ -50,7 +51,7 @@
5051
from tests.nlu.conftest import DEFAULT_DATA_PATH
5152
from rasa.nlu.selectors.response_selector import ResponseSelector
5253
from rasa.nlu.test import is_response_selector_present
53-
from rasa.utils.tensorflow.constants import EPOCHS
54+
from rasa.utils.tensorflow.constants import EPOCHS, ENTITY_RECOGNITION
5455

5556

5657
# https://github.com/pytest-dev/pytest-asyncio/issues/68
@@ -510,6 +511,26 @@ def test_response_evaluation_report(tmpdir_factory):
510511
assert result["predictions"][1] == prediction
511512

512513

514+
@pytest.mark.parametrize(
515+
"components, expected_extractors",
516+
[
517+
([DIETClassifier({ENTITY_RECOGNITION: False})], set()),
518+
([DIETClassifier({ENTITY_RECOGNITION: True})], {"DIETClassifier"}),
519+
([CRFEntityExtractor()], {"CRFEntityExtractor"}),
520+
(
521+
[SpacyEntityExtractor(), CRFEntityExtractor()],
522+
{"SpacyEntityExtractor", "CRFEntityExtractor"},
523+
),
524+
([ResponseSelector()], set()),
525+
],
526+
)
527+
def test_get_entity_extractors(components, expected_extractors):
528+
mock_interpreter = Interpreter(components, None)
529+
extractors = get_entity_extractors(mock_interpreter)
530+
531+
assert extractors == expected_extractors
532+
533+
513534
def test_entity_evaluation_report(tmpdir_factory):
514535
class EntityExtractorA(EntityExtractor):
515536

@@ -653,13 +674,6 @@ def test_evaluate_entities_cv():
653674
}, "Wrong entity prediction alignment"
654675

655676

656-
def test_get_entity_extractors(pretrained_interpreter):
657-
assert get_entity_extractors(pretrained_interpreter) == {
658-
"SpacyEntityExtractor",
659-
"DucklingHTTPExtractor",
660-
}
661-
662-
663677
def test_remove_pretrained_extractors(pretrained_interpreter):
664678
target_components_names = ["SpacyNLP"]
665679
filtered_pipeline = remove_pretrained_extractors(pretrained_interpreter.pipeline)

0 commit comments

Comments
 (0)