Merge branch 'one_table' into 'master'
Fused GPU Hashtable Ops

See merge request data/monolith!1882

GitOrigin-RevId: 163e0466e35ceb95a67038c368d21a9942ba0412
hanzhi713 authored and zlqiszlqbd committed Jan 4, 2023
1 parent 65a0617 commit 5260427
Showing 21 changed files with 1,066 additions and 245 deletions.
2 changes: 0 additions & 2 deletions monolith/native_training/cpu_training.py
@@ -604,8 +604,6 @@ def create_hash_table_and_filters_fn():
with device_utils.maybe_device_if_allowed('/device:GPU:0'):
hash_filters = hash_filter_ops.create_hash_filters(
self.config.num_ps,
- # disable hash filter (only use dummy filter) for GPU table for now
- False if self._params.train.use_gpu_emb_table else
self._enable_hash_filter,
config=slot_occurrence_threshold_config.SerializeToString(),
filter_capacity=self.config.filter_capacity,
3 changes: 3 additions & 0 deletions monolith/native_training/hash_filter_ops.py
@@ -43,6 +43,7 @@ class FilterType(object):

SLIDING_HASH_FILTER = 'sliding_hash_filter'
PROBABILISTIC_FILTER = 'probabilistic_filter'
+ NO_FILTER = 'no_filter'


def create_hash_filter(capacity: int,
@@ -88,6 +89,8 @@ def _create_hash_filter(
elif filter_type == FilterType.PROBABILISTIC_FILTER:
return create_probabilistic_filter(filter_equal_probability, config,
name_suffix)
+ elif filter_type == FilterType.NO_FILTER:
+   return create_dummy_hash_filter(name_suffix)
else:
raise ValueError("Invalid filter type, please investigate and retry!")
else:
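Together with the cpu_training.py change above (which drops the hard-coded False for GPU tables), callers can now request no filtering explicitly instead of relying on a special case. A minimal self-contained sketch of the new dispatch, with stand-in definitions since the real ones live in hash_filter_ops.py:

    # Sketch only: FilterType and create_dummy_hash_filter are stand-ins for
    # the real definitions in hash_filter_ops.py; the dummy filter admits
    # every id (no occurrence-count gating).
    class FilterType:
      SLIDING_HASH_FILTER = 'sliding_hash_filter'
      PROBABILISTIC_FILTER = 'probabilistic_filter'
      NO_FILTER = 'no_filter'  # new in this commit

    def create_dummy_hash_filter(name_suffix):
      return f"dummy_hash_filter_{name_suffix}"

    def create_filter(filter_type, name_suffix="0"):
      if filter_type == FilterType.NO_FILTER:
        # New branch: explicit no-op filtering instead of raising.
        return create_dummy_hash_filter(name_suffix)
      raise ValueError("Invalid filter type, please investigate and retry!")

    print(create_filter(FilterType.NO_FILTER))  # dummy_hash_filter_0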
8 changes: 4 additions & 4 deletions monolith/native_training/hash_table_ops.py
@@ -477,14 +477,14 @@ def hash_table_from_config(config: entry.HashTableConfigInstance,
slot_expire_time_config = config.table_config.slot_expire_time_config.SerializeToString(
)
hash_table_name = "MonolithHashTable_" + name_suffix
- if hash_filter is None or use_gpu: # We don't have gpu filter for now, get rid of or use_gpu if added one
-   with tf.device(d):
-     hash_filter = hash_filter_ops.create_dummy_hash_filter(
-         name_suffix=name_suffix)
if len(config.learning_rate_fns) != len(
config.table_config.entry_config.segments):
raise ValueError(
"Size of learning_rate_fns and size of segments must be equal.")
+ if hash_filter is None:
+   with tf.device(d):
+     hash_filter = hash_filter_ops.create_dummy_hash_filter(
+         name_suffix=name_suffix)
if sync_client is None or use_gpu: # We don't have gpu sync for now, get rid of or use_gpu if added one
with tf.device(d):
sync_client = distributed_serving_ops.create_dummy_sync_client()
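The reorder above does two things: the learning-rate/segment validation now runs before any filter is created, and the `or use_gpu` clause is gone, so a GPU table keeps a real filter when one is supplied. A rough sketch of the resulting fallback logic (function shape assumed for illustration, not from this diff):

    # Assumed shape: fall back to a dummy filter only when the caller supplied
    # none, regardless of whether the table lives on GPU.
    def resolve_hash_filter(hash_filter, learning_rate_fns, segments, make_dummy):
      if len(learning_rate_fns) != len(segments):
        raise ValueError(
            "Size of learning_rate_fns and size of segments must be equal.")
      if hash_filter is None:  # previously 'is None or use_gpu'
        hash_filter = make_dummy()
      return hash_filter

    assert resolve_hash_filter(None, [1], ["seg"], lambda: "dummy") == "dummy"
    assert resolve_hash_filter("real", [1], ["seg"], lambda: "dummy") == "real"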
8 changes: 4 additions & 4 deletions monolith/native_training/hash_table_ops_test.py
@@ -1103,8 +1103,8 @@ def test_fused_lookup(self):
embeddings)
self.assertAllEqual(embeddings, [1, 0, 1, 1, 1, 0, 1, 1])
self.assertAllEqual(recv_splits, [4, 4])
- self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4, 5])
- self.assertAllEqual(emb_offsets, [0, 1, 2, 4, 5, 6])
+ self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4, 5, 6])
+ self.assertAllEqual(emb_offsets, [0, 1, 2, 4, 5, 6, 8])
self.assertAllEqual(emb_dims, [1, 1, 2, 1, 1, 2])

def test_fused_optimize(self):
@@ -1147,8 +1147,8 @@ def test_fused_optimize(self):
lookup_op)
self.assertAllClose(embeddings, [1.1, 0.2, 0.2, 1.1, 0.2, 0.2])
self.assertAllEqual(recv_splits, [3, 3])
- self.assertAllEqual(id_offsets, [0, 1, 2, 3])
- self.assertAllEqual(emb_offsets, [0, 1, 3, 4])
+ self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4])
+ self.assertAllEqual(emb_offsets, [0, 1, 3, 4, 6])
self.assertAllEqual(emb_dims, [1, 2, 1, 2])

def test_batch_softmax_optimizer(self):
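The updated assertions reflect that the fused ops now append the total count as a trailing element, making id_offsets and emb_offsets CSR-style exclusive prefix sums: segment i spans offsets[i] to offsets[i+1]. A small sketch using the test_fused_lookup values above (the 2-shard-by-3-table segment layout is inferred from those assertions):

    # With the trailing total, per-segment sizes and per-shard splits both
    # fall out of adjacent differences; no separate size array is required.
    id_offsets = [0, 1, 2, 3, 4, 5, 6]   # one id per segment here
    emb_offsets = [0, 1, 2, 4, 5, 6, 8]
    num_shards, num_tables = 2, 3        # 6 segments total

    emb_dims = [emb_offsets[i + 1] - emb_offsets[i]
                for i in range(len(emb_offsets) - 1)]
    assert emb_dims == [1, 1, 2, 1, 1, 2]      # matches the assertion above

    recv_splits = [emb_offsets[(s + 1) * num_tables] - emb_offsets[s * num_tables]
                   for s in range(num_shards)]
    assert recv_splits == [4, 4]               # matches the assertion above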
16 changes: 8 additions & 8 deletions monolith/native_training/multi_type_hash_table_test.py
@@ -171,8 +171,8 @@ def test_fused_lookup(self):
values_dict)
self.assertAllEqual(embeddings, [1, 2, 2, 4, 4, 8, 8])
self.assertAllEqual(recv_splits, [7])
- self.assertAllEqual(id_offsets, [0, 1, 2])
- self.assertAllEqual(emb_offsets, [0, 1, 3])
+ self.assertAllEqual(id_offsets, [0, 1, 2, 4])
+ self.assertAllEqual(emb_offsets, [0, 1, 3, 7])
self.assertAllEqual(emb_sizes, [1, 2, 4])

def test_fused_lookup_multi_shards(self):
@@ -193,8 +193,8 @@ def test_fused_lookup_multi_shards(self):
values_dict)
self.assertAllEqual(embeddings, [1, 4, 4, 2, 2, 8, 8])
self.assertAllEqual(recv_splits, [3, 4])
- self.assertAllEqual(id_offsets, [0, 1, 1, 2, 2, 3])
- self.assertAllEqual(emb_offsets, [0, 1, 1, 3, 3, 5])
+ self.assertAllEqual(id_offsets, [0, 1, 1, 2, 2, 3, 4])
+ self.assertAllEqual(emb_offsets, [0, 1, 1, 3, 3, 5, 7])
self.assertAllEqual(emb_sizes, [1, 0, 2, 0, 2, 2])

def test_fused_apply_gradients(self):
@@ -217,8 +217,8 @@ def test_fused_apply_gradients(self):
lookup_op)
self.assertAllEqual(embeddings, [-2, -1, -3, -2, -4])
self.assertAllEqual(recv_splits, [5])
- self.assertAllEqual(id_offsets, [0, 1])
- self.assertAllEqual(emb_offsets, [0, 1])
+ self.assertAllEqual(id_offsets, [0, 1, 3])
+ self.assertAllEqual(emb_offsets, [0, 1, 5])
self.assertAllEqual(emb_sizes, [1, 4])

def test_fused_apply_gradients_missing_tables(self):
@@ -241,8 +241,8 @@ def test_fused_apply_gradients_missing_tables(self):
lookup_op)
self.assertAllEqual(embeddings, [-3, -3])
self.assertAllEqual(recv_splits, [1, 1])
- self.assertAllEqual(id_offsets, [0, 1, 1, 2])
- self.assertAllEqual(emb_offsets, [0, 1, 1, 2])
+ self.assertAllEqual(id_offsets, [0, 1, 1, 2, 2])
+ self.assertAllEqual(emb_offsets, [0, 1, 1, 2, 2])
self.assertAllEqual(emb_sizes, [1, 0, 1, 0])


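The multi-shard expectations above are consistent with a shard-major segment order: all tables of shard 0 first, then shard 1, with an empty segment keeping a zero-width slot. A sketch decoding the test_fused_lookup_multi_shards offsets under that assumption:

    # Decode (shard, table) segment sizes from the sentinel-terminated offsets.
    num_shards, num_tables = 2, 3
    emb_offsets = [0, 1, 1, 3, 3, 5, 7]  # from test_fused_lookup_multi_shards

    for shard in range(num_shards):
      for table in range(num_tables):
        seg = shard * num_tables + table
        size = emb_offsets[seg + 1] - emb_offsets[seg]
        print(f"shard {shard} table {table}: {size} floats")  # 1,0,2 / 0,2,2

    # Shard totals reproduce recv_splits:
    assert [emb_offsets[(s + 1) * num_tables] - emb_offsets[s * num_tables]
            for s in range(num_shards)] == [3, 4]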
@@ -24,7 +24,7 @@ namespace monolith {
namespace hash_table {

std::unique_ptr<EmbeddingHashTableInterface> NewEmbeddingHashTableFromConfig(
- EmbeddingHashTableConfig config, cudaStream_t stream) {
+ EmbeddingHashTableConfig config, GpuExtraArgs args) {
std::unique_ptr<EntryAccessorInterface> accessor =
NewEntryAccessor(config.entry_config());
switch (config.type_case()) {
@@ -23,7 +23,7 @@ namespace monolith {
namespace hash_table {

std::unique_ptr<EmbeddingHashTableInterface> NewEmbeddingHashTableFromConfig(
- EmbeddingHashTableConfig config, cudaStream_t stream = 0);
+ EmbeddingHashTableConfig config, GpuExtraArgs args = GpuExtraArgs{});

} // namespace hash_table
} // namespace monolith
@@ -24,6 +24,12 @@ using cudaStream_t = void*;
namespace monolith {
namespace hash_table {

+ class CucoMultiHashTableOp;
+ struct GpuExtraArgs {
+   CucoMultiHashTableOp* shared_state;
+   cudaStream_t stream;
+ };

// Hash table maps int64 to a float array with fixed length.
// Implemention of this interface should guarantee thread safety.
class EmbeddingHashTableInterface {
@@ -22,7 +22,6 @@

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "glog/logging.h"
#include "monolith/native_training/runtime/hash_table/embedding_hash_table.pb.h"
#include "monolith/native_training/runtime/hash_table/initializer/initializer_interface.h"
#include "monolith/native_training/runtime/hash_table/optimizer/optimizer_interface.h"
1 change: 1 addition & 0 deletions monolith/native_training/runtime/ops/BUILD
@@ -152,6 +152,7 @@ tf_gpu_kernel_library_allow_except(
"hash_table_restore_op.cc",
"hash_table_save_op.cc",
"hash_table_update_op.cc",
"gpu_multi_hash_table.h",
"multi_hash_table.h",
"multi_hash_table_lookup_op.cc",
"multi_hash_table_op.cc",
@@ -49,7 +49,7 @@ Status ValidateDim(const Tensor& t, int64 expected_dim) {
Status EmbeddingHashTableTfBridge::New(
monolith::hash_table::EmbeddingHashTableConfig config,
HashFilterTfBridge* hash_filter, EmbeddingHashTableTfBridge** new_bridge,
const std::string& name, cudaStream_t stream) {
const std::string& name, hash_table::GpuExtraArgs args) {
auto bridge = core::RefCountPtr<EmbeddingHashTableTfBridge>(
new EmbeddingHashTableTfBridge(hash_filter));
bridge->config_ = config;
@@ -60,7 +60,7 @@
}
try {
bridge->table_ =
- hash_table::NewEmbeddingHashTableFromConfig(config, stream);
+ hash_table::NewEmbeddingHashTableFromConfig(config, std::move(args));
} catch (const std::exception& e) {
return errors::InvalidArgument(e.what());
}
@@ -134,7 +134,6 @@ Status EmbeddingHashTableTfBridge::BatchLookup(OpKernelContext* ctx,
}
}


Status EmbeddingHashTableTfBridge::BatchLookupEntry(
OpKernelContext* ctx, const int num_ids, int64_t* ids,
EntryDump* out_entries) const {
@@ -42,7 +42,8 @@ class EmbeddingHashTableTfBridge : public ResourceBase {
static Status New(monolith::hash_table::EmbeddingHashTableConfig config,
HashFilterTfBridge* hash_filter,
EmbeddingHashTableTfBridge** new_bridge,
const std::string& name, cudaStream_t stream = 0);
const std::string& name,
monolith::hash_table::GpuExtraArgs args = {});

~EmbeddingHashTableTfBridge();

@@ -106,6 +107,9 @@ class EmbeddingHashTableTfBridge : public ResourceBase {
std::vector<std::pair<int64_t, const void*>> TouchedKeySet() const;

const monolith::hash_table::EmbeddingHashTableConfig& GetConfig() const;
+ monolith::hash_table::EmbeddingHashTableInterface* GetTable() const {
+   return table_.get();
+ }

private:
explicit EmbeddingHashTableTfBridge(HashFilterTfBridge* hash_filter)
42 changes: 42 additions & 0 deletions monolith/native_training/runtime/ops/gpu_multi_hash_table.h
@@ -0,0 +1,42 @@
// Copyright 2022 ByteDance and/or its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_GPU_MULTI_HASH_TABLE
#define MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_GPU_MULTI_HASH_TABLE
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <vector>

#include "monolith/native_training/runtime/hash_table/GPUcucohash/cuco_multi_table_ops.cuh.h"
#include "tensorflow/core/framework/resource_mgr.h"

namespace tensorflow {
namespace monolith_tf {

class GpuMultiHashTable : public ResourceBase {
public:
::monolith::hash_table::CucoMultiHashTableOp op;
explicit GpuMultiHashTable(
std::vector<int> slot_occ = {},
::monolith::hash_table::GpucucoEmbeddingHashTableConfig config_ = {})
: op(std::move(slot_occ), std::move(config_)) {}
std::string DebugString() const override {
return "This is a GPU multi hash table";
}
};

} // namespace monolith_tf
} // namespace tensorflow
#endif
#endif // MONOLITH_NATIVE_TRAINING_RUNTIME_OPS_GPU_MULTI_HASH_TABLE
9 changes: 4 additions & 5 deletions monolith/native_training/runtime/ops/hash_filter_op.cc
@@ -25,8 +25,8 @@ namespace tensorflow {
namespace monolith_tf {

using ::monolith::hash_filter::DummyHashFilter;
- using ::monolith::hash_filter::SlidingHashFilter;
using ::monolith::hash_filter::ProbabilisticFilter;
+ using ::monolith::hash_filter::SlidingHashFilter;

class DummyFilterOp : public ResourceOpKernel<HashFilterTfBridge> {
public:
@@ -99,7 +99,7 @@ class ProbabilisticFilterOp : public ResourceOpKernel<HashFilterTfBridge> {
Status CreateResource(HashFilterTfBridge** filter_bridge)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) override {
auto filter = std::make_unique<ProbabilisticFilter>(equal_probability_);
- *filter_bridge = new HashFilterTfBridge(std::move(filter), config_);
+ *filter_bridge = new HashFilterTfBridge(std::move(filter), config_, true);
return Status::OK();
};

@@ -132,9 +132,6 @@ REGISTER_OP("MonolithProbabilisticFilter")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithProbabilisticFilter").Device(DEVICE_CPU),
ProbabilisticFilterOp);

REGISTER_OP("MonolithDummyHashFilter")
.Output("handle: resource")
.Attr("container: string = ''")
@@ -143,6 +140,8 @@ REGISTER_OP("MonolithDummyHashFilter")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_KERNEL_BUILDER(Name("MonolithDummyHashFilter").Device(DEVICE_CPU),
DummyFilterOp);
REGISTER_KERNEL_BUILDER(Name("MonolithProbabilisticFilter").Device(DEVICE_CPU),
ProbabilisticFilterOp);

} // namespace monolith_tf
} // namespace tensorflow
8 changes: 5 additions & 3 deletions monolith/native_training/runtime/ops/hash_filter_tf_bridge.cc
@@ -13,19 +13,21 @@
// limitations under the License.

#include "monolith/native_training/runtime/ops/hash_filter_tf_bridge.h"

#include "monolith/native_training/data/training_instance/cc/reader_util.h"

namespace tensorflow {
namespace monolith_tf {

using ::monolith::hash_filter::Filter;
- using ::monolith::hash_table::HashFilterSplitMetaDump;
using ::monolith::hash_table::HashFilterSplitDataDump;
+ using ::monolith::hash_table::HashFilterSplitMetaDump;
using ::monolith::hash_table::SlotOccurrenceThresholdConfig;

HashFilterTfBridge::HashFilterTfBridge(
std::unique_ptr<Filter> filter, const SlotOccurrenceThresholdConfig& config)
: filter_(std::move(filter)) {
std::unique_ptr<Filter> filter, const SlotOccurrenceThresholdConfig& config,
bool is_probabilistic)
: filter_(std::move(filter)), is_probabilistic_(is_probabilistic) {
slot_to_occurrence_threshold_.resize(get_max_slot_number(),
config.default_occurrence_threshold());
for (const auto& slot_occurrence_threshold :
8 changes: 7 additions & 1 deletion monolith/native_training/runtime/ops/hash_filter_tf_bridge.h
@@ -30,7 +30,8 @@ class HashFilterTfBridge : public ResourceBase {
public:
explicit HashFilterTfBridge(
std::unique_ptr<monolith::hash_filter::Filter> filter,
const monolith::hash_table::SlotOccurrenceThresholdConfig& config);
const monolith::hash_table::SlotOccurrenceThresholdConfig& config,
bool is_probabilistic = false);

bool ShouldBeFiltered(
int64_t id, int64_t count,
@@ -65,12 +66,17 @@
get_meta_fn,
std::function<bool(::monolith::hash_table::HashFilterSplitDataDump*)>
get_data_fn) const;
const std::vector<int>& GetOccuranceThresholdArray() const {
return slot_to_occurrence_threshold_;
}
bool IsProbabilistic() const { return is_probabilistic_; }

private:
int GetSlotOccurrenceThreshold(int64_t fid) const;

std::unique_ptr<monolith::hash_filter::Filter> filter_;
std::vector<int> slot_to_occurrence_threshold_;
+ const bool is_probabilistic_;
};

// Carries the data through async process.