@@ -10,11 +10,13 @@ let with_training t =
 
 type activation =
   | Relu
+  | Gelu
   | Softmax
   | Log_softmax
   | Tanh
   | Leaky_relu
   | Sigmoid
+  | Hardsigmoid
 
 let kaiming_uniform vs ~name ~shape ~a =
   let fan_in =
@@ -32,10 +34,12 @@ let kaiming_uniform vs ~name ~shape ~a =
 let apply ?activation ys =
   match activation with
   | Some Relu -> Tensor.relu ys
+  | Some Gelu -> Tensor.gelu ys
   | Some Softmax -> Tensor.softmax ys ~dim:(-1) ~dtype:(T Float)
   | Some Log_softmax -> Tensor.log_softmax ys ~dim:(-1) ~dtype:(T Float)
   | Some Tanh -> Tensor.tanh ys
   | Some Sigmoid -> Tensor.sigmoid ys
+  | Some Hardsigmoid -> Tensor.hardsigmoid ys
   | Some Leaky_relu -> Tensor.leaky_relu ys
   | None -> ys
 ;;
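
(* A minimal sketch of the activation hook, assuming only a tensor [ys]; the
   new [Gelu] and [Hardsigmoid] variants are used exactly like the existing
   ones, either directly or through a layer's [?activation] argument.

     let ys = Tensor.randn [ 2; 8 ] in
     let zs = apply ~activation:Gelu ys in
     ignore zs *)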
@@ -64,6 +68,58 @@ let linear vs ?activation ?(use_bias = true) ?w_init ~input_dim output_dim =
   { apply }
 ;;
 
+let conv1d
+  vs
+  ~ksize:k1
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?w_init
+  ?(padding = 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    let shape = [ output_dim; input_dim / groups; k1 ] in
+    match w_init with
+    | None -> kaiming_uniform vs ~shape ~a:(Float.sqrt 5.) ~name:"weight"
+    | Some init -> Var_store.new_var vs ~shape ~init ~name:"weight"
+  in
+  let b =
+    if use_bias
+    then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+    else None
+  in
+  let apply xs = Tensor.conv1d xs w b ~padding ~stride ~groups |> apply ?activation in
+  { apply }
+;;
+
+let conv1d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv1d
+    vs
+    ~ksize
+    ~stride
+    ?use_bias
+    ?activation
+    ?w_init
+    ~padding
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
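(* A hypothetical usage sketch for the new 1d convolution; the [Var_store]
   named [vs] and the input batch of shape [batch; channels; length] are
   assumptions, not part of the diff itself.

     let vs = Var_store.create ~name:"sketch" () in
     let conv = conv1d vs ~ksize:3 ~stride:1 ~padding:1 ~input_dim:16 32 in
     (* [8; 16; 100] -> [8; 32; 100]: padding 1 with ksize 3 keeps the length. *)
     let ys = forward conv (Tensor.randn [ 8; 16; 100 ]) in
     ignore ys *)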
 let conv2d
   vs
   ~ksize:(k1, k2)
@@ -116,6 +172,118 @@ let conv2d_
     output_dim
 ;;
 
+let conv3d
+  vs
+  ~ksize:(k1, k2, k3)
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?w_init
+  ?(padding = 0, 0, 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    let shape = [ output_dim; input_dim / groups; k1; k2; k3 ] in
+    match w_init with
+    | None -> kaiming_uniform vs ~shape ~a:(Float.sqrt 5.) ~name:"weight"
+    | Some init -> Var_store.new_var vs ~shape ~init ~name:"weight"
+  in
+  let b =
+    if use_bias
+    then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+    else None
+  in
+  let apply xs = Tensor.conv3d xs w b ~padding ~stride ~groups |> apply ?activation in
+  { apply }
+;;
+
+let conv3d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv3d
+    vs
+    ~ksize:(ksize, ksize, ksize)
+    ~stride:(stride, stride, stride)
+    ?use_bias
+    ?activation
+    ?w_init
+    ~padding:(padding, padding, padding)
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
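(* Sketch for the 3d variants, under the same assumptions as above: [conv3d_]
   expands its scalar [ksize], [stride], and [padding] into cubes before
   delegating to [conv3d].

     let conv = conv3d_ vs ~ksize:3 ~stride:1 ~padding:1 ~input_dim:4 8 in
     (* [2; 4; 16; 16; 16] -> [2; 8; 16; 16; 16] *)
     let ys = forward conv (Tensor.randn [ 2; 4; 16; 16; 16 ]) in
     ignore ys *)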
+let conv_transpose1d
+  vs
+  ~ksize:k1
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?(w_init = Var_store.Init.Normal { mean = 0.; stdev = 0.1 })
+  ?(padding = 0)
+  ?(output_padding = 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    Var_store.new_var
+      vs
+      ~shape:[ input_dim; output_dim / groups; k1 ]
+      ~init:w_init
+      ~name:"weight"
+  in
+  let apply =
+    let b =
+      if use_bias
+      then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+      else None
+    in
+    fun xs ->
+      Tensor.conv_transpose1d xs w b ~output_padding ~padding ~stride ~groups
+      |> apply ?activation
+  in
+  { apply }
+;;
+
+let conv_transpose1d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?(output_padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv_transpose1d
+    vs
+    ~ksize
+    ~stride
+    ?activation
+    ?use_bias
+    ?w_init
+    ~padding
+    ~output_padding
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
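(* Sketch of upsampling with the transposed 1d convolution, assuming [vs];
   with ksize 4, stride 2, and padding 1 the length doubles, following
   l_out = (l_in - 1) * stride - 2 * padding + ksize + output_padding.

     let up = conv_transpose1d vs ~ksize:4 ~stride:2 ~padding:1 ~input_dim:32 16 in
     (* [8; 32; 50] -> [8; 16; 100] *)
     let ys = forward up (Tensor.randn [ 8; 32; 50 ]) in
     ignore ys *)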
 let conv_transpose2d
   vs
   ~ksize:(k1, k2)
@@ -176,6 +344,66 @@ let conv_transpose2d_
     output_dim
 ;;
 
+let conv_transpose3d
+  vs
+  ~ksize:(k1, k2, k3)
+  ~stride
+  ?activation
+  ?(use_bias = true)
+  ?(w_init = Var_store.Init.Normal { mean = 0.; stdev = 0.1 })
+  ?(padding = 0, 0, 0)
+  ?(output_padding = 0, 0, 0)
+  ?(groups = 1)
+  ~input_dim
+  output_dim
+  =
+  let w =
+    Var_store.new_var
+      vs
+      ~shape:[ input_dim; output_dim / groups; k1; k2; k3 ]
+      ~init:w_init
+      ~name:"weight"
+  in
+  let apply =
+    let b =
+      if use_bias
+      then Some (Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias")
+      else None
+    in
+    fun xs ->
+      Tensor.conv_transpose3d xs w b ~output_padding ~padding ~stride ~groups
+      |> apply ?activation
+  in
+  { apply }
+;;
+
+let conv_transpose3d_
+  vs
+  ~ksize
+  ~stride
+  ?activation
+  ?use_bias
+  ?w_init
+  ?(padding = 0)
+  ?(output_padding = 0)
+  ?groups
+  ~input_dim
+  output_dim
+  =
+  conv_transpose3d
+    vs
+    ~ksize:(ksize, ksize, ksize)
+    ~stride:(stride, stride, stride)
+    ?activation
+    ?use_bias
+    ?w_init
+    ~padding:(padding, padding, padding)
+    ~output_padding:(output_padding, output_padding, output_padding)
+    ?groups
+    ~input_dim
+    output_dim
+;;
+
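(* The 3d transposed variant mirrors the 1d/2d sketches above; a hypothetical
   doubling of a volumetric feature map, assuming [vs]:

     let up = conv_transpose3d_ vs ~ksize:4 ~stride:2 ~padding:1 ~input_dim:8 4 in
     (* [2; 8; 10; 10; 10] -> [2; 4; 20; 20; 20] *)
     let ys = forward up (Tensor.randn [ 2; 8; 10; 10; 10 ]) in
     ignore ys *)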
 let batch_norm2d
   vs
   ?(w_init = Var_store.Init.Uniform (0., 1.))
@@ -221,17 +449,60 @@ let layer_norm vs ?(cudnn_enable = true) ?(eps = 1e-5) dim =
   let weight = Var_store.new_var vs ~name:"weight" ~shape:[ dim ] ~init:Ones in
   let bias = Var_store.new_var vs ~name:"bias" ~shape:[ dim ] ~init:Zeros in
   let apply xs =
     Tensor.layer_norm
       xs
       ~normalized_shape:[ dim ]
       ~weight:(Some weight)
       ~bias:(Some bias)
       ~eps
       ~cudnn_enable
   in
   { apply }
 ;;
 
+let instance_norm2d
+  vs
+  ?(w_init = Var_store.Init.Uniform (0., 1.))
+  ?(cudnn_enabled = true)
+  ?(eps = 1e-5)
+  ?(momentum = 0.1)
+  output_dim
+  =
+  let w = Var_store.new_var vs ~shape:[ output_dim ] ~init:w_init ~name:"weight" in
+  let b = Var_store.new_var vs ~shape:[ output_dim ] ~init:Zeros ~name:"bias" in
+  let running_mean =
+    Var_store.new_var
+      vs
+      ~trainable:false
+      ~shape:[ output_dim ]
+      ~init:Zeros
+      ~name:"running_mean"
+  in
+  let running_var =
+    Var_store.new_var
+      vs
+      ~trainable:false
+      ~shape:[ output_dim ]
+      ~init:Ones
+      ~name:"running_var"
+  in
+  let apply_with_training xs ~is_training =
+    Tensor.instance_norm
+      xs
+      ~weight:(Some w)
+      ~bias:(Some b)
+      ~running_mean:(Some running_mean)
+      ~running_var:(Some running_var)
+      ~use_input_stats:is_training
+      ~momentum
+      ~eps
+      ~cudnn_enabled
+  in
+  { apply_with_training }
+;;
+
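(* Usage sketch, assuming [vs] and a [batch; channels; h; w] input: instance
   norm layers carry running statistics, so they are driven through [forward_]
   with an explicit [~is_training] flag rather than through [forward].

     let norm = instance_norm2d vs 16 in
     let ys = forward_ norm (Tensor.randn [ 8; 16; 32; 32 ]) ~is_training:true in
     ignore ys *)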
 let forward t xs = t.apply xs
 
 let forward_ t_with_training xs ~is_training =