Skip to content

Commit

Permalink
Port interface of store base class from Caffe2 (pytorch#7439)
Browse files Browse the repository at this point in the history
The file store implementation is new and based on the file
initialization method (which uses a single file and file locking) and
the interface of the Caffe2 store handler.

See pytorch#7434.
  • Loading branch information
pietern authored May 10, 2018
1 parent 6547245 commit d5e77fb
Show file tree
Hide file tree
Showing 9 changed files with 536 additions and 0 deletions.
8 changes: 8 additions & 0 deletions torch/lib/c10d/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Build script for the c10d store library (store base class + file store).
cmake_minimum_required(VERSION 3.2 FATAL_ERROR)
# Prepend the cmake/ directory three levels above this one to the module search path.
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH})

# Single library containing the abstract store and the file-backed implementation.
add_library(store Store.cpp FileStore.cpp)
# PUBLIC: the flag applies to this target and to targets that link against it.
target_compile_options(store PUBLIC "-std=c++11")

# Enable CTest and pull in the targets defined under test/.
enable_testing()
add_subdirectory(test)
280 changes: 280 additions & 0 deletions torch/lib/c10d/FileStore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
#include "FileStore.hpp"

#include <assert.h>
#include <stdint.h>
#include <sys/file.h>
#include <fcntl.h>
#include <unistd.h>
#include <sys/stat.h>

#include <chrono>
#include <functional>
#include <iostream>
#include <limits>
#include <sstream>
#include <system_error>
#include <thread>

