Skip to content

Commit ddf9bbd

Browse files
committed
removed call to generator in predict
1 parent d4ebe7b commit ddf9bbd

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

rasa/utils/tensorflow/models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,9 @@ def predict(self, predict_data: RasaModelData) -> Dict[Text, tf.Tensor]:
206206
logger.debug("There is no tensorflow prediction graph.")
207207
self.build_for_predict(predict_data)
208208

209-
predict_dataset = predict_data.as_tf_dataset(batch_size=1)
210-
batch_in = next(iter(predict_dataset))
209+
# predict_dataset = predict_data.as_tf_dataset(batch_size=1)
210+
batch_in = predict_data.prepare_batch(predict_data.data, 0, 1)
211+
# batch_in = next(iter(predict_dataset))
211212

212213
self._training = False # needed for eager mode
213214
return self._predict_function(batch_in)

0 commit comments

Comments
 (0)