Skip to content

Commit d8877a6

Browse files
authored
Merge branch 'master' into include-source-in-failed-stories
2 parents 75d73d2 + ee88904 commit d8877a6

File tree

4 files changed

+35
-12
lines changed

4 files changed

+35
-12
lines changed
File renamed without changes.

changelog/5646.improvement.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
``DIETClassifier`` only counts as extractor in ``rasa test`` if it was actually trained for entity recognition.
2+

rasa/nlu/test.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from rasa.nlu.model import Interpreter, Trainer, TrainingData
3838
from rasa.nlu.components import Component
3939
from rasa.nlu.tokenizers.tokenizer import Token
40+
from rasa.utils.tensorflow.constants import ENTITY_RECOGNITION
4041

4142
logger = logging.getLogger(__name__)
4243

@@ -1022,12 +1023,18 @@ def get_entity_extractors(interpreter: Interpreter) -> Set[Text]:
10221023
10231024
Processors are removed since they do not detect the boundaries themselves.
10241025
"""
1025-
10261026
from rasa.nlu.extractors.extractor import EntityExtractor
1027+
from rasa.nlu.classifiers.diet_classifier import DIETClassifier
1028+
1029+
extractors = set()
1030+
for c in interpreter.pipeline:
1031+
if isinstance(c, EntityExtractor):
1032+
if isinstance(c, DIETClassifier):
1033+
if c.component_config[ENTITY_RECOGNITION]:
1034+
extractors.add(c.name)
1035+
else:
1036+
extractors.add(c.name)
10271037

1028-
extractors = {
1029-
c.name for c in interpreter.pipeline if isinstance(c, EntityExtractor)
1030-
}
10311038
return extractors - ENTITY_PROCESSORS
10321039

10331040

tests/nlu/test_evaluation.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from _pytest.tmpdir import TempdirFactory
88

99
import rasa.utils.io
10+
from rasa.nlu.classifiers.diet_classifier import DIETClassifier
1011
from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
1112
from rasa.test import compare_nlu_models
1213
from rasa.nlu.extractors.extractor import EntityExtractor
@@ -50,7 +51,7 @@
5051
from tests.nlu.conftest import DEFAULT_DATA_PATH
5152
from rasa.nlu.selectors.response_selector import ResponseSelector
5253
from rasa.nlu.test import is_response_selector_present
53-
from rasa.utils.tensorflow.constants import EPOCHS
54+
from rasa.utils.tensorflow.constants import EPOCHS, ENTITY_RECOGNITION
5455

5556

5657
# https://github.com/pytest-dev/pytest-asyncio/issues/68
@@ -510,6 +511,26 @@ def test_response_evaluation_report(tmpdir_factory):
510511
assert result["predictions"][1] == prediction
511512

512513

514+
@pytest.mark.parametrize(
515+
"components, expected_extractors",
516+
[
517+
([DIETClassifier({ENTITY_RECOGNITION: False})], set()),
518+
([DIETClassifier({ENTITY_RECOGNITION: True})], {"DIETClassifier"}),
519+
([CRFEntityExtractor()], {"CRFEntityExtractor"}),
520+
(
521+
[SpacyEntityExtractor(), CRFEntityExtractor()],
522+
{"SpacyEntityExtractor", "CRFEntityExtractor"},
523+
),
524+
([ResponseSelector()], set()),
525+
],
526+
)
527+
def test_get_entity_extractors(components, expected_extractors):
528+
mock_interpreter = Interpreter(components, None)
529+
extractors = get_entity_extractors(mock_interpreter)
530+
531+
assert extractors == expected_extractors
532+
533+
513534
def test_entity_evaluation_report(tmpdir_factory):
514535
class EntityExtractorA(EntityExtractor):
515536

@@ -653,13 +674,6 @@ def test_evaluate_entities_cv():
653674
}, "Wrong entity prediction alignment"
654675

655676

656-
def test_get_entity_extractors(pretrained_interpreter):
657-
assert get_entity_extractors(pretrained_interpreter) == {
658-
"SpacyEntityExtractor",
659-
"DucklingHTTPExtractor",
660-
}
661-
662-
663677
def test_remove_pretrained_extractors(pretrained_interpreter):
664678
target_components_names = ["SpacyNLP"]
665679
filtered_pipeline = remove_pretrained_extractors(pretrained_interpreter.pipeline)

0 commit comments

Comments
 (0)