Merge pull request tkipf#10 from daniilidis-group/master
Small changes to make it compatible with PyTorch 0.4.0
tkipf authored Jun 26, 2018
2 parents 7474897 + 88c6676 commit bf1410e
Showing 2 changed files with 6 additions and 32 deletions.
29 changes: 3 additions & 26 deletions pygcn/layers.py
@@ -6,29 +6,6 @@
 from torch.nn.modules.module import Module
 
 
-class SparseMM(torch.autograd.Function):
-    """
-    Sparse x dense matrix multiplication with autograd support.
-
-    Implementation by Soumith Chintala:
-    https://discuss.pytorch.org/t/
-    does-pytorch-support-autograd-on-sparse-matrix/6156/7
-    """
-
-    def __init__(self, sparse):
-        super(SparseMM, self).__init__()
-        self.sparse = sparse
-
-    def forward(self, dense):
-        return torch.mm(self.sparse, dense)
-
-    def backward(self, grad_output):
-        grad_input = None
-        if self.needs_input_grad[0]:
-            grad_input = torch.mm(self.sparse.t(), grad_output)
-        return grad_input
-
-
 class GraphConvolution(Module):
     """
     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
@@ -38,9 +15,9 @@ def __init__(self, in_features, out_features, bias=True):
         super(GraphConvolution, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.weight = Parameter(torch.Tensor(in_features, out_features))
+        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
         if bias:
-            self.bias = Parameter(torch.Tensor(out_features))
+            self.bias = Parameter(torch.FloatTensor(out_features))
         else:
             self.register_parameter('bias', None)
         self.reset_parameters()
@@ -53,7 +30,7 @@ def reset_parameters(self):
 
     def forward(self, input, adj):
         support = torch.mm(input, self.weight)
-        output = SparseMM(adj)(support)
+        output = torch.spmm(adj, support)
         if self.bias is not None:
             return output + self.bias
         else:
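Note on the layers.py change: since PyTorch 0.4.0, torch.spmm differentiates through sparse x dense products directly, so the hand-written SparseMM autograd.Function (whose old-style stateful __init__/forward API was also dropped in 0.4.0) is no longer needed; the torch.Tensor to torch.FloatTensor switch additionally pins the parameters to float32 regardless of the default tensor type. A minimal sketch of the replacement call, using a made-up 3-node adjacency matrix; the torch.sparse_coo_tensor constructor used here is the one from current PyTorch releases (older code built the same tensor with torch.sparse.FloatTensor(indices, values, size)):

import torch

# Hypothetical 3-node adjacency matrix in sparse COO format.
indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
values = torch.tensor([1.0, 1.0, 1.0])
adj = torch.sparse_coo_tensor(indices, values, (3, 3))

# Dense input, standing in for torch.mm(input, self.weight) in the layer.
support = torch.randn(3, 4, requires_grad=True)

# Sparse x dense product; gradients flow to the dense argument
# without a hand-written backward pass.
output = torch.spmm(adj, support)
output.sum().backward()
print(support.grad.shape)  # torch.Size([3, 4])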
9 changes: 3 additions & 6 deletions pygcn/train.py
@@ -8,7 +8,6 @@
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-from torch.autograd import Variable
 
 from pygcn.utils import load_data, accuracy
 from pygcn.models import GCN
@@ -45,7 +44,7 @@
 # Model and optimizer
 model = GCN(nfeat=features.shape[1],
             nhid=args.hidden,
-            nclass=labels.max() + 1,
+            nclass=labels.max().item() + 1,
             dropout=args.dropout)
 optimizer = optim.Adam(model.parameters(),
                        lr=args.lr, weight_decay=args.weight_decay)
@@ -59,8 +58,6 @@
     idx_val = idx_val.cuda()
     idx_test = idx_test.cuda()
 
-features, labels = Variable(features), Variable(labels)
-
 
 def train(epoch):
     t = time.time()
@@ -94,8 +91,8 @@ def test():
     loss_test = F.nll_loss(output[idx_test], labels[idx_test])
     acc_test = accuracy(output[idx_test], labels[idx_test])
     print("Test set results:",
-          "loss= {:.4f}".format(loss_test.data[0]),
-          "accuracy= {:.4f}".format(acc_test.data[0]))
+          "loss= {:.4f}".format(loss_test.item()),
+          "accuracy= {:.4f}".format(acc_test.item()))
 
 
 # Train model
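Note on the train.py changes: PyTorch 0.4.0 merged Variable into Tensor, so wrapping features and labels in Variable(...) became a no-op and the import can go; scalar results are now 0-dimensional tensors, where the old .data[0] indexing raises an error and .item() is the replacement. A minimal sketch of the new idioms, with made-up values for illustration:

import torch

# Variable and Tensor are merged in 0.4.0: tensors carry autograd
# state themselves, so Variable(features) is redundant.

# Scalar losses are 0-dim tensors; .item() extracts the Python number.
loss_test = torch.tensor(0.4321)
print("loss= {:.4f}".format(loss_test.item()))

# labels.max() likewise returns a 0-dim tensor, so the class count
# needs .item() to become a plain Python int.
labels = torch.tensor([0, 2, 1, 1])
nclass = labels.max().item() + 1  # 3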
