Skip to content

Commit

Permalink
pytorch 1.0 compatibility
Browse files Browse the repository at this point in the history
Mask s in Line 45 needs to be in long format to be usable by torch 1.0
  • Loading branch information
StefOe authored Jan 14, 2019
1 parent 7caeb83 commit bd9e989
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/ode_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def forward(self, t, y):


def get_batch():
s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time), args.batch_size, replace=False))
s = torch.from_numpy(np.random.choice(np.arange(args.data_size - args.batch_time, dtype=np.int64), args.batch_size, replace=False))
batch_y0 = true_y[s] # (M, D)
batch_t = t[:args.batch_time] # (T)
batch_y = torch.stack([true_y[s + i] for i in range(args.batch_time)], dim=0) # (T, M, D)
Expand Down

0 comments on commit bd9e989

Please sign in to comment.