// Throws std::system_error built from the current errno when `rv` is
// negative. Any additional arguments are forwarded to the
// std::system_error constructor (typically a message naming the call).
//
// The do/while(0) wrapper makes the expansion a single statement, so
// `if (cond) SYSASSERT(rv, "x"); else ...` parses as intended.
// errno is copied before the message arguments are evaluated so that
// building the message (e.g. string concatenation) cannot clobber it.
#define SYSASSERT(rv, ...)                    \
  do {                                        \
    if ((rv) < 0) {                           \
      const int sysassertErrno_ = errno;      \
      throw std::system_error(                \
          sysassertErrno_,                    \
          std::system_category(),             \
          ##__VA_ARGS__);                     \
    }                                         \
  } while (0)

namespace c10d {

namespace {

// Invokes `fn` repeatedly until it either succeeds or fails with an
// error other than EINTR, and returns the result of that last call.
// A return value of -1 together with errno == EINTR triggers a retry.
template<typename F>
typename std::result_of<F()>::type syscall(F fn) {
  for (;;) {
    auto result = fn();
    const bool interrupted = (result == -1) && (errno == EINTR);
    if (!interrupted) {
      return result;
    }
  }
}

// For a comprehensive overview of file locking methods,
// see: https://gavv.github.io/blog/file-locks/.
// We stick to flock(2) here because we don't care about
// locking byte ranges and don't want locks to be process-wide.

// RAII wrapper around flock(2).
//
// The constructor applies `operation` (e.g. LOCK_SH or LOCK_EX) to `fd`
// and throws std::system_error if that fails. The lock is released by
// unlock() or, at the latest, by the destructor. The object does not
// own the file descriptor and never closes it.
class Lock {
 public:
  explicit Lock(int fd, int operation) : fd_(fd) {
    flock(operation);
  }

  // Destructors are implicitly noexcept, so an exception escaping from
  // here would call std::terminate. Release on a best-effort basis and
  // ignore failures; callers that need to observe an unlock error should
  // call unlock() explicitly.
  ~Lock() {
    if (fd_ >= 0) {
      syscall(std::bind(::flock, fd_, LOCK_UN));
      fd_ = -1;
    }
  }

  Lock(const Lock& that) = delete;
  Lock& operator=(const Lock& that) = delete;

  // Transfers responsibility for releasing the lock; `other` becomes inert.
  Lock(Lock&& other) noexcept : fd_(other.fd_) {
    other.fd_ = -1;
  }

  // Releases the lock now. Does nothing if the lock was already released
  // or moved from. Throws std::system_error if flock(2) fails.
  void unlock() {
    if (fd_ >= 0) {
      flock(LOCK_UN);
      fd_ = -1;
    }
  }

 protected:
  // Descriptor the lock is held on, or -1 when there is nothing to release.
  int fd_;

  // Runs flock(2) on fd_, retrying on EINTR, and throws on failure.
  void flock(int operation) {
    auto rv = syscall(std::bind(::flock, fd_, operation));
    SYSASSERT(rv, "flock");
  }
};

// Owns a file descriptor opened with open(2) and offers small helpers for
// locking, seeking, and reading/writing length-prefixed records.
//
// Records are encoded as a uint32_t byte count (in host byte order, as
// written by write(&len, sizeof(len))) followed by that many bytes.
// All failing system calls are reported as std::system_error.
class File {
 public:
  // Opens `path` with the given open(2) `flags`; files created through
  // O_CREAT get mode 0644. Throws std::system_error if open fails.
  explicit File(const std::string& path, int flags) {
    fd_ = syscall(std::bind(::open, path.c_str(), flags, 0644));
    SYSASSERT(fd_, "open(" + path + ")");
  }

  ~File() {
    ::close(fd_);
  }

  // The descriptor is closed exactly once by the destructor, so copies
  // (which would close it twice) are not allowed.
  File(const File&) = delete;
  File& operator=(const File&) = delete;

  // Acquires a shared flock(2) lock on the file.
  Lock lockShared() {
    return Lock(fd_, LOCK_SH);
  }

  // Acquires an exclusive flock(2) lock on the file.
  Lock lockExclusive() {
    return Lock(fd_, LOCK_EX);
  }

  // Thin wrapper around lseek(2); returns the resulting offset.
  off_t seek(off_t offset, int whence) {
    auto rv = syscall(std::bind(lseek, fd_, offset, whence));
    SYSASSERT(rv, "lseek");
    return rv;
  }

  // Returns the current file offset.
  off_t tell() {
    auto rv = syscall(std::bind(lseek, fd_, 0, SEEK_CUR));
    SYSASSERT(rv, "lseek");
    return rv;
  }

  // Returns the file size by seeking to the end, then restores the
  // offset the file had before the call.
  off_t size() {
    auto pos = tell();
    auto size = seek(0, SEEK_END);
    seek(pos, SEEK_SET);
    return size;
  }

  // Writes exactly `count` bytes from `buf`, looping over partial writes.
  void write(const void* buf, size_t count) {
    while (count > 0) {
      auto rv = syscall(std::bind(::write, fd_, buf, count));
      SYSASSERT(rv, "write");
      // Advance by the number of bytes actually written, which may be
      // fewer than requested.
      buf = static_cast<const uint8_t*>(buf) + rv;
      count -= rv;
    }
  }

  // Reads exactly `count` bytes into `buf`, looping over partial reads.
  // Throws std::runtime_error if end-of-file is reached first.
  void read(void* buf, size_t count) {
    while (count > 0) {
      auto rv = syscall(std::bind(::read, fd_, buf, count));
      SYSASSERT(rv, "read");
      if (rv == 0) {
        // read(2) returns 0 at end-of-file; without this check the loop
        // would never terminate on a truncated file.
        throw std::runtime_error("read: unexpected end of file");
      }
      // Advance by the number of bytes actually read.
      buf = static_cast<uint8_t*>(buf) + rv;
      count -= rv;
    }
  }

  // Writes `str` as a length-prefixed record.
  void write(const std::string& str) {
    uint32_t len = str.size();
    assert(str.size() <= std::numeric_limits<decltype(len)>::max());
    write(&len, sizeof(len));
    write(str.c_str(), len);
  }

  // Writes `data` as a length-prefixed record.
  void write(const std::vector<uint8_t>& data) {
    uint32_t len = data.size();
    assert(data.size() <= std::numeric_limits<decltype(len)>::max());
    write(&len, sizeof(len));
    write(data.data(), len);
  }

  // Reads one length-prefixed record into `str`, replacing its contents.
  void read(std::string& str) {
    uint32_t len = 0;
    read(&len, sizeof(len));
    std::vector<uint8_t> buf(len);
    read(buf.data(), len);
    str.assign(buf.begin(), buf.end());
  }

  // Reads one length-prefixed record into `data`, replacing its contents.
  void read(std::vector<uint8_t>& data) {
    uint32_t len = 0;
    read(&len, sizeof(len));
    data.resize(len);
    read(data.data(), len);
  }

 protected:
  // Descriptor returned by open(2); closed in the destructor.
  int fd_;
};

// Replays every record stored in `file` from offset `pos` up to the
// current end of the file into `cache` (later records overwrite earlier
// ones with the same key) and returns the offset reached. If the file
// size equals `pos`, nothing is read and the file offset is left alone.
off_t refresh(
    File& file,
    off_t pos,
    std::unordered_map<std::string, std::vector<uint8_t>>& cache) {
  const auto end = file.size();
  if (end == pos) {
    return pos;
  }

  file.seek(pos, SEEK_SET);
  std::string key;
  std::vector<uint8_t> value;
  while (pos < end) {
    file.read(key);
    file.read(value);
    cache[key] = std::move(value);
    pos = file.tell();
  }
  return pos;
}

} // namespace

FileStore::FileStore(const std::string& path)
: Store(),
path_(path),
pos_(0) {
}

// Nothing to release: files are opened and closed per operation.
FileStore::~FileStore() = default;

void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) {
File file(path_, O_RDWR | O_CREAT);
auto lock = file.lockExclusive();
file.seek(0, SEEK_END);
file.write(key);
file.write(value);
}

// Returns the cached value for `key`. If the key is not cached yet, the
// backing file is polled every 10ms (under a shared lock) until new
// records appear and one of them carries `key`. There is no timeout.
std::vector<uint8_t> FileStore::get(const std::string& key) {
  for (;;) {
    const auto hit = cache_.find(key);
    if (hit != cache_.end()) {
      return hit->second;
    }

    File file(path_, O_RDONLY);
    auto lock = file.lockShared();
    if (file.size() != pos_) {
      pos_ = refresh(file, pos_, cache_);
      continue;
    }

    // No new entries; release the shared lock and sleep for a bit
    lock.unlock();
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}

// Atomically adds `i` to the integer stored under `key` and returns the
// new value. A key that does not exist yet (or holds an empty value) is
// treated as 0. The value is stored as its decimal string representation.
//
// The whole read-modify-write runs under an exclusive lock on the
// backing file, which is created if it does not exist. Throws whatever
// std::stoll throws if the existing value is not a decimal integer.
int64_t FileStore::add(const std::string& key, int64_t i) {
  File file(path_, O_RDWR | O_CREAT);
  auto lock = file.lockExclusive();
  pos_ = refresh(file, pos_, cache_);

  // Look the key up without operator[]: inserting an empty placeholder
  // would make a subsequent get(key) return that empty value instead of
  // reading the record written below.
  int64_t ti = i;
  const auto it = cache_.find(key);
  if (it != cache_.end() && !it->second.empty()) {
    const auto& value = it->second;
    auto buf = reinterpret_cast<const char*>(value.data());
    auto len = value.size();
    ti += std::stoll(std::string(buf, len));
  }

  // refresh() only moves the file offset when it finds new records, so
  // the offset may still be 0 from open(2). Always seek to the end
  // before appending; otherwise existing records would be overwritten.
  file.seek(0, SEEK_END);
  file.write(key);
  file.write(std::to_string(ti));

  return ti;
}

// Reloads any new records from the backing file (under a shared lock)
// and reports whether every key in `keys` is present in the cache.
// An empty `keys` list yields true.
bool FileStore::check(const std::vector<std::string>& keys) {
  File file(path_, O_RDONLY);
  auto lock = file.lockShared();
  pos_ = refresh(file, pos_, cache_);

  bool allPresent = true;
  for (auto it = keys.begin(); allPresent && it != keys.end(); ++it) {
    allPresent = cache_.find(*it) != cache_.end();
  }
  return allPresent;
}

// Blocks until check(keys) reports that all keys are present, polling
// every 10ms. Throws std::runtime_error("Wait timeout") once more than
// `timeout` has elapsed, unless `timeout` equals kNoTimeout, in which
// case it waits indefinitely.
void FileStore::wait(
    const std::vector<std::string>& keys,
    const std::chrono::milliseconds& timeout) {
  // Not using inotify because it doesn't work on many
  // shared filesystems (such as NFS).
  const auto start = std::chrono::steady_clock::now();
  while (!check(keys)) {
    // Measure in milliseconds, the unit of `timeout`. Truncating to whole
    // seconds would delay sub-second timeouts by up to a second.
    const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start);
    if (timeout != kNoTimeout && elapsed > timeout) {
      throw std::runtime_error("Wait timeout");
    }

    /* sleep override */
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
  }
}

} // namespace c10d
38 changes: 38 additions & 0 deletions torch/lib/c10d/FileStore.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <sys/types.h>

