Skip to content

Commit f950f7c

Browse files
authored
Merge pull request RasaHQ#4951 from RasaHQ/3549_no_plot
Add a flag to disable plotting in rasa test
2 parents 3072050 + d0c0cbd commit f950f7c

File tree

6 files changed

+104
-25
lines changed

6 files changed

+104
-25
lines changed

changelog/3549.improvement.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added a ``--no-plot`` option for the ``rasa test`` command, which disables rendering of the confusion matrix and histogram. By default, plots will be rendered.

rasa/cli/arguments/test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
def set_test_arguments(parser: argparse.ArgumentParser):
1717
add_model_param(parser, add_positional_arg=False)
18+
add_no_plot_param(parser)
1819

1920
core_arguments = parser.add_argument_group("Core Test Arguments")
2021
add_test_core_argument_group(core_arguments)
@@ -79,6 +80,7 @@ def add_test_core_argument_group(
7980
"All models in the provided directory are evaluated "
8081
"and compared against each other.",
8182
)
83+
add_no_plot_param(parser)
8284

8385

8486
def add_test_nlu_argument_group(
@@ -162,6 +164,8 @@ def add_test_nlu_argument_group(
162164
help="Percentages of training data to exclude during comparison.",
163165
)
164166

167+
add_no_plot_param(parser)
168+
165169

166170
def add_test_core_model_param(parser: argparse.ArgumentParser):
167171
default_path = get_latest_model(DEFAULT_MODELS_PATH)
@@ -175,3 +179,16 @@ def add_test_core_model_param(parser: argparse.ArgumentParser):
175179
"will be used (exception: '--evaluate-model-directory' flag is set). If multiple "
176180
"'tar.gz' files are provided, all those models will be compared.",
177181
)
182+
183+
184+
def add_no_plot_param(
    parser: argparse.ArgumentParser, default: bool = False, required: bool = False,
) -> None:
    """Register the ``--no-plot`` switch on ``parser``.

    The switch is a ``store_true`` flag; its value is stored under the
    ``disable_plotting`` attribute of the parsed namespace.

    :param parser: parser the option is added to
    :param default: value of ``disable_plotting`` when the flag is not given
    :param required: whether the flag must be present on the command line
    """
    parser.add_argument(
        "--no-plot",
        dest="disable_plotting",
        action="store_true",
        default=default,
        # Plain literal: there is nothing to interpolate, so no f-string prefix.
        help="Don't render evaluation plots",
        required=required,
    )

rasa/core/test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ async def test(
488488
out_directory: Optional[Text] = None,
489489
fail_on_prediction_errors: bool = False,
490490
e2e: bool = False,
491+
disable_plotting: bool = False,
491492
):
492493
"""Run the evaluation of the stories, optionally plot the results."""
493494
from rasa.nlu.test import get_evaluation_metrics
@@ -518,6 +519,7 @@ async def test(
518519
accuracy,
519520
story_evaluation.in_training_data_fraction,
520521
out_directory,
522+
disable_plotting,
521523
)
522524

523525
log_failed_stories(story_evaluation.failed_stories, out_directory)
@@ -566,6 +568,7 @@ def plot_story_evaluation(
566568
accuracy,
567569
in_training_data_fraction,
568570
out_directory,
571+
disable_plotting,
569572
):
570573
"""Plot the results of story evaluation"""
571574
from sklearn.metrics import confusion_matrix
@@ -584,6 +587,9 @@ def plot_story_evaluation(
584587
include_report=True,
585588
)
586589

590+
if disable_plotting:
591+
return
592+
587593
cnf_matrix = confusion_matrix(test_y, predictions)
588594

