forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port interface of store base class from Caffe2 (pytorch#7439)
The file store implementation is new and based on the file initialization method (which uses a single file and file locking) and the interface of the Caffe2 store handler. See pytorch#7434.
- Loading branch information
Showing
9 changed files
with
536 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Build script for the `store` library and its tests.
# Requires CMake 3.2 or newer; older versions abort with a fatal error.
cmake_minimum_required(VERSION 3.2 FATAL_ERROR) | ||
# Put the cmake/ directory three levels up at the front of the module search path.
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH}) | ||
|
||
# The library is built from the abstract Store base and the file-backed implementation.
add_library(store Store.cpp FileStore.cpp) | ||
# PUBLIC: the C++11 flag is also applied to targets that link against `store`.
target_compile_options(store PUBLIC "-std=c++11") | ||
|
||
# Enable CTest and pull in the targets defined under test/.
enable_testing() | ||
add_subdirectory(test) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,280 @@ | ||
#include "FileStore.hpp" | ||
|
||
#include <assert.h> | ||
#include <stdint.h> | ||
#include <sys/file.h> | ||
#include <fcntl.h> | ||
#include <unistd.h> | ||
#include <sys/stat.h> | ||
|
||
#include <chrono> | ||
#include <functional> | ||
#include <iostream> | ||
#include <limits> | ||
#include <sstream> | ||
#include <system_error> | ||
#include <thread> | ||
|
||
// Throws std::system_error built from the current errno when `rv` is
// negative. An optional extra argument is forwarded to the exception as its
// "what" text. The body is wrapped in do/while(0) so that the macro expands
// to exactly one statement and can be used in an unbraced if/else branch.
#define SYSASSERT(rv, ...)                                \
  do {                                                    \
    if ((rv) < 0) {                                       \
      throw std::system_error(                            \
          errno, std::system_category(), ##__VA_ARGS__);  \
    }                                                     \
  } while (0)
|
||
namespace c10d { | ||
|
||
namespace { | ||
|
||
// Calls `fn` repeatedly until it finishes without having been interrupted:
// a result of -1 with errno equal to EINTR triggers another attempt, any
// other result is handed back to the caller unchanged.
template <typename F>
typename std::result_of<F()>::type syscall(F fn) {
  for (;;) {
    const auto result = fn();
    // errno is only inspected when the call reported a failure.
    const bool interrupted = (result == -1) && (errno == EINTR);
    if (!interrupted) {
      return result;
    }
  }
}
|
||
// For a comprehensive overview of file locking methods, | ||
// see: https://gavv.github.io/blog/file-locks/. | ||
// We stick to flock(2) here because we don't care about | ||
// locking byte ranges and don't want locks to be process-wide. | ||
|
||
// RAII wrapper around flock(2) | ||
class Lock { | ||
public: | ||
explicit Lock(int fd, int operation) : fd_(fd) { | ||
flock(operation); | ||
} | ||
|
||
~Lock() { | ||
unlock(); | ||
} | ||
|
||
Lock(const Lock& that) = delete; | ||
|
||
Lock(Lock&& other) noexcept { | ||
fd_ = other.fd_; | ||
other.fd_ = -1; | ||
} | ||
|
||
void unlock() { | ||
if (fd_ >= 0) { | ||
flock(LOCK_UN); | ||
fd_ = -1; | ||
} | ||
} | ||
|
||
protected: | ||
int fd_; | ||
|
||
void flock(int operation) { | ||
auto rv = syscall(std::bind(::flock, fd_, operation)); | ||
SYSASSERT(rv, "flock"); | ||
} | ||
}; | ||
|
||
class File { | ||
public: | ||
explicit File(const std::string& path, int flags) { | ||
fd_ = syscall(std::bind(::open, path.c_str(), flags, 0644)); | ||
SYSASSERT(fd_, "open(" + path + ")"); | ||
} | ||
|
||
~File() { | ||
::close(fd_); | ||
} | ||
|
||
Lock lockShared() { | ||
return Lock(fd_, LOCK_SH); | ||
} | ||
|
||
Lock lockExclusive() { | ||
return Lock(fd_, LOCK_EX); | ||
} | ||
|
||
off_t seek(off_t offset, int whence) { | ||
auto rv = syscall(std::bind(lseek, fd_, offset, whence)); | ||
SYSASSERT(rv, "lseek"); | ||
return rv; | ||
} | ||
|
||
off_t tell() { | ||
auto rv = syscall(std::bind(lseek, fd_, 0, SEEK_CUR)); | ||
SYSASSERT(rv, "lseek"); | ||
return rv; | ||
} | ||
|
||
off_t size() { | ||
auto pos = tell(); | ||
auto size = seek(0, SEEK_END); | ||
seek(pos, SEEK_SET); | ||
return size; | ||
} | ||
|
||
void write(const void* buf, size_t count) { | ||
while (count > 0) { | ||
auto rv = syscall(std::bind(::write, fd_, buf, count)); | ||
SYSASSERT(rv, "write"); | ||
buf = (uint8_t*) buf + count; | ||
count -= rv; | ||
} | ||
} | ||
|
||
void read(void* buf, size_t count) { | ||
while (count > 0) { | ||
auto rv = syscall(std::bind(::read, fd_, buf, count)); | ||
SYSASSERT(rv, "read"); | ||
buf = (uint8_t*) buf + count; | ||
count -= rv; | ||
} | ||
} | ||
|
||
void write(const std::string& str) { | ||
uint32_t len = str.size(); | ||
assert(str.size() <= std::numeric_limits<decltype(len)>::max()); | ||
write(&len, sizeof(len)); | ||
write(str.c_str(), len); | ||
} | ||
|
||
void write(const std::vector<uint8_t>& data) { | ||
uint32_t len = data.size(); | ||
assert(data.size() <= std::numeric_limits<decltype(len)>::max()); | ||
write(&len, sizeof(len)); | ||
write(data.data(), len); | ||
} | ||
|
||
void read(std::string& str) { | ||
uint32_t len; | ||
read(&len, sizeof(len)); | ||
std::vector<uint8_t> buf(len); | ||
read(buf.data(), len); | ||
str.assign(buf.begin(), buf.end()); | ||
} | ||
|
||
void read(std::vector<uint8_t>& data) { | ||
uint32_t len; | ||
read(&len, sizeof(len)); | ||
data.resize(len); | ||
read(data.data(), len); | ||
} | ||
|
||
protected: | ||
int fd_; | ||
}; | ||
|
||
// Loads the key/value records stored in `file` from byte offset `pos` up to
// the current end of the file into `cache`, overwriting entries that share
// a key. Returns the offset reached, which callers pass back as `pos` on the
// next call. When the file size already equals `pos` nothing is read and the
// file cursor is left untouched.
off_t refresh(
    File& file,
    off_t pos,
    std::unordered_map<std::string, std::vector<uint8_t>>& cache) {
  const auto end = file.size();
  if (end == pos) {
    return pos;
  }

  file.seek(pos, SEEK_SET);
  std::string key;
  std::vector<uint8_t> value;
  while (pos < end) {
    file.read(key);
    file.read(value);
    cache[key] = std::move(value);
    pos = file.tell();
  }
  return pos;
}
|
||
} // namespace | ||
|
||
// Remembers the backing file's path. The file is not opened here, and the
// read offset starts at the beginning of the file.
FileStore::FileStore(const std::string& path)
    : Store(), path_(path), pos_(0) {}
|
||
// Nothing to release: files are opened and closed per operation.
FileStore::~FileStore() = default;
|
||
void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) { | ||
File file(path_, O_RDWR | O_CREAT); | ||
auto lock = file.lockExclusive(); | ||
file.seek(0, SEEK_END); | ||
file.write(key); | ||
file.write(value); | ||
} | ||
|
||
std::vector<uint8_t> FileStore::get(const std::string& key) { | ||
while (cache_.count(key) == 0) { | ||
File file(path_, O_RDONLY); | ||
auto lock = file.lockShared(); | ||
auto size = file.size(); | ||
if (size == pos_) { | ||
// No new entries; release the shared lock and sleep for a bit | ||
lock.unlock(); | ||
std::this_thread::sleep_for(std::chrono::milliseconds(10)); | ||
continue; | ||
} | ||
|
||
pos_ = refresh(file, pos_, cache_); | ||
} | ||
|
||
return cache_[key]; | ||
} | ||
|
||
// Adds `i` to the integer stored under `key` (taken as 0 when the key is
// missing or its value is empty), appends the new total to the backing file
// as a decimal string and returns it. The whole read-modify-write runs
// under an exclusive lock. std::stoll throws if the stored value is not a
// valid integer.
int64_t FileStore::add(const std::string& key, int64_t i) {
  File file(path_, O_RDWR | O_CREAT);
  auto lock = file.lockExclusive();
  pos_ = refresh(file, pos_, cache_);

  // Look the key up without inserting it: an empty placeholder entry would
  // later be returned by get() as if it were the stored value.
  int64_t ti = i;
  const auto it = cache_.find(key);
  if (it != cache_.end() && !it->second.empty()) {
    const auto& value = it->second;
    auto buf = reinterpret_cast<const char*>(value.data());
    ti += std::stoll(std::string(buf, value.size()));
  }

  // refresh() leaves the cursor of this freshly opened file at offset 0
  // when there was nothing new to read, so move to the end explicitly
  // before appending; otherwise the first record would be overwritten.
  file.seek(0, SEEK_END);
  file.write(key);
  file.write(std::to_string(ti));

  return ti;
}
|
||
// Loads any new records from the backing file under a shared lock, then
// reports whether every entry of `keys` is present in the cache. An empty
// `keys` list yields true.
bool FileStore::check(const std::vector<std::string>& keys) {
  File file(path_, O_RDONLY);
  auto guard = file.lockShared();
  pos_ = refresh(file, pos_, cache_);

  bool allPresent = true;
  for (auto it = keys.begin(); allPresent && it != keys.end(); ++it) {
    allPresent = cache_.find(*it) != cache_.end();
  }
  return allPresent;
}
|
||
void FileStore::wait( | ||
const std::vector<std::string>& keys, | ||
const std::chrono::milliseconds& timeout) { | ||
// Not using inotify because it doesn't work on many | ||
// shared filesystems (such as NFS). | ||
const auto start = std::chrono::steady_clock::now(); | ||
while (!check(keys)) { | ||
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>( | ||
std::chrono::steady_clock::now() - start); | ||
if (timeout != kNoTimeout && elapsed > timeout) { | ||
throw std::runtime_error("Wait timeout"); | ||
} | ||
|
||
/* sleep override */ | ||
std::this_thread::sleep_for(std::chrono::milliseconds(10)); | ||
} | ||
} | ||
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#pragma once | ||
|
||
#include <sys/types.h> | ||
|
||
#include <unordered_map> | ||
|
||
#include "Store.hpp" | ||
|
||
namespace c10d { | ||
|
||
class FileStore : public Store { | ||
public: | ||
explicit FileStore(const std::string& path); | ||
|
||
virtual ~FileStore(); | ||
|
||
void set( | ||
const std::string& key, | ||
const std::vector<uint8_t>& value) override; | ||
|
||
std::vector<uint8_t> get(const std::string& key) override; | ||
|
||
int64_t add(const std::string& key, int64_t value) override; | ||
|
||
bool check(const std::vector<std::string>& keys) override; | ||
|
||
void wait( | ||
const std::vector<std::string>& keys, | ||
const std::chrono::milliseconds& timeout = kDefaultTimeout) override; | ||
|
||
protected: | ||
std::string path_; | ||
off_t pos_; | ||
|
||
std::unordered_map<std::string, std::vector<uint8_t>> cache_; | ||
}; | ||
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# THD refactor | ||
|
||
This is a work in progress. It is separate from the main THD directory | ||
to avoid disrupting THD users or having to deal with backwards compatibility | ||
early on. Once this gets to a usable state, we'll add Python bindings | ||
and a compat layer. | ||
|
||
See https://github.com/pytorch/pytorch/issues/7434 for the main issue. | ||
|
||
This tree is intentionally not part of the main build and will be | ||
buildable/testable in isolation, as long as ATen is available in | ||
`<repository root>/torch/lib/tmp_install`. | ||
|
||
To build and install ATen here, navigate to the root of this | ||
repository and run: | ||
|
||
``` shell | ||
tools/build_pytorch_libs.sh --with-cuda ATen | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#include "Store.hpp" | ||
|
||
namespace c10d { | ||
|
||
constexpr std::chrono::milliseconds Store::kDefaultTimeout; | ||
constexpr std::chrono::milliseconds Store::kNoTimeout; | ||
|
||
// Define destructor symbol for abstract base class. | ||
Store::~Store() { | ||
} | ||
|
||
} // namespace c10d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#pragma once | ||
|
||
#include <chrono> | ||
#include <cstdint> | ||
#include <stdexcept> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace c10d { | ||
|
||
// Abstract interface of a key/value store. Keys are strings and values are
// byte vectors; concrete stores implement every operation below.
class Store {
 public:
  // Timeout used by wait() when the caller does not supply one: 30 seconds.
  static constexpr std::chrono::milliseconds kDefaultTimeout =
      std::chrono::seconds(30);

  // Zero-length duration; FileStore::wait skips its timeout check when
  // given this value.
  static constexpr std::chrono::milliseconds kNoTimeout =
      std::chrono::milliseconds::zero();

  virtual ~Store();

  // Stores `value` under `key`.
  virtual void set(const std::string& key, const std::vector<uint8_t>& value) = 0;

  // Returns the value stored under `key`.
  virtual std::vector<uint8_t> get(const std::string& key) = 0;

  // Adds `value` to the integer stored under `key`; returns the new total.
  virtual int64_t add(const std::string& key, int64_t value) = 0;

  // Returns true when every key in `keys` is present in the store.
  virtual bool check(const std::vector<std::string>& keys) = 0;

  // Waits for every key in `keys` to become present, bounded by `timeout`.
  virtual void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout = kDefaultTimeout) = 0;
};
|
||
} // namespace c10d |
Oops, something went wrong.