@@ -73,13 +73,10 @@ def fit(
73
73
train_dataset_function ,
74
74
tf_train_on_batch_function ,
75
75
) = self ._get_tf_train_functions (eager , model_data , batch_strategy )
76
-
77
76
(
78
77
evaluation_dataset_function ,
79
78
tf_evaluation_on_batch_function ,
80
- ) = self ._get_tf_evaluation_functions (
81
- eager , evaluate_on_num_examples , evaluation_model_data
82
- )
79
+ ) = self ._get_tf_evaluation_functions (eager , evaluation_model_data )
83
80
84
81
val_results = {} # validation is not performed every epoch
85
82
pbar = tqdm (range (epochs ), desc = "Epochs" , disable = disable )
@@ -235,32 +232,28 @@ def train_dataset_function(_batch_size: int) -> tf.data.Dataset:
235
232
)
236
233
237
234
def _get_tf_evaluation_functions (
238
- self ,
239
- eager : bool ,
240
- evaluate_on_num_examples : int ,
241
- evaluation_model_data : RasaModelData ,
235
+ self , eager : bool , evaluation_model_data : Optional [RasaModelData ],
242
236
) -> Tuple [Optional [Callable ], Optional [Callable ]]:
243
237
"""Create evaluation tensorflow functions"""
244
238
245
- if evaluate_on_num_examples > 0 :
246
-
247
- def evaluation_dataset_function (_batch_size : int ) -> tf .data .Dataset :
248
- return evaluation_model_data .as_tf_dataset (
249
- _batch_size , "sequence" , shuffle = False
250
- )
239
+ if evaluation_model_data is None :
240
+ return None , None
251
241
252
- self ._training = False # needed for tf graph mode
253
- return (
254
- evaluation_dataset_function ,
255
- self ._get_tf_call_model_function (
256
- evaluation_dataset_function ,
257
- self ._total_batch_loss ,
258
- eager ,
259
- "evaluation" ,
260
- ),
242
+ def evaluation_dataset_function (_batch_size : int ) -> tf .data .Dataset :
243
+ return evaluation_model_data .as_tf_dataset (
244
+ _batch_size , "sequence" , shuffle = False
261
245
)
262
246
263
- return None , None
247
+ self ._training = False # needed for tf graph mode
248
+ return (
249
+ evaluation_dataset_function ,
250
+ self ._get_tf_call_model_function (
251
+ evaluation_dataset_function ,
252
+ self ._total_batch_loss ,
253
+ eager ,
254
+ "evaluation" ,
255
+ ),
256
+ )
264
257
265
258
def _get_metric_results (self , prefix : Optional [Text ] = None ) -> Dict [Text , Text ]:
266
259
"""Get the metrics results"""
0 commit comments