@@ -61,13 +61,6 @@ def split_ps(point_set):
61
61
right_idx = torch .squeeze (torch .nonzero (point_set [:,dim ] < cut ))
62
62
middle_idx = torch .squeeze (torch .nonzero (point_set [:,dim ] == cut ))
63
63
64
- #if torch.numel(left_idx) > 0:
65
- # left_idx = left_idx[:,0]
66
- #if torch.numel(right_idx) > 0:
67
- # right_idx = right_idx[:,0]
68
- #if torch.numel(middle_idx) > 0:
69
- # middle_idx = middle_idx[:,0]
70
-
71
64
if torch .numel (left_idx ) < num_points :
72
65
left_idx = torch .cat ([left_idx , middle_idx [0 :1 ].repeat (num_points - torch .numel (left_idx ))], 0 )
73
66
if torch .numel (right_idx ) < num_points :
@@ -77,11 +70,41 @@ def split_ps(point_set):
77
70
right_ps = torch .index_select (point_set , dim = 0 , index = right_idx )
78
71
return left_ps , right_ps , dim
79
72
73
+
74
+
75
+ def split_ps_reuse (point_set , level , pos , tree , cutdim ):
76
+ sz = point_set .size ()
77
+ num_points = np .array (sz )[0 ]/ 2
78
+ max_value = point_set .max (dim = 0 )[0 ]
79
+ min_value = - (- point_set ).max (dim = 0 )[0 ]
80
+ diff = max_value - min_value
81
+
82
+ dim = torch .max (diff , dim = 1 )[1 ][0 ,0 ]
83
+
84
+ cut = torch .median (point_set [:,dim ])[0 ][0 ]
85
+ left_idx = torch .squeeze (torch .nonzero (point_set [:,dim ] > cut ))
86
+ right_idx = torch .squeeze (torch .nonzero (point_set [:,dim ] < cut ))
87
+ middle_idx = torch .squeeze (torch .nonzero (point_set [:,dim ] == cut ))
88
+
89
+ if torch .numel (left_idx ) < num_points :
90
+ left_idx = torch .cat ([left_idx , middle_idx [0 :1 ].repeat (num_points - torch .numel (left_idx ))], 0 )
91
+ if torch .numel (right_idx ) < num_points :
92
+ right_idx = torch .cat ([right_idx , middle_idx [0 :1 ].repeat (num_points - torch .numel (right_idx ))], 0 )
93
+
94
+ left_ps = torch .index_select (point_set , dim = 0 , index = left_idx )
95
+ right_ps = torch .index_select (point_set , dim = 0 , index = right_idx )
96
+
97
+ tree [level + 1 ][pos * 2 ] = left_ps
98
+ tree [level + 1 ][pos * 2 + 1 ] = right_ps
99
+ cutdim [level ][pos * 2 ] = dim
100
+ cutdim [level ][pos * 2 + 1 ] = dim
101
+
102
+ return
103
+
80
104
d = PartDataset (root = '../unsupervised3d/shapenetcore_partanno_segmentation_benchmark_v0' , classification = True )
81
105
l = len (d )
82
106
print (len (d .classes ), l )
83
107
levels = (np .log (2048 )/ np .log (2 )).astype (int )
84
- cutdim = torch .zeros ((levels )).long ()
85
108
net = KDNet ().cuda ()
86
109
optimizer = optim .SGD (net .parameters (), lr = 0.01 , momentum = 0.9 )
87
110
@@ -92,22 +115,41 @@ def split_ps(point_set):
92
115
for batch in range (10 ):
93
116
j = np .random .randint (l )
94
117
point_set , class_label = d [j ]
118
+
95
119
target = Variable (class_label ).cuda ()
96
- tree = [[] for i in range (levels + 1 )]
97
- cutdim = [[] for i in range (levels )]
98
- tree [0 ].append (point_set )
99
- for level in range (levels ):
100
- for item in tree [level ]:
101
- left_ps , right_ps , dim = split_ps (item )
102
- tree [level + 1 ].append (left_ps )
103
- tree [level + 1 ].append (right_ps )
104
- cutdim [level ].append (dim )
105
- cutdim [level ].append (dim )
106
- cutdim = [(torch .from_numpy (np .array (item ).astype (np .int64 ))) for item in cutdim ]
120
+ if batch == 0 and it == 0 :
121
+ tree = [[] for i in range (levels + 1 )]
122
+ cutdim = [[] for i in range (levels )]
123
+ tree [0 ].append (point_set )
124
+
125
+ for level in range (levels ):
126
+ for item in tree [level ]:
127
+ left_ps , right_ps , dim = split_ps (item )
128
+ tree [level + 1 ].append (left_ps )
129
+ tree [level + 1 ].append (right_ps )
130
+ cutdim [level ].append (dim )
131
+ cutdim [level ].append (dim )
132
+
133
+ else :
134
+ tree [0 ] = [point_set ]
135
+ for level in range (levels ):
136
+ for pos , item in enumerate (tree [level ]):
137
+ split_ps_reuse (item , level , pos , tree , cutdim )
138
+ #print level, pos
139
+
140
+ cutdim_v = [(torch .from_numpy (np .array (item ).astype (np .int64 ))) for item in cutdim ]
141
+
142
+ #import gc
143
+ #import resource
144
+ #gc.collect()
145
+ #max_mem_used = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
146
+ #print("{:.2f} MB".format(max_mem_used / 1024))
147
+
148
+
107
149
points = torch .stack (tree [- 1 ])
108
150
points_v = Variable (torch .unsqueeze (torch .squeeze (points ), 0 )).transpose (2 ,1 ).cuda ()
109
151
110
- pred = net (points_v , cutdim )
152
+ pred = net (points_v , cutdim_v )
111
153
112
154
pred_choice = pred .data .max (1 )[1 ]
113
155
correct = pred_choice .eq (target .data ).cpu ().sum ()
0 commit comments