Skip to content

Commit

Permalink
Merge pull request FluxML#570 from avik-pal/ap/batchnorm_fixes
Browse files Browse the repository at this point in the history
Patches for default initializers
  • Loading branch information
MikeInnes authored Jan 28, 2019
2 parents bb2210f + 2f3ad56 commit 013b421
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
stride = 1, pad = 0) where {T,N} =
DepthwiseConv(σ, w, b, expand.(sub2(Val(N)), (stride, pad))...)

DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = initn,
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
stride = 1, pad = 0) where N =
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
stride = stride, pad = pad)

DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
stride::NTuple{N,Integer} = map(_->1,k),
pad::NTuple{N,Integer} = map(_->0,k)) where N =
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
Expand Down
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ mutable struct BatchNorm{F,V,W,N}
end

BatchNorm(chs::Integer, λ = identity;
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)

Expand Down
12 changes: 12 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,15 @@ end

@test size(m(r)) == (10, 5)
end

@testset "Depthwise Conv" begin
r = zeros(Float32, 28, 28, 3, 5)

m1 = DepthwiseConv((2, 2), 3=>5)

@test size(m1(r), 3) == 15

m2 = DepthwiseConv((2, 2), 3)

@test size(m2(r), 3) == 3
end

0 comments on commit 013b421

Please sign in to comment.