#include <unordered_map>

#include "Store.hpp"

namespace c10d {

// Store implementation that keeps its key/value records in a single file.
// Records are appended to the file and replayed into an in-memory cache;
// access to the file is guarded with file locks.
class FileStore : public Store {
public:
// `path` names the backing file; it is not opened by the constructor.
explicit FileStore(const std::string& path);

virtual ~FileStore();

// Appends a (key, value) record to the backing file.
void set(
const std::string& key,
const std::vector<uint8_t>& value) override;

// Returns the value for `key`, polling the backing file until the key
// is available.
std::vector<uint8_t> get(const std::string& key) override;

// Adds `value` to the integer stored under `key` (stored as a decimal
// string) and returns the result.
int64_t add(const std::string& key, int64_t value) override;

// Returns true if all of `keys` are present after reloading the file.
bool check(const std::vector<std::string>& keys) override;

// Polls check(keys) until it succeeds; throws std::runtime_error on
// timeout unless `timeout` is kNoTimeout.
void wait(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout = kDefaultTimeout) override;

protected:
// Path of the backing file.
std::string path_;
// Offset in the backing file up to which records have been cached.
off_t pos_;

// Most recent value seen in the file for each key.
std::unordered_map<std::string, std::vector<uint8_t>> cache_;
};

} // namespace c10d
19 changes: 19 additions & 0 deletions torch/lib/c10d/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# THD refactor

