|
7 | 7 | from _pytest.tmpdir import TempdirFactory
|
8 | 8 |
|
9 | 9 | import rasa.utils.io
|
| 10 | +from rasa.nlu.classifiers.diet_classifier import DIETClassifier |
10 | 11 | from rasa.nlu.extractors.crf_entity_extractor import CRFEntityExtractor
|
11 | 12 | from rasa.test import compare_nlu_models
|
12 | 13 | from rasa.nlu.extractors.extractor import EntityExtractor
|
|
50 | 51 | from tests.nlu.conftest import DEFAULT_DATA_PATH
|
51 | 52 | from rasa.nlu.selectors.response_selector import ResponseSelector
|
52 | 53 | from rasa.nlu.test import is_response_selector_present
|
53 |
| -from rasa.utils.tensorflow.constants import EPOCHS |
| 54 | +from rasa.utils.tensorflow.constants import EPOCHS, ENTITY_RECOGNITION |
54 | 55 |
|
55 | 56 |
|
56 | 57 | # https://github.com/pytest-dev/pytest-asyncio/issues/68
|
@@ -510,6 +511,26 @@ def test_response_evaluation_report(tmpdir_factory):
|
510 | 511 | assert result["predictions"][1] == prediction
|
511 | 512 |
|
512 | 513 |
|
| 514 | +@pytest.mark.parametrize( |
| 515 | + "components, expected_extractors", |
| 516 | + [ |
| 517 | + ([DIETClassifier({ENTITY_RECOGNITION: False})], set()), |
| 518 | + ([DIETClassifier({ENTITY_RECOGNITION: True})], {"DIETClassifier"}), |
| 519 | + ([CRFEntityExtractor()], {"CRFEntityExtractor"}), |
| 520 | + ( |
| 521 | + [SpacyEntityExtractor(), CRFEntityExtractor()], |
| 522 | + {"SpacyEntityExtractor", "CRFEntityExtractor"}, |
| 523 | + ), |
| 524 | + ([ResponseSelector()], set()), |
| 525 | + ], |
| 526 | +) |
| 527 | +def test_get_entity_extractors(components, expected_extractors): |
| 528 | + mock_interpreter = Interpreter(components, None) |
| 529 | + extractors = get_entity_extractors(mock_interpreter) |
| 530 | + |
| 531 | + assert extractors == expected_extractors |
| 532 | + |
| 533 | + |
513 | 534 | def test_entity_evaluation_report(tmpdir_factory):
|
514 | 535 | class EntityExtractorA(EntityExtractor):
|
515 | 536 |
|
@@ -653,13 +674,6 @@ def test_evaluate_entities_cv():
|
653 | 674 | }, "Wrong entity prediction alignment"
|
654 | 675 |
|
655 | 676 |
|
656 |
| -def test_get_entity_extractors(pretrained_interpreter): |
657 |
| - assert get_entity_extractors(pretrained_interpreter) == { |
658 |
| - "SpacyEntityExtractor", |
659 |
| - "DucklingHTTPExtractor", |
660 |
| - } |
661 |
| - |
662 |
| - |
663 | 677 | def test_remove_pretrained_extractors(pretrained_interpreter):
|
664 | 678 | target_components_names = ["SpacyNLP"]
|
665 | 679 | filtered_pipeline = remove_pretrained_extractors(pretrained_interpreter.pipeline)
|
|
0 commit comments