Skip to content

Commit

Permalink
Add Tensor::fromArray, Tensor ptr ctors take const*
Browse files Browse the repository at this point in the history
Summary:
- Add `Tensor::fromArray`, which constructs a Tensor from a `std::array`.
- All input pointers for Tensor construction are now `const*`.

Reviewed By: benoitsteiner

Differential Revision: D31984649

fbshipit-source-id: 5bcb53abdced89c62157f0611327f0663ab97c09
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Nov 18, 2021
1 parent 45d57e2 commit c5efb81
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 23 deletions.
4 changes: 2 additions & 2 deletions flashlight/fl/tensor/TensorAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ struct TensorCreator {
virtual std::unique_ptr<TensorAdapterBase> get(
const Shape& shape = {},
fl::dtype type = fl::dtype::f32,
void* ptr = nullptr,
const void* ptr = nullptr,
MemoryLocation memoryLocation = MemoryLocation::Host) const = 0;

// Sparse tensor ctor
Expand All @@ -278,7 +278,7 @@ struct TensorCreatorImpl : public TensorCreator {
std::unique_ptr<TensorAdapterBase> get(
const Shape& shape = {},
fl::dtype type = fl::dtype::f32,
void* ptr = nullptr,
const void* ptr = nullptr,
MemoryLocation memoryLocation = MemoryLocation::Host) const override {
return std::make_unique<T>(shape, type, ptr, memoryLocation);
}
Expand Down
21 changes: 19 additions & 2 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Tensor::Tensor() : impl_(detail::getDefaultAdapter()) {}
Tensor::Tensor(
const Shape& shape,
fl::dtype type,
void* ptr,
const void* ptr,
MemoryLocation memoryLocation)
: impl_(detail::getDefaultAdapter(shape, type, ptr, memoryLocation)) {}

Expand Down Expand Up @@ -136,6 +136,9 @@ TensorBackend& Tensor::backend() const {
#define FL_CREATE_MEMORY_OPS(TYPE) \
template <> \
TYPE Tensor::scalar() const { \
if (isEmpty()) { \
throw std::invalid_argument("Tensor::scalar called on empty tensor"); \
} \
if (type() != dtype_traits<TYPE>::fl_type) { \
throw std::invalid_argument( \
"Tensor::scalar: requested type of " + \
Expand All @@ -149,6 +152,9 @@ TensorBackend& Tensor::backend() const {
\
template <> \
TYPE* Tensor::device() const { \
if (isEmpty()) { \
return nullptr; \
} \
TYPE* out; \
void** addr = reinterpret_cast<void**>(&out); \
impl_->device(addr); \
Expand All @@ -157,14 +163,19 @@ TensorBackend& Tensor::backend() const {
\
template <> \
TYPE* Tensor::host() const { \
if (isEmpty()) { \
return nullptr; \
} \
TYPE* out = reinterpret_cast<TYPE*>(new char[bytes()]); \
impl_->host(out); \
return out; \
} \
\
template <> \
void Tensor::host(TYPE* ptr) const { \
impl_->host(ptr); \
if (!isEmpty()) { \
impl_->host(ptr); \
} \
}
FL_CREATE_MEMORY_OPS(int);
FL_CREATE_MEMORY_OPS(unsigned);
Expand All @@ -181,13 +192,19 @@ FL_CREATE_MEMORY_OPS(unsigned short);
// void specializations
template <>
void* Tensor::device() const {
if (isEmpty()) {
return nullptr;
}
void* out;
impl_->device(&out);
return out;
}

template <>
void* Tensor::host() const {
if (isEmpty()) {
return nullptr;
}
void* out = reinterpret_cast<void*>(new char[bytes()]);
impl_->host(out);
return out;
Expand Down
33 changes: 29 additions & 4 deletions flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Tensor {
Tensor(
const Shape& shape,
fl::dtype type,
void* ptr,
const void* ptr,
MemoryLocation memoryLocation);

/**
Expand Down Expand Up @@ -156,11 +156,21 @@ class Tensor {
return Tensor(s, fl::dtype_traits<T>::fl_type, v.data(), Location::Host);
}

template <typename T, std::size_t S>
static Tensor fromArray(Shape s, std::array<T, S> a) {
return Tensor(s, fl::dtype_traits<T>::fl_type, a.data(), Location::Host);
}

template <typename T>
static Tensor fromVector(Shape s, std::vector<T> v, dtype type) {
return Tensor(s, type, v.data(), Location::Host);
}

template <typename T, std::size_t S>
static Tensor fromArray(Shape s, std::array<T, S> a, dtype type) {
return Tensor(s, type, a.data(), Location::Host);
}

template <typename T>
static Tensor fromVector(std::vector<T> v) {
return Tensor(
Expand All @@ -170,6 +180,15 @@ class Tensor {
Location::Host);
}

template <typename T, std::size_t S>
static Tensor fromArray(std::array<T, S> a) {
return Tensor(
{static_cast<long long>(a.size())},
fl::dtype_traits<T>::fl_type,
a.data(),
Location::Host);
}

/**
* Create a tensor from an existing buffer.
*
Expand All @@ -180,7 +199,7 @@ class Tensor {
* @return a tensor with values and shape as given.
*/
template <typename T>
static Tensor fromBuffer(Shape s, T* ptr, Location memoryLocation) {
static Tensor fromBuffer(Shape s, const T* ptr, Location memoryLocation) {
return Tensor(s, fl::dtype_traits<T>::fl_type, ptr, memoryLocation);
}

Expand All @@ -194,8 +213,11 @@ class Tensor {
* with which to create the tensor resides.
* @return a tensor with values and shape as given.
*/
static Tensor
fromBuffer(Shape s, fl::dtype t, uint8_t* ptr, Location memoryLocation) {
static Tensor fromBuffer(
Shape s,
fl::dtype t,
const uint8_t* ptr,
Location memoryLocation) {
return Tensor(s, t, ptr, memoryLocation);
}

Expand Down Expand Up @@ -452,6 +474,9 @@ class Tensor {
*/
template <typename T>
std::vector<T> toHostVector() const {
if (isEmpty()) {
return std::vector<T>();
}
std::vector<T> vec(this->size());
host(vec.data());
return vec;
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ ArrayFireTensor::ArrayFireTensor() : handle_(ArrayComponent()) {}
ArrayFireTensor::ArrayFireTensor(
const Shape& shape,
fl::dtype type,
void* ptr,
const void* ptr,
Location memoryLocation)
: arrayHandle_(std::make_shared<af::array>(
detail::fromFlData(shape, ptr, type, memoryLocation))),
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/backend/af/ArrayFireTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class ArrayFireTensor : public TensorAdapterBase {
ArrayFireTensor(
const Shape& shape,
fl::dtype type,
void* ptr,
const void* ptr,
Location memoryLocation);

ArrayFireTensor(
Expand Down
23 changes: 12 additions & 11 deletions flashlight/fl/tensor/backend/af/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ af_source flToAfLocation(Location location) {

af::array fromFlData(
const Shape& shape,
void* ptr,
const void* ptr,
fl::dtype type,
fl::Location memoryLocation) {
af::dim4 dims = detail::flToAfDims(shape);
Expand All @@ -246,25 +246,26 @@ af::array fromFlData(
using af::dtype;
switch (afType) {
case f32:
return af::array(dims, reinterpret_cast<float*>(ptr), loc);
return af::array(dims, reinterpret_cast<const float*>(ptr), loc);
case f64:
return af::array(dims, reinterpret_cast<double*>(ptr), loc);
return af::array(dims, reinterpret_cast<const double*>(ptr), loc);
case s32:
return af::array(dims, reinterpret_cast<int*>(ptr), loc);
return af::array(dims, reinterpret_cast<const int*>(ptr), loc);
case u32:
return af::array(dims, reinterpret_cast<unsigned*>(ptr), loc);
return af::array(dims, reinterpret_cast<const unsigned*>(ptr), loc);
case s64:
return af::array(dims, reinterpret_cast<long long*>(ptr), loc);
return af::array(dims, reinterpret_cast<const long long*>(ptr), loc);
case u64:
return af::array(dims, reinterpret_cast<unsigned long long*>(ptr), loc);
return af::array(
dims, reinterpret_cast<const unsigned long long*>(ptr), loc);
case s16:
return af::array(dims, reinterpret_cast<short*>(ptr), loc);
return af::array(dims, reinterpret_cast<const short*>(ptr), loc);
case u16:
return af::array(dims, reinterpret_cast<unsigned short*>(ptr), loc);
return af::array(dims, reinterpret_cast<const unsigned short*>(ptr), loc);
case b8:
return af::array(dims, reinterpret_cast<char*>(ptr), loc);
return af::array(dims, reinterpret_cast<const char*>(ptr), loc);
case u8:
return af::array(dims, reinterpret_cast<unsigned char*>(ptr), loc);
return af::array(dims, reinterpret_cast<const unsigned char*>(ptr), loc);
default:
throw std::invalid_argument(
"fromFlData: can't construct ArrayFire array from given type.");
Expand Down
2 changes: 1 addition & 1 deletion flashlight/fl/tensor/backend/af/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ af_source flToAfLocation(Location location);
*/
af::array fromFlData(
const Shape& shape,
void* ptr,
const void* ptr,
fl::dtype type,
fl::Location memoryLocation);

Expand Down
24 changes: 23 additions & 1 deletion flashlight/fl/test/tensor/TensorBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ TEST(TensorBaseTest, CopyOperators) {
}

TEST(TensorBaseTest, ConstructFromData) {
// Tensor::fromVector
float val = 3.;
std::vector<float> vec(100, val);
fl::Shape s = {10, 10};
Expand All @@ -143,6 +144,23 @@ TEST(TensorBaseTest, ConstructFromData) {
std::vector<float> flat = {0, 1, 2, 3, 4, 5, 6, 7};
unsigned size = flat.size();
ASSERT_EQ(fl::Tensor::fromVector(flat).shape(), Shape({size}));

// Tensor::fromArray
constexpr unsigned arrFSize = 5;
std::array<float, arrFSize> arrF = {1, 2, 3, 4, 5};
auto tArrF = Tensor::fromArray(arrF);
ASSERT_EQ(tArrF.type(), fl::dtype::f32);
ASSERT_EQ(tArrF.shape(), Shape({arrFSize}));
auto tArrD = Tensor::fromArray({arrFSize}, arrF, fl::dtype::f64);
ASSERT_EQ(tArrD.type(), fl::dtype::f64);

constexpr unsigned arrISize = 8;
std::array<unsigned, arrISize> arrI = {1, 2, 3, 4, 5, 6, 7, 8};
auto tArrI = Tensor::fromArray(arrI);
ASSERT_EQ(tArrI.type(), fl::dtype::u32);
ASSERT_EQ(tArrI.shape(), Shape({arrISize}));
auto tArrIs = Tensor::fromArray({2, 4}, arrI);
ASSERT_EQ(tArrIs.shape(), Shape({2, 4}));
}

TEST(TensorBaseTest, reshape) {
Expand Down Expand Up @@ -637,6 +655,8 @@ TEST(TensorBaseTest, host) {
for (int i = 0; i < a.size(); ++i) {
ASSERT_EQ(existingBuffer[i], a.flatten()(i).scalar<float>());
}

ASSERT_EQ(Tensor().host<void>(), nullptr);
}

TEST(TensorBaseTest, toHostVector) {
Expand All @@ -646,6 +666,8 @@ TEST(TensorBaseTest, toHostVector) {
for (int i = 0; i < a.size(); ++i) {
ASSERT_EQ(vec[i], a.flatten()(i).scalar<float>());
}

ASSERT_EQ(Tensor().toHostVector<float>().size(), 0);
}

TEST(TensorBaseTest, matmul) {
Expand Down Expand Up @@ -826,7 +848,7 @@ TEST(TensorBaseTest, sum) {

TEST(TensorBaseTest, mean) {
auto r = fl::rand({8, 7, 6});
ASSERT_NEAR(fl::mean(r).scalar<float>(), 0.5, 0.01);
ASSERT_NEAR(fl::mean(r).scalar<float>(), 0.5, 0.05);
ASSERT_EQ(
fl::mean(r, {0, 1}, /* keepDims = */ true).shape(), Shape({1, 1, 6}));

Expand Down

0 comments on commit c5efb81

Please sign in to comment.