@@ -441,6 +441,7 @@ def evaluate_intents(
441
441
errors : bool ,
442
442
confmat_filename : Optional [Text ],
443
443
intent_hist_filename : Optional [Text ],
444
+ disable_plotting : bool ,
444
445
) -> Dict : # pragma: no cover
445
446
"""Creates a confusion matrix and summary statistics for intent predictions.
446
447
@@ -504,26 +505,13 @@ def evaluate_intents(
504
505
# log and save misclassified samples to file for debugging
505
506
collect_nlu_errors (intent_results , errors_filename )
506
507
507
- if confmat_filename :
508
- import matplotlib .pyplot as plt
509
-
510
- if output_directory :
511
- confmat_filename = os .path .join (output_directory , confmat_filename )
512
- intent_hist_filename = os .path .join (output_directory , intent_hist_filename )
513
-
514
- plot_confusion_matrix (
515
- cnf_matrix ,
516
- classes = labels ,
517
- title = "Intent Confusion matrix" ,
518
- out = confmat_filename ,
519
- )
520
- plt .show (block = False )
521
-
522
- plot_attribute_confidences (
523
- intent_results , intent_hist_filename , "intent_target" , "intent_prediction"
524
- )
525
-
526
- plt .show (block = False )
508
+ if not disable_plotting :
509
+ if confmat_filename :
510
+ _plot_confusion_matrix (
511
+ output_directory , confmat_filename , cnf_matrix , labels
512
+ )
513
+ if intent_hist_filename :
514
+ _plot_histogram (output_directory , intent_hist_filename , intent_results )
527
515
528
516
predictions = [
529
517
{
@@ -544,6 +532,35 @@ def evaluate_intents(
544
532
}
545
533
546
534
535
def _plot_confusion_matrix(
    output_directory: Optional[Text],
    confmat_filename: Optional[Text],
    cnf_matrix: np.array,
    labels: Collection[Text],
) -> None:
    """Plot the intent confusion matrix through `plot_confusion_matrix`.

    Args:
        output_directory: if truthy, `confmat_filename` is joined onto it
            before plotting; otherwise the file name is used unchanged.
        confmat_filename: name handed to `plot_confusion_matrix` as `out`.
        cnf_matrix: confusion matrix forwarded to the plotting routine.
        labels: forwarded to the plotting routine as `classes`.
    """
    # Only prefix the file name when a target directory was supplied.
    target_path = (
        os.path.join(output_directory, confmat_filename)
        if output_directory
        else confmat_filename
    )

    plot_confusion_matrix(
        cnf_matrix, classes=labels, title="Intent Confusion matrix", out=target_path
    )
550
+
551
+
552
def _plot_histogram(
    output_directory: Optional[Text],
    intent_hist_filename: Optional[Text],
    intent_results: List[IntentEvaluationResult],
) -> None:
    """Plot the intent confidence histogram through `plot_attribute_confidences`.

    Args:
        output_directory: if truthy, `intent_hist_filename` is joined onto it
            before plotting; otherwise the file name is used unchanged.
        intent_hist_filename: name handed to `plot_attribute_confidences`.
        intent_results: intent evaluation results forwarded to the plotting
            routine together with the "intent_target" and
            "intent_prediction" attribute names.
    """
    if not output_directory:
        hist_path = intent_hist_filename
    else:
        hist_path = os.path.join(output_directory, intent_hist_filename)

    plot_attribute_confidences(
        intent_results, hist_path, "intent_target", "intent_prediction"
    )
562
+
563
+
547
564
def merge_labels (
548
565
aligned_predictions : List [Dict ], extractor : Optional [Text ] = None
549
566
) -> np .array :
@@ -1037,6 +1054,7 @@ def run_evaluation(
1037
1054
confmat : Optional [Text ] = None ,
1038
1055
histogram : Optional [Text ] = None ,
1039
1056
component_builder : Optional [ComponentBuilder ] = None ,
1057
+ disable_plotting : bool = False ,
1040
1058
) -> Dict : # pragma: no cover
1041
1059
"""
1042
1060
Evaluate intent classification, response selection and entity extraction.
@@ -1049,6 +1067,7 @@ def run_evaluation(
1049
1067
:param confmat: path to file that will show the confusion matrix
1050
1068
:param histogram: path to file that will show a histogram
1051
1069
:param component_builder: component builder
1070
+ :param disable_plotting: if true, the confusion matrix and histogram will not be rendered
1052
1071
1053
1072
:return: dictionary containing evaluation results
1054
1073
"""
@@ -1075,7 +1094,13 @@ def run_evaluation(
1075
1094
if intent_results :
1076
1095
logger .info ("Intent evaluation results:" )
1077
1096
result ["intent_evaluation" ] = evaluate_intents (
1078
- intent_results , output_directory , successes , errors , confmat , histogram
1097
+ intent_results ,
1098
+ output_directory ,
1099
+ successes ,
1100
+ errors ,
1101
+ confmat ,
1102
+ histogram ,
1103
+ disable_plotting ,
1079
1104
)
1080
1105
1081
1106
if response_selection_results :
@@ -1168,6 +1193,7 @@ def cross_validate(
1168
1193
errors : bool = False ,
1169
1194
confmat : Optional [Text ] = None ,
1170
1195
histogram : Optional [Text ] = None ,
1196
+ disable_plotting : bool = False ,
1171
1197
) -> Tuple [CVEvaluationResult , CVEvaluationResult ]:
1172
1198
"""Stratified cross validation on data.
1173
1199
@@ -1230,7 +1256,13 @@ def cross_validate(
1230
1256
if intent_classifier_present :
1231
1257
logger .info ("Accumulated test folds intent evaluation results:" )
1232
1258
evaluate_intents (
1233
- intent_test_results , output , successes , errors , confmat , histogram
1259
+ intent_test_results ,
1260
+ output ,
1261
+ successes ,
1262
+ errors ,
1263
+ confmat ,
1264
+ histogram ,
1265
+ disable_plotting ,
1234
1266
)
1235
1267
1236
1268
if extractors :
0 commit comments