Skip to content

Commit 65ad5d9

Browse files
authored
Merge pull request RasaHQ#5240 from RasaHQ/tf2-fix
don't create tf function for dataset
2 parents 4694f60 + 7bf37cf commit 65ad5d9

File tree

2 files changed

+38
-56
lines changed

2 files changed

+38
-56
lines changed

rasa/utils/tensorflow/model_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,7 @@ def batch_tuple_sizes(self) -> Dict[Text, int]:
304304
return tuple_sizes
305305

306306
def as_tf_dataset(
307-
self,
308-
batch_size: Union[tf.Tensor, int],
309-
batch_strategy: Text = "sequence",
310-
shuffle: bool = False,
307+
self, batch_size: int, batch_strategy: Text = "sequence", shuffle: bool = False
311308
) -> tf.data.Dataset:
312309
"""Create tf dataset."""
313310

rasa/utils/tensorflow/models.py

Lines changed: 37 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -70,38 +70,30 @@ def fit(
7070
)
7171

7272
(
73-
tf_train_dataset_function,
73+
train_dataset_function,
7474
tf_train_on_batch_function,
7575
) = self._get_tf_train_functions(eager, model_data, batch_strategy)
76-
7776
(
78-
tf_evaluation_dataset_function,
77+
evaluation_dataset_function,
7978
tf_evaluation_on_batch_function,
80-
) = self._get_tf_evaluation_functions(
81-
eager, evaluate_on_num_examples, evaluation_model_data
82-
)
79+
) = self._get_tf_evaluation_functions(eager, evaluation_model_data)
8380

8481
val_results = {} # validation is not performed every epoch
8582
pbar = tqdm(range(epochs), desc="Epochs", disable=disable)
8683

8784
for ep in pbar:
8885
ep_batch_size = self.linearly_increasing_batch_size(ep, batch_size, epochs)
89-
if not eager:
90-
ep_batch_size *= tf.ones((), tf.int32)
9186

9287
self._batch_loop(
93-
tf_train_dataset_function,
94-
tf_train_on_batch_function,
95-
ep_batch_size,
96-
True,
88+
train_dataset_function, tf_train_on_batch_function, ep_batch_size, True
9789
)
9890

9991
postfix_dict = self._get_metric_results()
10092

10193
if evaluate_on_num_examples > 0:
10294
if self._should_evaluate(evaluate_every_num_epochs, epochs, ep):
10395
self._batch_loop(
104-
tf_evaluation_dataset_function,
96+
evaluation_dataset_function,
10597
tf_evaluation_on_batch_function,
10698
ep_batch_size,
10799
False,
@@ -130,14 +122,9 @@ def train_on_batch(
130122
def build_for_predict(
131123
self, predict_data: RasaModelData, eager: bool = False
132124
) -> None:
133-
def predict_dataset_function( # to reuse the same helper method
134-
_batch_size: Union[tf.Tensor, int]
135-
) -> tf.data.Dataset:
136-
return predict_data.as_tf_dataset(_batch_size, "sequence", shuffle=False)
137-
138125
self._training = False # needed for tf graph mode
139-
_, self._predict_function = self._get_tf_functions(
140-
predict_dataset_function, self.batch_predict, eager, "prediction"
126+
self._predict_function = self._get_tf_call_model_function(
127+
predict_data.as_tf_dataset, self.batch_predict, eager, "prediction"
141128
)
142129

143130
def predict(self, predict_data: RasaModelData) -> Dict[Text, tf.Tensor]:
@@ -194,7 +181,7 @@ def _batch_loop(
194181
self,
195182
dataset_function: Callable,
196183
call_model_function: Callable,
197-
batch_size: Union[tf.Tensor, int],
184+
batch_size: int,
198185
training: bool,
199186
) -> None:
200187
"""Run on batches"""
@@ -205,70 +192,68 @@ def _batch_loop(
205192
call_model_function(batch_in)
206193

207194
@staticmethod
208-
def _get_tf_functions(
195+
def _get_tf_call_model_function(
209196
dataset_function: Callable,
210197
call_model_function: Callable,
211198
eager: bool,
212199
phase: Text,
213-
) -> Tuple[Callable, Callable]:
200+
) -> Callable:
214201
"""Convert functions to tensorflow functions"""
215202

216203
if eager:
217-
return dataset_function, call_model_function
204+
return call_model_function
218205

219206
logger.debug(f"Building tensorflow {phase} graph...")
220-
# allows increasing batch size
221-
tf_dataset_function = tf.function(func=dataset_function)
222207

223-
init_dataset = tf_dataset_function(tf.ones((), tf.int32))
224-
225-
tf_method_function = tf.function(
208+
init_dataset = dataset_function(1)
209+
tf_call_model_function = tf.function(
226210
call_model_function, input_signature=[init_dataset.element_spec]
227211
)
228-
tf_method_function(next(iter(init_dataset)))
212+
tf_call_model_function(next(iter(init_dataset)))
229213

230214
logger.debug(f"Finished building tensorflow {phase} graph.")
231215

232-
return tf_dataset_function, tf_method_function
216+
return tf_call_model_function
233217

234218
def _get_tf_train_functions(
235219
self, eager: bool, model_data: RasaModelData, batch_strategy: Text
236220
) -> Tuple[Callable, Callable]:
237221
"""Create train tensorflow functions"""
238222

239-
def train_dataset_function(
240-
_batch_size: Union[tf.Tensor, int]
241-
) -> tf.data.Dataset:
223+
def train_dataset_function(_batch_size: int) -> tf.data.Dataset:
242224
return model_data.as_tf_dataset(_batch_size, batch_strategy, shuffle=True)
243225

244226
self._training = True # needed for tf graph mode
245-
return self._get_tf_functions(
246-
train_dataset_function, self.train_on_batch, eager, "train"
227+
return (
228+
train_dataset_function,
229+
self._get_tf_call_model_function(
230+
train_dataset_function, self.train_on_batch, eager, "train"
231+
),
247232
)
248233

249234
def _get_tf_evaluation_functions(
250-
self,
251-
eager: bool,
252-
evaluate_on_num_examples: int,
253-
evaluation_model_data: RasaModelData,
235+
self, eager: bool, evaluation_model_data: Optional[RasaModelData],
254236
) -> Tuple[Optional[Callable], Optional[Callable]]:
255237
"""Create evaluation tensorflow functions"""
256238

257-
if evaluate_on_num_examples > 0:
258-
259-
def evaluation_dataset_function(
260-
_batch_size: Union[tf.Tensor, int]
261-
) -> tf.data.Dataset:
262-
return evaluation_model_data.as_tf_dataset(
263-
_batch_size, "sequence", shuffle=False
264-
)
239+
if evaluation_model_data is None:
240+
return None, None
265241

266-
self._training = False # needed for tf graph mode
267-
return self._get_tf_functions(
268-
evaluation_dataset_function, self._total_batch_loss, eager, "evaluation"
242+
def evaluation_dataset_function(_batch_size: int) -> tf.data.Dataset:
243+
return evaluation_model_data.as_tf_dataset(
244+
_batch_size, "sequence", shuffle=False
269245
)
270246

271-
return None, None
247+
self._training = False # needed for tf graph mode
248+
return (
249+
evaluation_dataset_function,
250+
self._get_tf_call_model_function(
251+
evaluation_dataset_function,
252+
self._total_batch_loss,
253+
eager,
254+
"evaluation",
255+
),
256+
)
272257

273258
def _get_metric_results(self, prefix: Optional[Text] = None) -> Dict[Text, Text]:
274259
"""Get the metrics results"""

0 commit comments

Comments
 (0)