Skip to content

Commit

Permalink
Merge caffe2::/at::Storage
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#11637

Reviewed By: gchanan

Differential Revision: D9806425

Pulled By: ezyang

fbshipit-source-id: e20ec93bff6dc7fb22ca9b7e7348d060b3876b67
  • Loading branch information
cpuhrsch authored and facebook-github-bot committed Sep 13, 2018
1 parent 77f6998 commit 36fc1a0
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 165 deletions.
24 changes: 0 additions & 24 deletions aten/src/ATen/core/Storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,4 @@

namespace at {

Storage::Storage(
caffe2::TypeMeta data_type,
size_t size,
Allocator* allocator,
bool resizable)
: storage_impl_(c10::make_intrusive<StorageImpl>(
data_type,
size,
allocator,
resizable)) {}

Storage::Storage(
caffe2::TypeMeta data_type,
at::DataPtr data_ptr,
size_t size,
const std::function<void(void*)>& deleter,
bool resizable)
: storage_impl_(c10::make_intrusive<StorageImpl>(
data_type,
size,
std::move(data_ptr),
/* allocator */ nullptr,
resizable)) {}

} // namespace at
158 changes: 144 additions & 14 deletions aten/src/ATen/core/Storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,179 @@ struct AT_API Storage {
Storage(const c10::intrusive_ptr<StorageImpl>& ptr) : storage_impl_(ptr) {}
Storage(c10::intrusive_ptr<StorageImpl>&& ptr) : storage_impl_(std::move(ptr)) {}
Storage(
caffe2::TypeMeta,
caffe2::TypeMeta data_type,
size_t size,
Allocator* allocator,
bool resizable = false);
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
data_type,
size,
allocator,
resizable)) {}

Storage(
caffe2::TypeMeta,
at::DataPtr,
caffe2::TypeMeta data_type,
at::DataPtr data_ptr,
size_t size,
const std::function<void(void*)>& deleter,
bool resizable = false);
bool resizable = false)
: storage_impl_(c10::make_intrusive<StorageImpl>(
data_type,
size,
std::move(data_ptr),
/* allocator */ nullptr,
resizable)) {}

Storage(at::DeviceType device_type)
: storage_impl_(c10::make_intrusive<StorageImpl>(device_type)) {}
Storage(at::DeviceType device_type, caffe2::TypeMeta data_type)
: storage_impl_(
c10::make_intrusive<StorageImpl>(device_type, data_type)) {}

Storage(
caffe2::TypeMeta data_type,
int64_t numel,
at::DataPtr data_ptr,
at::Allocator* allocator,
bool resizable)
: storage_impl_(c10::make_intrusive<StorageImpl>(
data_type,
numel,
std::move(data_ptr),
allocator,
resizable)) {}

void reset() {
storage_impl_->reset();
}

template <typename T>
inline bool IsType() const {
return storage_impl_->IsType<T>();
}

template <typename T>
T* data() const { return storage_impl_->data<T>(); }

template <typename T>
T* unsafe_data() const { return storage_impl_->unsafe_data<T>(); }

size_t elementSize() const { return storage_impl_->itemsize(); }
ptrdiff_t size() const { return storage_impl_->numel(); }
bool resizable() const { return storage_impl_->resizable(); }
size_t elementSize() const {
return storage_impl_->itemsize();
}

inline size_t itemsize() const {
return storage_impl_->itemsize();
}

ptrdiff_t size() const {
return storage_impl_->numel();
}

int64_t numel() const {
return storage_impl_->numel();
}

// TODO: remove later
void set_numel(int64_t numel) {
storage_impl_->set_numel(numel);
}

bool resizable() const {
return storage_impl_->resizable();
}

size_t capacity() const {
return storage_impl_->capacity();
}
// get() use here is to get const-correctness
void* data() const { return storage_impl_.get()->data(); }
const caffe2::TypeMeta dtype() const {

void* data() {
return storage_impl_->data();
}

void* data() const {
return storage_impl_.get()->data();
}

const caffe2::TypeMeta& dtype() const {
return storage_impl_->dtype();
}
const at::DataPtr& data_ptr() const { return storage_impl_->data_ptr(); }
DeviceType device_type() const { return storage_impl_->device_type(); }
at::Allocator* allocator() const { return storage_impl_.get()->allocator(); }
at::Device device() const { return storage_impl_->device(); }

at::DataPtr& data_ptr() {
return storage_impl_->data_ptr();
}

const at::DataPtr& data_ptr() const {
return storage_impl_->data_ptr();
}

// Returns the previous data_ptr
at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
return storage_impl_->set_data_ptr(std::move(data_ptr));
};

void set_dtype(const caffe2::TypeMeta& data_type) {
storage_impl_->set_dtype(data_type);
}

DeviceType device_type() const {
return storage_impl_->device_type();
}

at::Allocator* allocator() const {
return storage_impl_.get()->allocator();
}

at::Device device() const {
return storage_impl_->device();
}

StorageImpl* unsafeReleaseStorageImpl() {
return storage_impl_.release();
}

StorageImpl* unsafeGetStorageImpl() const noexcept {
return storage_impl_.get();
}

