Skip to content

Commit 7bf37cf

Browse files
committed
refactor evaluation function creation
1 parent 935537c commit 7bf37cf

File tree

1 file changed

+17
-24
lines changed

1 file changed

+17
-24
lines changed

rasa/utils/tensorflow/models.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,10 @@ def fit(
7373
train_dataset_function,
7474
tf_train_on_batch_function,
7575
) = self._get_tf_train_functions(eager, model_data, batch_strategy)
76-
7776
(
7877
evaluation_dataset_function,
7978
tf_evaluation_on_batch_function,
80-
) = self._get_tf_evaluation_functions(
81-
eager, evaluate_on_num_examples, evaluation_model_data
82-
)
79+
) = self._get_tf_evaluation_functions(eager, evaluation_model_data)
8380

8481
val_results = {} # validation is not performed every epoch
8582
pbar = tqdm(range(epochs), desc="Epochs", disable=disable)
@@ -235,32 +232,28 @@ def train_dataset_function(_batch_size: int) -> tf.data.Dataset:
235232
)
236233

237234
def _get_tf_evaluation_functions(
238-
self,
239-
eager: bool,
240-
evaluate_on_num_examples: int,
241-
evaluation_model_data: RasaModelData,
235+
self, eager: bool, evaluation_model_data: Optional[RasaModelData],
242236
) -> Tuple[Optional[Callable], Optional[Callable]]:
243237
"""Create evaluation tensorflow functions"""
244238

245-
if evaluate_on_num_examples > 0:
246-
247-
def evaluation_dataset_function(_batch_size: int) -> tf.data.Dataset:
248-
return evaluation_model_data.as_tf_dataset(
249-
_batch_size, "sequence", shuffle=False
250-
)
239+
if evaluation_model_data is None:
240+
return None, None
251241

252-
self._training = False # needed for tf graph mode
253-
return (
254-
evaluation_dataset_function,
255-
self._get_tf_call_model_function(
256-
evaluation_dataset_function,
257-
self._total_batch_loss,
258-
eager,
259-
"evaluation",
260-
),
242+
def evaluation_dataset_function(_batch_size: int) -> tf.data.Dataset:
243+
return evaluation_model_data.as_tf_dataset(
244+
_batch_size, "sequence", shuffle=False
261245
)
262246

263-
return None, None
247+
self._training = False # needed for tf graph mode
248+
return (
249+
evaluation_dataset_function,
250+
self._get_tf_call_model_function(
251+
evaluation_dataset_function,
252+
self._total_batch_loss,
253+
eager,
254+
"evaluation",
255+
),
256+
)
264257

265258
def _get_metric_results(self, prefix: Optional[Text] = None) -> Dict[Text, Text]:
266259
"""Get the metrics results"""

0 commit comments

Comments
 (0)