Skip to content

Commit

Permalink
Merge branch 'fix_item_pool_restore_bug' into 'master'
Browse files Browse the repository at this point in the history
fix item pool restore bug

See merge request data/monolith!2136

GitOrigin-RevId: 2e6c30536260831c8246cee6f0071545ddeace85
  • Loading branch information
李博 authored and monolith committed Oct 13, 2023
1 parent 24aa253 commit 9e1202c
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 23 deletions.
39 changes: 27 additions & 12 deletions monolith/native_training/data/kernels/internal/cache_mgr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ namespace tensorflow {
namespace monolith_tf {
namespace internal {

// Builds a new shared ItemFeatures populated from a FeatureData proto message.
//
// The proto's gid becomes the item id, every entry of fids() is appended to
// the fids container in iteration order, and every feature column is emplaced
// into example_features keyed by the column's name().
//
// @param feature_data  Source proto message; it is only read.
// @return A freshly allocated ItemFeatures owned by the returned shared_ptr.
std::shared_ptr<ItemFeatures> MakeItemFeaturesFromProto(
    const ::monolith::io::proto::FeatureData& feature_data) {
  auto features = std::make_shared<ItemFeatures>();

  features->item_id = feature_data.gid();

  // Copy the feature ids one by one, preserving the proto's ordering.
  for (const auto& feature_id : feature_data.fids()) {
    features->fids.push_back(feature_id);
  }

  // Index each feature column by its name.
  // NOTE(review): what happens on a duplicate name depends on the container
  // type of example_features, which is not visible here -- confirm.
  for (const auto& column : feature_data.feature_columns()) {
    features->example_features.emplace(column.name(), column);
  }

  return features;
}

bool ItemFeatures::Equal(const ItemFeatures& other) const {
if (item_id != other.item_id) {
return false;
Expand Down Expand Up @@ -79,7 +93,8 @@ CacheWithGid::CacheWithGid(int max_item_num, int start_num)
: start_num_(start_num), max_item_num_(max_item_num) {}

void CacheWithGid::Push(uint64_t item_id,
std::shared_ptr<const ItemFeatures> item) {
std::shared_ptr<const ItemFeatures> item,
int64_t origin_cnt, int64_t sample_cnt) {
auto it = data_.find(item_id);
if (it == data_.end()) {
data_queue_.emplace_back(item_id);
Expand All @@ -89,14 +104,15 @@ void CacheWithGid::Push(uint64_t item_id,
auto iit = stats_.find(item_id);
if (iit == stats_.end()) {
auto stats_ptr = std::make_shared<GroupStat>();
stats_ptr->origin_cnt = 1;
stats_ptr->sample_cnt = 0;
stats_ptr->origin_cnt = origin_cnt;
stats_ptr->sample_cnt = sample_cnt;
stats_.emplace(item_id, stats_ptr);
} else {
iit->second->origin_cnt++;
iit->second->origin_cnt += origin_cnt;
iit->second->sample_cnt += sample_cnt;
}

if (data_queue_.size() > max_item_num_) {
if ((int64_t)data_queue_.size() > max_item_num_) {
uint64_t item_id = data_queue_.front();
data_.erase(item_id);
stats_.erase(item_id);
Expand All @@ -106,7 +122,7 @@ void CacheWithGid::Push(uint64_t item_id,

std::shared_ptr<const ItemFeatures> CacheWithGid::RandomSelectOne(
double* freq_factor, double* time_factor) const {
if (data_queue_.size() <= start_num_) {
if ((int64_t)data_queue_.size() <= start_num_) {
return nullptr;
}
thread_local std::mt19937 gen((std::random_device())());
Expand Down Expand Up @@ -228,18 +244,17 @@ std::shared_ptr<const ItemFeatures> CacheManager::RandomSelectOne(
}

void CacheManager::Push(uint64_t channel_id, uint64_t item_id,
const std::shared_ptr<const ItemFeatures>& item) {
const std::shared_ptr<const ItemFeatures>& item,
int64_t origin_cnt, int64_t sample_cnt) {
auto it = channel_cache_.find(channel_id);
if (it == channel_cache_.end()) {
LOG(INFO) << "Create channel(" << channel_id
<< ") in ItemPoolResource CacheManager";
auto ret = channel_cache_.emplace(
channel_id, CacheWithGid(max_item_num_per_channel_, start_num_));
it = ret.first;
}
it->second.Push(item_id, item);
}

void CacheManager::Push(uint64_t channel_id, const CacheWithGid& cwg) {
channel_cache_.emplace(channel_id, cwg);
it->second.Push(item_id, item, origin_cnt, sample_cnt);
}

absl::flat_hash_map<uint64_t, CacheWithGid>& CacheManager::GetCache() {
Expand Down
13 changes: 8 additions & 5 deletions monolith/native_training/data/kernels/internal/cache_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ struct ItemFeatures {
bool Equal(const ItemFeatures &other) const;
};

std::shared_ptr<ItemFeatures> MakeItemFeaturesFromProto(
const ::monolith::io::proto::FeatureData &feature_data);

class CacheWithGid {
public:
explicit CacheWithGid(int max_item_num, int start_num = 0);

void Push(uint64_t item_id, std::shared_ptr<const ItemFeatures> item);
void Push(uint64_t item_id, std::shared_ptr<const ItemFeatures> item,
int64_t origin_cnt = 1, int64_t sample_cnt = 0);

std::shared_ptr<const ItemFeatures> RandomSelectOne(
double *freq_factor, double *time_factor) const;
Expand All @@ -67,7 +71,7 @@ class CacheWithGid {

bool Equal(const CacheWithGid &other) const;

inline int Size() { return data_queue_.size(); }
inline int Size() const { return data_queue_.size(); }

private:
int start_num_;
Expand All @@ -86,9 +90,8 @@ class CacheManager {
uint64_t channel_id, double *freq_factor, double *time_factor) const;

void Push(uint64_t channel_id, uint64_t item_id,
const std::shared_ptr<const ItemFeatures> &item);

void Push(uint64_t channel_id, const CacheWithGid &cwg);
const std::shared_ptr<const ItemFeatures> &item,
int64_t origin_cnt = 1, int64_t sample_cnt = 0);

absl::flat_hash_map<uint64_t, CacheWithGid> &GetCache();

Expand Down
17 changes: 11 additions & 6 deletions monolith/native_training/data/kernels/item_pool_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Status ItemPoolResource::Add(
uint64_t channel_id, uint64_t item_id,
const std::shared_ptr<const internal::ItemFeatures>& item) {
absl::MutexLock l(&mu_);
cache_->Push(channel_id, item_id, item);
cache_->Push(channel_id, item_id, item, 1, 0);
return Status::OK();
}

Expand Down Expand Up @@ -139,10 +139,15 @@ Status ItemPoolResource::Restore(RandomAccessFile* istream, int64 buffer_size) {
restore_status.Update(Status::OK());
}

internal::CacheWithGid cache_with_gid(max_item_num_per_channel_,
start_num_);
cache_with_gid.FromProto(channel_cache);
cache_->Push(channel_cache.channel_id(), cache_with_gid);
for (const auto& feature_data : channel_cache.feature_datas()) {
auto item_feature_ptr = internal::MakeItemFeaturesFromProto(feature_data);
cache_->Push(channel_cache.channel_id(), feature_data.gid(),
item_feature_ptr, feature_data.origin_cnt(),
feature_data.sample_cnt());
}
LOG(INFO) << absl::StrFormat(
"ItemPoolResource: after restore, channel %lld restore %llu items",
channel_cache.channel_id(), channel_cache.feature_datas_size());
}

TF_RETURN_IF_ERROR(restore_status);
Expand Down Expand Up @@ -567,7 +572,7 @@ class ItemPoolRestoreOp : public AsyncOpKernel {
ctx->env()->GetMatchingPaths(fuzzy_matching_path, &files_fuzzy);
if (fuzzy_match.ok() && !files_fuzzy.empty()) {
int ckpt_num = FindFuzzyCkptNumber(files_fuzzy);
if (ckpt_num <= global_step_ && ckpt_num > global_step_ * 9 / 10) {
if (ckpt_num <= global_step_) {
return absl::StrCat(model_path_, "/model.ckpt-", ckpt_num, "_",
FILE_NAME_PREFIX, index % last_save_worker_num,
"_", shard_index, "_", last_save_worker_num);
Expand Down

0 comments on commit 9e1202c

Please sign in to comment.