Merge branch 'gpu_cpu_coop_rb' into 'master'
Allow CPU/GPU hash table code to compile together under `--config=cuda`

See merge request data/monolith!1915

GitOrigin-RevId: 197d2650389d3c9c797cf1d2dea6d5ac7daea9ba
hanzhi713 authored and zlqiszlqbd committed Jan 4, 2023
1 parent 8146308 commit ad64c2d
Showing 13 changed files with 241 additions and 230 deletions.
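The thread running through the Python changes below is the narrowed return signature of the fused lookup: the trailing per-slot embedding-size tensor (`fused_emb_sizes` / `emb_dims` / `emb_sizes`) is dropped, so callers now unpack four tensors instead of five. A minimal sketch of the updated call pattern follows; the wrapper name and the import path are illustrative assumptions (the diffs only show that the tests' `ops` alias exposes `fused_lookup`), not part of this commit.

```python
import tensorflow as tf

# Assumption: the tests' `ops` alias points at the repository's hash table ops
# module; adjust the import to the actual module if it differs.
from monolith.native_training import hash_table_ops as ops


def fused_lookup_sketch(table_handles, ids, fused_slot_size, num_of_shards):
  """Illustrative wrapper showing the four-output unpacking after this commit."""
  # Before this commit, a fifth tensor with per-slot embedding dimensions was
  # also returned; it is gone now.
  embeddings, recv_splits, id_offsets, emb_offsets = ops.fused_lookup(
      table_handles, ids, fused_slot_size, num_of_shards=num_of_shards)
  # recv_splits holds the per-shard split sizes of the fused embedding buffer;
  # id_offsets and emb_offsets are later fed back into the fused gradient
  # update, as the test diffs below show.
  return embeddings, recv_splits, id_offsets, emb_offsets
```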
4 changes: 2 additions & 2 deletions monolith/native_training/distributed_ps_sync.py
@@ -214,7 +214,7 @@ def lookup(self,
# fused_embeddings: [E], fused_splits: [N]
# id_offsets: [K*N], emb_offsets: [K*N]
with tf.device("/GPU:0"):
- fused_embeddings, embedding_splits, id_offsets, emb_offsets, fused_emb_sizes = \
+ fused_embeddings, embedding_splits, id_offsets, emb_offsets = \
self._table.fused_lookup(id_flat_t, id_size_flat_t, self._shard_num)
if FLAGS.enable_alltoall_metrics:
with tf.device("/CPU:0"):
@@ -348,7 +348,7 @@ def assign_add(
# Apply_gradients uses fused update.
def apply_gradients(
self,
- slot_to_grad: Dict[str, Tensor],
+ slot_to_grad: Dict[str, tf.Tensor],
auxiliary_bundle: Dict[str, tf.Tensor],
global_step: tf.Tensor,
req_time: tf.Tensor = None) -> DistributedMultiTypeHashTableMpi:
8 changes: 3 additions & 5 deletions monolith/native_training/gpu_hash_table_ops_test.py
@@ -141,13 +141,12 @@ def test_fused_lookup(self):
fused_slot_size=tf.constant([1, 2, 1, 1, 1, 1]),
num_of_shards=2)

- embeddings, recv_splits, id_offsets, emb_offsets, emb_dims = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
embeddings)
self.assertAllEqual(embeddings, [1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1])
self.assertAllEqual(recv_splits, [8, 6])
self.assertAllEqual(id_offsets, [0, 1, 3, 4, 5, 6])
self.assertAllEqual(emb_offsets, [0, 1, 5, 8, 9, 11])
- self.assertAllEqual(emb_dims, [1, 4, 3, 1, 2, 3])

@test_util.run_gpu_only
def test_fused_optimize(self):
@@ -169,7 +168,7 @@ def test_fused_optimize(self):
hash_tables.append(hash_table)
hash_table_resource = [hash_table.table for hash_table in hash_tables]
#embeddings=[1, 0, 0, 1, 0, 0]
- embeddings, recv_splits, id_offsets, emb_offsets, emb_dims = ops.fused_lookup(
+ embeddings, recv_splits, id_offsets, emb_offsets = ops.fused_lookup(
hash_table_resource, ids, fused_slot_size, num_of_shards=2)
new_tables = ops.fused_apply_gradient(hash_table_resource,
ids,
@@ -191,15 +190,14 @@ def test_fused_optimize(self):
ids,
fused_slot_size,
num_of_shards=2)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_dims = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
lookup_op)
self.assertAllClose(
embeddings,
[1.0953462, 0.09877297, 0.09877297, 1.0953462, 0.09877297, 0.09877297])
self.assertAllEqual(recv_splits, [3, 3])
self.assertAllEqual(id_offsets, [0, 1, 2, 3])
self.assertAllEqual(emb_offsets, [0, 1, 3, 4])
- self.assertAllEqual(emb_dims, [1, 2, 1, 2])


if __name__ == '__main__':
8 changes: 3 additions & 5 deletions monolith/native_training/hash_table_ops_test.py
@@ -1099,13 +1099,12 @@ def test_fused_lookup(self):
_get_id_tensor([0, 4, 6, 1, 3, 7]),
fused_slot_size=tf.constant([1, 1, 1, 1, 1, 1]),
num_of_shards=2)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_dims = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
embeddings)
self.assertAllEqual(embeddings, [1, 0, 1, 1, 1, 0, 1, 1])
self.assertAllEqual(recv_splits, [4, 4])
self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4, 5, 6])
self.assertAllEqual(emb_offsets, [0, 1, 2, 4, 5, 6, 8])
- self.assertAllEqual(emb_dims, [1, 1, 2, 1, 1, 2])

def test_fused_optimize(self):
with tf.compat.v1.Session() as sess:
@@ -1122,7 +1121,7 @@ def test_fused_optimize(self):
hash_tables.append(hash_table)
hash_table_resource = [hash_table.table for hash_table in hash_tables]
#embeddings=[1, 0, 0, 1, 0, 0]
- embeddings, recv_splits, id_offsets, emb_offsets, emb_dims = ops.fused_lookup(
+ embeddings, recv_splits, id_offsets, emb_offsets = ops.fused_lookup(
hash_table_resource, ids, fused_slot_size, num_of_shards=2)
new_tables = ops.fused_apply_gradient(hash_table_resource,
ids,
@@ -1143,13 +1142,12 @@ def test_fused_optimize(self):
ids,
fused_slot_size,
num_of_shards=2)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_dims = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
lookup_op)
self.assertAllClose(embeddings, [1.1, 0.2, 0.2, 1.1, 0.2, 0.2])
self.assertAllEqual(recv_splits, [3, 3])
self.assertAllEqual(id_offsets, [0, 1, 2, 3, 4])
self.assertAllEqual(emb_offsets, [0, 1, 3, 4, 6])
- self.assertAllEqual(emb_dims, [1, 2, 1, 2])

def test_batch_softmax_optimizer(self):
table_config = embedding_hash_table_pb2.EmbeddingHashTableConfig()
16 changes: 6 additions & 10 deletions monolith/native_training/multi_type_hash_table_test.py
@@ -167,13 +167,12 @@ def test_fused_lookup(self):
"slot2": (_id([2, 3]), _value([[4, 4], [8, 8]]))
})
values_dict = hash_table.fused_lookup([0, 1, 2, 3], [1, 1, 2], 1)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_sizes = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
values_dict)
self.assertAllEqual(embeddings, [1, 2, 2, 4, 4, 8, 8])
self.assertAllEqual(recv_splits, [7])
self.assertAllEqual(id_offsets, [0, 1, 2, 4])
self.assertAllEqual(emb_offsets, [0, 1, 3, 7])
- self.assertAllEqual(emb_sizes, [1, 2, 4])

def test_fused_lookup_multi_shards(self):
with self.session() as sess:
@@ -189,13 +188,12 @@ def test_fused_lookup_multi_shards(self):
"slot2": (_id([2, 3]), _value([[4, 4], [8, 8]]))
})
values_dict = hash_table.fused_lookup([0, 2, 1, 3], [1, 0, 1, 0, 1, 1], 2)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_sizes = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
values_dict)
self.assertAllEqual(embeddings, [1, 4, 4, 2, 2, 8, 8])
self.assertAllEqual(recv_splits, [3, 4])
self.assertAllEqual(id_offsets, [0, 1, 1, 2, 2, 3, 4])
self.assertAllEqual(emb_offsets, [0, 1, 1, 3, 3, 5, 7])
- self.assertAllEqual(emb_sizes, [1, 0, 2, 0, 2, 2])

def test_fused_apply_gradients(self):
with self.session() as sess:
@@ -206,20 +204,19 @@
}, factory)
ids = tf.constant([0, 1, 2], dtype=tf.int64)
fused_slot_size = tf.constant([1, 2])
- embeddings, _, id_offsets, emb_offsets, _ = hash_table.fused_lookup(
+ embeddings, _, id_offsets, emb_offsets = hash_table.fused_lookup(
ids, fused_slot_size, 1)
grads = tf.constant([2.0, 1.0, 3.0, 2.0, 4.0])
hash_table = hash_table.fused_apply_gradient(
ids, fused_slot_size, grads, id_offsets, emb_offsets,
tf.constant(0, dtype=tf.int64), tf.constant(0, dtype=tf.int64), 1)
lookup_op = hash_table.fused_lookup(ids, fused_slot_size, 1)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_sizes = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
lookup_op)
self.assertAllEqual(embeddings, [-2, -1, -3, -2, -4])
self.assertAllEqual(recv_splits, [5])
self.assertAllEqual(id_offsets, [0, 1, 3])
self.assertAllEqual(emb_offsets, [0, 1, 5])
- self.assertAllEqual(emb_sizes, [1, 4])

def test_fused_apply_gradients_missing_tables(self):
with self.session() as sess:
@@ -230,20 +227,19 @@
}, factory)
ids = tf.constant([1, 1], dtype=tf.int64)
fused_slot_size = tf.constant([1, 0, 1, 0])
- embeddings, _, id_offsets, emb_offsets, _ = hash_table.fused_lookup(
+ embeddings, _, id_offsets, emb_offsets = hash_table.fused_lookup(
ids, fused_slot_size, 2)
grads = tf.constant([1.0, 2.0])
hash_table = hash_table.fused_apply_gradient(
ids, fused_slot_size, grads, id_offsets, emb_offsets,
tf.constant(0, dtype=tf.int64), tf.constant(0, dtype=tf.int64), 2)
lookup_op = hash_table.fused_lookup(ids, fused_slot_size, 2)
- embeddings, recv_splits, id_offsets, emb_offsets, emb_sizes = sess.run(
+ embeddings, recv_splits, id_offsets, emb_offsets = sess.run(
lookup_op)
self.assertAllEqual(embeddings, [-3, -3])
self.assertAllEqual(recv_splits, [1, 1])
self.assertAllEqual(id_offsets, [0, 1, 1, 2, 2])
self.assertAllEqual(emb_offsets, [0, 1, 1, 2, 2])
- self.assertAllEqual(emb_sizes, [1, 0, 1, 0])


def _multi_type_factory(slot_to_config):
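For the multi-type table, the offsets returned by `fused_lookup` feed directly into `fused_apply_gradient`. Below is a condensed sketch of that round trip under the four-output signature, lifted from the test above; it assumes `hash_table` is a fused multi-type hash table built the way the test's `_multi_type_factory` builds one, so it is illustrative rather than standalone.

```python
import tensorflow as tf

# Assumption: `hash_table` is a fused multi-type hash table constructed as in
# the test above; only the call pattern is shown here.
ids = tf.constant([0, 1, 2], dtype=tf.int64)
fused_slot_size = tf.constant([1, 2])

# Four outputs: the per-slot embedding-size tensor is no longer returned.
embeddings, recv_splits, id_offsets, emb_offsets = hash_table.fused_lookup(
    ids, fused_slot_size, 1)

# The lookup's offsets are passed straight into the fused gradient update,
# which returns the updated table object.
grads = tf.constant([2.0, 1.0, 3.0, 2.0, 4.0])
hash_table = hash_table.fused_apply_gradient(
    ids, fused_slot_size, grads, id_offsets, emb_offsets,
    tf.constant(0, dtype=tf.int64), tf.constant(0, dtype=tf.int64), 1)
```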
@@ -25,13 +25,12 @@ namespace hash_table {

std::unique_ptr<EmbeddingHashTableInterface> NewEmbeddingHashTableFromConfig(
EmbeddingHashTableConfig config, GpuExtraArgs args) {
- std::unique_ptr<EntryAccessorInterface> accessor =
-     NewEntryAccessor(config.entry_config());
switch (config.type_case()) {
case EmbeddingHashTableConfig::kCuckoo:
return NewCuckooEmbeddingHashTable(
- config.cuckoo(), std::move(accessor), config.entry_type(),
- config.initial_capacity(), config.slot_expire_time_config());
+ config.cuckoo(), NewEntryAccessor(config.entry_config()),
+ config.entry_type(), config.initial_capacity(),
+ config.slot_expire_time_config());
default:
throw std::invalid_argument(absl::StrFormat(
"Unknown type of hash table. %s", config.ShortDebugString()));
1 change: 1 addition & 0 deletions monolith/native_training/runtime/ops/hash_filter_op.cc
@@ -138,6 +138,7 @@ REGISTER_OP("MonolithDummyHashFilter")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithDummyHashFilter").Device(DEVICE_CPU),
DummyFilterOp);
REGISTER_KERNEL_BUILDER(Name("MonolithProbabilisticFilter").Device(DEVICE_CPU),
2 changes: 2 additions & 0 deletions monolith/native_training/runtime/ops/hash_table/misc_ops.cc
@@ -39,8 +39,10 @@ REGISTER_OP("MonolithHashTableSize")
.Input("table_handle: resource")
.Output("size: int64")
.SetShapeFn(shape_inference::ScalarShape);

REGISTER_KERNEL_BUILDER(Name("MonolithHashTableSize").Device(DEVICE_CPU),
HashTableSizeOp);

class SaveAsTensorOp : public OpKernel {
public:
explicit SaveAsTensorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}