Skip to content

Commit

Permalink
Update nn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dsman1823 authored Apr 13, 2024
1 parent cf822b1 commit 9f8b4f4
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions deepdowmine/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,40 @@
from .layers.misc import Cov2Corr, CovarianceMatrix, KMeans


class RnnNetFullOpti3(torch.nn.Module, Benchmark):
def __init__(self, n_assets, p, shrinkage_strategy, max_weight):
super().__init__()

self.norm_layer = torch.nn.BatchNorm2d(1, affine=True)
self.dropout_layer = torch.nn.Dropout(p=p)
self.transform_layer = torch.nn.RNN(
input_size=n_assets,
hidden_size=n_assets,
batch_first=True,
)

self.linear_for_cov = torch.nn.Linear(250, 250, bias=True)
self.covariance_layer = CovarianceMatrix(sqrt=True, shrinkage_strategy=shrinkage_strategy)
self.gamma = torch.nn.Parameter(torch.ones(1), requires_grad=True)
self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=True)
self.portfolio_layer = ThesisMarkowitzFullOpti(n_assets, max_weight=max_weight)

def forward(self, x):
n_samples, _, _, _ = x.shape
x = self.norm_layer(x)
x = x.squeeze(1) # Removes the channel dimension if it's singular, adjust accordingly
output, hidden = self.transform_layer(x)
x = self.dropout_layer(output)
x = self.linear_for_cov(x.reshape(n_samples, -1)) # Reshape to (n_samples, -1) for Linear layer
x = F.relu(x)
covmat_sqrt = self.covariance_layer(x.reshape(n_samples, 50, 5))
exp_rets = hidden[0] # (n_samples, n_assets)
gamma_all = (torch.ones(len(exp_rets)).to(device=exp_rets.device, dtype=exp_rets.dtype) * self.gamma)
alpha_all = (torch.ones(len(exp_rets)).to(device=exp_rets.device, dtype=exp_rets.dtype) * self.alpha)
weights = self.portfolio_layer(exp_rets, covmat_sqrt, gamma_all, alpha_all)
return weights


class RnnNetMinVar3(torch.nn.Module, Benchmark):
def __init__(self, n_assets, p, shrinkage_strategy, max_weight):
super().__init__()
Expand Down

0 comments on commit 9f8b4f4

Please sign in to comment.