@@ -77,8 +77,8 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
77
77
num_points = np .array (sz )[0 ]/ 2
78
78
max_value = point_set .max (dim = 0 )[0 ]
79
79
min_value = - (- point_set ).max (dim = 0 )[0 ]
80
+
80
81
diff = max_value - min_value
81
-
82
82
dim = torch .max (diff , dim = 1 )[1 ][0 ,0 ]
83
83
84
84
cut = torch .median (point_set [:,dim ])[0 ][0 ]
@@ -108,7 +108,7 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
108
108
net = KDNet ().cuda ()
109
109
optimizer = optim .SGD (net .parameters (), lr = 0.01 , momentum = 0.9 )
110
110
111
- for it in range (1000 ):
111
+ for it in range (10000 ):
112
112
optimizer .zero_grad ()
113
113
losses = []
114
114
corrects = []
@@ -144,7 +144,6 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
144
144
#gc.collect()
145
145
#max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
146
146
#print("{:.2f} MB".format(max_mem_used / 1024))
147
-
148
147
149
148
points = torch .stack (tree [- 1 ])
150
149
points_v = Variable (torch .unsqueeze (torch .squeeze (points ), 0 )).transpose (2 ,1 ).cuda ()
@@ -160,4 +159,8 @@ def split_ps_reuse(point_set, level, pos, tree, cutdim):
160
159
losses .append (loss .data [0 ])
161
160
162
161
optimizer .step ()
163
- print ('batch: %d, loss: %f, correct %d/10' % ( it , np .mean (losses ), np .sum (corrects )))
162
+ print ('batch: %d, loss: %f, correct %d/10' % ( it , np .mean (losses ), np .sum (corrects )))
163
+
164
+ if it % 1000 == 0 :
165
+ torch .save (net .state_dict (), 'save_model_%d.pth' % (it ))
166
+
0 commit comments