Skip to content

Commit

Permalink
Merge branch 'fix_item_pool_ckpt_restore' into 'master'
Browse files Browse the repository at this point in the history
fix item_pool ckpt restore

See merge request data/monolith!2106

GitOrigin-RevId: d7a67a84f7bff000634dceb88a66acb8a632125b
  • Loading branch information
李博 authored and monolith committed Sep 14, 2023
1 parent 26cba68 commit 88095e5
Showing 1 changed file with 60 additions and 23 deletions.
83 changes: 60 additions & 23 deletions monolith/native_training/data/kernels/item_pool_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "absl/random/random.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
Expand Down Expand Up @@ -507,58 +508,94 @@ class ItemPoolRestoreOp : public AsyncOpKernel {
int FindLastNumber(std::vector<std::string> const &files, OpKernelContext* ctx) {
// 支持 restore 时的 worker_num 可以和 save 时不同
int last_worker_num = 1;
int64 mtime_nsec = 0;
for (const auto& file : files) {
if (absl::EndsWith(file, "tmp")) {
LOG(INFO) << "Files vector contains file with tmp suffix.";
continue;
}
FileStatistics stat;
ctx->env()->Stat(file, &stat);
if (mtime_nsec < stat.mtime_nsec) {
std::vector<absl::string_view> items = absl::StrSplit(file, "_");
absl::SimpleAtoi(items.back(), &last_worker_num);
mtime_nsec = stat.mtime_nsec;
std::vector<absl::string_view> items = absl::StrSplit(file, "_");
if (!items.empty() && absl::SimpleAtoi(items.back(), &last_worker_num)) {
break;
}
}
return last_worker_num;
}

int FindFuzzyCkptNumber(const std::vector<std::string>& files) {
int max_match_step = -1;
for (const std::string& file : files) {
LOG(INFO) << "match fuzzy ckpt:" << file;
if (absl::EndsWith(file, "tmp")) {
LOG(INFO) << "Files vector contains file with tmp suffix.";
continue;
}
// file like "xxx/model.ckpt-25541095_item_pool_28_0_60"
std::vector<absl::string_view> items =
absl::StrSplit(file, absl::ByAnyChar("-_"));
CHECK_GT(items.size(), 6) << absl::StrFormat(
"item_pool ckpt's filepath is not correct: %s", file);
int global_step = -1;
CHECK(absl::SimpleAtoi(items.at(items.size() - 6), &global_step));
if (global_step > max_match_step) {
max_match_step = global_step;
}
}
return max_match_step;
}

std::string GetRestoreFileName(OpKernelContext* ctx, int shard_index) {
int index, worker_num;
get_index_and_worker_num(&index, &worker_num);
std::vector<std::string> files_new;
std::vector<std::string> files_old;

// the global step of chief's item_pool ckpt is correct
Status s_new = ctx->env()->GetMatchingPaths(
absl::StrCat(model_path_, "/model.ckpt-", global_step_, "_",
FILE_NAME_PREFIX, "*"),
&files_new);
Status s_old = ctx->env()->GetMatchingPaths(
absl::StrCat(model_path_, "/", FILE_NAME_PREFIX, "*"), &files_old);
if (!s_new.ok() && !s_old.ok()) {
LOG(INFO) << "GetMatchingPaths Error: [new] " << s_new << " and [old] " << s_old;
return "";

if (s_new.ok() && !files_new.empty()) {
int last_save_worker_num = FindLastNumber(files_new, ctx);
LOG(INFO) << "last worker num is: " << last_save_worker_num;
std::vector<std::string> files_fuzzy;
std::string fuzzy_matching_path =
absl::StrCat(model_path_, "/model.ckpt-", "*", "_", FILE_NAME_PREFIX,
index % last_save_worker_num, "_", shard_index, "_",
last_save_worker_num);
Status fuzzy_match =
ctx->env()->GetMatchingPaths(fuzzy_matching_path, &files_fuzzy);
if (fuzzy_match.ok() && !files_fuzzy.empty()) {
int ckpt_num = FindFuzzyCkptNumber(files_fuzzy);
if (ckpt_num <= global_step_ && ckpt_num > global_step_ * 9 / 10) {
return absl::StrCat(model_path_, "/model.ckpt-", ckpt_num, "_",
FILE_NAME_PREFIX, index % last_save_worker_num,
"_", shard_index, "_", last_save_worker_num);
} else {
LOG(INFO) << absl::StrFormat(
"step not match: fuzzy match step is %d, target global step is "
"%d",
ckpt_num, global_step_);
}
} else {
LOG(INFO) << absl::StrFormat("path not match: %s", fuzzy_matching_path);
}
}

if (files_new.size() > 0) {
// {model_path}/item_pool_{index}_{worker_num}
LOG(INFO) << "new version files > 0";
int last_worker_num = FindLastNumber(files_new, ctx);
LOG(INFO) << "last worker num is: " << last_worker_num;
return absl::StrCat(model_path_, "/model.ckpt-", global_step_, "_",
FILE_NAME_PREFIX, index % last_worker_num, "_",
shard_index, "_", last_worker_num);
} else if (files_old.size() > 0) {
Status s_old = ctx->env()->GetMatchingPaths(
absl::StrCat(model_path_, "/", FILE_NAME_PREFIX, "*"), &files_old);
if (s_old.ok() && !files_old.empty()) {
LOG(INFO) << "old version files > 0";
int last_worker_num = FindLastNumber(files_old, ctx);
LOG(INFO) << "last worker num is: " << last_worker_num;
return absl::StrCat(model_path_, "/", FILE_NAME_PREFIX,
index % last_worker_num, "_", shard_index, "_",
last_worker_num);
} else {
return "";
}

LOG(INFO) << "GetMatchingPaths Error: [new] " << s_new << " and [old] "
<< s_old;
return "";
}
};

Expand Down

0 comments on commit 88095e5

Please sign in to comment.