Skip to content

Commit

Permalink
Add weights to control different types of losses in debug tools
Browse files Browse the repository at this point in the history
Summary: Add per-term loss weights so that feature importance analysis can be run on any specific combination of the next-state, reward, and terminal predictions

Reviewed By: kittipatv

Differential Revision: D14517213

fbshipit-source-id: b4d7306f29093ef89923de6889f91482eafd4339
  • Loading branch information
czxttkl authored and facebook-github-bot committed Mar 27, 2019
1 parent 4376efb commit ccddfe4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 8 additions & 1 deletion ml/rl/thrift/core.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,12 @@ struct MDNRNNParameters {
2: i32 num_hidden_layers = 2,
3: i32 minibatch_size = 16,
4: double learning_rate = 0.001,
5: i32 num_gaussians = 5
5: i32 num_gaussians = 5,
6: double train_data_percentage = 60.0,
7: double validation_data_percentage = 20.0,
8: double test_data_percentage = 20.0,
# weights applied to the individual terms of the world-model loss
9: double reward_loss_weight = 1.0,
10: double next_state_loss_weight = 1.0,
11: double not_terminal_loss_weight = 1.0,
}
12 changes: 9 additions & 3 deletions ml/rl/training/world_model/mdnrnn_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,15 @@ def get_loss(
mdnrnn_output.not_terminal,
)

gmm = gmm_loss(learning_input.next_state, mus, sigmas, logpi)
bce = F.binary_cross_entropy_with_logits(ds, learning_input.not_terminal)
mse = F.mse_loss(rs, learning_input.reward)
gmm = (
gmm_loss(learning_input.next_state, mus, sigmas, logpi)
* self.params.next_state_loss_weight
)
bce = (
F.binary_cross_entropy_with_logits(ds, learning_input.not_terminal)
* self.params.not_terminal_loss_weight
)
mse = F.mse_loss(rs, learning_input.reward) * self.params.reward_loss_weight
if state_dim is not None:
loss = (gmm + bce + mse) / (state_dim + 2)
else:
Expand Down

0 comments on commit ccddfe4

Please sign in to comment.