589595
plot_confusion_matrix(

rasa/nlu/test.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ def evaluate_intents(
441441
errors: bool,
442442
confmat_filename: Optional[Text],
443443
intent_hist_filename: Optional[Text],
444+
disable_plotting: bool,
444445
) -> Dict: # pragma: no cover
445446
"""Creates a confusion matrix and summary statistics for intent predictions.
446447
@@ -504,26 +505,13 @@ def evaluate_intents(
504505
# log and save misclassified samples to file for debugging
505506
collect_nlu_errors(intent_results, errors_filename)
506507

507-
if confmat_filename:
508-
import matplotlib.pyplot as plt
509-
510-
if output_directory:
511-
confmat_filename = os.path.join(output_directory, confmat_filename)
512-
intent_hist_filename = os.path.join(output_directory, intent_hist_filename)
513-
514-
plot_confusion_matrix(
515-
cnf_matrix,
516-
classes=labels,
517-
title="Intent Confusion matrix",
518-
out=confmat_filename,
519-
)
520-
plt.show(block=False)
521-
522-
plot_attribute_confidences(
523-
intent_results, intent_hist_filename, "intent_target", "intent_prediction"
524-
)
525-
526-
plt.show(block=False)
508+
if not disable_plotting:
509+
if confmat_filename:
510+
_plot_confusion_matrix(
511+
output_directory, confmat_filename, cnf_matrix, labels
512+
)
513+
if intent_hist_filename:
514+
_plot_histogram(output_directory, intent_hist_filename, intent_results)
527515

528516
predictions = [
529517
{
@@ -544,6 +532,35 @@ def evaluate_intents(
544532
}
545533

546534

535+
def _plot_confusion_matrix(
    output_directory: Optional[Text],
    confmat_filename: Optional[Text],
    cnf_matrix: np.ndarray,
    labels: Collection[Text],
) -> None:
    """Plot the intent confusion matrix via ``plot_confusion_matrix``.

    :param output_directory: if set, ``confmat_filename`` is joined onto it
    :param confmat_filename: file name (or path) handed to the plot as ``out``
    :param cnf_matrix: confusion matrix to plot
    :param labels: class labels passed to the plot as ``classes``
    """
    if output_directory:
        # Resolve the file name relative to the requested output directory.
        confmat_filename = os.path.join(output_directory, confmat_filename)

    plot_confusion_matrix(
        cnf_matrix,
        classes=labels,
        title="Intent Confusion matrix",
        out=confmat_filename,
    )
550+
551+
552+
def _plot_histogram(
    output_directory: Optional[Text],
    intent_hist_filename: Optional[Text],
    intent_results: List[IntentEvaluationResult],
) -> None:
    """Plot the intent confidence histogram via ``plot_attribute_confidences``.

    :param output_directory: if set, ``intent_hist_filename`` is joined onto it
    :param intent_hist_filename: file name (or path) handed to the plot
    :param intent_results: intent evaluation results to plot
    """
    # Only prefix the file name when an output directory was requested.
    histogram_path = (
        os.path.join(output_directory, intent_hist_filename)
        if output_directory
        else intent_hist_filename
    )
    plot_attribute_confidences(
        intent_results, histogram_path, "intent_target", "intent_prediction"
    )
562+
563+
547564
def merge_labels(
548565
aligned_predictions: List[Dict], extractor: Optional[Text] = None
549566
) -> np.array:
@@ -1037,6 +1054,7 @@ def run_evaluation(
10371054
confmat: Optional[Text] = None,
10381055
histogram: Optional[Text] = None,
10391056
component_builder: Optional[ComponentBuilder] = None,
1057+
disable_plotting: bool = False,
10401058
) -> Dict: # pragma: no cover
10411059
"""
10421060
Evaluate intent classification, response selection and entity extraction.
@@ -1049,6 +1067,7 @@ def run_evaluation(
10491067
:param confmat: path to file that will show the confusion matrix
10501068
:param histogram: path to file that will show a histogram
10511069
:param component_builder: component builder
1070+
:param disable_plotting: if true confusion matrix and histogram will not be rendered
10521071
10531072
:return: dictionary containing evaluation results
10541073
"""
@@ -1075,7 +1094,13 @@ def run_evaluation(
10751094
if intent_results:
10761095
logger.info("Intent evaluation results:")
10771096
result["intent_evaluation"] = evaluate_intents(
1078-
intent_results, output_directory, successes, errors, confmat, histogram
1097+
intent_results,
1098+
output_directory,
1099+
successes,
1100+
errors,
1101+
confmat,
1102+
histogram,
1103+
disable_plotting,
10791104
)
10801105

10811106
if response_selection_results:
@@ -1168,6 +1193,7 @@ def cross_validate(
11681193
errors: bool = False,
11691194
confmat: Optional[Text] = None,
11701195
histogram: Optional[Text] = None,
1196+
disable_plotting: bool = False,
11711197
) -> Tuple[CVEvaluationResult, CVEvaluationResult]:
11721198
"""Stratified cross validation on data.
11731199
@@ -1230,7 +1256,13 @@ def cross_validate(
12301256
if intent_classifier_present:
12311257
logger.info("Accumulated test folds intent evaluation results:")
12321258
evaluate_intents(
1233-
intent_test_results, output, successes, errors, confmat, histogram
1259+
intent_test_results,
1260+
output,
1261+
successes,
1262+
errors,
1263+
confmat,
1264+
histogram,
1265+
disable_plotting,
12341266
)
12351267

12361268
if extractors:

tests/cli/test_rasa_test.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ def test_test_core(run_in_default_project: Callable[..., RunResult]):
1212
assert os.path.exists("results")
1313

1414

15+
def test_test_core_no_plot(run_in_default_project: Callable[..., RunResult]):
    """Running ``test core --no-plot`` must not leave a story confusion matrix."""
    run_in_default_project("test", "core", "--no-plot")

    story_confmat = "results/story_confmat.pdf"
    assert not os.path.exists(story_confmat)
19+
20+
1521
def test_test(run_in_default_project: Callable[..., RunResult]):
1622
run_in_default_project("test")
1723

@@ -20,6 +26,14 @@ def test_test(run_in_default_project: Callable[..., RunResult]):
2026
assert os.path.exists("results/confmat.png")
2127

2228

29+
def test_test_no_plot(run_in_default_project: Callable[..., RunResult]):
    """Running ``test --no-plot`` must not leave any of the plot files behind."""
    run_in_default_project("test", "--no-plot")

    # Checked in the same order as before: histogram, NLU matrix, story matrix.
    for plot_file in (
        "results/hist.png",
        "results/confmat.png",
        "results/story_confmat.pdf",
    ):
        assert not os.path.exists(plot_file)
35+
36+
2337
def test_test_nlu(run_in_default_project: Callable[..., RunResult]):
2438
run_in_default_project("test", "nlu", "--nlu", "data", "--successes")
2539

@@ -28,6 +42,13 @@ def test_test_nlu(run_in_default_project: Callable[..., RunResult]):
2842
assert os.path.exists("results/intent_successes.json")
2943

3044

45+
def test_test_nlu_no_plot(run_in_default_project: Callable[..., RunResult]):
    """Running ``test nlu --no-plot`` must not leave the NLU plot files behind."""
    run_in_default_project("test", "nlu", "--no-plot")

    for plot_file in ("results/confmat.png", "results/hist.png"):
        assert not os.path.exists(plot_file)
50+
51+
3152
def test_test_nlu_cross_validation(run_in_default_project: Callable[..., RunResult]):
3253
run_in_default_project(
3354
"test", "nlu", "--cross-validation", "-c", "config.yml", "-f", "2"
@@ -134,7 +155,7 @@ def test_test_help(run: Callable[..., RunResult]):
134155
[--successes] [--no-errors] [--histogram HISTOGRAM]
135156
[--confmat CONFMAT] [-c CONFIG [CONFIG ...]]
136157
[--cross-validation] [-f FOLDS] [-r RUNS]
137-
[-p PERCENTAGES [PERCENTAGES ...]]
158+
[-p PERCENTAGES [PERCENTAGES ...]] [--no-plot]
138159
{core,nlu} ..."""
139160

140161
lines = help_text.split("\n")
@@ -150,7 +171,7 @@ def test_test_nlu_help(run: Callable[..., RunResult]):
150171
[--successes] [--no-errors] [--histogram HISTOGRAM]
151172
[--confmat CONFMAT] [-c CONFIG [CONFIG ...]]
152173
[--cross-validation] [-f FOLDS] [-r RUNS]
153-
[-p PERCENTAGES [PERCENTAGES ...]]"""
174+
[-p PERCENTAGES [PERCENTAGES ...]] [--no-plot]"""
154175

155176
lines = help_text.split("\n")
156177

@@ -165,7 +186,7 @@ def test_test_core_help(run: Callable[..., RunResult]):
165186
[-s STORIES] [--max-stories MAX_STORIES] [--out OUT]
166187
[--e2e] [--endpoints ENDPOINTS]
167188
[--fail-on-prediction-errors] [--url URL]
168-
[--evaluate-model-directory]"""
189+
[--evaluate-model-directory] [--no-plot]"""
169190

170191
lines = help_text.split("\n")
171192

tests/nlu/base/test_evaluation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def test_intent_evaluation_report(tmpdir_factory):
305305
errors=False,
306306
confmat_filename=None,
307307
intent_hist_filename=None,
308+
disable_plotting=False,
308309
)
309310

310311
report = json.loads(rasa.utils.io.read_file(report_filename))
@@ -357,6 +358,7 @@ def incorrect(label: Text, _label: Text) -> IntentEvaluationResult:
357358
errors=False,
358359
confmat_filename=None,
359360
intent_hist_filename=None,
361+
disable_plotting=False,
360362
)
361363

362364
report = json.loads(rasa.utils.io.read_file(str(report_filename)))

0 commit comments

Comments
 (0)