Skip to content

Commit

Permalink
[Embedding] Fix BatchCache coredump in background thread. (DeepRec-AI…
Browse files Browse the repository at this point in the history
  • Loading branch information
candyzone authored Jul 5, 2022
1 parent 4520669 commit d806b3f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 10 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <unordered_map>
#include <set>
#include <list>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/lib/core/status.h"
Expand All @@ -17,6 +18,9 @@ template <class K>
class BatchCache {
public:
BatchCache() {}
void add_to_rank(const Tensor& t) {
add_to_rank((K*)t.data(), t.NumElements());
}
virtual size_t get_evic_ids(K* evic_ids, size_t k_size) = 0;
virtual void add_to_rank(const K* batch_ids, size_t batch_size) = 0;
virtual size_t size() = 0;
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/core/framework/embedding/embedding_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ class EmbeddingVar : public ResourceBase {
return storage_manager_->Size();
}

// Returns the capacity of the multi-level storage cache, delegating to
// the underlying StorageManager.
int64 CacheSize() const {
return storage_manager_->CacheSize();
}

int64 MinFreq() {
return emb_config_.filter_freq;
}
Expand Down
14 changes: 7 additions & 7 deletions tensorflow/core/framework/embedding/multilevel_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ class StorageManager {
return total_size;
}

// Accessor for the configured cache capacity (number of ids the
// first-level cache can hold).
int64 CacheSize() const {
return cache_capacity_;
}

Status GetSnapshot(std::vector<K>* key_list,
std::vector<ValuePtr<V>* >* value_ptr_list) {
for (auto kv : kvs_) {
Expand Down Expand Up @@ -375,7 +379,6 @@ class StorageManager {
Status Destroy() {
if (eviction_thread_) {
mutex_lock l(mu_);
shutdown_cv_.notify_all();
shutdown_ = true;
}
delete eviction_thread_;
Expand Down Expand Up @@ -432,9 +435,7 @@ class StorageManager {
if (shutdown_) {
break;
}
const int kTimeoutMilliseconds = 1;
WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);

// add WaitForMilliseconds() for sleep if necessary
for (int i = 0; i < value_ptr_out_of_date_.size(); i++) {
value_ptr_out_of_date_[i]->Destroy(kvs_[0].second);
delete value_ptr_out_of_date_[i];
Expand Down Expand Up @@ -478,10 +479,9 @@ class StorageManager {
BatchCache<K>* cache_;
int64 cache_capacity_;
mutex mu_;
condition_variable shutdown_cv_;
bool shutdown_ GUARDED_BY(mu_) = false;
volatile bool shutdown_ GUARDED_BY(mu_) = false;

bool done_ = false;
volatile bool done_ = false;
std::atomic_flag flag_ = ATOMIC_FLAG_INIT;

};
Expand Down
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/kv_variable_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ class KvResourceGatherOp : public OpKernel {
errors::InvalidArgument(
"ev's value_len should same with output's dimension(1)",
std::to_string(slice_elems), std::to_string(ev->ValueLen())));
OP_REQUIRES(c, !ev->IsMultiLevel() || (ev->IsMultiLevel() && ev->CacheSize() >= N),
errors::InvalidArgument(
"MultiLevel EV's Cache size ", ev->CacheSize(),
" should large than IDs in batch ", N));
const size_t slice_bytes = slice_elems * sizeof(TValue);
auto do_work = [this, indices_flat,
out_base, slice_elems, c, default_v, ev, counts] (
Expand All @@ -436,10 +440,10 @@ class KvResourceGatherOp : public OpKernel {
worker_threads->workers, indices_size,
slice_bytes, do_work);

ev->storage_manager()->Schedule([ev, indices_flat, indices_size]() {
ev->storage_manager()->Schedule([ev, indices]() {
embedding::BatchCache<TKey>* cache = ev->Cache();
if (cache) {
cache->add_to_rank(indices_flat.data(), indices_size);
cache->add_to_rank(indices);
}
});
}
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/python/ops/embedding_variable_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,7 @@ def runTestAdagrad(self, var, g):
sess.run([init])
for i in xrange(60):
r, _, _ = sess.run([emb, train_op, loss])
r = sess.run(emb)
return r

with ops.Graph().as_default() as g:
Expand All @@ -1976,7 +1977,7 @@ def runTestAdagrad(self, var, g):
steps_to_live=5,
ev_option = variables.EmbeddingVariableOption(storage_option=variables.StorageOption(storage_type=config_pb2.StorageType.DRAM_SSDHASH,
storage_path="/tmp/ssd_utpy",
storage_size=[512])))
storage_size=[5120])))
emb1 = runTestAdagrad(self, emb_var, g)

with ops.Graph().as_default() as g:
Expand Down

0 comments on commit d806b3f

Please sign in to comment.