|
1 | 1 | import os
|
2 | 2 | from pathlib import Path
|
| 3 | +from typing import Any, Text, Dict |
| 4 | + |
| 5 | +import pytest |
3 | 6 |
|
4 | 7 | import rasa.utils.io
|
5 | 8 | from rasa.core.test import (
|
|
10 | 13 | CONFUSION_MATRIX_STORIES_FILE,
|
11 | 14 | REPORT_STORIES_FILE,
|
12 | 15 | SUCCESSFUL_STORIES_FILE,
|
| 16 | + _clean_entity_results, |
13 | 17 | )
|
14 | 18 | from rasa.core.policies.memoization import MemoizationPolicy
|
15 | 19 |
|
@@ -165,3 +169,70 @@ async def test_end_to_evaluation_trips_circuit_breaker():
|
165 | 169 | story_evaluation.evaluation_store.action_predictions == circuit_trip_predicted
|
166 | 170 | )
|
167 | 171 | assert num_stories == 1
|
| 172 | + |
| 173 | + |
| 174 | +@pytest.mark.parametrize( |
| 175 | + "text, entity, expected_entity", |
| 176 | + [ |
| 177 | + ( |
| 178 | + "The first one please.", |
| 179 | + { |
| 180 | + "extractor": "DucklingHTTPExtractor", |
| 181 | + "entity": "ordinal", |
| 182 | + "confidence": 0.87, |
| 183 | + "start": 4, |
| 184 | + "end": 9, |
| 185 | + "value": 1, |
| 186 | + }, |
| 187 | + { |
| 188 | + "text": "The first one please.", |
| 189 | + "entity": "ordinal", |
| 190 | + "start": 4, |
| 191 | + "end": 9, |
| 192 | + "value": "1", |
| 193 | + }, |
| 194 | + ), |
| 195 | + ( |
| 196 | + "The first one please.", |
| 197 | + { |
| 198 | + "extractor": "CRFEntityExtractor", |
| 199 | + "entity": "ordinal", |
| 200 | + "confidence": 0.87, |
| 201 | + "start": 4, |
| 202 | + "end": 9, |
| 203 | + "value": "1", |
| 204 | + }, |
| 205 | + { |
| 206 | + "text": "The first one please.", |
| 207 | + "entity": "ordinal", |
| 208 | + "start": 4, |
| 209 | + "end": 9, |
| 210 | + "value": "1", |
| 211 | + }, |
| 212 | + ), |
| 213 | + ( |
| 214 | + "Italian food", |
| 215 | + { |
| 216 | + "extractor": "DIETClassifier", |
| 217 | + "entity": "cuisine", |
| 218 | + "confidence": 0.99, |
| 219 | + "start": 0, |
| 220 | + "end": 7, |
| 221 | + "value": "Italian", |
| 222 | + }, |
| 223 | + { |
| 224 | + "text": "Italian food", |
| 225 | + "entity": "cuisine", |
| 226 | + "start": 0, |
| 227 | + "end": 7, |
| 228 | + "value": "Italian", |
| 229 | + }, |
| 230 | + ), |
| 231 | + ], |
| 232 | +) |
| 233 | +def test_event_has_proper_implementation( |
| 234 | + text: Text, entity: Dict[Text, Any], expected_entity: Dict[Text, Any] |
| 235 | +): |
| 236 | + actual_entities = _clean_entity_results(text, [entity]) |
| 237 | + |
| 238 | + assert actual_entities[0] == expected_entity |
0 commit comments