@@ -70,38 +70,30 @@ def fit(
70
70
)
71
71
72
72
(
73
- tf_train_dataset_function ,
73
+ train_dataset_function ,
74
74
tf_train_on_batch_function ,
75
75
) = self ._get_tf_train_functions (eager , model_data , batch_strategy )
76
-
77
76
(
78
- tf_evaluation_dataset_function ,
77
+ evaluation_dataset_function ,
79
78
tf_evaluation_on_batch_function ,
80
- ) = self ._get_tf_evaluation_functions (
81
- eager , evaluate_on_num_examples , evaluation_model_data
82
- )
79
+ ) = self ._get_tf_evaluation_functions (eager , evaluation_model_data )
83
80
84
81
val_results = {} # validation is not performed every epoch
85
82
pbar = tqdm (range (epochs ), desc = "Epochs" , disable = disable )
86
83
87
84
for ep in pbar :
88
85
ep_batch_size = self .linearly_increasing_batch_size (ep , batch_size , epochs )
89
- if not eager :
90
- ep_batch_size *= tf .ones ((), tf .int32 )
91
86
92
87
self ._batch_loop (
93
- tf_train_dataset_function ,
94
- tf_train_on_batch_function ,
95
- ep_batch_size ,
96
- True ,
88
+ train_dataset_function , tf_train_on_batch_function , ep_batch_size , True
97
89
)
98
90
99
91
postfix_dict = self ._get_metric_results ()
100
92
101
93
if evaluate_on_num_examples > 0 :
102
94
if self ._should_evaluate (evaluate_every_num_epochs , epochs , ep ):
103
95
self ._batch_loop (
104
- tf_evaluation_dataset_function ,
96
+ evaluation_dataset_function ,
105
97
tf_evaluation_on_batch_function ,
106
98
ep_batch_size ,
107
99
False ,
@@ -130,14 +122,9 @@ def train_on_batch(
130
122
def build_for_predict(
    self, predict_data: RasaModelData, eager: bool = False
) -> None:
    """Build the prediction callable and store it in ``self._predict_function``.

    Args:
        predict_data: Model data; its bound ``as_tf_dataset`` method is passed
            on as the dataset factory.
        eager: Forwarded to ``_get_tf_call_model_function`` together with
            ``self.batch_predict`` and the phase name ``"prediction"``.
    """
    self._training = False  # needed for tf graph mode
    self._predict_function = self._get_tf_call_model_function(
        predict_data.as_tf_dataset, self.batch_predict, eager, "prediction"
    )
142
129
143
130
def predict (self , predict_data : RasaModelData ) -> Dict [Text , tf .Tensor ]:
@@ -194,7 +181,7 @@ def _batch_loop(
194
181
self ,
195
182
dataset_function : Callable ,
196
183
call_model_function : Callable ,
197
- batch_size : Union [ tf . Tensor , int ] ,
184
+ batch_size : int ,
198
185
training : bool ,
199
186
) -> None :
200
187
"""Run on batches"""
@@ -205,70 +192,68 @@ def _batch_loop(
205
192
call_model_function (batch_in )
206
193
207
194
@staticmethod
208
- def _get_tf_functions (
195
+ def _get_tf_call_model_function (
209
196
dataset_function : Callable ,
210
197
call_model_function : Callable ,
211
198
eager : bool ,
212
199
phase : Text ,
213
- ) -> Tuple [ Callable , Callable ] :
200
+ ) -> Callable :
214
201
"""Convert functions to tensorflow functions"""
215
202
216
203
if eager :
217
- return dataset_function , call_model_function
204
+ return call_model_function
218
205
219
206
logger .debug (f"Building tensorflow { phase } graph..." )
220
- # allows increasing batch size
221
- tf_dataset_function = tf .function (func = dataset_function )
222
207
223
- init_dataset = tf_dataset_function (tf .ones ((), tf .int32 ))
224
-
225
- tf_method_function = tf .function (
208
+ init_dataset = dataset_function (1 )
209
+ tf_call_model_function = tf .function (
226
210
call_model_function , input_signature = [init_dataset .element_spec ]
227
211
)
228
- tf_method_function (next (iter (init_dataset )))
212
+ tf_call_model_function (next (iter (init_dataset )))
229
213
230
214
logger .debug (f"Finished building tensorflow { phase } graph." )
231
215
232
- return tf_dataset_function , tf_method_function
216
+ return tf_call_model_function
233
217
234
218
def _get_tf_train_functions(
    self, eager: bool, model_data: RasaModelData, batch_strategy: Text
) -> Tuple[Callable, Callable]:
    """Create train tensorflow functions.

    Args:
        eager: Forwarded to ``_get_tf_call_model_function``.
        model_data: Training data used to build the datasets.
        batch_strategy: Batch strategy passed to ``model_data.as_tf_dataset``.

    Returns:
        A pair ``(dataset_function, train_on_batch_function)``. The dataset
        function is a plain Python callable taking a batch size; the second
        item is whatever ``_get_tf_call_model_function`` returns for
        ``self.train_on_batch``.
    """

    def train_dataset_function(_batch_size: int) -> tf.data.Dataset:
        return model_data.as_tf_dataset(_batch_size, batch_strategy, shuffle=True)

    self._training = True  # needed for tf graph mode
    return (
        train_dataset_function,
        self._get_tf_call_model_function(
            train_dataset_function, self.train_on_batch, eager, "train"
        ),
    )
248
233
249
234
def _get_tf_evaluation_functions(
    self, eager: bool, evaluation_model_data: Optional[RasaModelData],
) -> Tuple[Optional[Callable], Optional[Callable]]:
    """Create evaluation tensorflow functions.

    Args:
        eager: Forwarded to ``_get_tf_call_model_function``.
        evaluation_model_data: Data to evaluate on, or ``None`` when there is
            no evaluation data.

    Returns:
        ``(None, None)`` when ``evaluation_model_data`` is ``None``; otherwise
        a pair ``(dataset_function, evaluation_on_batch_function)`` where the
        dataset function takes a batch size and the second item is whatever
        ``_get_tf_call_model_function`` returns for ``self._total_batch_loss``.
    """
    if evaluation_model_data is None:
        return None, None

    def evaluation_dataset_function(_batch_size: int) -> tf.data.Dataset:
        return evaluation_model_data.as_tf_dataset(
            _batch_size, "sequence", shuffle=False
        )

    self._training = False  # needed for tf graph mode
    return (
        evaluation_dataset_function,
        self._get_tf_call_model_function(
            evaluation_dataset_function,
            self._total_batch_loss,
            eager,
            "evaluation",
        ),
    )
272
257
273
258
def _get_metric_results (self , prefix : Optional [Text ] = None ) -> Dict [Text , Text ]:
274
259
"""Get the metrics results"""
0 commit comments