Skip to content

Commit

Permalink
Fix rsample bug (jankrepl#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
jankrepl authored Dec 7, 2020
1 parent b56e63d commit 75d190c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion deepdow/layers/allocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def forward(self, matrix, rets=None, **kwargs):
portfolios = [] # n_portfolios elements of (n_samples, n_assets)

for _ in range(self.n_portfolios):
draws = dist.sample((n_draws,)) # (n_draws, n_samples, n_assets)
draws = dist.rsample((n_draws,)) # (n_draws, n_samples, n_assets)
rets_ = draws.mean(dim=0) if rets is not None else None # (n_samples, n_assets)
covmat_ = CovarianceMatrix(sqrt=self.uses_sqrt)(draws.permute(1, 0, 2)) # (n_samples, n_assets, ...)

Expand Down
12 changes: 11 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_basic(self, dtype_device, allocator_class, random_state):
single_ = torch.rand(n_assets, n_assets, dtype=dtype, device=device)
single = single_ @ single_.t()
covmat = torch.stack([single for _ in range(n_samples)], dim=0)
rets = torch.rand(n_samples, n_assets, dtype=dtype, device=device)
rets = torch.rand(n_samples, n_assets, dtype=dtype, device=device, requires_grad=True)

if allocator_class.__name__ == 'AnalyticalMarkowitz':
allocator = allocator_class()
Expand Down Expand Up @@ -440,6 +440,16 @@ def test_basic(self, dtype_device, allocator_class, random_state):
else:
assert torch.allclose(weights_1, weights_2)

# Make sure one can run backward pass (just sum the weights to get a scalar)
some_loss = weights_1.sum()

assert rets.grad is None

some_loss.backward()

assert rets.grad is not None
assert single.grad is None


class TestRNN:
@pytest.mark.parametrize('bidirectional', [True, False], ids=['bidir', 'onedir'])
Expand Down

0 comments on commit 75d190c

Please sign in to comment.