Skip to content

Commit d14a70e

Browse files
committed
fix memory leak
1 parent d8a4c4b commit d14a70e

File tree

1 file changed

+62
-20
lines changed

1 file changed

+62
-20
lines changed

train.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,6 @@ def split_ps(point_set):
6161
right_idx = torch.squeeze(torch.nonzero(point_set[:,dim] < cut))
6262
middle_idx = torch.squeeze(torch.nonzero(point_set[:,dim] == cut))
6363

64-
#if torch.numel(left_idx) > 0:
65-
# left_idx = left_idx[:,0]
66-
#if torch.numel(right_idx) > 0:
67-
# right_idx = right_idx[:,0]
68-
#if torch.numel(middle_idx) > 0:
69-
# middle_idx = middle_idx[:,0]
70-
7164
if torch.numel(left_idx) < num_points:
7265
left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)
7366
if torch.numel(right_idx) < num_points:
@@ -77,11 +70,41 @@ def split_ps(point_set):
7770
right_ps = torch.index_select(point_set, dim = 0, index = right_idx)
7871
return left_ps, right_ps, dim
7972

73+
74+
75+
def split_ps_reuse(point_set, level, pos, tree, cutdim):
76+
sz = point_set.size()
77+
num_points = np.array(sz)[0]/2
78+
max_value = point_set.max(dim=0)[0]
79+
min_value = -(-point_set).max(dim=0)[0]
80+
diff = max_value - min_value
81+
82+
dim = torch.max(diff, dim = 1)[1][0,0]
83+
84+
cut = torch.median(point_set[:,dim])[0][0]
85+
left_idx = torch.squeeze(torch.nonzero(point_set[:,dim] > cut))
86+
right_idx = torch.squeeze(torch.nonzero(point_set[:,dim] < cut))
87+
middle_idx = torch.squeeze(torch.nonzero(point_set[:,dim] == cut))
88+
89+
if torch.numel(left_idx) < num_points:
90+
left_idx = torch.cat([left_idx, middle_idx[0:1].repeat(num_points - torch.numel(left_idx))], 0)
91+
if torch.numel(right_idx) < num_points:
92+
right_idx = torch.cat([right_idx, middle_idx[0:1].repeat(num_points - torch.numel(right_idx))], 0)
93+
94+
left_ps = torch.index_select(point_set, dim = 0, index = left_idx)
95+
right_ps = torch.index_select(point_set, dim = 0, index = right_idx)
96+
97+
tree[level+1][pos * 2] = left_ps
98+
tree[level+1][pos * 2 + 1] = right_ps
99+
cutdim[level][pos * 2] = dim
100+
cutdim[level][pos * 2 + 1] = dim
101+
102+
return
103+
80104
d = PartDataset(root = '../unsupervised3d/shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
81105
l = len(d)
82106
print(len(d.classes), l)
83107
levels = (np.log(2048)/np.log(2)).astype(int)
84-
cutdim = torch.zeros((levels)).long()
85108
net = KDNet().cuda()
86109
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
87110

@@ -92,22 +115,41 @@ def split_ps(point_set):
92115
for batch in range(10):
93116
j = np.random.randint(l)
94117
point_set, class_label = d[j]
118+
95119
target = Variable(class_label).cuda()
96-
tree = [[] for i in range(levels + 1)]
97-
cutdim = [[] for i in range(levels)]
98-
tree[0].append(point_set)
99-
for level in range(levels):
100-
for item in tree[level]:
101-
left_ps, right_ps, dim = split_ps(item)
102-
tree[level+1].append(left_ps)
103-
tree[level+1].append(right_ps)
104-
cutdim[level].append(dim)
105-
cutdim[level].append(dim)
106-
cutdim = [(torch.from_numpy(np.array(item).astype(np.int64))) for item in cutdim]
120+
if batch == 0 and it ==0:
121+
tree = [[] for i in range(levels + 1)]
122+
cutdim = [[] for i in range(levels)]
123+
tree[0].append(point_set)
124+
125+
for level in range(levels):
126+
for item in tree[level]:
127+
left_ps, right_ps, dim = split_ps(item)
128+
tree[level+1].append(left_ps)
129+
tree[level+1].append(right_ps)
130+
cutdim[level].append(dim)
131+
cutdim[level].append(dim)
132+
133+
else:
134+
tree[0] = [point_set]
135+
for level in range(levels):
136+
for pos, item in enumerate(tree[level]):
137+
split_ps_reuse(item, level, pos, tree, cutdim)
138+
#print level, pos
139+
140+
cutdim_v = [(torch.from_numpy(np.array(item).astype(np.int64))) for item in cutdim]
141+
142+
#import gc
143+
#import resource
144+
#gc.collect()
145+
#max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
146+
#print("{:.2f} MB".format(max_mem_used / 1024))
147+
148+
107149
points = torch.stack(tree[-1])
108150
points_v = Variable(torch.unsqueeze(torch.squeeze(points), 0)).transpose(2,1).cuda()
109151

110-
pred = net(points_v, cutdim)
152+
pred = net(points_v, cutdim_v)
111153

112154
pred_choice = pred.data.max(1)[1]
113155
correct = pred_choice.eq(target.data).cpu().sum()

0 commit comments

Comments
 (0)