Skip to content

Commit

Permalink
onehotvector/matrix behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhairya Gandhi committed Feb 2, 2019
1 parent 0469394 commit bd6158d
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Base.size(xs::OneHotVector) = (Int64(xs.of),)

Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix

Base.getindex(xs::OneHotVector, ::Colon) = xs

A::AbstractMatrix * b::OneHotVector = A[:, b.ix]

struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
Expand All @@ -22,6 +24,21 @@ Base.getindex(xs::OneHotMatrix, i::Integer, j::Integer) = xs.data[j][i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::Integer) = xs.data[i]
Base.getindex(xs::OneHotMatrix, ::Colon, i::AbstractArray) = OneHotMatrix(xs.height, xs.data[i])

Base.getindex(xs::Flux.OneHotMatrix, j::Base.UnitRange, i::Int) = xs.data[i][j]

Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = xs

# handle special case for when we want the entire column without allocating
function Base.getindex(xs::Flux.OneHotMatrix{T}, ot::Union{Base.Slice, Base.OneTo}, i::Int) where {T<:AbstractArray}
res = similar(xs, size(xs, 1), 1)
if length(ot) == size(xs, 1)
res = xs[:,i]
else
res = xs[1:length(ot),i]
end
res
end

A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]

Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
Expand Down Expand Up @@ -54,13 +71,15 @@ end
onehotbatch(ls, labels, unk...) =
OneHotMatrix(length(labels), [onehot(l, labels, unk...) for l in ls])

Base.argmax(xs::OneHotVector) = xs.ix

onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]

onecold(y::AbstractMatrix, labels...) =
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)

function argmax(xs...)
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
Base.depwarn("`argmax(...)` is deprecated, use `onecold(...)` instead.", :argmax)
return onecold(xs...)
end

Expand Down

0 comments on commit bd6158d

Please sign in to comment.