Skip to content

Commit

Permalink
ARROW-13541: [C++][Python] Implement ExtensionScalar
Browse files Browse the repository at this point in the history
Closes apache#10904 from pitrou/ARROW-13541-ext-scalar

Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
pitrou committed Aug 11, 2021
1 parent 7365806 commit 4fa9832
Show file tree
Hide file tree
Showing 14 changed files with 404 additions and 79 deletions.
10 changes: 6 additions & 4 deletions cpp/src/arrow/array/array_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ struct ScalarFromArraySlotImpl {
Status Visit(const DictionaryArray& a) {
auto ty = a.type();

ARROW_ASSIGN_OR_RAISE(auto index,
MakeScalar(checked_cast<DictionaryType&>(*ty).index_type(),
a.GetValueIndex(index_)));
ARROW_ASSIGN_OR_RAISE(
auto index, MakeScalar(checked_cast<const DictionaryType&>(*ty).index_type(),
a.GetValueIndex(index_)));

auto scalar = DictionaryScalar(ty);
scalar.is_valid = a.IsValid(index_);
Expand All @@ -148,7 +148,9 @@ struct ScalarFromArraySlotImpl {
}

// Produces the scalar for slot `index_` of an extension array: the storage
// scalar is fetched with `a.storage()->GetScalar(index_)` and wrapped, together
// with `a.type()`, in an ExtensionScalar that is stored in `out_`.
// NOTE(review): this hunk is rendered without +/- markers. The NotImplemented
// return directly below is presumably the line being removed and the statements
// after it its replacement; read literally, they would be unreachable. Confirm
// against the post-commit file.
Status Visit(const ExtensionArray& a) {
return Status::NotImplemented("Non-null ExtensionScalar");
// Propagates any error from GetScalar to the caller.
ARROW_ASSIGN_OR_RAISE(auto storage, a.storage()->GetScalar(index_));
out_ = std::make_shared<ExtensionScalar>(std::move(storage), a.type());
return Status::OK();
}

template <typename Arg>
Expand Down
12 changes: 4 additions & 8 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -776,13 +776,7 @@ class ScalarEqualsVisitor {

// Equality for union scalars: `right_` is downcast to UnionScalar and the
// outcome is written to `result_`.
// NOTE(review): this hunk is rendered without +/- markers. The if/else chain on
// `is_valid` is presumably the version being removed, and the unconditional
// ScalarEquals assignment after it the replacement; read literally, the last
// assignment overwrites whatever the chain stored. Confirm against the
// post-commit file.
Status Visit(const UnionScalar& left) {
const auto& right = checked_cast<const UnionScalar&>(right_);
if (left.is_valid && right.is_valid) {
result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_);
} else if (!left.is_valid && !right.is_valid) {
result_ = true;
} else {
result_ = false;
}
// Delegates to the comparison of the two child scalars.
result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_);
return Status::OK();
}

Expand All @@ -796,7 +790,9 @@ class ScalarEqualsVisitor {
}

// Equality for extension scalars: compares the two storage scalars with
// ScalarEquals and stores the outcome in `result_`.
// NOTE(review): this hunk is rendered without +/- markers. The NotImplemented
// return directly below is presumably the line being removed; read literally,
// the statements after it would be unreachable. Confirm against the post-commit
// file.
Status Visit(const ExtensionScalar& left) {
return Status::NotImplemented("extension");
const auto& right = checked_cast<const ExtensionScalar&>(right_);
// Both `value` pointers are dereferenced without a null check here.
result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_);
return Status::OK();
}

bool result() const { return result_; }
Expand Down
30 changes: 20 additions & 10 deletions cpp/src/arrow/compute/kernels/scalar_cast_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
#include "arrow/compute/cast_internal.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/extension_type.h"
#include "arrow/util/checked_cast.h"

