Skip to content

Commit 0018798

Browse files
committed
add save
1 parent d14a70e commit 0018798

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

train.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
7777
num_points = np.array(sz)[0]/2
7878
max_value = point_set.max(dim=0)[0]
7979
min_value = -(-point_set).max(dim=0)[0]
80+
8081
diff = max_value - min_value
81-
8282
dim = torch.max(diff, dim = 1)[1][0,0]
8383

8484
cut = torch.median(point_set[:,dim])[0][0]
@@ -108,7 +108,7 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
108108
net = KDNet().cuda()
109109
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
110110

111-
for it in range(1000):
111+
for it in range(10000):
112112
optimizer.zero_grad()
113113
losses = []
114114
corrects = []
@@ -144,7 +144,6 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
144144
#gc.collect()
145145
#max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
146146
#print("{:.2f} MB".format(max_mem_used / 1024))
147-
148147

149148
points = torch.stack(tree[-1])
150149
points_v = Variable(torch.unsqueeze(torch.squeeze(points), 0)).transpose(2,1).cuda()
@@ -160,4 +159,8 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
160159
losses.append(loss.data[0])
161160

162161
optimizer.step()
163-
print('batch: %d, loss: %f, correct %d/10' %( it, np.mean(losses), np.sum(corrects)))
162+
print('batch: %d, loss: %f, correct %d/10' %( it, np.mean(losses), np.sum(corrects)))
163+
164+
if it % 1000 == 0:
165+
torch.save(net.state_dict(), 'save_model_%d.pth' % (it))
166+

0 commit comments

Comments
 (0)