-
Notifications
You must be signed in to change notification settings - Fork 0
/
mutag_net_train.py
75 lines (64 loc) · 2.32 KB
/
mutag_net_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch_geometric.loader import DataLoader

from nci import NCI
from models import GcnEncoderGraph
from mutag import Mutagenicity
# ---------------------------------------------------------------------------
# Training settings
# ---------------------------------------------------------------------------
batch_size = 128      # graphs per mini-batch
lr = 1e-2             # learning rate handed to the optimizer
epochs = 10000        # number of passes over the training split
num_workers = 32      # DataLoader worker processes per loader

# Three splits read from the same dataset root, selected by ``mode``.
train_set = Mutagenicity('data/MUTAG', mode='training')
test_set = Mutagenicity('data/MUTAG', mode='testing')
val_set = Mutagenicity('data/MUTAG', mode='evaluation')

# Evaluation loaders iterate in a fixed order; only the training loader shuffles.
test_loader, val_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=False,
               num_workers=num_workers)
    for split in (test_set, val_set)
)
train_loader = DataLoader(train_set, batch_size=batch_size,
                          shuffle=True, num_workers=num_workers)
# Cross-entropy over the model's label_dim (=2) output logits.
criterion = CrossEntropyLoss()

# Prefer the originally targeted GPU (cuda:2). Fall back to the default GPU, or
# to the CPU, so the script does not crash on machines with fewer than 3 GPUs.
if torch.cuda.device_count() > 2:
    device = torch.device('cuda:2')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): input_dim=14 presumably matches the node-feature width of the
# Mutagenicity dataset -- confirm against mutag.py.
model = GcnEncoderGraph(input_dim=14,
                        hidden_dim=50,
                        embedding_dim=10,
                        num_layers=3,
                        pred_hidden_dims=[10, 10],
                        label_dim=2)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

# Best evaluation accuracy seen so far; updated by the training loop below.
best_acc = 0
def _accuracy(loader):
    """Return the fraction of graphs in ``loader`` whose argmax prediction equals ``data.y``.

    Puts the global ``model`` into eval mode and runs it without gradient tracking.
    The denominator is ``len(loader.dataset)``, i.e. every graph the loader yields.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            y_pred = model(data.x, data.edge_index, data.batch)
            correct += float(y_pred.argmax(dim=1).eq(data.y).sum().item())
    return correct / len(loader.dataset)


# Checkpoints are written under params/; create it so torch.save does not fail
# on a fresh checkout.
os.makedirs('params', exist_ok=True)

# Test accuracy of the checkpoint with the best validation accuracy.
test_at_best = 0
for epoch in range(1, epochs + 1):
    model.train()
    loss_all = 0
    for data in train_loader:
        data = data.to(device)
        # Clear the previous mini-batch's gradients. Otherwise they accumulate
        # across batches and every optimizer step uses stale, summed gradients.
        optimizer.zero_grad()
        y_pred = model(data.x, data.edge_index, data.batch)
        loss = criterion(y_pred, data.y)
        loss.backward()
        # Weight by graph count so loss_all / len(train_set) is a per-graph mean.
        loss_all += loss.item() * data.num_graphs
        optimizer.step()
    if epoch % 10 == 0:
        val_acc = _accuracy(val_loader)
        test_acc = _accuracy(test_loader)
        # Select the checkpoint on validation accuracy only. Selecting on the
        # test set would leak test data into model selection.
        if val_acc > best_acc:
            best_acc = val_acc
            test_at_best = test_acc
            torch.save(model.state_dict(), 'params/mutag_net.pt')
        print(f'epoch {epoch:5d}  train loss {loss_all / len(train_set):.4f}  '
              f'val acc {val_acc:.4f}  test acc {test_acc:.4f}  '
              f'best val acc {best_acc:.4f}  (test acc at best {test_at_best:.4f})')