Skip to content

Commit

Permalink
Implement onnx MeanVarianceNormalization (tinygrad#943)
Browse files Browse the repository at this point in the history
  • Loading branch information
M4tthewDE authored Jun 6, 2023
1 parent 3bb38c3 commit 664d6cc
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion extra/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,9 @@ def CastLike(input, target_type):
assert isinstance(target_type, Tensor), "can only CastLike Tensor"
return input

def Binarizer(input, threshold=0.0): return input > threshold
def Binarizer(input, threshold=0.0): return input > threshold

def MeanVarianceNormalization(input, axis=(0, 2, 3)):
data_mean = input.mean(axis=axis, keepdim=True)
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
return (input - data_mean) / (std + 1e-9)

0 comments on commit 664d6cc

Please sign in to comment.