forked from FluxML/Flux.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.jl
155 lines (125 loc) · 3.45 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Arrays
"""
    glorot_uniform(dims...)

Return a `Float32` array of size `dims` filled with uniform random values.
Samples from `rand` (uniform on `[0, 1)`) are centred on zero and multiplied by
`sqrt(24 / sum(dims))`, so every entry lies within `±sqrt(6 / sum(dims))`.
"""
function glorot_uniform(dims...)
    scale = sqrt(24.0f0 / sum(dims))
    return (rand(Float32, dims...) .- 0.5f0) .* scale
end
"""
    glorot_normal(dims...)

Return a `Float32` array of size `dims` filled with `randn` samples scaled by
`sqrt(2 / sum(dims))`.
"""
function glorot_normal(dims...)
    stddev = sqrt(2.0f0 / sum(dims))
    return randn(Float32, dims...) .* stddev
end
# Local `ones`/`zeros` wrappers around the `Base` constructors.
# With an explicit element type they forward unchanged to `Base`; without one
# the element type defaults to `Float32`.
function ones(T::Type, dims...)
    return Base.ones(T, dims...)
end

function zeros(T::Type, dims...)
    return Base.zeros(T, dims...)
end

# No element type given: route through the typed methods with `Float32`.
ones(dims...) = ones(Float32, dims...)
zeros(dims...) = zeros(Float32, dims...)
"""
    unsqueeze(xs, dim)

Reshape `xs` so that a new dimension of length 1 is inserted at position `dim`;
the existing dimensions from `dim` onwards shift one place to the right.
"""
function unsqueeze(xs, dim)
    dims = size(xs)
    leading = dims[1:dim-1]
    trailing = dims[dim:end]
    return reshape(xs, (leading..., 1, trailing...))
end
"""
    stack(xs, dim)

Concatenate the elements of `xs` along a new dimension `dim`. Each element
first gets a singleton dimension at `dim` (via `unsqueeze`), and the results
are joined with `cat` along that dimension.
"""
function stack(xs, dim)
    expanded = unsqueeze.(xs, dim)
    return cat(expanded...; dims=dim)
end
"""
    unstack(xs, dim)

Split `xs` along dimension `dim`, returning a vector with one copied slice
(`selectdim(xs, dim, i)`) for each index `i` in `1:size(xs, dim)`.
"""
function unstack(xs, dim)
    return map(1:size(xs, dim)) do i
        copy(selectdim(xs, dim, i))
    end
end
"""
chunk(xs, n)
Split `xs` into `n` parts.
```julia
julia> chunk(1:10, 3)
3-element Array{Array{Int64,1},1}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
```
"""
chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
# Build the index tuple that addresses the `i`-th slice along the last
# dimension of `xs`: every axis of `xs` except the last, followed by `i`.
# `Base.front` drops the final tuple element directly, replacing the previous
# reverse/tail/reverse round trip.
batchindex(xs, i) = (Base.front(axes(xs))..., i)
"""
frequencies(xs)
Count the number of times that each element of `xs` appears.
```julia
julia> frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries:
'b' => 2
'a' => 1
```
"""
function frequencies(xs)
fs = Dict{eltype(xs),Int}()
for x in xs
fs[x] = get(fs, x, 0) + 1
end
return fs
end
# Return `x` without its last element. `Base.front` is the standard-library
# counterpart of `Base.tail` and avoids reversing the tuple twice.
head(x::Tuple) = Base.front(x)
# Reshape `x` to its own size with the last dimension removed.
# NOTE(review): `reshape` must preserve the element count, so this presumably
# expects the trailing dimension to have length 1 — confirm against callers.
function squeezebatch(x)
    dims = head(size(x))
    return reshape(x, dims)
end
"""
batch(xs)
Batch the arrays in `xs` into a single array.
```julia
julia> batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
1 4
2 5
3 6
```
"""
function batch(xs)
data = first(xs) isa AbstractArray ?
similar(first(xs), size(first(xs))..., length(xs)) :
Vector{eltype(xs)}(undef, length(xs))
for (i, x) in enumerate(xs)
data[batchindex(data, i)...] = x
end
return data
end
# Right-pad the vector `v` with copies of `p` until it has at least `n`
# elements; a vector that is already long enough is returned as a new vector
# with the same contents (nothing is truncated).
# NOTE(review): this adds a method to `Base.rpad` for `AbstractVector`, neither
# of which looks to be owned by this file (type piracy). It is kept as-is
# because `batchseq` calls `rpad` on its sequences.
function Base.rpad(v::AbstractVector, n::Integer, p)
    shortfall = max(n - length(v), 0)
    return [v; fill(p, shortfall)]
end
"""
batchseq(seqs, pad)
Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `pad`.
```julia
julia> batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
[1, 4]
[2, 5]
[3, 0]
```
"""
function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
xs_ = [rpad(x, n, pad) for x in xs]
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end
# Other
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration;
but if you'd like to disable the execution on the leading edge, pass
`leading=false`. To enable execution on the trailing edge, ditto.
"""
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
later = nothing
result = nothing
function throttled(args...; kwargs...)
yield()
if cooldown
if leading
result = f(args...; kwargs...)
else
later = () -> f(args...; kwargs...)
end
cooldown = false
@async try
while (sleep(timeout); later != nothing)
later()
later = nothing
end
finally
cooldown = true
end
elseif trailing
later = () -> (result = f(args...; kwargs...))
end
return result
end
end
"""
@jit ...
The `@jit` annotation can be applied to any code, and the code will be compiled
for performance.
@jit f(x) = @jit(x) + @jit(x)
Note that compilation happens regardless of the `@jit` macro, so it should only
be used for aesthetic purposes, or by recovering Python users.
"""
macro jit(ex)
esc(ex)
end