namespace arrow {

using internal::checked_cast;
using internal::PrimitiveScalarBase;

namespace compute {
Expand Down Expand Up @@ -188,16 +190,24 @@ Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
// Cast kernel for inputs of extension type: casts the storage of the extension
// value (scalar or array) to `out->type()`, using the CastOptions read from the
// kernel state.
// NOTE(review): this hunk is rendered without +/- markers, so statements from
// the version being replaced appear interleaved with the new ones (for example,
// `casted_storage` is declared twice and an ExtensionArray is built from
// `batch[0].array()` inside the branch that handles Datum::SCALAR). Confirm
// against the post-commit file before relying on the statement order shown.
Status CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const CastOptions& options = checked_cast<const CastState*>(ctx->state())->options;

const DataType& in_type = *batch[0].type();
const auto storage_type = checked_cast<const ExtensionType&>(in_type).storage_type();
if (batch[0].kind() == Datum::SCALAR) {
const auto& ext_scalar = checked_cast<const ExtensionScalar&>(*batch[0].scalar());
Datum casted_storage;

ExtensionArray extension(batch[0].array());

Datum casted_storage;
RETURN_NOT_OK(Cast(*extension.storage(), out->type(), options, ctx->exec_context())
.Value(&casted_storage));
out->value = casted_storage.array();
return Status::OK();
// Valid scalar: cast the wrapped storage scalar directly into `out`.
if (ext_scalar.is_valid) {
return Cast(ext_scalar.value, out->type(), options, ctx->exec_context()).Value(out);
} else {
// Null scalar: cast a null scalar of the storage type instead of
// `ext_scalar.value`.
const auto& storage_type =
checked_cast<const ExtensionType&>(*ext_scalar.type).storage_type();
return Cast(MakeNullScalar(storage_type), out->type(), options, ctx->exec_context())
.Value(out);
}
} else {
// Array input: cast the storage array of the extension array into `out`.
DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
ExtensionArray extension(batch[0].array());
return Cast(*extension.storage(), out->type(), options, ctx->exec_context())
.Value(out);
}
}

Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
Expand Down Expand Up @@ -279,7 +289,7 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun
}

// From extension type to this type
DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType::Array(Type::EXTENSION)}, out_ty,
DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType(Type::EXTENSION)}, out_ty,
CastFromExtension, NullHandling::COMPUTED_NO_PREALLOCATE,
MemAllocation::NO_PREALLOCATE));
}
Expand Down
5 changes: 0 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ static void CheckCastFails(std::shared_ptr<Array> input, CastOptions options) {
<< "\n to_type: " << options.to_type->ToString()
<< "\n input: " << input->ToString();

if (input->type_id() == Type::EXTENSION) {
// ExtensionScalar not implemented
return;
}

// For the scalars, check that at least one of the input fails (since many
// of the tests contain a mix of passing and failing values). In some
// cases we will want to check more precisely
Expand Down
11 changes: 3 additions & 8 deletions cpp/src/arrow/compute/kernels/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,9 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte
}
ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output";

// Check all the input scalars, if scalars are implemented
if (std::none_of(inputs.begin(), inputs.end(), [](const Datum& datum) {
return datum.type()->id() == Type::EXTENSION;
})) {
// Check all the input scalars
for (int64_t i = 0; i < expected->length(); ++i) {
CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options);
}
// Check all the input scalars
for (int64_t i = 0; i < expected->length(); ++i) {
CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options);
}