operator bool() const {
return storage_impl_;
}

size_t use_count() const {
return storage_impl_.use_count();
}

inline bool unique() const {
return storage_impl_.unique();
}

void UniqueStorageShareExternalPointer(
void* src,
const caffe2::TypeMeta& data_type,
size_t capacity,
DeleterFnPtr d = nullptr) {
if (!storage_impl_.unique()) {
AT_ERROR(
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(
src, data_type, capacity, d);
}

void UniqueStorageShareExternalPointer(
at::DataPtr&& data_ptr,
const caffe2::TypeMeta& data_type,
size_t capacity) {
if (!storage_impl_.unique()) {
AT_ERROR(
"UniqueStorageShareExternalPointer can only be called when use_count == 1");
}
storage_impl_->UniqueStorageShareExternalPointer(
std::move(data_ptr), data_type, capacity);
}

protected:
c10::intrusive_ptr<StorageImpl> storage_impl_;
};
Expand Down
128 changes: 2 additions & 126 deletions caffe2/core/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,137 +20,13 @@
#include <ATen/core/Device.h>
#include <ATen/core/DeviceType.h>
#include <ATen/core/intrusive_ptr.h>
#include <ATen/core/Storage.h>
#include <ATen/core/StorageImpl.h>

namespace caffe2 {

using StorageImpl = at::StorageImpl;

class CAFFE2_API Storage {
public:
Storage() {}
Storage(at::DeviceType device_type)
: storage_impl_(c10::make_intrusive<StorageImpl>(device_type)) {}
Storage(at::DeviceType device_type, TypeMeta data_type)
: storage_impl_(
c10::make_intrusive<StorageImpl>(device_type, data_type)) {}

Storage(
TypeMeta data_type,
int64_t numel,
at::DataPtr data_ptr,
at::Allocator* allocator,
bool resizable)
: storage_impl_(c10::make_intrusive<StorageImpl>(
data_type,
numel,
std::move(data_ptr),
allocator,
resizable)) {}

void reset() {
storage_impl_->reset();
}

// For debugging purpose only, please don't call it
StorageImpl* unsafeGetStorageImp() const {
return storage_impl_.get();
}

template <typename T>
inline bool IsType() const {
return storage_impl_->IsType<T>();
}

void* data() const {
return storage_impl_->data();
}

void* data() {
return storage_impl_->data();
}

at::DataPtr& data_ptr() {
return storage_impl_->data_ptr();
}

const at::DataPtr& data_ptr() const {
return storage_impl_->data_ptr();
}
// Returns the previous data_ptr
at::DataPtr set_data_ptr(at::DataPtr&& data_ptr) {
return storage_impl_->set_data_ptr(std::move(data_ptr));
};

void set_dtype(const TypeMeta& data_type) {
storage_impl_->set_dtype(data_type);
}

const TypeMeta& dtype() const {
return storage_impl_->dtype();
}

size_t capacity() const {
return storage_impl_->capacity();
}

int64_t numel() const {
return storage_impl_->numel();
}

// TODO: remove later
void set_numel(int64_t numel) {
storage_impl_->set_numel(numel);
}

at::DeviceType device_type() const {
return storage_impl_->device_type();
}

const at::Allocator* allocator() const {
return storage_impl_->allocator();
}

inline size_t itemsize() const {
return storage_impl_->itemsize();
}

inline long use_count() const {
return storage_impl_.use_count();
}

inline bool unique() const {
return storage_impl_.unique();
}

void UniqueStorageShareExternalPointer(
void* src,
const TypeMeta& data_type,
size_t capacity,
MemoryDeleter d = nullptr) {
CAFFE_ENFORCE_WITH_CALLER(
storage_impl_.unique(),
"UniqueStorageShareExternalPointer can only be called when \
use_count == 1");
storage_impl_->UniqueStorageShareExternalPointer(
src, data_type, capacity, d);
}

void UniqueStorageShareExternalPointer(
at::DataPtr&& data_ptr,
const TypeMeta& data_type,
size_t capacity) {
CAFFE_ENFORCE_WITH_CALLER(
storage_impl_.unique(),
"UniqueStorageShareExternalPointer can only be called when \
use_count == 1");
storage_impl_->UniqueStorageShareExternalPointer(
std::move(data_ptr), data_type, capacity);
}

protected:
c10::intrusive_ptr<StorageImpl> storage_impl_;
};
using Storage = at::Storage;

} // namespace caffe2

Expand Down
2 changes: 1 addition & 1 deletion caffe2/core/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ vector<TIndex> GetTensorInfo(
const Tensor* tc = static_cast<const Tensor*>(c);
CHECK(tc);
CHECK(tc->unsafeGetTensorImpl());
CHECK(tc->unsafeGetTensorImpl()->storage().unsafeGetStorageImp());
CHECK(tc->unsafeGetTensorImpl()->storage().unsafeGetStorageImpl());
*capacity = tc->capacity_nbytes();
tc->ExtractDeviceOption(device);
return tc->dims();
Expand Down

0 comments on commit 36fc1a0

Please sign in to comment.