Skip to content

Commit

Permalink
fix JuliaLang#6672 (better reducedim type inference)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevengj committed May 27, 2014
1 parent 946e8cc commit 74ac755
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
20 changes: 17 additions & 3 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function gen_reduction_body(N, f::Function)
end
end

reduction_init{T}(A::AbstractArray, region, initial::T) = fill!(similar(A,T,reduced_dims(A,region)), initial)
reduction_init{T}(A::AbstractArray, region, initial::T, Tr=T) = fill!(similar(A,Tr,reduced_dims(A,region)), initial)


### Pre-generated cases
Expand Down Expand Up @@ -143,11 +143,25 @@ minimum{T}(A::AbstractArray{T}, region) =

eval(ngenerate(:N, :(typeof(R)), :(_sum!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, +)))
sum!{R}(r::AbstractArray{R}, A::AbstractArray; init::Bool=true) = _sum!(initarray!(r, zero(R), init), A)
sum{T}(A::AbstractArray{T}, region) = _sum!(reduction_init(A, region, zero(T)+zero(T)), A)

eval(ngenerate(:N, :(typeof(R)), :(_prod!{T,N}(R::AbstractArray, A::AbstractArray{T,N})), N->gen_reduction_body(N, *)))
prod!{R}(r::AbstractArray{R}, A::AbstractArray; init::Bool=true) = _prod!(initarray!(r, one(R), init), A)
prod{T}(A::AbstractArray{T}, region) = _prod!(reduction_init(A, region, one(T)*one(T)), A)

for (f,init,op) in ((:sum,:zero,:+), (:prod,:one,:*))
_f = symbol(string("_",f,"!")) # _sum!, _prod!
@eval function $f{T}(A::AbstractArray{T}, region)
if method_exists($init, (Type{T},))
z = $op($init(T), $init(T))
Tr = typeof(z) == typeof($init(T)) ? T : typeof(z)
else
# TODO: handle more heterogeneous sums. e.g. sum(A, 1) where
# A is a Matrix{Any} with one column of numbers and one of vectors
z = $init($f(A))
Tr = typeof(z)
end
$_f(reduction_init(A, region, z, Tr), A)
end
end

prod(A::AbstractArray{Bool}, region) = error("use all() instead of prod() for boolean arrays")

Expand Down
6 changes: 6 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ A = [1.0 3.0 6.0;
@test findmax(A, (1,)) == ([5.0 3.0 6.0], [2 3 5])
@test findmax(A, (2,)) == (reshape([6.0,5.0], 2, 1), reshape([5,2], 2, 1))
@test findmax(A, (1,2)) == (fill(6.0,1,1),fill(5,1,1))

# issue #6672
@test sum(Real[1 2 3; 4 5.3 7.1], 2) == reshape([6, 16.4], 2, 1)
@test std(FloatingPoint[1,2,3], 1) == [1.0]
@test sum({1 2;3 4},1) == [4 6]
@test sum(Vector{Int}[[1,2],[4,3]], 1)[1] == [5,5]

0 comments on commit 74ac755

Please sign in to comment.