@@ -130,10 +130,30 @@ Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) {
   return Status::OK();
 }
 
-template <typename T>
-class SparseReduceSumOp : public OpKernel {
+struct SumOp {
+  template <typename T>
+  static void Run(OpKernelContext *ctx, typename TTypes<T>::Scalar &s, const typename TTypes<T>::UnalignedVec &v) {
+    s.device(ctx->eigen_cpu_device()) = v.sum();
+  }
+  static StringPiece Name() {
+    return "sum";
+  }
+};
+
+struct MaxOp {
+  template <typename T>
+  static void Run(OpKernelContext *ctx, typename TTypes<T>::Scalar &s, const typename TTypes<T>::UnalignedVec &v) {
+    s.device(ctx->eigen_cpu_device()) = v.maximum();
+  }
+  static StringPiece Name() {
+    return "max";
+  }
+};
+
+template <typename T, typename Op>
+class SparseReduceOp : public OpKernel {
  public:
-  explicit SparseReduceSumOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+  explicit SparseReduceOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
   }
 
@@ -163,10 +183,10 @@ class SparseReduceSumOp : public OpKernel {
     auto out_flat = out_values->flat<T>();
     out_flat.setZero();
 
-    Tensor tmp_group_sum;
+    Tensor tmp_reduced_val;
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                           TensorShape({}), &tmp_group_sum));
-    auto group_sum = tmp_group_sum.scalar<T>();
+                                           TensorShape({}), &tmp_reduced_val));
+    auto reduced_val = tmp_reduced_val.scalar<T>();
 
     // Compute strides, and use it to convert coords to flat index. The
     // coordinates returned by .group() have the same ndims as group_by_dims.
@@ -196,11 +216,12 @@ class SparseReduceSumOp : public OpKernel {
     // g.group() provides the coordinates of a particular reduced value.
     sp.Reorder<T>(reduction.reorder_dims);
     for (const auto &g : sp.group(reduction.group_by_dims)) {
-      group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+      Op::template Run<T>(ctx, reduced_val, g.template values<T>());
       const int64 idx = CoordinatesToFlatIndex(g.group(), output_strides);
-      out_flat(idx) = group_sum();
+      out_flat(idx) = reduced_val();
       VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
-              << "; idx: " << idx << "; group sum: " << group_sum();
+              << "; idx: " << idx << "; group " << Op::Name() << ": "
+              << reduced_val();
     }
   }
 
@@ -212,14 +233,21 @@ class SparseReduceSumOp : public OpKernel {
 #define REGISTER_KERNELS(T)                                              \
   REGISTER_KERNEL_BUILDER(                                               \
       Name("SparseReduceSum").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      SparseReduceSumOp<T>)
+      SparseReduceOp<T, SumOp>)
 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
-template <typename T>
-class SparseReduceSumSparseOp : public OpKernel {
+#define REGISTER_KERNELS(T)                                              \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("SparseReduceMax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseReduceOp<T, MaxOp>)
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+template <typename T, typename Op>
+class SparseReduceSparseOp : public OpKernel {
  public:
-  explicit SparseReduceSumSparseOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
+  explicit SparseReduceSparseOp(OpKernelConstruction *ctx) : OpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
   }
 
@@ -260,13 +288,13 @@ class SparseReduceSumSparseOp : public OpKernel {
         ctx->allocate_output(1, TensorShape({nnz}), &out_values_t));
     auto out_flat = out_values_t->flat<T>();
 
-    Tensor tmp_group_sum;
+    Tensor tmp_reduced_val;
     OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
-                                           TensorShape({}), &tmp_group_sum));
-    auto group_sum = tmp_group_sum.scalar<T>();
+                                           TensorShape({}), &tmp_reduced_val));
+    auto reduced_val = tmp_reduced_val.scalar<T>();
     int64 i = 0;
     for (const auto &g : sp.group(reduction.group_by_dims)) {
-      group_sum.device(ctx->eigen_cpu_device()) = g.template values<T>().sum();
+      Op::template Run<T>(ctx, reduced_val, g.template values<T>());
       std::vector<int64> group = g.group();
       for (int64 j = 0; j < group.size(); j++) {
         if (keep_dims_) {
@@ -275,10 +303,11 @@ class SparseReduceSumSparseOp : public OpKernel {
           out_indices_mat(i, j) = group[j];
         }
       }
-      out_flat(i) = group_sum();
+      out_flat(i) = reduced_val();
       i++;
       VLOG(2) << "coords: " << str_util::Join(g.group(), ",")
-              << "; group sum: " << group_sum();
+              << "; group " << Op::Name() << ": "
+              << reduced_val();
     }
 
     Tensor *out_shape_t;
@@ -298,8 +327,15 @@ class SparseReduceSumSparseOp : public OpKernel {
 #define REGISTER_KERNELS(T)                                                    \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("SparseReduceSumSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      SparseReduceSumSparseOp<T>)
+      SparseReduceSparseOp<T, SumOp>)
 TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
 #undef REGISTER_KERNELS
 
+#define REGISTER_KERNELS(T)                                                    \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("SparseReduceMaxSparse").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      SparseReduceSparseOp<T, MaxOp>)
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
 
 } // namespace tensorflow
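Note: the SumOp/MaxOp functor pattern above makes further reductions cheap to add. As a minimal sketch only (not part of this commit; the "SparseReduceMin" op name and registration below are hypothetical and would also need a matching REGISTER_OP declaration plus Python wrappers), a new reduction needs just an Eigen reduction expression and a display name:

struct MinOp {
  template <typename T>
  static void Run(OpKernelContext *ctx, typename TTypes<T>::Scalar &s, const typename TTypes<T>::UnalignedVec &v) {
    // Eigen's minimum() reduction mirrors the maximum() used by MaxOp above.
    s.device(ctx->eigen_cpu_device()) = v.minimum();
  }
  static StringPiece Name() {
    return "min";
  }
};

// Hypothetical registration, mirroring the SparseReduceMax macros above.
// Like max, min is only well-defined for ordered types, hence
// TF_CALL_REAL_NUMBER_TYPES rather than TF_CALL_NUMBER_TYPES (which also
// expands to complex types).
#define REGISTER_KERNELS(T)                                              \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("SparseReduceMin").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      SparseReduceOp<T, MinOp>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS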