This is a work in progress. It is kept separate from the main THD directory
to avoid disrupting THD users or having to deal with backwards compatibility
early on. Once this gets to a usable state, we'll add Python bindings
and a compatibility layer.

See https://github.com/pytorch/pytorch/issues/7434 for the main issue.

This tree is intentionally not part of the main build and will be
buildable/testable in isolation, as long as ATen is available in
`<repository root>/torch/lib/tmp_install`.

To build and install ATen here, navigate to the root of this
repository and run:

``` shell
tools/build_pytorch_libs.sh --with-cuda ATen
```
12 changes: 12 additions & 0 deletions torch/lib/c10d/Store.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "Store.hpp"

namespace c10d {

// Namespace-scope definitions of the static constexpr members that are
// declared (with their initializers) in Store.hpp.
constexpr std::chrono::milliseconds Store::kNoTimeout;
constexpr std::chrono::milliseconds Store::kDefaultTimeout;

// Emit the destructor of the abstract base class in this translation unit.
Store::~Store() = default;

} // namespace c10d
35 changes: 35 additions & 0 deletions torch/lib/c10d/Store.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <chrono>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

namespace c10d {

// Abstract interface of a key/value store. All operations are pure
// virtual and must be provided by a concrete implementation.
class Store {
public:
// Timeout used by wait() when the caller does not pass one: 30 seconds.
static constexpr std::chrono::milliseconds kDefaultTimeout =
std::chrono::seconds(30);
// Sentinel timeout of zero milliseconds.
// NOTE(review): FileStore::wait treats this value as "wait forever";
// presumably other implementations should too — confirm.
static constexpr std::chrono::milliseconds kNoTimeout =
std::chrono::milliseconds::zero();

virtual ~Store();

// Stores `value` under `key`.
virtual void set(
const std::string& key,
const std::vector<uint8_t>& value) = 0;

// Returns the value stored under `key`.
virtual std::vector<uint8_t> get(const std::string& key) = 0;

// Adds `value` to the integer associated with `key` and returns an
// int64_t; presumably the updated total, as in FileStore — confirm for
// other implementations.
virtual int64_t add(const std::string& key, int64_t value) = 0;

// Reports whether the given keys are present in the store.
virtual bool check(const std::vector<std::string>& keys) = 0;

// Waits for the given keys, subject to `timeout`.
virtual void wait(
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout = kDefaultTimeout) = 0;
};

} // namespace c10d
Loading

0 comments on commit d5e77fb

Please sign in to comment.