Skip to content

Commit

Permalink
Move MatMulBCast class to core/util.
Browse files Browse the repository at this point in the history
Export it under the core/framework target (same as core/util/bcast.h) instead of core/kernels:batch_matmul_op.

As an aside, this allows TFLite use this class without adding a dependency on core/kernels when it just needs the util.

PiperOrigin-RevId: 244945168
  • Loading branch information
bloops authored and tensorflower-gardener committed Apr 24, 2019
1 parent 408cea8 commit b199a97
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 23 deletions.
1 change: 0 additions & 1 deletion tensorflow/contrib/makefile/tf_op_files.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ tensorflow/contrib/boosted_trees/ops/training_ops.cc
tensorflow/core/kernels/aggregate_ops.cc
tensorflow/core/kernels/argmax_op.cc
tensorflow/core/kernels/avgpooling_op.cc
tensorflow/core/kernels/batch_matmul_op_common.cc
tensorflow/core/kernels/batch_matmul_op_real.cc
tensorflow/core/kernels/batch_norm_op.cc
tensorflow/core/kernels/batchtospace_op.cc
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,7 @@ tf_cuda_library(
"util/activation_mode.h",
"util/batch_util.h",
"util/bcast.h",
"util/matmul_bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
"util/dump_graph.h",
Expand Down Expand Up @@ -3977,6 +3978,7 @@ tf_cc_tests(
"util/events_writer_test.cc",
"util/example_proto_fast_parsing_test.cc",
"util/example_proto_helper_test.cc",
"util/matmul_bcast_test.cc",
"util/memmapped_file_system_test.cc",
"util/presized_cuckoo_map_test.cc",
"util/reffed_status_callback_test.cc",
Expand Down
16 changes: 0 additions & 16 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3682,19 +3682,6 @@ tf_cuda_cc_test(
],
)

tf_cc_test(
name = "batch_matmul_op_common_test",
size = "small",
srcs = ["batch_matmul_op_common_test.cc"],
deps = [
":batch_matmul_op",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_cuda_cc_test(
name = "batch_matmul_op_test",
size = "small",
Expand Down Expand Up @@ -5606,8 +5593,6 @@ filegroup(
name = "mobile_srcs",
srcs = [
"avgpooling_op.h",
"batch_matmul_op_common.cc",
"batch_matmul_op_common.h",
"batch_util.h",
"cwise_ops.h",
"cwise_ops_common.h",
Expand Down Expand Up @@ -6108,7 +6093,6 @@ filegroup(
"*_3d*",
"*.cu.*",
# Ops already in android_srcs
"batch_matmul_op_common.cc",
"pooling_ops_common.cc",
# Ops which we are currently excluding because they are likely
# not used on Android. Those ops also do not compile if included,
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/batch_matmul_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/batch_matmul_op_common.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_matmul_op_common.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace tensorflow {
namespace {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
#ifndef TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_
#define TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_

#include <vector>

Expand Down Expand Up @@ -67,4 +67,4 @@ class MatMulBCast {

} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
#endif // TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/batch_matmul_op_common.h"
#include "tensorflow/core/util/matmul_bcast.h"

#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
Expand Down

0 comments on commit b199a97

Please sign in to comment.