diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml index d2b8c0fab..d72858504 100644 --- a/.github/workflows/Documentation.yml +++ b/.github/workflows/Documentation.yml @@ -11,7 +11,7 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@latest with: version: '1' @@ -23,6 +23,6 @@ jobs: DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key run: julia --project=docs/ --code-coverage=user docs/make.jl - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v5 with: file: lcov.info diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index 647941115..1af21937e 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - julia-version: [1,1.6] + julia-version: [1] os: [ubuntu-latest] package: - {user: JuliaDiff, repo: SparseDiffTools.jl, group: Core} @@ -32,14 +32,14 @@ jobs: - {user: SciML, repo: DelayDiffEq.jl, group: Interface} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - uses: julia-actions/julia-buildpkg@latest - name: Clone Downstream - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: repository: ${{ matrix.package.user }}/${{ matrix.package.repo }} path: downstream @@ -61,6 +61,6 @@ jobs: exit(0) # Exit immediately, as a success end - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v5 with: file: lcov.info diff --git a/.github/workflows/Invalidations.yml b/.github/workflows/Invalidations.yml index 4d0004e83..28b9ce2fa 100644 --- a/.github/workflows/Invalidations.yml +++ b/.github/workflows/Invalidations.yml @@ -19,12 +19,12 @@ jobs: - uses: julia-actions/setup-julia@v1 with: version: '1' - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-invalidations@v1 id: invs_pr - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: ref: ${{ github.event.repository.default_branch }} - uses: julia-actions/julia-buildpkg@v1 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8ebbcfdd9..cb9fe51a0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,13 +15,12 @@ jobs: - Core version: - '1' - - '1.6' steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} - - uses: actions/cache@v3 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: @@ -36,6 +35,6 @@ jobs: env: GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v5 with: file: lcov.info \ No newline at end of file diff --git a/Project.toml b/Project.toml index d91b1cbd8..c4a079e5b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,50 +1,74 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.8" +version = "7.19.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" -[compat] -Adapt = "3" -Requires = "1" -julia = "1.6" +[weakdeps] +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" +BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" ArrayInterfaceCUDAExt = "CUDA" +ArrayInterfaceCUDSSExt = "CUDSS" +ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" +ArrayInterfaceChainRulesExt = "ChainRules" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" +ArrayInterfaceReverseDiffExt = "ReverseDiff" +ArrayInterfaceSparseArraysExt = "SparseArrays" ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" ArrayInterfaceTrackerExt = "Tracker" +[compat] +Adapt = "4" +BandedMatrices = "1" +BlockBandedMatrices = "0.13" +CUDA = "5" +CUDSS = "0.2, 0.3, 0.4" +ChainRules = "1" +ChainRulesCore = "1" +ChainRulesTestUtils = "1" +GPUArraysCore = "0.1, 0.2" +LinearAlgebra = "1.10" +ReverseDiff = "1" +SparseArrays = "1.10" +StaticArraysCore = "1" +Tracker = "0.2" +julia = "1.10" + [extras] -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker"] - -[weakdeps] -BandedMatrices = "aae01518-5342-5314-be14-df237901396f" -BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" -GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" -StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "Tracker", "ReverseDiff", "ChainRules", "FillArrays", "ComponentArrays", "ChainRulesTestUtils"] diff --git a/README.md b/README.md index 812b7753d..9d20d7bc8 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,6 @@ [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://docs.sciml.ai/ArrayInterface/stable/) [![CI](https://github.com/JuliaArrays/ArrayInterface.jl/workflows/CI/badge.svg)](https://github.com/JuliaArrays/ArrayInterface.jl/actions?query=workflow%3ACI) -[![CI (Julia nightly)](https://github.com/JuliaArrays/ArrayInterface.jl/workflows/CI%20(Julia%20nightly)/badge.svg)](https://github.com/JuliaArrays/ArrayInterface.jl/actions?query=workflow%3A%22CI+%28Julia+nightly%29%22) [![Build status](https://badge.buildkite.com/a2db252d92478e1d7196ee7454004efdfb6ab59496cbac91a2.svg?branch=master)](https://buildkite.com/julialang/arrayinterface-dot-jl) [![codecov](https://codecov.io/gh/JuliaArrays/ArrayInterface.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/JuliaArrays/ArrayInterface.jl) diff --git a/docs/src/conversions.md b/docs/src/conversions.md index 51f9b1a37..129e3746d 100644 --- a/docs/src/conversions.md +++ b/docs/src/conversions.md @@ -3,8 +3,6 @@ The following ArrayInterface functions extend Julia's Array interface for how arrays can be converted to different forms. -## - ## Conversion Functions ```@docs @@ -12,4 +10,5 @@ ArrayInterface.aos_to_soa ArrayInterface.promote_eltype ArrayInterface.restructure ArrayInterface.safevec -``` \ No newline at end of file +ArrayInterface.has_trivial_array_constructor +``` diff --git a/ext/ArrayInterfaceBandedMatricesExt.jl b/ext/ArrayInterfaceBandedMatricesExt.jl index 358434f15..5eaeb0cdf 100644 --- a/ext/ArrayInterfaceBandedMatricesExt.jl +++ b/ext/ArrayInterfaceBandedMatricesExt.jl @@ -1,16 +1,19 @@ module ArrayInterfaceBandedMatricesExt +using ArrayInterface +using ArrayInterface: BandedMatrixIndex +using BandedMatrices +using LinearAlgebra -if isdefined(Base, :get_extension) - using ArrayInterface - using ArrayInterface: BandedMatrixIndex - using BandedMatrices -else - using ..ArrayInterface - using ..ArrayInterface: BandedMatrixIndex - using ..BandedMatrices -end +const TransOrAdjBandedMatrix = Union{ + Adjoint{T, <:BandedMatrix{T}}, + Transpose{T, <:BandedMatrix{T}}, +} where {T} +const AllBandedMatrix = Union{ + BandedMatrix{T}, + TransOrAdjBandedMatrix{T}, +} where {T} Base.firstindex(i::BandedMatrixIndex) = 1 Base.lastindex(i::BandedMatrixIndex) = i.count @@ -43,14 +46,14 @@ function _bandsize(bandind, rowsize, colsize) end end -function BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow) +function ArrayInterface.BandedMatrixIndex(rowsize, colsize, lowerbandwidth, upperbandwidth, isrow) upperbandwidth > -lowerbandwidth || throw(ErrorException("Invalid Bandwidths")) - bandinds = upperbandwidth:-1:-lowerbandwidth + bandinds = upperbandwidth:-1:(-lowerbandwidth) bandsizes = [_bandsize(band, rowsize, colsize) for band in bandinds] BandedMatrixIndex(sum(bandsizes), rowsize, colsize, bandinds, bandsizes, isrow) end -function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix) +function ArrayInterface.findstructralnz(x::AllBandedMatrix) l, u = BandedMatrices.bandwidths(x) rowsize, colsize = Base.size(x) rowind = BandedMatrixIndex(rowsize, colsize, l, u, true) @@ -58,11 +61,11 @@ function ArrayInterface.findstructralnz(x::BandedMatrices.BandedMatrix) return (rowind, colind) end -ArrayInterface.has_sparsestruct(::Type{<:BandedMatrices.BandedMatrix}) = true -ArrayInterface.isstructured(::Type{<:BandedMatrices.BandedMatrix}) = true -ArrayInterface.fast_matrix_colors(::Type{<:BandedMatrices.BandedMatrix}) = true +ArrayInterface.has_sparsestruct(::Type{<:AllBandedMatrix}) = true +ArrayInterface.isstructured(::Type{<:AllBandedMatrix}) = true +ArrayInterface.fast_matrix_colors(::Type{<:AllBandedMatrix}) = true -function ArrayInterface.matrix_colors(A::BandedMatrices.BandedMatrix) +function ArrayInterface.matrix_colors(A::AllBandedMatrix) l, u = BandedMatrices.bandwidths(A) width = u + l + 1 return ArrayInterface._cycle(1:width, Base.size(A, 2)) diff --git a/ext/ArrayInterfaceBlockBandedMatricesExt.jl b/ext/ArrayInterfaceBlockBandedMatricesExt.jl index f66362130..e4e7bbae1 100644 --- a/ext/ArrayInterfaceBlockBandedMatricesExt.jl +++ b/ext/ArrayInterfaceBlockBandedMatricesExt.jl @@ -1,18 +1,9 @@ module ArrayInterfaceBlockBandedMatricesExt - - -if isdefined(Base, :get_extension) - using ArrayInterface - using ArrayInterface: BandedMatrixIndex - using BlockBandedMatrices - using BlockBandedMatrices.BlockArrays -else - using ..ArrayInterface - using ..ArrayInterface: BandedMatrixIndex - using ..BlockBandedMatrices - using ..BlockBandedMatrices.BlockArrays -end +using ArrayInterface +using ArrayInterface: BandedMatrixIndex +using BlockBandedMatrices +using BlockBandedMatrices.BlockArrays struct BlockBandedMatrixIndex <: ArrayInterface.MatrixIndex count::Int diff --git a/ext/ArrayInterfaceCUDAExt.jl b/ext/ArrayInterfaceCUDAExt.jl index f5d2a9507..ae4477d3e 100644 --- a/ext/ArrayInterfaceCUDAExt.jl +++ b/ext/ArrayInterfaceCUDAExt.jl @@ -1,16 +1,9 @@ module ArrayInterfaceCUDAExt using ArrayInterface - -if isdefined(Base, :get_extension) - using CUDA - using CUDA.CUSOLVER - using LinearAlgebra -else - using ..CUDA - using ..CUDA.CUSOLVER - using ..LinearAlgebra -end +using CUDA +using CUDA.CUSOLVER +using LinearAlgebra function ArrayInterface.lu_instance(A::CuMatrix{T}) where {T} if VERSION >= v"1.8-" diff --git a/ext/ArrayInterfaceCUDSSExt.jl b/ext/ArrayInterfaceCUDSSExt.jl new file mode 100644 index 000000000..01fb23953 --- /dev/null +++ b/ext/ArrayInterfaceCUDSSExt.jl @@ -0,0 +1,17 @@ +module ArrayInterfaceCUDSSExt + +using ArrayInterface +using CUDSS + +function ArrayInterface.lu_instance(A::CUDSS.CuSparseMatrixCSR) + ArrayInterface.LinearAlgebra.checksquare(A) + fact = CudssSolver(A, "G", 'F') + T = eltype(A) + n = size(A,1) + x = CudssMatrix(T, n) + b = CudssMatrix(T, n) + cudss("analysis", fact, x, b) + fact +end + +end diff --git a/ext/ArrayInterfaceChainRulesCoreExt.jl b/ext/ArrayInterfaceChainRulesCoreExt.jl new file mode 100644 index 000000000..6cf4c406f --- /dev/null +++ b/ext/ArrayInterfaceChainRulesCoreExt.jl @@ -0,0 +1,22 @@ +module ArrayInterfaceChainRulesCoreExt + +import ArrayInterface +import ChainRulesCore +import ChainRulesCore: unthunk, NoTangent, ZeroTangent, ProjectTo, @thunk + +function ChainRulesCore.rrule(::typeof(ArrayInterface.restructure), target, src) + projectT = ProjectTo(target) + function restructure_pullback(dt) + dt = unthunk(dt) + + f̄ = NoTangent() + t̄ = ZeroTangent() + s̄ = @thunk(projectT(ArrayInterface.restructure(src, dt))) + + f̄, t̄, s̄ + end + + return ArrayInterface.restructure(target, src), restructure_pullback +end + +end diff --git a/ext/ArrayInterfaceChainRulesExt.jl b/ext/ArrayInterfaceChainRulesExt.jl new file mode 100644 index 000000000..0a91bbb37 --- /dev/null +++ b/ext/ArrayInterfaceChainRulesExt.jl @@ -0,0 +1,8 @@ +module ArrayInterfaceChainRulesExt + +using ArrayInterface +using ChainRules: OneElement + +ArrayInterface.can_setindex(::Type{<:OneElement}) = false + +end \ No newline at end of file diff --git a/ext/ArrayInterfaceGPUArraysCoreExt.jl b/ext/ArrayInterfaceGPUArraysCoreExt.jl index 40d0fc1ac..79bbf6063 100644 --- a/ext/ArrayInterfaceGPUArraysCoreExt.jl +++ b/ext/ArrayInterfaceGPUArraysCoreExt.jl @@ -1,17 +1,9 @@ module ArrayInterfaceGPUArraysCoreExt - -if isdefined(Base, :get_extension) - using Adapt - using ArrayInterface - using LinearAlgebra: lu - import GPUArraysCore -else - using Adapt # Will cause problems for relocatability. - using ..ArrayInterface - using ..LinearAlgebra: lu - import ..GPUArraysCore -end +using Adapt +using ArrayInterface +using LinearAlgebra: lu +import GPUArraysCore ArrayInterface.fast_scalar_indexing(::Type{<:GPUArraysCore.AbstractGPUArray}) = false @inline ArrayInterface.allowed_getindex(x::GPUArraysCore.AbstractGPUArray, i...) = GPUArraysCore.@allowscalar(x[i...]) diff --git a/ext/ArrayInterfaceReverseDiffExt.jl b/ext/ArrayInterfaceReverseDiffExt.jl new file mode 100644 index 000000000..3a000def3 --- /dev/null +++ b/ext/ArrayInterfaceReverseDiffExt.jl @@ -0,0 +1,19 @@ +module ArrayInterfaceReverseDiffExt + +using ArrayInterface +import ReverseDiff + +ArrayInterface.ismutable(::Type{<:ReverseDiff.TrackedArray}) = false +ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false +ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false +ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false +function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N} + y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1] + return reshape(y, size(x)) +end + +function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray) + reshape(y, Base.size(x)...) +end + +end # module diff --git a/ext/ArrayInterfaceSparseArraysExt.jl b/ext/ArrayInterfaceSparseArraysExt.jl new file mode 100644 index 000000000..9f0e06b43 --- /dev/null +++ b/ext/ArrayInterfaceSparseArraysExt.jl @@ -0,0 +1,38 @@ +module ArrayInterfaceSparseArraysExt + +import ArrayInterface: buffer, has_sparsestruct, issingular, findstructralnz, bunchkaufman_instance, DEFAULT_CHOLESKY_PIVOT, cholesky_instance, ldlt_instance, lu_instance, qr_instance +using ArrayInterface.LinearAlgebra +using SparseArrays + +buffer(x::SparseMatrixCSC) = getfield(x, :nzval) +buffer(x::SparseVector) = getfield(x, :nzval) +has_sparsestruct(::Type{<:SparseMatrixCSC}) = true +issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false)) + +function findstructralnz(x::SparseMatrixCSC) + rowind, colind, _ = findnz(x) + (rowind, colind) +end + +function bunchkaufman_instance(A::SparseMatrixCSC{Tv, Ti}) where {Tv, Ti} + bunchkaufman(SparseMatrixCSC{Tv, Ti}(similar(A, 1, 1)), check = false) +end + +function cholesky_instance(A::Union{SparseMatrixCSC{Tv, Ti},Symmetric{<:Number,<:SparseMatrixCSC{Tv, Ti}}}, pivot = DEFAULT_CHOLESKY_PIVOT) where {Tv, Ti} + cholesky(SparseMatrixCSC{Tv, Ti}(similar(A, 1, 1)), check = false) +end + +function ldlt_instance(A::SparseMatrixCSC{Tv, Ti}) where {Tv, Ti} + ldlt(SparseMatrixCSC{Tv, Ti}(similar(A, 1, 1)), check=false) +end + +# Could be optimized but this should work for any real case. +function lu_instance(jac_prototype::SparseMatrixCSC{Tv, Ti}, pivot = DEFAULT_CHOLESKY_PIVOT) where {Tv, Ti} + lu(SparseMatrixCSC{Tv, Ti}(rand(1,1))) +end + +function qr_instance(jac_prototype::SparseMatrixCSC{Tv, Ti}, pivot = DEFAULT_CHOLESKY_PIVOT) where {Tv, Ti} + qr(SparseMatrixCSC{Tv, Ti}(rand(1,1))) +end + +end diff --git a/ext/ArrayInterfaceStaticArraysCoreExt.jl b/ext/ArrayInterfaceStaticArraysCoreExt.jl index 5c555f638..7b33dc802 100644 --- a/ext/ArrayInterfaceStaticArraysCoreExt.jl +++ b/ext/ArrayInterfaceStaticArraysCoreExt.jl @@ -1,35 +1,30 @@ module ArrayInterfaceStaticArraysCoreExt -if isdefined(Base, :get_extension) - import ArrayInterface - using LinearAlgebra - import StaticArraysCore -else - import ..ArrayInterface - using ..LinearAlgebra - import ..StaticArraysCore -end +import ArrayInterface +using LinearAlgebra +import StaticArraysCore: SArray, SMatrix, SVector, StaticMatrix, StaticArray, SizedArray, MArray, MMatrix -function ArrayInterface.undefmatrix(::StaticArraysCore.MArray{S, T, N, L}) where {S, T, N, L} - return StaticArraysCore.MMatrix{L, L, T, L*L}(undef) +function ArrayInterface.undefmatrix(::MArray{S, T, N, L}) where {S, T, N, L} + return MMatrix{L, L, T, L*L}(undef) end # SArray doesn't have an undef constructor and is going to be small enough that this is fine. -function ArrayInterface.undefmatrix(s::StaticArraysCore.SArray) +function ArrayInterface.undefmatrix(s::SArray) v = vec(s) return v.*v' end -ArrayInterface.ismutable(::Type{<:StaticArraysCore.StaticArray}) = false -ArrayInterface.ismutable(::Type{<:StaticArraysCore.MArray}) = true -ArrayInterface.ismutable(::Type{<:StaticArraysCore.SizedArray}) = true +ArrayInterface.ismutable(::Type{<:StaticArray}) = false +ArrayInterface.ismutable(::Type{<:MArray}) = true +ArrayInterface.ismutable(::Type{<:SizedArray}) = true -ArrayInterface.can_setindex(::Type{<:StaticArraysCore.StaticArray}) = false -ArrayInterface.buffer(A::Union{StaticArraysCore.SArray,StaticArraysCore.MArray}) = getfield(A, :data) +ArrayInterface.can_setindex(::Type{<:StaticArray}) = false +ArrayInterface.can_setindex(::Type{<:MArray}) = true +ArrayInterface.buffer(A::Union{SArray, MArray}) = getfield(A, :data) -function ArrayInterface.lu_instance(_A::StaticArraysCore.StaticMatrix{N,N}) where {N} - lu(one(_A)) +function ArrayInterface.lu_instance(A::StaticMatrix{N,N}) where {N} + lu(one(A)) end -ArrayInterface.restructure(x::StaticArraysCore.SArray{S}, y) where {S} = StaticArraysCore.SArray{S}(y) +ArrayInterface.restructure(x::SArray{S}, y) where {S} = SArray{S}(y) end diff --git a/ext/ArrayInterfaceTrackerExt.jl b/ext/ArrayInterfaceTrackerExt.jl index 4bb10c39c..5723d9f1f 100644 --- a/ext/ArrayInterfaceTrackerExt.jl +++ b/ext/ArrayInterfaceTrackerExt.jl @@ -1,12 +1,7 @@ module ArrayInterfaceTrackerExt -if isdefined(Base, :get_extension) - using ArrayInterface - import Tracker -else - using ..ArrayInterface - import ..Tracker -end +using ArrayInterface +import Tracker ArrayInterface.ismutable(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.ismutable(T::Type{<:Tracker.TrackedReal}) = false @@ -14,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x) +function ArrayInterface.restructure(x::Array, y::Tracker.TrackedArray) + reshape(y, Base.size(x)...) +end +function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal}) + reshape(y, Base.size(x)...) +end + end # module diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index ed616a87f..b9e56d01b 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -1,8 +1,6 @@ module ArrayInterface using LinearAlgebra -using SparseArrays -using SuiteSparse @static if isdefined(Base, Symbol("@assume_effects")) using Base: @assume_effects @@ -40,22 +38,8 @@ julia> ArrayInterface.map_tuple_type(sqrt, Tuple{1,4,16}) ``` """ function map_tuple_type end -if VERSION >= v"1.8" - @inline function map_tuple_type(f, @nospecialize(T::Type)) - ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}()) - end -else - function map_tuple_type(f::F, ::Type{T}) where {F, T <: Tuple} - if @generated - t = Expr(:tuple) - for i in 1:fieldcount(T) - push!(t.args, :(f($(fieldtype(T, i))))) - end - Expr(:block, Expr(:meta, :inline), t) - else - Tuple(f(fieldtype(T, i)) for i in 1:fieldcount(T)) - end - end +@inline function map_tuple_type(f, @nospecialize(T::Type)) + ntuple(i -> f(fieldtype(T, i)), Val{fieldcount(T)}()) end """ @@ -78,50 +62,22 @@ julia> ArrayInterface.flatten_tuples((1, (2, (3,)))) ``` """ function flatten_tuples end -if VERSION >= v"1.8" - function flatten_tuples(t::Tuple) - fields = _new_field_positions(t) - ntuple(Val{nfields(fields)}()) do k - i, j = getfield(fields, k) - i = length(t) - i - @inbounds j === 0 ? getfield(t, i) : getfield(getfield(t, i), j) - end - end - _new_field_positions(::Tuple{}) = () - @nospecialize - function _new_field_positions(x::Tuple) - (_fl1(x, x[1])..., _new_field_positions(Base.tail(x))...) - end - _fl1(x::Tuple, x1::Tuple) = ntuple(Base.Fix1(tuple, length(x) - 1), Val(length(x1))) - _fl1(x::Tuple, x1) = ((length(x) - 1, 0),) - @specialize -else - @inline function flatten_tuples(t::Tuple) - if @generated - texpr = Expr(:tuple) - for i in 1:fieldcount(t) - p = fieldtype(t, i) - if p <: Tuple - for j in 1:fieldcount(p) - push!(texpr.args, :(@inbounds(getfield(getfield(t, $i), $j)))) - end - else - push!(texpr.args, :(@inbounds(getfield(t, $i)))) - end - end - Expr(:block, Expr(:meta, :inline), texpr) - else - _flatten(t) - end - end - _flatten(::Tuple{}) = () - @inline function _flatten(t::Tuple{Any, Vararg{Any}}) - (getfield(t, 1), _flatten(Base.tail(t))...) - end - @inline function _flatten(t::Tuple{Tuple, Vararg{Any}}) - (getfield(t, 1)..., _flatten(Base.tail(t))...) +function flatten_tuples(t::Tuple) + fields = _new_field_positions(t) + ntuple(Val{nfields(fields)}()) do k + i, j = getfield(fields, k) + i = length(t) - i + @inbounds j === 0 ? getfield(t, i) : getfield(getfield(t, i), j) end end +_new_field_positions(::Tuple{}) = () +@nospecialize +function _new_field_positions(x::Tuple) + (_fl1(x, x[1])..., _new_field_positions(Base.tail(x))...) +end +_fl1(x::Tuple, x1::Tuple) = ntuple(Base.Fix1(tuple, length(x) - 1), Val(length(x1))) +_fl1(x::Tuple, x1) = ((length(x) - 1, 0),) +@specialize """ parent_type(::Type{T}) -> Type @@ -163,8 +119,6 @@ Return the buffer data that `x` points to. Unlike `parent(x::AbstractArray)`, `b may not return another array type. """ buffer(x) = parent(x) -buffer(x::SparseMatrixCSC) = getfield(x, :nzval) -buffer(x::SparseVector) = getfield(x, :nzval) buffer(@nospecialize x::Union{Base.Slice, Base.IdentityUnitRange}) = getfield(x, :indices) """ @@ -299,11 +253,7 @@ ismutable(::Type{BigFloat}) = false ismutable(::Type{BigInt}) = false function ismutable(::Type{T}) where {T} if parent_type(T) <: T - @static if VERSION ≥ v"1.7.0-DEV.1208" - return Base.ismutabletype(T) - else - return T.mutable - end + return Base.ismutabletype(T) else return ismutable(parent_type(T)) end @@ -354,7 +304,6 @@ Determine whether `findstructralnz` accepts the parameter `x`. has_sparsestruct(x) = has_sparsestruct(typeof(x)) has_sparsestruct(::Type) = false has_sparsestruct(::Type{<:AbstractArray}) = false -has_sparsestruct(::Type{<:SparseMatrixCSC}) = true has_sparsestruct(::Type{<:Diagonal}) = true has_sparsestruct(::Type{<:Bidiagonal}) = true has_sparsestruct(::Type{<:Tridiagonal}) = true @@ -366,7 +315,6 @@ has_sparsestruct(::Type{<:SymTridiagonal}) = true Determine whether a given abstract matrix is singular. """ issingular(A::AbstractMatrix) = issingular(Matrix(A)) -issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false)) issingular(A::Matrix) = !issuccess(lu(A, check = false)) issingular(A::UniformScaling) = A.λ == 0 issingular(A::Diagonal) = any(iszero, A.diag) @@ -405,11 +353,6 @@ function findstructralnz(x::Union{Tridiagonal, SymTridiagonal}) (rowind, colind) end -function findstructralnz(x::SparseMatrixCSC) - rowind, colind, _ = findnz(x) - (rowind, colind) -end - abstract type ColoringAlgorithm end """ @@ -449,9 +392,6 @@ cheaply. function bunchkaufman_instance(A::Matrix{T}) where T return bunchkaufman(similar(A, 0, 0), check = false) end -function bunchkaufman_instance(A::SparseMatrixCSC) - bunchkaufman(sparse(similar(A, 1, 1)), check = false) -end """ bunchkaufman_instance(a::Number) -> a @@ -467,11 +407,7 @@ Returns the number. """ bunchkaufman_instance(a::Any) = bunchkaufman(a, check = false) -@static if VERSION < v"1.7beta" - const DEFAULT_CHOLESKY_PIVOT = Val(false) -else - const DEFAULT_CHOLESKY_PIVOT = LinearAlgebra.NoPivot() -end +const DEFAULT_CHOLESKY_PIVOT = LinearAlgebra.NoPivot() """ cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorization_instance @@ -479,14 +415,10 @@ cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorizati Returns an instance of the Cholesky factorization object with the correct type cheaply. """ -function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T} +function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T} return cholesky(similar(A, 0, 0), pivot, check = false) end -function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT) - cholesky(sparse(similar(A, 1, 1)), check = false) -end - """ cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) -> a @@ -508,11 +440,12 @@ ldlt_instance(A) -> ldlt_factorization_instance Returns an instance of the LDLT factorization object with the correct type cheaply. """ -function ldlt_instance(A::Matrix{T}) where {T} - return ldlt(SymTridiagonal(similar(A, 0, 0))) +function ldlt_instance(A::Matrix{T}) where {T} + return ldlt_instance(SymTridiagonal(similar(A, 0, 0))) end -function ldlt_instance(A::SparseMatrixCSC) - ldlt(sparse(similar(A, 1, 1)), check=false) + +function ldlt_instance(A::SymTridiagonal{T,V}) where {T,V} + return LinearAlgebra.LDLt{T,SymTridiagonal{T,V}}(A) end """ @@ -543,19 +476,27 @@ function lu_instance(A::Matrix{T}) where {T} info = zero(LinearAlgebra.BlasInt) return LU{luT}(similar(A, 0, 0), ipiv, info) end -function lu_instance(jac_prototype::SparseMatrixCSC) - @static if VERSION < v"1.9.0-DEV.1622" - SuiteSparse.UMFPACK.UmfpackLU(Ptr{Cvoid}(), - Ptr{Cvoid}(), - 1, - 1, - jac_prototype.colptr[1:1], - jac_prototype.rowval[1:1], - jac_prototype.nzval[1:1], - 0) - else - SuiteSparse.UMFPACK.UmfpackLU(similar(jac_prototype, 1, 1)) - end + +function lu_instance(A::Symmetric{T}) where {T} + noUnitT = typeof(zero(T)) + luT = LinearAlgebra.lutype(noUnitT) + ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0) + info = zero(LinearAlgebra.BlasInt) + return LU{luT}(similar(A, 0, 0), ipiv, info) +end + +noalloc_diag(A::Diagonal) = A.diag +noalloc_diag(A::Tridiagonal) = A.d +noalloc_diag(A::SymTridiagonal) = A.dv + +function lu_instance(A::Union{Tridiagonal{T},Diagonal{T},SymTridiagonal{T}}) where {T} + noUnitT = typeof(zero(T)) + luT = LinearAlgebra.lutype(noUnitT) + ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0) + info = zero(LinearAlgebra.BlasInt) + vectype = similar(noalloc_diag(A), 0) + newA = Tridiagonal(vectype, vectype, vectype) + return LU{luT}(newA, ipiv, info) end """ @@ -574,30 +515,29 @@ specialized for new matrix types. lu_instance(a::Any) = lu(a, check = false) """ - qr_instance(A) -> qr_factorization_instance + qr_instance(A, pivot = NoPivot()) -> qr_factorization_instance Returns an instance of the QR factorization object with the correct type cheaply. """ -function qr_instance(A::Matrix{T}) where {T} - LinearAlgebra.QRCompactWY(zeros(T,0,0),zeros(T,0,0)) +function qr_instance(A::Matrix{T},pivot = DEFAULT_CHOLESKY_PIVOT) where {T} + if pivot === DEFAULT_CHOLESKY_PIVOT + LinearAlgebra.QRCompactWY(zeros(T,0,0),zeros(T,0,0)) + else + LinearAlgebra.QRPivoted(zeros(T,0,0),zeros(T,0),zeros(Int,0)) + end end -function qr_instance(A::Matrix{BigFloat}) +function qr_instance(A::Matrix{BigFloat},pivot = DEFAULT_CHOLESKY_PIVOT) LinearAlgebra.QR(zeros(BigFloat,0,0),zeros(BigFloat,0)) end -# Could be optimized but this should work for any real case. -function qr_instance(jac_prototype::SparseMatrixCSC) - qr(sparse(rand(1,1))) -end - """ qr_instance(a::Number) -> a Returns the number. """ -qr_instance(a::Number) = a +qr_instance(a::Number, pivot = DEFAULT_CHOLESKY_PIVOT) = a """ qr_instance(a::Any) -> qr(a) @@ -605,7 +545,7 @@ qr_instance(a::Number) = a Slow fallback which gets the instance via factorization. Should get specialized for new matrix types. """ -qr_instance(a::Any) = qr(a)# check = false) +qr_instance(a::Any, pivot = DEFAULT_CHOLESKY_PIVOT) = qr(a)# check = false) """ svd_instance(A) -> qr_factorization_instance @@ -614,7 +554,7 @@ Returns an instance of the SVD factorization object with the correct type cheaply. """ function svd_instance(A::Matrix{T}) where {T} - LinearAlgebra.SVD(zeros(T,0,0),zeros(T,0),zeros(T,0,0)) + LinearAlgebra.SVD(zeros(T,0,0),zeros(real(T),0),zeros(T,0,0)) end """ @@ -1030,18 +970,28 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x)) -## Extensions - -import Requires -@static if !isdefined(Base, :get_extension) - function __init__() - Requires.@require BandedMatrices = "aae01518-5342-5314-be14-df237901396f" begin include("../ext/ArrayInterfaceBandedMatricesExt.jl") end - Requires.@require BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" begin include("../ext/ArrayInterfaceBlockBandedMatricesExt.jl") end - Requires.@require GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" begin include("../ext/ArrayInterfaceGPUArraysCoreExt.jl") end - Requires.@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysCoreExt.jl") end - Requires.@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin include("../ext/ArrayInterfaceCUDAExt.jl") end - Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/ArrayInterfaceTrackerExt.jl") end - end +""" + has_trivial_array_constructor(T::Type, args...) -> Bool + +Returns `true` if an object of type `T` can be constructed using the collection of `args` + +Note: This checks if a compatible `convert` methood exists between `T` and `args` + +# Examples: + +```julia +julia> ca = ComponentVector((x = rand(3), y = rand(4),)) +ComponentVector{Float64}(x = [0.6549137106381634, 0.37555505280294565, 0.8521039568665254], y = [0.40314196291239024, 0.35484725607638834, 0.6580528978034597, 0.10055508457632167]) + +julia> ArrayInterface.has_trivial_array_constructor(typeof(ca), ones(6)) +true + +julia> ArrayInterface.has_trivial_array_constructor(typeof(cv), (x = rand(6),)) +false +``` +""" +function has_trivial_array_constructor(::Type{T}, args...) where T + applicable(convert, T, args...) end end # module diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 000000000..3c61873e6 --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,41 @@ +using ArrayInterface, ReverseDiff, Tracker, Test +x = ReverseDiff.track([4.0]) +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = reshape([ReverseDiff.track(rand(1, 1, 1))[1]], 1, 1, 1) +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +@test ndims(ArrayInterface.aos_to_soa(x)) == 3 +x = reduce(vcat, ReverseDiff.track([4.0,4.0])) +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = [ReverseDiff.track([4.0])[1]] +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = reduce(vcat, ReverseDiff.track([4.0,4.0])) +x = [x[1],x[2]] +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray + +x = Tracker.TrackedArray([4.0]) +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray +x = [Tracker.TrackedArray([4.0])[1]] +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray +x = Tracker.TrackedArray([4.0,4.0]) +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray +x = reduce(vcat, Tracker.TrackedArray([4.0,4.0])) +x = [x[1],x[2]] +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray + +x = rand(4) +y = Tracker.TrackedReal.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = Tracker.TrackedArray(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) + +x = rand(4) +y = ReverseDiff.track(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = ReverseDiff.track.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,) diff --git a/test/bandedmatrices.jl b/test/bandedmatrices.jl index a6142efd0..94626aac3 100644 --- a/test/bandedmatrices.jl +++ b/test/bandedmatrices.jl @@ -1,19 +1,50 @@ - using ArrayInterface using BandedMatrices using Test -B=BandedMatrix(Ones(5,5), (-1,2)) -B[band(1)].=[1,2,3,4] -B[band(2)].=[5,6,7] +function checkequal(idx1::ArrayInterface.BandedMatrixIndex, + idx2::ArrayInterface.BandedMatrixIndex) + return idx1.rowsize == idx2.rowsize && idx1.colsize == idx2.colsize && + idx1.bandinds == idx2.bandinds && idx1.bandsizes == idx2.bandsizes && + idx1.isrow == idx2.isrow && idx1.count == idx2.count +end + +B = BandedMatrix(Ones(5, 5), (-1, 2)) +B[band(1)] .= [1, 2, 3, 4] +B[band(2)] .= [5, 6, 7] @test ArrayInterface.has_sparsestruct(B) -rowind,colind=ArrayInterface.findstructralnz(B) -@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,1,2,3,4] -B=BandedMatrix(Ones(4,6), (-1,2)) -B[band(1)].=[1,2,3,4] -B[band(2)].=[5,6,7,8] -rowind,colind=ArrayInterface.findstructralnz(B) -@test [B[rowind[i],colind[i]] for i in 1:length(rowind)]==[5,6,7,8,1,2,3,4] +rowind, colind = ArrayInterface.findstructralnz(B) +@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 1, 2, 3, 4] +B = BandedMatrix(Ones(4, 6), (-1, 2)) +B[band(1)] .= [1, 2, 3, 4] +B[band(2)] .= [5, 6, 7, 8] +rowind, colind = ArrayInterface.findstructralnz(B) +@test [B[rowind[i], colind[i]] for i in 1:length(rowind)] == [5, 6, 7, 8, 1, 2, 3, 4] @test ArrayInterface.isstructured(typeof(B)) @test ArrayInterface.fast_matrix_colors(typeof(B)) +for op in (adjoint, transpose) + B = BandedMatrix(Ones(5, 5), (-1, 2)) + B[band(1)] .= [1, 2, 3, 4] + B[band(2)] .= [5, 6, 7] + B′ = op(B) + @test ArrayInterface.has_sparsestruct(B′) + rowind′, colind′ = ArrayInterface.findstructralnz(B′) + rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′)) + @test checkequal(rowind′, rowind′′) + @test checkequal(colind′, colind′′) + + B = BandedMatrix(Ones(4, 6), (-1, 2)) + B[band(1)] .= [1, 2, 3, 4] + B[band(2)] .= [5, 6, 7, 8] + B′ = op(B) + rowind′, colind′ = ArrayInterface.findstructralnz(B′) + rowind′′, colind′′ = ArrayInterface.findstructralnz(BandedMatrix(B′)) + @test checkequal(rowind′, rowind′′) + @test checkequal(colind′, colind′′) + + @test ArrayInterface.isstructured(typeof(B′)) + @test ArrayInterface.fast_matrix_colors(typeof(B′)) + + @test ArrayInterface.matrix_colors(B′) == ArrayInterface.matrix_colors(BandedMatrix(B′)) +end diff --git a/test/blockbandedmatrices.jl b/test/blockbandedmatrices.jl index 9f29716d0..5f1588687 100644 --- a/test/blockbandedmatrices.jl +++ b/test/blockbandedmatrices.jl @@ -1,11 +1,12 @@ using ArrayInterface using BlockBandedMatrices +using FillArrays using Test BB=BlockBandedMatrix(Ones(10,10),[1,2,3,4],[4,3,2,1],(1,0)) -BB[Block(1,1)].=[1 2 3 4] -BB[Block(2,1)].=[5 6 7 8;9 10 11 12] +BB[BlockBandedMatrices.Block(1,1)].=[1 2 3 4] +BB[BlockBandedMatrices.Block(2,1)].=[5 6 7 8;9 10 11 12] rowind,colind=ArrayInterface.findstructralnz(BB) @test [BB[rowind[i],colind[i]] for i in 1:length(rowind)]== [1,5,9,2,6,10,3,7,11,4,8,12, diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 000000000..759a55bee --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,13 @@ +using ArrayInterface, ChainRules, Test +using ComponentArrays, ChainRulesTestUtils, StaticArrays + +x = ChainRules.OneElement(3.0, (3, 3), (1:4, 1:4)) + +@test !ArrayInterface.can_setindex(x) +@test !ArrayInterface.can_setindex(typeof(x)) + +arr = ComponentArray(a = 1.0, b = [2.0, 3.0], c = (; a = 4.0, b = 5.0), d = SVector{2}(6.0, 7.0)) +b = zeros(length(arr)) + +ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, arr, b) +ChainRulesTestUtils.test_rrule(ArrayInterface.restructure, b, arr) diff --git a/test/core.jl b/test/core.jl index bd0cd6cf3..003f7b5db 100644 --- a/test/core.jl +++ b/test/core.jl @@ -2,6 +2,7 @@ using ArrayInterface using ArrayInterface: zeromatrix, undefmatrix import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, parent_type, zeromatrix +using ComponentArrays using LinearAlgebra using Random using SparseArrays @@ -71,6 +72,11 @@ end A = sprand(50, 50, 0.5) @test lu_instance(A) isa typeof(lu(A)) @test lu_instance(1) === 1 + + @test lu_instance(Symmetric(rand(3,3))) isa typeof(lu(Symmetric(rand(4,4)))) + @test lu_instance(Tridiagonal(rand(3),rand(4),rand(3))) isa typeof(lu(Tridiagonal(rand(3),rand(4),rand(3)))) + @test lu_instance(SymTridiagonal(rand(4),rand(3))) isa typeof(lu(SymTridiagonal(rand(4),rand(3)))) + @test lu_instance(Diagonal(rand(4))) isa typeof(lu(Diagonal(rand(4)))) end @testset "ismutable" begin @@ -261,16 +267,18 @@ end end @testset "linearalgebra instances" begin - for A in [rand(2,2), rand(Float32,2,2), rand(BigFloat,2,2)] + for A in [rand(2,2), rand(Float32,2,2), rand(BigFloat,2,2), rand(ComplexF32,2,2), rand(ComplexF64,2,2)] @test ArrayInterface.lu_instance(A) isa typeof(lu(A)) @test ArrayInterface.qr_instance(A) isa typeof(qr(A)) if !(eltype(A) <: BigFloat) - @test ArrayInterface.bunchkaufman_instance(A' * A) isa typeof(bunchkaufman(A' * A)) @test ArrayInterface.cholesky_instance(A' * A) isa typeof(cholesky(A' * A)) - @test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A))) @test ArrayInterface.svd_instance(A) isa typeof(svd(A)) + if !(eltype(A) <: Union{ComplexF16,ComplexF32,ComplexF64}) + @test ArrayInterface.bunchkaufman_instance(A' * A) isa typeof(bunchkaufman(A' * A)) + @test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A))) + end end end @@ -282,4 +290,9 @@ end end @test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A))) end -end \ No newline at end of file +end + +@testset "Array conversion" begin + cv = ComponentVector((x = rand(3), y = rand(3))) + @test ArrayInterface.has_trivial_array_constructor(typeof(cv), rand(6)) +end diff --git a/test/runtests.jl b/test/runtests.jl index ec3493fd8..8a5d7b363 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,11 +13,13 @@ end @time @safetestset "BandedMatrices" begin include("bandedmatrices.jl") end @time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end @time @safetestset "Core" begin include("core.jl") end - @time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end + @time @safetestset "AD Integration" begin include("ad.jl") end + @time @safetestset "StaticArrays" begin include("staticarrays.jl") end + @time @safetestset "ChainRules" begin include("chainrules.jl") end end if GROUP == "GPU" activate_gpu_env() @time @safetestset "CUDA" begin include("gpu/cuda.jl") end end -end \ No newline at end of file +end diff --git a/test/staticarrayscore.jl b/test/staticarrays.jl similarity index 96% rename from test/staticarrayscore.jl rename to test/staticarrays.jl index 420a05c74..cbbba184b 100644 --- a/test/staticarrayscore.jl +++ b/test/staticarrays.jl @@ -11,6 +11,7 @@ x = @SVector [1,2,3] x = @MVector [1,2,3] @test ArrayInterface.ismutable(x) == true @test ArrayInterface.ismutable(view(x, 1:2)) == true +@test ArrayInterface.can_setindex(typeof(x)) == true A = @SMatrix(randn(5, 5)) @test ArrayInterface.lu_instance(A) isa typeof(lu(A))