Commit 6b89d13

Use output_dict

1 parent 00e4824 commit 6b89d13

4 files changed: 17 additions & 48 deletions
alt_requirements/requirements_spacy_sklearn.txt

Lines changed: 1 addition & 1 deletion

@@ -2,6 +2,6 @@
 -r requirements_bare.txt
 
 spacy==2.0.18
-scikit-learn==0.19.1
+scikit-learn==0.20.2
 scipy==1.1.0
 sklearn-crfsuite==0.3.6
Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 # Minimum Install Requirements
 -r requirements_bare.txt
 
-scikit-learn==0.19.1
+scikit-learn==0.20.2
 tensorflow==1.12.0
 scipy==1.1.0
 sklearn-crfsuite==0.3.6
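
Both pins move from scikit-learn 0.19.1 to 0.20.2, presumably because the output_dict argument used in rasa_nlu/evaluate.py below was only added to sklearn.metrics.classification_report in scikit-learn 0.20. A minimal sketch of the call this commit switches to (the example labels are invented):

    # Requires scikit-learn >= 0.20; labels here are made up for illustration.
    from sklearn.metrics import classification_report

    targets = ["greet", "goodbye", "greet"]
    predictions = ["greet", "goodbye", "goodbye"]

    # Default: a preformatted text table (what the removed helpers parsed by hand).
    print(classification_report(targets, predictions))

    # output_dict=True: a nested dict keyed by label (plus average rows),
    # directly serialisable to JSON -- no string parsing needed.
    report = classification_report(targets, predictions, output_dict=True)
    print(report["greet"]["precision"])  # 1.0 in this toy example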

rasa_nlu/classifiers/sklearn_intent_classifier.py

Lines changed: 2 additions & 2 deletions

@@ -178,10 +178,10 @@ def process(self, message, **kwargs):
         else:
             X = message.get("text_features").reshape(1, -1)
             intent_ids, probabilities = self.predict(X)
-            intents = self.transform_labels_num2str(intent_ids)
+            intents = self.transform_labels_num2str(np.ravel(intent_ids))
             # `predict` returns a matrix as it is supposed
             # to work for multiple examples as well, hence we need to flatten
-            intents, probabilities = intents.flatten(), probabilities.flatten()
+            probabilities = probabilities.flatten()
 
             if intents.size > 0 and probabilities.size > 0:
                 ranking = list(zip(list(intents),
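
The np.ravel call is presumably needed because predict() returns its label ids as a 2-d (n, 1) matrix, while transform_labels_num2str (which in this classifier wraps a scikit-learn LabelEncoder) expects 1-d input. Its output is then already flat, so only probabilities still needs flattening. A sketch of the shape handling, assuming a LabelEncoder underneath:

    # Assumes transform_labels_num2str delegates to LabelEncoder.inverse_transform.
    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    le = LabelEncoder().fit(["affirm", "goodbye", "greet"])

    intent_ids = np.array([[2]])     # predict() yields a (1, 1) matrix of ids
    flat_ids = np.ravel(intent_ids)  # shape (1,): the 1-d form inverse_transform expects

    intents = le.inverse_transform(flat_ids)
    print(intents)                   # ['greet'] -- already 1-d, no .flatten() needed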

rasa_nlu/evaluate.py

Lines changed: 13 additions & 44 deletions

@@ -171,14 +171,15 @@ def log_evaluation_table(report,  # type: Text
     logger.info("Classification report: \n{}".format(report))
 
 
-def get_evaluation_metrics(targets, predictions):  # pragma: no cover
+def get_evaluation_metrics(targets, predictions, output_dict=False):  # pragma: no cover
     """Compute the f1, precision, accuracy and summary report from sklearn."""
     from sklearn import metrics
 
     targets = clean_intent_labels(targets)
     predictions = clean_intent_labels(predictions)
 
-    report = metrics.classification_report(targets, predictions)
+    report = metrics.classification_report(targets, predictions,
+                                           output_dict=output_dict)
     precision = metrics.precision_score(targets, predictions,
                                         average='weighted')
     f1 = metrics.f1_score(targets, predictions, average='weighted')
@@ -187,43 +188,6 @@ def get_evaluation_metrics(targets, predictions):  # pragma: no cover
     return report, precision, f1, accuracy
 
 
-def report_to_dict(report, f1, precision, accuracy):
-    """Convert sklearn metrics report into dict"""
-
-    report_dict = {
-        'f1': f1,
-        'precision': precision,
-        'accuracy': accuracy,
-        'intents': []
-    }
-
-    lines = list(filter(None, report.split('\n')))
-    labels = lines[0].split()
-
-    report_dict['intents'] = report_row_to_dict(labels, lines[1:-1])
-
-    return report_dict
-
-
-def report_row_to_dict(labels, lines):
-    """Convert sklearn metrics report row to dict"""
-    import re
-
-    array = []
-    for line in lines:
-        row_data = re.split('\s{2,}', line.strip())
-        name = row_data[0]
-        values = row_data[1:]
-        r = {
-            'name': name
-        }
-        for i in range(len(values)):
-            r[labels[i]] = values[i]
-        array.append(r)
-
-    return array
-
-
 def remove_empty_intent_examples(intent_results):
     """Remove those examples without an intent."""
 
@@ -343,16 +307,21 @@ def evaluate_intents(intent_results,
                 "of {} examples".format(len(intent_results), num_examples))
 
     targets, predictions = _targets_predictions_from(intent_results)
-    report, precision, f1, accuracy = get_evaluation_metrics(targets,
-                                                             predictions)
-
-    log_evaluation_table(report, precision, f1, accuracy)
 
     if report_filename:
-        save_json(report_to_dict(report, f1, precision, accuracy), report_filename)
+        report, precision, f1, accuracy = get_evaluation_metrics(targets,
+                                                                 predictions,
+                                                                 output_dict=True)
+
+        save_json(report, report_filename)
         logger.info("Classification report saved to {}."
                     .format(report_filename))
 
+    else:
+        report, precision, f1, accuracy = get_evaluation_metrics(targets,
+                                                                 predictions)
+        log_evaluation_table(report, precision, f1, accuracy)
+
     if successes_filename:
         # save classified samples to file for debugging
         collect_nlu_successes(intent_results, successes_filename)
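
One behavioural consequence of the last hunk: the human-readable table is now only logged in the else branch, i.e. when no report_filename is given; otherwise the dict report is saved directly. A hypothetical call of the updated helper, with names taken from the diff:

    # targets/predictions as produced by _targets_predictions_from above.
    report, precision, f1, accuracy = get_evaluation_metrics(
        targets, predictions, output_dict=True)
    save_json(report, "intent_report.json")  # dict report, serialised as-is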
