Skip to content

Commit

Permalink
[Embedding] Fix: EV saves checkpoint incorrectly when using leveldb sto…
Browse files Browse the repository at this point in the history
…rage.
  • Loading branch information
candyzone committed Apr 29, 2022
1 parent bc2ed24 commit 9ede09d
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ class EmbeddingVar : public ResourceBase {
return storage_manager_->Cache();
}

// Accessor for the `emb_index` field of this variable's embedding
// configuration (`emb_config_`). The value is returned as-is, without
// validation.
int64 GetEmbeddingIndex() {
  const auto& config = emb_config_;
  return config.emb_index;
}

private:
std::string name_;
bool is_initialized_ = false;
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/kv_variable_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ Status DumpEmbeddingValues(EmbeddingVar<K, V>* ev, const string& tensor_key, Bun
st = SaveTensorWithFixedBuffer(tensor_key + "-keys", writer, dump_buffer,
bytes_limit, &ev_key_dump_iter,
TensorShape({partitioned_tot_key_list.size() + iterator_size}),
it, true);
it);
if (!st.ok()) {
free(dump_buffer);
return st;
Expand All @@ -328,7 +328,7 @@ Status DumpEmbeddingValues(EmbeddingVar<K, V>* ev, const string& tensor_key, Bun
st = SaveTensorWithFixedBuffer(tensor_key + "-values", writer, dump_buffer,
bytes_limit, &ev_value_dump_iter,
TensorShape({partitioned_tot_key_list.size() + iterator_size, ev->ValueLen()}),
it, false);
it, ev->storage_manager()->GetOffset(ev->GetEmbeddingIndex()));
if (!st.ok()) {
free(dump_buffer);
return st;
Expand Down
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/save_restore_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ Status SaveTensorWithFixedBuffer(const string& tensor_name,
DumpIterator<T>* dump_iter,
const TensorShape& dump_tensor_shape,
embedding::Iterator* it = nullptr,
bool is_key = true,
int64 value_offset = -1, // -1: save key, x_offset: save embedding(primary or slot offset)
bool use_shape = true) {
bool dump_happened = false;
size_t bytes_written = 0;
Expand Down Expand Up @@ -152,7 +152,7 @@ Status SaveTensorWithFixedBuffer(const string& tensor_name,
std::string value_str;
int64 dim = 0;
void* start = nullptr;
if (is_key) {
if (value_offset == -1) {
value_str = it->Key();

if (bytes_written + sizeof(T) > bytes_limit) {
Expand All @@ -178,7 +178,7 @@ Status SaveTensorWithFixedBuffer(const string& tensor_name,
bytes_written = 0;
buffer_idx = 0;
}
key_dump_buffer[buffer_idx] = *((T*)start + j);
key_dump_buffer[buffer_idx] = *((T*)start + j + value_offset);
buffer_idx++;
bytes_written += sizeof(T);
total_bytes_written += sizeof(T);
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/python/ops/embedding_variable_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,13 @@ def testLevelDBCheckpoint(self):
self.assertEqual(ckpt_value.tolist()[0][j], r[0][j])
self.assertEqual(ckpt_value.tolist()[1][j], r[3][j])
self.assertEqual(ckpt_value.tolist()[2][j], r[5][j])
if name == "var_1/AdagradDecay-values":
ckpt_value = checkpoint_utils.load_variable(model_path, name)
slot = [[72.1, 72.1, 72.1], [32.1, 32.1, 32.1], [8.1, 8.1, 8.1]]
for j in range(0, 3):
self.assertAlmostEqual(ckpt_value.tolist()[0][j], slot[0][j], delta=1e-5)
self.assertAlmostEqual(ckpt_value.tolist()[1][j], slot[1][j], delta=1e-5)
self.assertAlmostEqual(ckpt_value.tolist()[2][j], slot[2][j], delta=1e-5)
with self.test_session() as sess:
saver.restore(sess, model_path)
r1, _, _ = sess.run([emb, train_op,loss])
Expand Down

0 comments on commit 9ede09d

Please sign in to comment.