Skip to content

Commit

Permalink
for cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
tvayer committed Oct 28, 2019
1 parent 8b7330c commit bdc4fcb
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions lib/risgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,19 @@ def risgw_gpu(xs,xt,device,nproj=200,P=None,lr=0.001,
from sgw_pytorch import sgw
n_samples=300
Xs=np.random.rand(n_samples,2)
Xt=np.random.rand(n_samples,1)
Xs=np.random.rand(n_samples,1)
Xt=np.random.rand(n_samples,2)
xs=torch.from_numpy(Xs).to(torch.float32)
xt=torch.from_numpy(Xt).to(torch.float32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
P=np.random.randn(2,500)
risgw_gpu(xs,xt,device,P=torch.from_numpy(P).to(torch.float32))
"""
affine_map = StiefelLinear(in_features=xs.size(1),
out_features=xt.size(1))
out_features=xt.size(1),device=device)

optimizer = geoopt.optim.RiemannianAdam(affine_map.parameters(), lr=lr)


running_loss = 0.0
for i in range(max_iter):
Expand All @@ -70,21 +71,21 @@ def risgw_gpu(xs,xt,device,nproj=200,P=None,lr=0.001,
optimizer.step()

# print statistics
running_loss = loss.item()
running_loss = loss
if verbose and (i + 1) % step_verbose == 0:
print('Iteration {}: sgw loss: {:.3f}'.format(i + 1,
running_loss))
running_loss.item()))

return running_loss


def stiefel_uniform_(tensor): # TODO: make things better
with torch.no_grad():
tensor.data = torch.eye(2)
tensor.data = torch.eye(tensor.size(0),tensor.size(1))
return tensor

class StiefelLinear(torch.nn.Module):
def __init__(self, in_features, out_features, bias=False):
def __init__(self, in_features, out_features,device, bias=True):
super(StiefelLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
Expand All @@ -93,10 +94,12 @@ def __init__(self, in_features, out_features, bias=False):
manifold=geoopt.Stiefel()
)
if bias:
self.bias = torch.nn.Parameter(torch.Tensor(out_features))
self.bias = torch.nn.Parameter(torch.Tensor(out_features).to(device))
else:
self.register_parameter('bias', None)
self.reset_parameters()

self.weight.data=self.weight.data.to(device)

def reset_parameters(self):
stiefel_uniform_(self.weight)
Expand Down

0 comments on commit bdc4fcb

Please sign in to comment.