diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index dad689d3ca782..11770d0090ce4 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -134,9 +134,9 @@ struct ScalarFromArraySlotImpl { Status Visit(const DictionaryArray& a) { auto ty = a.type(); - ARROW_ASSIGN_OR_RAISE(auto index, - MakeScalar(checked_cast(*ty).index_type(), - a.GetValueIndex(index_))); + ARROW_ASSIGN_OR_RAISE( + auto index, MakeScalar(checked_cast(*ty).index_type(), + a.GetValueIndex(index_))); auto scalar = DictionaryScalar(ty); scalar.is_valid = a.IsValid(index_); @@ -148,7 +148,9 @@ struct ScalarFromArraySlotImpl { } Status Visit(const ExtensionArray& a) { - return Status::NotImplemented("Non-null ExtensionScalar"); + ARROW_ASSIGN_OR_RAISE(auto storage, a.storage()->GetScalar(index_)); + out_ = std::make_shared(std::move(storage), a.type()); + return Status::OK(); } template diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 4c6f97faf9513..4ecb00a3f0852 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -776,13 +776,7 @@ class ScalarEqualsVisitor { Status Visit(const UnionScalar& left) { const auto& right = checked_cast(right_); - if (left.is_valid && right.is_valid) { - result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); - } else if (!left.is_valid && !right.is_valid) { - result_ = true; - } else { - result_ = false; - } + result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } @@ -796,7 +790,9 @@ class ScalarEqualsVisitor { } Status Visit(const ExtensionScalar& left) { - return Status::NotImplemented("extension"); + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); + return Status::OK(); } bool result() const { return result_; } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc index 8076c35a1321d..d9109b0f3f7b7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_internal.cc @@ -19,9 +19,11 @@ #include "arrow/compute/cast_internal.h" #include "arrow/compute/kernels/common.h" #include "arrow/extension_type.h" +#include "arrow/util/checked_cast.h" namespace arrow { +using internal::checked_cast; using internal::PrimitiveScalarBase; namespace compute { @@ -188,16 +190,24 @@ Status OutputAllNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { Status CastFromExtension(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const CastOptions& options = checked_cast(ctx->state())->options; - const DataType& in_type = *batch[0].type(); - const auto storage_type = checked_cast(in_type).storage_type(); + if (batch[0].kind() == Datum::SCALAR) { + const auto& ext_scalar = checked_cast(*batch[0].scalar()); + Datum casted_storage; - ExtensionArray extension(batch[0].array()); - - Datum casted_storage; - RETURN_NOT_OK(Cast(*extension.storage(), out->type(), options, ctx->exec_context()) - .Value(&casted_storage)); - out->value = casted_storage.array(); - return Status::OK(); + if (ext_scalar.is_valid) { + return Cast(ext_scalar.value, out->type(), options, ctx->exec_context()).Value(out); + } else { + const auto& storage_type = + checked_cast(*ext_scalar.type).storage_type(); + return Cast(MakeNullScalar(storage_type), out->type(), options, ctx->exec_context()) + .Value(out); + } + } else { + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + ExtensionArray extension(batch[0].array()); + return Cast(*extension.storage(), out->type(), options, ctx->exec_context()) + .Value(out); + } } Status CastFromNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { @@ -279,7 +289,7 @@ void AddCommonCasts(Type::type out_type_id, OutputType out_ty, CastFunction* fun } // From extension type to this type - DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType::Array(Type::EXTENSION)}, out_ty, + DCHECK_OK(func->AddKernel(Type::EXTENSION, {InputType(Type::EXTENSION)}, out_ty, CastFromExtension, NullHandling::COMPUTED_NO_PREALLOCATE, MemAllocation::NO_PREALLOCATE)); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 1b6c862648e97..90d418945783e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -93,11 +93,6 @@ static void CheckCastFails(std::shared_ptr input, CastOptions options) { << "\n to_type: " << options.to_type->ToString() << "\n input: " << input->ToString(); - if (input->type_id() == Type::EXTENSION) { - // ExtensionScalar not implemented - return; - } - // For the scalars, check that at least one of the input fails (since many // of the tests contains a mix of passing and failing values). In some // cases we will want to check more precisely diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index e6be1b81adf02..611dc6141c114 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -121,14 +121,9 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte } ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output"; - // Check all the input scalars, if scalars are implemented - if (std::none_of(inputs.begin(), inputs.end(), [](const Datum& datum) { - return datum.type()->id() == Type::EXTENSION; - })) { - // Check all the input scalars - for (int64_t i = 0; i < expected->length(); ++i) { - CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options); - } + // Check all the input scalars + for (int64_t i = 0; i < expected->length(); ++i) { + CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options); } // Since it's a scalar function, calling it on sliced inputs should diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h index 7d91a574f4eb1..39cbc805a889f 100644 --- a/cpp/src/arrow/extension_type.h +++ b/cpp/src/arrow/extension_type.h @@ -43,7 +43,7 @@ class ARROW_EXPORT ExtensionType : public DataType { static constexpr const char* type_name() { return "extension"; } /// \brief The type of array used to represent this extension type's data - std::shared_ptr storage_type() const { return storage_type_; } + const std::shared_ptr& storage_type() const { return storage_type_; } DataTypeLayout layout() const override; @@ -114,7 +114,7 @@ class ARROW_EXPORT ExtensionArray : public Array { } /// \brief The physical storage for the extension array - std::shared_ptr storage() const { return storage_; } + const std::shared_ptr& storage() const { return storage_; } protected: void SetData(const std::shared_ptr& data); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 712fae188f0e2..dc3f6ffc23021 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -106,8 +106,10 @@ struct ScalarHashImpl { return Status::OK(); } - // TODO(bkietz) implement less wimpy hashing when this has ValueType - Status Visit(const ExtensionScalar& s) { return Status::OK(); } + Status Visit(const ExtensionScalar& s) { + AccumulateHashFrom(*s.value); + return Status::OK(); + } template Status StdHash(const T& t) { @@ -142,14 +144,14 @@ struct ScalarHashImpl { } explicit ScalarHashImpl(const Scalar& scalar) : hash_(scalar.type->Hash()) { - if (scalar.is_valid) { - AccumulateHashFrom(scalar); - } + AccumulateHashFrom(scalar); } void AccumulateHashFrom(const Scalar& scalar) { - DCHECK_OK(StdHash(scalar.type->fingerprint())); - DCHECK_OK(VisitScalarInline(scalar, this)); + // Note we already injected the type in ScalarHashImpl::ScalarHashImpl + if (scalar.is_valid) { + DCHECK_OK(VisitScalarInline(scalar, this)); + } } size_t hash_; @@ -371,8 +373,29 @@ struct ScalarValidateImpl { return Status::OK(); } - // TODO - Status Visit(const ExtensionScalar& s) { return Status::OK(); } + Status Visit(const ExtensionScalar& s) { + if (!s.is_valid) { + if (s.value) { + return Status::Invalid("null ", s.type->ToString(), " scalar has storage value"); + } + return Status::OK(); + } + + if (!s.value) { + return Status::Invalid("non-null ", s.type->ToString(), + " scalar doesn't have storage value"); + } + if (!s.value->is_valid) { + return Status::Invalid("non-null ", s.type->ToString(), + " scalar has null storage value"); + } + const auto st = Validate(*s.value); + if (!st.ok()) { + return st.WithMessage(s.type->ToString(), + " scalar fails validation for storage value: ", st.message()); + } + return Status::OK(); + } Status ValidateStringScalar(const BaseBinaryScalar& s) { RETURN_NOT_OK(ValidateBinaryScalar(s)); @@ -616,6 +639,11 @@ struct MakeNullImpl { return Status::OK(); } + Status Visit(const ExtensionType& type) { + out_ = std::make_shared(type_); + return Status::OK(); + } + std::shared_ptr Finish() && { // Should not fail. DCHECK_OK(VisitTypeInline(*type_, this)); diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index ce4e083fe61ed..7fd48be86f45a 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -26,6 +26,7 @@ #include #include "arrow/compare.h" +#include "arrow/extension_type.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" @@ -475,11 +476,19 @@ struct ARROW_EXPORT DictionaryScalar : public Scalar { Result> GetEncodedValue() const; }; +/// \brief A Scalar value for ExtensionType +/// +/// The value is the underlying storage scalar. +/// `is_valid` must only be true if `value` is non-null and `value->is_valid` is true struct ARROW_EXPORT ExtensionScalar : public Scalar { using Scalar::Scalar; using TypeClass = ExtensionType; + using ValueType = std::shared_ptr; - // TODO complete this + ExtensionScalar(std::shared_ptr storage, std::shared_ptr type) + : Scalar(std::move(type), true), value(std::move(storage)) {} + + std::shared_ptr value; }; /// @} @@ -494,34 +503,7 @@ ARROW_EXPORT Status CheckBufferLength(const FixedSizeBinaryType* t, } // namespace internal template -struct MakeScalarImpl { - template ::ScalarType, - typename ValueType = typename ScalarType::ValueType, - typename Enable = typename std::enable_if< - std::is_constructible>::value && - std::is_convertible::value>::type> - Status Visit(const T& t) { - ARROW_RETURN_NOT_OK(internal::CheckBufferLength(&t, &value_)); - out_ = std::make_shared( - static_cast(static_cast(value_)), std::move(type_)); - return Status::OK(); - } - - Status Visit(const DataType& t) { - return Status::NotImplemented("constructing scalars of type ", t, - " from unboxed values"); - } - - Result> Finish() && { - ARROW_RETURN_NOT_OK(VisitTypeInline(*type_, this)); - return std::move(out_); - } - - std::shared_ptr type_; - ValueRef value_; - std::shared_ptr out_; -}; +struct MakeScalarImpl; /// \defgroup scalar-factories Scalar factory functions /// @@ -557,4 +539,42 @@ inline std::shared_ptr MakeScalar(std::string value) { /// @} +template +struct MakeScalarImpl { + template ::ScalarType, + typename ValueType = typename ScalarType::ValueType, + typename Enable = typename std::enable_if< + std::is_constructible>::value && + std::is_convertible::value>::type> + Status Visit(const T& t) { + ARROW_RETURN_NOT_OK(internal::CheckBufferLength(&t, &value_)); + // `static_cast` makes a rvalue if ValueRef is `ValueType&&` + out_ = std::make_shared( + static_cast(static_cast(value_)), std::move(type_)); + return Status::OK(); + } + + Status Visit(const ExtensionType& t) { + ARROW_ASSIGN_OR_RAISE(auto storage, + MakeScalar(t.storage_type(), static_cast(value_))); + out_ = std::make_shared(std::move(storage), type_); + return Status::OK(); + } + + Status Visit(const DataType& t) { + return Status::NotImplemented("constructing scalars of type ", t, + " from unboxed values"); + } + + Result> Finish() && { + ARROW_RETURN_NOT_OK(VisitTypeInline(*type_, this)); + return std::move(out_); + } + + std::shared_ptr type_; + ValueRef value_; + std::shared_ptr out_; +}; + } // namespace arrow diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 5ea814582afae..247cd1d9f9d53 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -29,6 +29,7 @@ #include "arrow/memory_pool.h" #include "arrow/scalar.h" #include "arrow/status.h" +#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/type_traits.h" @@ -1485,4 +1486,136 @@ TEST_F(TestDenseUnionScalar, GetScalar) { CheckGetValidUnionScalar(arr, 4, *union_three_, *three_); } +#define UUID_STRING1 "abcdefghijklmnop" +#define UUID_STRING2 "zyxwvutsrqponmlk" + +class TestExtensionScalar : public ::testing::Test { + public: + void SetUp() { + type_ = uuid(); + storage_type_ = fixed_size_binary(16); + uuid_type_ = checked_cast(type_.get()); + } + + protected: + ExtensionScalar MakeUuidScalar(util::string_view value) { + return ExtensionScalar(std::make_shared( + std::make_shared(value), storage_type_), + type_); + } + + std::shared_ptr type_, storage_type_; + const UuidType* uuid_type_{nullptr}; + + const util::string_view uuid_string1_{UUID_STRING1}; + const util::string_view uuid_string2_{UUID_STRING2}; + const util::string_view uuid_json_{"[\"" UUID_STRING1 "\", \"" UUID_STRING2 + "\", null]"}; +}; + +#undef UUID_STRING1 +#undef UUID_STRING2 + +TEST_F(TestExtensionScalar, Basics) { + const ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_); + ASSERT_OK(uuid_scalar.ValidateFull()); + ASSERT_TRUE(uuid_scalar.is_valid); + + const ExtensionScalar uuid_scalar2 = MakeUuidScalar(uuid_string2_); + ASSERT_OK(uuid_scalar2.ValidateFull()); + ASSERT_TRUE(uuid_scalar2.is_valid); + + const ExtensionScalar uuid_scalar3 = MakeUuidScalar(uuid_string2_); + ASSERT_OK(uuid_scalar2.ValidateFull()); + ASSERT_TRUE(uuid_scalar2.is_valid); + + const ExtensionScalar null_scalar(type_); + ASSERT_OK(null_scalar.ValidateFull()); + ASSERT_FALSE(null_scalar.is_valid); + + ASSERT_FALSE(uuid_scalar.Equals(uuid_scalar2)); + ASSERT_TRUE(uuid_scalar2.Equals(uuid_scalar3)); + ASSERT_FALSE(uuid_scalar.Equals(null_scalar)); +} + +TEST_F(TestExtensionScalar, MakeScalar) { + const ExtensionScalar null_scalar(type_); + const ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_); + + auto scalar = CheckMakeNullScalar(type_); + ASSERT_OK(scalar->ValidateFull()); + ASSERT_FALSE(scalar->is_valid); + + ASSERT_OK_AND_ASSIGN(auto scalar2, + MakeScalar(type_, std::make_shared(uuid_string1_))); + ASSERT_OK(scalar2->ValidateFull()); + ASSERT_TRUE(scalar2->is_valid); + + ASSERT_OK_AND_ASSIGN(auto scalar3, + MakeScalar(type_, std::make_shared(uuid_string2_))); + ASSERT_OK(scalar3->ValidateFull()); + ASSERT_TRUE(scalar3->is_valid); + + ASSERT_TRUE(scalar->Equals(null_scalar)); + ASSERT_TRUE(scalar2->Equals(uuid_scalar)); + ASSERT_FALSE(scalar3->Equals(uuid_scalar)); +} + +TEST_F(TestExtensionScalar, GetScalar) { + const ExtensionScalar null_scalar(type_); + const ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_); + const ExtensionScalar uuid_scalar2 = MakeUuidScalar(uuid_string2_); + + auto storage_array = ArrayFromJSON(storage_type_, uuid_json_); + auto array = ExtensionType::WrapArray(type_, storage_array); + + ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0)); + ASSERT_OK(scalar->ValidateFull()); + AssertTypeEqual(scalar->type, type_); + ASSERT_TRUE(scalar->is_valid); + ASSERT_TRUE(scalar->Equals(uuid_scalar)); + ASSERT_FALSE(scalar->Equals(uuid_scalar2)); + + ASSERT_OK_AND_ASSIGN(scalar, array->GetScalar(1)); + ASSERT_OK(scalar->ValidateFull()); + AssertTypeEqual(scalar->type, type_); + ASSERT_TRUE(scalar->is_valid); + ASSERT_TRUE(scalar->Equals(uuid_scalar2)); + ASSERT_FALSE(scalar->Equals(uuid_scalar)); + + ASSERT_OK_AND_ASSIGN(scalar, array->GetScalar(2)); + ASSERT_OK(scalar->ValidateFull()); + AssertTypeEqual(scalar->type, type_); + ASSERT_FALSE(scalar->is_valid); + ASSERT_TRUE(scalar->Equals(null_scalar)); + ASSERT_FALSE(scalar->Equals(uuid_scalar)); +} + +TEST_F(TestExtensionScalar, ValidateErrors) { + // Mismatching is_valid and value + ExtensionScalar null_scalar(type_); + null_scalar.is_valid = true; + AssertValidationFails(null_scalar); + + ExtensionScalar uuid_scalar = MakeUuidScalar(uuid_string1_); + uuid_scalar.is_valid = false; + AssertValidationFails(uuid_scalar); + + // Null storage scalar + auto null_storage = std::make_shared(storage_type_); + ExtensionScalar scalar(null_storage, type_); + scalar.is_valid = true; + AssertValidationFails(scalar); + scalar.is_valid = false; + AssertValidationFails(scalar); + + // Invalid storage scalar (wrong length) + auto invalid_storage = std::make_shared(storage_type_); + invalid_storage->is_valid = true; + invalid_storage->value = std::make_shared("123"); + AssertValidationFails(*invalid_storage); + scalar = ExtensionScalar(invalid_storage, type_); + AssertValidationFails(scalar); +} + } // namespace arrow diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 41914f436634a..e16b26e2465b6 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -378,7 +378,7 @@ bool DataType::Equals(const std::shared_ptr& other) const { size_t DataType::Hash() const { static constexpr size_t kHashSeed = 0; size_t result = kHashSeed; - internal::hash_combine(result, this->ComputeFingerprint()); + internal::hash_combine(result, this->fingerprint()); return result; } diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 91bffeb6ad4ac..684c97315bb33 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -147,11 +147,12 @@ def show_versions(): ListScalar, LargeListScalar, FixedSizeListScalar, Date32Scalar, Date64Scalar, Time32Scalar, Time64Scalar, + TimestampScalar, DurationScalar, BinaryScalar, LargeBinaryScalar, StringScalar, LargeStringScalar, FixedSizeBinaryScalar, DictionaryScalar, - MapScalar, UnionScalar, StructScalar, - TimestampScalar, DurationScalar) + MapScalar, StructScalar, UnionScalar, + ExtensionScalar) # Buffers, allocation from pyarrow.lib import (Buffer, ResizableBuffer, foreign_buffer, py_buffer, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 7dcde652a9575..9260cd28f8591 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -919,10 +919,15 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: c_bool Equals(const CSparseCSFTensor& other) cdef cppclass CScalar" arrow::Scalar": + CScalar(shared_ptr[CDataType]) + shared_ptr[CDataType] type c_bool is_valid + c_string ToString() const c_bool Equals(const CScalar& other) const + CStatus Validate() const + CStatus ValidateFull() const CResult[shared_ptr[CScalar]] CastTo(shared_ptr[CDataType] to) const cdef cppclass CScalarHash" arrow::Scalar::Hash": @@ -1016,12 +1021,16 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CDictionaryScalar(CDictionaryScalarIndexAndDictionary value, shared_ptr[CDataType], c_bool is_valid) CDictionaryScalarIndexAndDictionary value + CResult[shared_ptr[CScalar]] GetEncodedValue() cdef cppclass CUnionScalar" arrow::UnionScalar"(CScalar): shared_ptr[CScalar] value int8_t type_code + cdef cppclass CExtensionScalar" arrow::ExtensionScalar"(CScalar): + shared_ptr[CScalar] value + shared_ptr[CScalar] MakeScalar[Value](Value value) cdef cppclass CConcatenateTablesOptions" arrow::ConcatenateTablesOptions": diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 7953bd936210a..01debcf9835a3 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -39,7 +39,12 @@ cdef class Scalar(_Weakrefable): if type_id == _Type_NA: return _NULL - typ = _scalar_classes[type_id] + try: + typ = _scalar_classes[type_id] + except KeyError: + raise NotImplementedError( + "Wrapping scalar of type " + + frombytes(wrapped.get().type.get().ToString())) self = typ.__new__(typ) self.init(wrapped) @@ -838,6 +843,67 @@ cdef class UnionScalar(Scalar): return sp.type_code +cdef class ExtensionScalar(Scalar): + """ + Concrete class for Extension scalars. + """ + + @property + def value(self): + """ + Return storage value as a scalar. + """ + cdef CExtensionScalar* sp = self.wrapped.get() + return Scalar.wrap(sp.value) if sp.is_valid else None + + def as_py(self): + """ + Return this scalar as a Python object. + """ + # XXX should there be a hook to wrap the result in a custom class? + value = self.value + return None if value is None else value.as_py() + + @staticmethod + def from_storage(BaseExtensionType typ, value): + """ + Construct ExtensionScalar from type and storage value. + + Parameters + ---------- + typ: DataType + The extension type for the result scalar. + value: object + The storage value for the result scalar. + + Returns + ------- + ext_scalar : ExtensionScalar + """ + cdef: + shared_ptr[CExtensionScalar] sp_scalar + CExtensionScalar* ext_scalar + + if value is None: + storage = None + elif isinstance(value, Scalar): + if value.type != typ.storage_type: + raise TypeError("Incompatible storage type {0} " + "for extension type {1}" + .format(value.type, typ)) + storage = value + else: + storage = scalar(value, typ.storage_type) + + sp_scalar = make_shared[CExtensionScalar](typ.sp_type) + ext_scalar = sp_scalar.get() + ext_scalar.is_valid = storage is not None and storage.is_valid + if ext_scalar.is_valid: + ext_scalar.value = pyarrow_unwrap_scalar(storage) + check_status(ext_scalar.Validate()) + return pyarrow_wrap_scalar( sp_scalar) + + cdef dict _scalar_classes = { _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, @@ -872,6 +938,7 @@ cdef dict _scalar_classes = { _Type_DICTIONARY: DictionaryScalar, _Type_SPARSE_UNION: UnionScalar, _Type_DENSE_UNION: UnionScalar, + _Type_EXTENSION: ExtensionScalar, } diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index ba8366a43c69b..391149772ccf9 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -166,6 +166,14 @@ def test_ext_array_lifetime(): assert ref() is None +def test_ext_array_to_pylist(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar", None], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + + assert arr.to_pylist() == [b"foo", b"bar", None] + + def test_ext_array_errors(): ty = ParamExtType(4) storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) @@ -193,6 +201,67 @@ def test_ext_array_equality(): assert not d.equals(f) +def test_ext_scalar_from_array(): + data = [b"0123456789abcdef", b"0123456789abcdef", + b"zyxwvutsrqponmlk", None] + storage = pa.array(data, type=pa.binary(16)) + ty1 = UuidType() + ty2 = ParamExtType(16) + + a = pa.ExtensionArray.from_storage(ty1, storage) + b = pa.ExtensionArray.from_storage(ty2, storage) + + scalars_a = list(a) + assert len(scalars_a) == 4 + + for s, val in zip(scalars_a, data): + assert isinstance(s, pa.ExtensionScalar) + assert s.is_valid == (val is not None) + assert s.type == ty1 + if val is not None: + assert s.value == pa.scalar(val, storage.type) + else: + assert s.value is None + assert s.as_py() == val + + scalars_b = list(b) + assert len(scalars_b) == 4 + + for sa, sb in zip(scalars_a, scalars_b): + assert sa.is_valid == sb.is_valid + assert sa.as_py() == sb.as_py() + assert sa != sb + + +def test_ext_scalar_from_storage(): + ty = UuidType() + + s = pa.ExtensionScalar.from_storage(ty, None) + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is False + assert s.value is None + + s = pa.ExtensionScalar.from_storage(ty, b"0123456789abcdef") + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is True + assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type) + + s = pa.ExtensionScalar.from_storage(ty, pa.scalar(None, ty.storage_type)) + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is False + assert s.value is None + + s = pa.ExtensionScalar.from_storage( + ty, pa.scalar(b"0123456789abcdef", ty.storage_type)) + assert isinstance(s, pa.ExtensionScalar) + assert s.type == ty + assert s.is_valid is True + assert s.value == pa.scalar(b"0123456789abcdef", ty.storage_type) + + def test_ext_array_pickling(): for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): ty = ParamExtType(3)