// Since it's a scalar function, calling it on sliced inputs should
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/extension_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ARROW_EXPORT ExtensionType : public DataType {
static constexpr const char* type_name() { return "extension"; }

/// \brief The type of array used to represent this extension type's data
// NOTE(review): two signatures are shown, one returning by value and one by
// const reference. They differ only in return type, so they cannot both be
// declared; presumably this is the before/after of the change. The reference
// form returns a reference to the `storage_type_` member.
std::shared_ptr<DataType> storage_type() const { return storage_type_; }
const std::shared_ptr<DataType>& storage_type() const { return storage_type_; }

DataTypeLayout layout() const override;

Expand Down Expand Up @@ -114,7 +114,7 @@ class ARROW_EXPORT ExtensionArray : public Array {
}

/// \brief The physical storage for the extension array
// NOTE(review): two signatures are shown, one returning by value and one by
// const reference. They differ only in return type, so they cannot both be
// declared; presumably this is the before/after of the change. The reference
// form returns a reference to the `storage_` member.
std::shared_ptr<Array> storage() const { return storage_; }
const std::shared_ptr<Array>& storage() const { return storage_; }

protected:
void SetData(const std::shared_ptr<ArrayData>& data);
Expand Down
46 changes: 37 additions & 9 deletions cpp/src/arrow/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,10 @@ struct ScalarHashImpl {
return Status::OK();
}

// TODO(bkietz) implement less wimpy hashing when this has ValueType
// As written, visiting an extension scalar contributes nothing to the hash and
// simply reports success.
// NOTE(review): a fuller definition with the same signature follows directly
// below in this rendering; this one-liner is presumably the version being
// replaced.
Status Visit(const ExtensionScalar& s) { return Status::OK(); }
// Hashes an extension scalar by folding in its storage scalar.
// `s.value` is dereferenced without a null check.
Status Visit(const ExtensionScalar& s) {
  const Scalar& storage = *s.value;
  AccumulateHashFrom(storage);
  return Status::OK();
}

template <typename T>
Status StdHash(const T& t) {
Expand Down Expand Up @@ -142,14 +144,14 @@ struct ScalarHashImpl {
}

// Seeds `hash_` with the hash of the scalar's type, then folds the scalar in
// through AccumulateHashFrom.
// NOTE(review): this hunk is rendered without +/- markers. The call guarded by
// `is_valid` and the unguarded call after it are presumably the before/after of
// the change rather than two calls meant to run in sequence. Confirm against
// the post-commit file.
explicit ScalarHashImpl(const Scalar& scalar) : hash_(scalar.type->Hash()) {
if (scalar.is_valid) {
AccumulateHashFrom(scalar);
}
AccumulateHashFrom(scalar);
}

// Folds `scalar` into the running hash by dispatching to the matching Visit
// overload through VisitScalarInline.
// NOTE(review): this hunk is rendered without +/- markers. The first two
// DCHECK_OK lines (type fingerprint plus an unconditional visit) are presumably
// the version being removed, and the `is_valid`-guarded visit below them the
// replacement. Confirm against the post-commit file.
void AccumulateHashFrom(const Scalar& scalar) {
DCHECK_OK(StdHash(scalar.type->fingerprint()));
DCHECK_OK(VisitScalarInline(scalar, this));
// Note we already injected the type in ScalarHashImpl::ScalarHashImpl
// Scalars with is_valid == false are not visited.
if (scalar.is_valid) {
DCHECK_OK(VisitScalarInline(scalar, this));
}
}

size_t hash_;
Expand Down Expand Up @@ -371,8 +373,29 @@ struct ScalarValidateImpl {
return Status::OK();
}

// TODO
// As written, no checks are performed on extension scalars; the visit always
// succeeds.
// NOTE(review): a fuller definition with the same signature follows directly
// below in this rendering; this one-liner is presumably the version being
// replaced.
Status Visit(const ExtensionScalar& s) { return Status::OK(); }
// Validates the relationship between an extension scalar and its storage:
//  - a null extension scalar must not carry a storage scalar;
//  - a non-null extension scalar must carry a storage scalar that is itself
//    non-null and passes Validate().
// A storage validation failure is re-reported with the extension type's name
// prepended to the original message.
Status Visit(const ExtensionScalar& s) {
  const bool has_storage = static_cast<bool>(s.value);

  if (s.is_valid) {
    if (!has_storage) {
      return Status::Invalid("non-null ", s.type->ToString(),
                             " scalar doesn't have storage value");
    }
    if (!s.value->is_valid) {
      return Status::Invalid("non-null ", s.type->ToString(),
                             " scalar has null storage value");
    }
    const auto storage_status = Validate(*s.value);
    if (storage_status.ok()) {
      return Status::OK();
    }
    return storage_status.WithMessage(
        s.type->ToString(),
        " scalar fails validation for storage value: ", storage_status.message());
  }

  // Null extension scalar: the only requirement is the absence of storage.
  if (has_storage) {
    return Status::Invalid("null ", s.type->ToString(), " scalar has storage value");
  }
  return Status::OK();
}

Status ValidateStringScalar(const BaseBinaryScalar& s) {
RETURN_NOT_OK(ValidateBinaryScalar(s));
Expand Down Expand Up @@ -616,6 +639,11 @@ struct MakeNullImpl {
return Status::OK();
}

// Null scalar for an extension type: the ExtensionScalar is built from the type
// alone, with no storage scalar attached (presumably leaving it marked as
// null -- the inherited constructor is not visible here).
Status Visit(const ExtensionType& type) {
  auto null_scalar = std::make_shared<ExtensionScalar>(type_);
  out_ = std::move(null_scalar);
  return Status::OK();
}

std::shared_ptr<Scalar> Finish() && {
// Should not fail.
DCHECK_OK(VisitTypeInline(*type_, this));
Expand Down
78 changes: 49 additions & 29 deletions cpp/src/arrow/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <vector>

#include "arrow/compare.h"
#include "arrow/extension_type.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type.h"
Expand Down Expand Up @@ -475,11 +476,19 @@ struct ARROW_EXPORT DictionaryScalar : public Scalar {
Result<std::shared_ptr<Scalar>> GetEncodedValue() const;
};

/// \brief A Scalar value for ExtensionType
///
/// The value is the underlying storage scalar.
/// `is_valid` must only be true if `value` is non-null and `value->is_valid` is true
struct ARROW_EXPORT ExtensionScalar : public Scalar {
// Inherit the base Scalar constructors; these do not set `value`.
using Scalar::Scalar;
using TypeClass = ExtensionType;
using ValueType = std::shared_ptr<Scalar>;

/// \brief Construct from a storage scalar and the extension type.
///
/// `true` is passed as the second base-constructor argument (presumably
/// `is_valid`), so the caller is responsible for supplying a storage scalar
/// that satisfies the invariant stated above.
ExtensionScalar(std::shared_ptr<Scalar> storage, std::shared_ptr<DataType> type)
: Scalar(std::move(type), true), value(std::move(storage)) {}

/// The underlying storage scalar.
std::shared_ptr<Scalar> value;
};

/// @}
Expand All @@ -494,34 +503,7 @@ ARROW_EXPORT Status CheckBufferLength(const FixedSizeBinaryType* t,
} // namespace internal

// Type visitor that builds a Scalar of type `type_` from the unboxed value
// `value_`; `Finish()` runs the visit and hands back the result.
// NOTE(review): this hunk is rendered without +/- markers. The full definition
// shown here is presumably being replaced by the bare forward declaration on
// the last line (a definition with the same name appears again further down in
// this rendering). Confirm against the post-commit file.
template <typename ValueRef>
struct MakeScalarImpl {
// Selected only when the scalar class for T can be constructed from
// (ValueType, type) and ValueRef converts to that ValueType.
template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType,
typename ValueType = typename ScalarType::ValueType,
typename Enable = typename std::enable_if<
std::is_constructible<ScalarType, ValueType,
std::shared_ptr<DataType>>::value &&
std::is_convertible<ValueRef, ValueType>::value>::type>
Status Visit(const T& t) {
ARROW_RETURN_NOT_OK(internal::CheckBufferLength(&t, &value_));
out_ = std::make_shared<ScalarType>(
static_cast<ValueType>(static_cast<ValueRef>(value_)), std::move(type_));
return Status::OK();
}

// Fallback for every type the constrained overload above does not accept.
Status Visit(const DataType& t) {
return Status::NotImplemented("constructing scalars of type ", t,
" from unboxed values");
}

// Rvalue-qualified: callable only on an expiring visitor, since it moves
// `out_` out.
Result<std::shared_ptr<Scalar>> Finish() && {
ARROW_RETURN_NOT_OK(VisitTypeInline(*type_, this));
return std::move(out_);
}

std::shared_ptr<DataType> type_;
ValueRef value_;
std::shared_ptr<Scalar> out_;
};
struct MakeScalarImpl;

/// \defgroup scalar-factories Scalar factory functions
///
Expand Down Expand Up @@ -557,4 +539,42 @@ inline std::shared_ptr<Scalar> MakeScalar(std::string value) {

/// @}

/// \brief Type visitor that builds a Scalar of type `type_` from the unboxed
/// value `value_`.
///
/// Usage is `MakeScalarImpl<ValueRef>{type, value, nullptr}.Finish()`-style:
/// `Finish()` dispatches on the runtime type and returns the scalar, or the
/// error produced by the selected `Visit` overload.
template <typename ValueRef>
struct MakeScalarImpl {
  /// Types whose scalar class can be constructed from (ValueType, type) and for
  /// which ValueRef converts to that ValueType.
  template <typename T, typename ScalarType = typename TypeTraits<T>::ScalarType,
            typename ValueType = typename ScalarType::ValueType,
            typename Enable = typename std::enable_if<
                std::is_constructible<ScalarType, ValueType,
                                      std::shared_ptr<DataType>>::value &&
                std::is_convertible<ValueRef, ValueType>::value>::type>
  Status Visit(const T& t) {
    ARROW_RETURN_NOT_OK(internal::CheckBufferLength(&t, &value_));
    // `static_cast<ValueRef>` makes a rvalue if ValueRef is `ValueType&&`
    out_ = std::make_shared<ScalarType>(
        static_cast<ValueType>(static_cast<ValueRef>(value_)), std::move(type_));
    return Status::OK();
  }

  /// Extension types: build a scalar of the storage type from the value, then
  /// wrap it in an ExtensionScalar carrying the extension type.
  Status Visit(const ExtensionType& t) {
    ARROW_ASSIGN_OR_RAISE(auto storage,
                          MakeScalar(t.storage_type(), static_cast<ValueRef>(value_)));
    // Hand `type_` over instead of copying it, as the generic overload above
    // does. The new scalar takes ownership of the type, and `type_` is not read
    // again after the visit.
    out_ = std::make_shared<ExtensionScalar>(std::move(storage), std::move(type_));
    return Status::OK();
  }

  /// Fallback for every type not accepted by the overloads above.
  Status Visit(const DataType& t) {
    return Status::NotImplemented("constructing scalars of type ", t,
                                  " from unboxed values");
  }

  /// Runs the visit and returns the constructed scalar. Rvalue-qualified
  /// because it moves `out_` (and, on success, `type_`) out of the visitor.
  Result<std::shared_ptr<Scalar>> Finish() && {
    ARROW_RETURN_NOT_OK(VisitTypeInline(*type_, this));
    return std::move(out_);
  }

  std::shared_ptr<DataType> type_;
  ValueRef value_;
  std::shared_ptr<Scalar> out_;
};

} // namespace arrow
Loading

0 comments on commit 4fa9832

Please sign in to comment.