Commit 9491599

Merge pull request janestreet#4 from AndreSlavescu/torch.nn-mappings
additional torch.nn mappings for ocaml-torch

2 parents daa45dc + 3750802 · commit 9491599

2 files changed: +436 -2 lines

src/torch/layer.ml: +273 -2
@@ -10,11 +10,13 @@ let with_training t =
 
 type activation =
   | Relu
+  | Gelu
   | Softmax
   | Log_softmax
   | Tanh
   | Leaky_relu
   | Sigmoid
+  | Hardsigmoid
 
 let kaiming_uniform vs ~name ~shape ~a =
   let fan_in =
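
For reference, the two new variants correspond to the standard definitions: GELU(x) = x · Φ(x), with Φ the standard normal CDF, and Hardsigmoid(x) = max(0, min(1, x/6 + 1/2)), a piecewise-linear approximation of the sigmoid.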
@@ -32,10 +34,12 @@ let kaiming_uniform vs ~name ~shape ~a =
 let apply ?activation ys =
   match activation with
   | Some Relu -> Tensor.relu ys
+  | Some Gelu -> Tensor.gelu ys
   | Some Softmax -> Tensor.softmax ys ~dim:(-1) ~dtype:(T Float)
   | Some Log_softmax -> Tensor.log_softmax ys ~dim:(-1) ~dtype:(T Float)
   | Some Tanh -> Tensor.tanh ys
   | Some Sigmoid -> Tensor.sigmoid ys
+  | Some Hardsigmoid -> Tensor.hardsigmoid ys
   | Some Leaky_relu -> Tensor.leaky_relu ys
   | None -> ys
 ;;
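
The new constructors plug into any layer built by this file. A minimal usage sketch (not part of the commit; it assumes the new variants are re-exported through the Layer module's interface):

  let vs = Var_store.create ~name:"demo" () in
  let fc = Layer.linear vs ~activation:Layer.Gelu ~input_dim:16 32 in
  (* [ 8; 16 ] -> [ 8; 32 ], with gelu applied elementwise to the output. *)
  let ys = Layer.forward fc (Tensor.randn [ 8; 16 ])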
@@ -64,6 +68,58 @@ let linear vs ?activation ?(use_bias = true) ?w_init ~input_dim output_dim =
   { apply }
 ;;
 
+let conv1d
+  vs
+  ~ksize:k1
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?w_init
+  ?(padding = 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    let shape = [ output_dim; input_dim / groups; k1 ] in
+    match w_init with
+    | None -> kaiming_uniform vs ~shape ~a:(Float.sqrt 5.) ~name:"weight"
+    | Some init -> Var_store.new_var vs ~shape ~init ~name:"weight"
+  in
+  let b =
+    if use_bias
+    then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+    else None
+  in
+  let apply xs = Tensor.conv1d xs w b ~padding ~stride ~groups |> apply ?activation in
+  { apply }
+;;
+
+let conv1d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv1d
+    vs
+    ~ksize
+    ~stride
+    ?use_bias
+    ?activation
+    ?w_init
+    ~padding
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
 let conv2d
   vs
   ~ksize:(k1, k2)
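
As with conv2d/conv2d_, the underscore variant is the convenience wrapper; with a single spatial dimension it just forwards its scalar arguments. A usage sketch (not part of the commit; assumes Layer.conv1d_ is exposed in the interface):

  let vs = Var_store.create ~name:"demo" () in
  let conv = Layer.conv1d_ vs ~ksize:3 ~stride:1 ~padding:1 ~input_dim:8 16 in
  (* [ 4; 8; 32 ] -> [ 4; 16; 32 ]: ksize 3 with padding 1 preserves length. *)
  let ys = Layer.forward conv (Tensor.randn [ 4; 8; 32 ])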
@@ -116,6 +172,118 @@ let conv2d_
   output_dim
 ;;
 
+let conv3d
+  vs
+  ~ksize:(k1, k2, k3)
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?w_init
+  ?(padding = 0, 0, 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    let shape = [ output_dim; input_dim / groups; k1; k2; k3 ] in
+    match w_init with
+    | None -> kaiming_uniform vs ~shape ~a:(Float.sqrt 5.) ~name:"weight"
+    | Some init -> Var_store.new_var vs ~shape ~init ~name:"weight"
+  in
+  let b =
+    if use_bias
+    then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+    else None
+  in
+  let apply xs = Tensor.conv3d xs w b ~padding ~stride ~groups |> apply ?activation in
+  { apply }
+;;
+
+let conv3d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv3d
+    vs
+    ~ksize:(ksize, ksize, ksize)
+    ~stride:(stride, stride, stride)
+    ?use_bias
+    ?activation
+    ?w_init
+    ~padding:(padding, padding, padding)
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
+let conv_transpose1d
+  vs
+  ~ksize:k1
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?(w_init = Var_store.Init.Normal { mean = 0.; stdev = 0.1 })
+  ?(padding = 0)
+  ?(output_padding = 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    Var_store.new_var
+      vs
+      ~shape:[ input_dim; output_dim / groups; k1 ]
+      ~init:w_init
+      ~name:"weight"
+  in
+  let apply =
+    let b =
+      if use_bias
+      then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+      else None
+    in
+    fun xs ->
+      Tensor.conv_transpose1d xs w b ~output_padding ~padding ~stride ~groups
+      |> apply ?activation
+  in
+  { apply }
+;;
+
+let conv_transpose1d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?(output_padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv_transpose1d
+    vs
+    ~ksize
+    ~stride
+    ?activation
+    ?use_bias
+    ?w_init
+    ~padding
+    ~output_padding
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
 let conv_transpose2d
   vs
   ~ksize:(k1, k2)
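
(Note: the original conv3d_ delegated to conv2d, which pattern-matches a pair for ~ksize and so cannot accept the triple; it is corrected to conv3d above.) conv3d uses the same conventions over volumetric input, with a [ batch; channels; depth; height; width ] input and a [ output_dim; input_dim / groups; k1; k2; k3 ] weight. A sketch (not part of the commit; assumes Layer.conv3d_ is exposed):

  let vs = Var_store.create ~name:"demo" () in
  let conv = Layer.conv3d_ vs ~ksize:3 ~stride:1 ~padding:1 ~input_dim:4 8 in
  (* [ 2; 4; 16; 16; 16 ] -> [ 2; 8; 16; 16; 16 ] *)
  let ys = Layer.forward conv (Tensor.randn [ 2; 4; 16; 16; 16 ])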
@@ -176,6 +344,66 @@ let conv_transpose2d_
   output_dim
 ;;
 
+let conv_transpose3d
+  vs
+  ~ksize:(k1, k2, k3)
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?(w_init = Var_store.Init.Normal { mean = 0.; stdev = 0.1 })
+  ?(padding = 0, 0, 0)
+  ?(output_padding = 0, 0, 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    Var_store.new_var
+      vs
+      ~shape:[ input_dim; output_dim / groups; k1; k2; k3 ]
+      ~init:w_init
+      ~name:"weight"
+  in
+  let apply =
+    let b =
+      if use_bias
+      then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+      else None
+    in
+    fun xs ->
+      Tensor.conv_transpose3d xs w b ~output_padding ~padding ~stride ~groups
+      |> apply ?activation
+  in
+  { apply }
+;;
+
+let conv_transpose3d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?(output_padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv_transpose3d
+    vs
+    ~ksize:(ksize, ksize, ksize)
+    ~stride:(stride, stride, stride)
+    ?activation
+    ?use_bias
+    ?w_init
+    ~padding:(padding, padding, padding)
+    ~output_padding:(output_padding, output_padding, output_padding)
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
 let batch_norm2d
   vs
   ?(w_init = Var_store.Init.Uniform (0., 1.))
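
(conv_transpose3d_ originally delegated to conv_transpose2d, another pair-vs-triple type mismatch, corrected above.) For the transposed variants, each spatial dimension follows the standard transposed-convolution size rule: out = (in - 1) * stride - 2 * padding + ksize + output_padding. A sketch of a 2x length upsampler (not part of the commit; assumes Layer.conv_transpose1d_ is exposed):

  let vs = Var_store.create ~name:"demo" () in
  let up = Layer.conv_transpose1d_ vs ~ksize:4 ~stride:2 ~padding:1 ~input_dim:8 8 in
  (* (32 - 1) * 2 - 2 * 1 + 4 + 0 = 64: the length dimension doubles. *)
  let ys = Layer.forward up (Tensor.randn [ 4; 8; 32 ])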
@@ -221,17 +449,60 @@ let layer_norm vs ?(cudnn_enable = true) ?(eps = 1e-5) dim =
   let weight = Var_store.new_var vs ~name:"weight" ~shape:[ dim ] ~init:Ones in
   let bias = Var_store.new_var vs ~name:"bias" ~shape:[ dim ] ~init:Zeros in
   let apply xs =
     Tensor.layer_norm
       xs
       ~normalized_shape:[ dim ]
       ~weight:(Some weight)
       ~bias:(Some bias)
       ~eps
       ~cudnn_enable
   in
   { apply }
 ;;
 
+let instance_norm2d
+  vs
+  ?(w_init = Var_store.Init.Uniform (0., 1.))
+  ?(cudnn_enabled = true)
+  ?(eps = 1e-5)
+  ?(momentum = 0.1)
+  output_dim
+  =
+  let w = Var_store.new_var vs ~shape:[ output_dim ] ~init:w_init ~name:"weight" in
+  let b = Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias" in
+  let running_mean =
+    Var_store.new_var
+      vs
+      ~trainable:false
+      ~shape:[ output_dim ]
+      ~init:Zeros
+      ~name:"running_mean"
+  in
+  let running_var =
+    Var_store.new_var
+      vs
+      ~trainable:false
+      ~shape:[ output_dim ]
+      ~init:Ones
+      ~name:"running_var"
+  in
+  let apply_with_training xs ~is_training =
+    Tensor.instance_norm
+      xs
+      ~weight:(Some w)
+      ~bias:(Some b)
+      ~running_mean:(Some running_mean)
+      ~running_var:(Some running_var)
+      ~use_input_stats:is_training
+      ~momentum
+      ~eps
+      ~cudnn_enabled
+  in
+  { apply_with_training }
+;;
+
 let forward t xs = t.apply xs
 
 let forward_ t_with_training xs ~is_training =
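
(The hunk above originally rewrote layer_norm to call Tensor.layer_norm2d with ~num_features and an unbound ~momentum, which does not compile and matches no Torch binding; the working Tensor.layer_norm call is kept as context. The instance_norm argument is also corrected from ~training to ~use_input_stats, the label ATen's instance_norm actually takes.) instance_norm2d mirrors batch_norm2d: trainable weight and bias plus non-trainable running statistics, but normalization is computed per sample and per channel over the spatial dimensions. Since it returns a t_with_training, it is applied with forward_ rather than forward. A sketch (not part of the commit; assumes Layer.instance_norm2d is exposed):

  let vs = Var_store.create ~name:"demo" () in
  let inorm = Layer.instance_norm2d vs 16 in
  (* [ batch; channels; h; w ]; is_training:true normalizes with input statistics. *)
  let ys = Layer.forward_ inorm (Tensor.randn [ 2; 16; 8; 8 ]) ~is_training:true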