Skip to content

Commit

Permalink
Merge branch 'matching_leqi' into 'master'
Browse files Browse the repository at this point in the history
Fix bug when fileset contains tmp files

See merge request data/monolith!2013

GitOrigin-RevId: 7ef1b9c944123e3d4c4c833fbcb8a88d4c9dd380
  • Loading branch information
zlqiszlqbd authored and monolith committed Mar 14, 2023
1 parent 817451e commit e27ca15
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 13 deletions.
29 changes: 28 additions & 1 deletion monolith/native_training/runtime/ops/file_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ std::string GetShardedFileName(absl::string_view basename, int shard,
}

Status ValidateShardedFiles(absl::string_view basename,
absl::Span<const std::string> filenames) {
absl::Span<const std::string> filenames,
FileSpec* spec) {
std::vector<bool> show;
for (absl::string_view filename : filenames) {
if (filename.substr(0, basename.size()) != basename) {
Expand Down Expand Up @@ -69,8 +70,34 @@ Status ValidateShardedFiles(absl::string_view basename,
basename);
}
}

if (spec != nullptr) {
*spec = FileSpec::ShardedFileSpec(basename, show.size());
}
return Status::OK();
}

FileSpec FileSpec::ShardedFileSpec(absl::string_view prefix, int nshards) {
FileSpec spec;
spec.type_ = FileSpec::SHARDED_FILES;
spec.prefix_ = std::string(prefix);
spec.nshards_ = nshards;
return spec;
}

std::vector<std::string> FileSpec::GetFilenames() const {
std::vector<std::string> filenames;
switch (type_) {
case FileSpec::SHARDED_FILES:
for (int i = 0; i < nshards_; ++i) {
filenames.push_back(GetShardedFileName(prefix_, i, nshards_));
}
break;
default:
break;
}
return filenames;
}

} // namespace monolith_tf
} // namespace tensorflow
22 changes: 21 additions & 1 deletion monolith/native_training/runtime/ops/file_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,29 @@ namespace monolith_tf {
std::string GetShardedFileName(absl::string_view basename, int shard,
int nshards);

// A spec reprsents a set of files.
class FileSpec final {
public:
FileSpec() {}

static FileSpec ShardedFileSpec(absl::string_view prefix, int nshards);

std::vector<std::string> GetFilenames() const;

int nshards() const { return nshards_; }

private:
enum Type { UNKNOWN, SHARDED_FILES };

Type type_ = UNKNOWN;
std::string prefix_;
int nshards_ = 0;
};

// Validates if filenames construct a valid file spec for base name.
Status ValidateShardedFiles(absl::string_view basename,
absl::Span<const std::string> filenames);
absl::Span<const std::string> filenames,
FileSpec* spec = nullptr);

} // namespace monolith_tf
} // namespace tensorflow
Expand Down
13 changes: 12 additions & 1 deletion monolith/native_training/runtime/ops/file_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@

#include <string>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace monolith_tf {
namespace {

using ::testing::ElementsAre;

TEST(ValidateShardedFilesTest, Basic) {
FileSpec spec;
TF_EXPECT_OK(ValidateShardedFiles("a/b", {"a/b-00000-of-00001"}));
TF_EXPECT_OK(ValidateShardedFiles(
"a/b", {"a/b-00000-of-00002", "a/b-00001-of-00002"}));
"a/b", {"a/b-00000-of-00002", "a/b-00001-of-00002"}, &spec));
EXPECT_THAT(spec.nshards(), 2);
TF_EXPECT_OK(ValidateShardedFiles(
"a", {"a-00000-of-00001", "a-00000-of-00001-tmp-1234"}));
TF_EXPECT_OK(ValidateShardedFiles(
Expand All @@ -43,6 +48,12 @@ TEST(ValidateShardedFilesTest, Basic) {
EXPECT_FALSE(ValidateShardedFiles("a/b", {"a/b-random-string"}).ok());
}

TEST(ValidateShardedFilesTest, FileSpecTest) {
auto spec = FileSpec::ShardedFileSpec("a/b", 2);
EXPECT_THAT(spec.GetFilenames(),
ElementsAre("a/b-00000-of-00002", "a/b-00001-of-00002"));
}

TEST(ValidateShardedFilesTest, LargeFileSet) {
std::vector<std::string> filenames;
for (int i = 0; i < 100; ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ class HashFilterRestoreOp : public AsyncOpKernel {
ctx, ctx->env()->GetMatchingPaths(absl::StrCat(basename, "-*"), &files),
done);

OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files), done);
OP_REQUIRES_ASYNC(ctx, !files.empty(),
FileSpec file_spec;
OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files, &file_spec),
done);
OP_REQUIRES_ASYNC(ctx, file_spec.nshards() > 0,
errors::NotFound("Unable to find the dump files for: ",
name(), " in ", basename),
done);
ctx->set_output(0, ctx->input(0));
int nsplits = files.size();
int nsplits = file_spec.nshards();
auto pack = new HashFilterAsyncPack(ctx, hash_filter, basename,
std::move(done), nsplits);
for (int i = 0; i < nsplits; ++i) {
Expand Down
9 changes: 5 additions & 4 deletions monolith/native_training/runtime/ops/hash_table_restore_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ class HashTableRestoreOp : public AsyncOpKernel {
ctx, ctx->env()->GetMatchingPaths(absl::StrCat(basename, "-*"), &files),
done);

OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files), done);
OP_REQUIRES_ASYNC(ctx, !files.empty(),
FileSpec file_spec;
OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files, &file_spec),
done);
OP_REQUIRES_ASYNC(ctx, file_spec.nshards() > 0,
errors::NotFound("Unable to find the dump files for: ",
name(), " in ", basename),
done);
Expand Down Expand Up @@ -122,8 +124,7 @@ class HashTableRestoreOp : public AsyncOpKernel {
io::SequentialRecordReader reader(f.get(), opts);
Status restore_status;
auto get_fn = [&reader, &restore_status, &p](
EmbeddingHashTableTfBridge::EntryDump* dump,
int64_t* max_update_ts) {
EmbeddingHashTableTfBridge::EntryDump* dump, int64_t* max_update_ts) {
Status s = GetRecord(&reader, dump);
if (TF_PREDICT_FALSE(!s.ok())) {
if (!errors::IsOutOfRange(s)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,15 @@ class MultiHashTableRestoreOp : public AsyncOpKernel {
ctx, ctx->env()->GetMatchingPaths(absl::StrCat(basename, "-*"), &files),
done);

OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files), done);
OP_REQUIRES_ASYNC(ctx, !files.empty(),
FileSpec file_spec;
OP_REQUIRES_OK_ASYNC(ctx, ValidateShardedFiles(basename, files, &file_spec),
done);
OP_REQUIRES_ASYNC(ctx, file_spec.nshards() > 0,
errors::NotFound("Unable to find the dump files for: ",
name(), " in ", basename),
done);

int nshards = files.size();
int nshards = file_spec.nshards();
auto pack = std::make_shared<const AsyncPack<TableType>>(
ctx, std::move(mtable), basename,
std::vector<std::unique_ptr<EmbeddingHashTableTfBridge::LockCtx>>(),
Expand Down

0 comments on commit e27ca15

Please sign in to comment.