Skip to content

Commit

Permalink
Update TFT
Browse files Browse the repository at this point in the history
  • Loading branch information
v-blin committed Nov 28, 2020
1 parent 30ab4a8 commit fdf0f9a
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
12 changes: 2 additions & 10 deletions examples/benchmarks/TFT/libs/tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,12 +721,7 @@ def _build_base_graph(self):
encoder_steps = self.num_encoder_steps

# Inputs.
all_inputs = tf.keras.layers.Input(
shape=(
time_steps,
combined_input_size,
)
)
all_inputs = tf.keras.layers.Input(shape=(time_steps, combined_input_size,))

unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs)

Expand Down Expand Up @@ -866,10 +861,7 @@ def get_lstm(return_state):
"""Returns LSTM cell initialized with default parameters."""
if self.use_cudnn:
lstm = tf.keras.layers.CuDNNLSTM(
self.hidden_layer_size,
return_sequences=True,
return_state=return_state,
stateful=False,
self.hidden_layer_size, return_sequences=True, return_state=return_state, stateful=False,
)
else:
lstm = tf.keras.layers.LSTM(
Expand Down
2 changes: 1 addition & 1 deletion examples/benchmarks/TFT/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def predict(self, dataset):

predict50 = format_score(p50_forecast, "pred", 1)
predict90 = format_score(p90_forecast, "pred", 1)
predict = (predict50 + predict90)/2 # self.label_shift
predict = (predict50 + predict90) / 2 # self.label_shift
# ===========================Predicting Process===========================
return predict

Expand Down

0 comments on commit fdf0f9a

Please sign in to comment.