Commit

pytorch 0.3 fix
rusty1s committed Mar 6, 2018
1 parent d3e4f39 commit 6abbb6b
Showing 2 changed files with 50 additions and 42 deletions.
23 changes: 10 additions & 13 deletions pygcn/layers.py
@@ -15,21 +15,18 @@ class SparseMM(torch.autograd.Function):
     does-pytorch-support-autograd-on-sparse-matrix/6156/7
     """
 
-    def forward(self, matrix1, matrix2):
-        self.save_for_backward(matrix1, matrix2)
-        return torch.mm(matrix1, matrix2)
+    def __init__(self, sparse):
+        super(SparseMM, self).__init__()
+        self.sparse = sparse
 
-    def backward(self, grad_output):
-        matrix1, matrix2 = self.saved_tensors
-        grad_matrix1 = grad_matrix2 = None
+    def forward(self, dense):
+        return torch.mm(self.sparse, dense)
 
+    def backward(self, grad_output):
+        grad_input = None
         if self.needs_input_grad[0]:
-            grad_matrix1 = torch.mm(grad_output, matrix2.t())
-
-        if self.needs_input_grad[1]:
-            grad_matrix2 = torch.mm(matrix1.t(), grad_output)
-
-        return grad_matrix1, grad_matrix2
+            grad_input = torch.mm(self.sparse.t(), grad_output)
+        return grad_input
 
 
 class GraphConvolution(Module):
@@ -56,7 +53,7 @@ def reset_parameters(self):
 
     def forward(self, input, adj):
         support = torch.mm(input, self.weight)
-        output = SparseMM()(adj, support)
+        output = SparseMM(adj)(support)
         if self.bias is not None:
             return output + self.bias
         else:
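The substance of the fix is visible here: instead of passing both operands through forward() and calling save_for_backward() on them, the reworked SparseMM receives the sparse matrix in __init__ and keeps it on the instance, so only the dense operand is an autograd input; presumably this is the PyTorch 0.3 incompatibility the commit works around, since save_for_backward did not handle sparse tensors at the time. A minimal sketch of the new calling convention (the 3-node toy adjacency is hypothetical; in the repository adj comes from load_data()):

    import torch
    from torch.autograd import Variable
    from pygcn.layers import SparseMM

    # Hypothetical 3-node toy graph in COO form; the sparse adjacency is
    # held by the Function instance and never wrapped in a Variable.
    i = torch.LongTensor([[0, 1, 1], [1, 0, 2]])
    v = torch.FloatTensor([1., 1., 1.])
    adj = torch.sparse.FloatTensor(i, v, torch.Size([3, 3]))

    x = Variable(torch.randn(3, 4), requires_grad=True)
    out = SparseMM(adj)(x)  # sparse operand bound in __init__, dense passed in
    out.sum().backward()    # the gradient flows back to x only

Note that backward() only ever returns a gradient for the dense input, which matches the single needs_input_grad[0] check above.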
69 changes: 40 additions & 29 deletions pygcn/train.py
@@ -15,21 +15,33 @@
 
 # Training settings
 parser = argparse.ArgumentParser()
-parser.add_argument('--no-cuda', action='store_true', default=False,
-                    help='Disables CUDA training.')
-parser.add_argument('--fastmode', action='store_true', default=False,
-                    help='Validate during training pass.')
+parser.add_argument(
+    '--no-cuda',
+    action='store_true',
+    default=False,
+    help='Disables CUDA training.')
+parser.add_argument(
+    '--fastmode',
+    action='store_true',
+    default=False,
+    help='Validate during training pass.')
 parser.add_argument('--seed', type=int, default=42, help='Random seed.')
-parser.add_argument('--epochs', type=int, default=200,
-                    help='Number of epochs to train.')
-parser.add_argument('--lr', type=float, default=0.01,
-                    help='Initial learning rate.')
-parser.add_argument('--weight_decay', type=float, default=5e-4,
-                    help='Weight decay (L2 loss on parameters).')
-parser.add_argument('--hidden', type=int, default=16,
-                    help='Number of hidden units.')
-parser.add_argument('--dropout', type=float, default=0.5,
-                    help='Dropout rate (1 - keep probability).')
+parser.add_argument(
+    '--epochs', type=int, default=200, help='Number of epochs to train.')
+parser.add_argument(
+    '--lr', type=float, default=0.01, help='Initial learning rate.')
+parser.add_argument(
+    '--weight_decay',
+    type=float,
+    default=5e-4,
+    help='Weight decay (L2 loss on parameters).')
+parser.add_argument(
+    '--hidden', type=int, default=16, help='Number of hidden units.')
+parser.add_argument(
+    '--dropout',
+    type=float,
+    default=0.5,
+    help='Dropout rate (1 - keep probability).')
 
 args = parser.parse_args()
 args.cuda = not args.no_cuda and torch.cuda.is_available()
@@ -43,12 +55,13 @@
 adj, features, labels, idx_train, idx_val, idx_test = load_data()
 
 # Model and optimizer
-model = GCN(nfeat=features.shape[1],
-            nhid=args.hidden,
-            nclass=labels.max() + 1,
-            dropout=args.dropout)
-optimizer = optim.Adam(model.parameters(),
-                       lr=args.lr, weight_decay=args.weight_decay)
+model = GCN(
+    nfeat=features.shape[1],
+    nhid=args.hidden,
+    nclass=labels.max() + 1,
+    dropout=args.dropout)
+optimizer = optim.Adam(
+    model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
 
 if args.cuda:
     model.cuda()
@@ -59,7 +72,7 @@
     idx_val = idx_val.cuda()
     idx_test = idx_test.cuda()
 
-features, adj, labels = Variable(features), Variable(adj), Variable(labels)
+features, labels = Variable(features), Variable(labels)
 
 
 def train(epoch):
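Dropping Variable(adj) here is the train-side half of the SparseMM rework above: the adjacency is now handed to SparseMM.__init__ as a raw tensor, so wrapping it in a Variable no longer serves a purpose. A short sketch of the resulting setup, with hypothetical toy tensors standing in for the Cora data that load_data() returns:

    import torch
    from torch.autograd import Variable

    # After this commit only the dense tensors become Variables.
    features = Variable(torch.randn(3, 4))
    labels = Variable(torch.LongTensor([0, 1, 0]))

    # The sparse adjacency stays a plain tensor (identity-like toy graph);
    # model(features, adj) threads it into SparseMM(adj) inside layers.py.
    adj = torch.sparse.FloatTensor(
        torch.LongTensor([[0, 1, 2], [0, 1, 2]]),
        torch.ones(3), torch.Size([3, 3]))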
@@ -80,21 +93,19 @@ def train(epoch):
 
     loss_val = F.nll_loss(output[idx_val], labels[idx_val])
     acc_val = accuracy(output[idx_val], labels[idx_val])
-    print('Epoch: {:04d}'.format(epoch+1),
-          'loss_train: {:.4f}'.format(loss_train.data[0]),
-          'acc_train: {:.4f}'.format(acc_train.data[0]),
-          'loss_val: {:.4f}'.format(loss_val.data[0]),
-          'acc_val: {:.4f}'.format(acc_val.data[0]),
-          'time: {:.4f}s'.format(time.time() - t))
+    print(
+        'Epoch: {:04d}'.format(epoch + 1), 'loss_train: {:.4f}'.format(
+            loss_train.data[0]), 'acc_train: {:.4f}'.format(acc_train.data[0]),
+        'loss_val: {:.4f}'.format(loss_val.data[0]), 'acc_val: {:.4f}'.format(
+            acc_val.data[0]), 'time: {:.4f}s'.format(time.time() - t))
 
 
 def test():
     model.eval()
     output = model(features, adj)
     loss_test = F.nll_loss(output[idx_test], labels[idx_test])
     acc_test = accuracy(output[idx_test], labels[idx_test])
-    print("Test set results:",
-          "loss= {:.4f}".format(loss_test.data[0]),
+    print("Test set results:", "loss= {:.4f}".format(loss_test.data[0]),
           "accuracy= {:.4f}".format(acc_test.data[0]))
