pact.py
# Implementation taken from https://discuss.pytorch.org/t/evaluator-returns-nan/107972/3
# Ref: PACT, https://arxiv.org/abs/1805.06085
import torch
import torch.nn as nn


class PACTClip(torch.autograd.Function):
    """Clamp x to [0, alpha] and route a gradient to the learnable clipping bound alpha."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return torch.clamp(x, 0, alpha.data)

    @staticmethod
    def backward(ctx, dy):
        x, alpha = ctx.saved_tensors
        # Straight-through estimate for x: pass the gradient only where 0 <= x <= alpha.
        dx = dy.clone()
        dx[x < 0] = 0
        dx[x > alpha] = 0
        # Gradient w.r.t. alpha: nonzero only where the input was clipped at the upper bound.
        dalpha = dy.clone()
        dalpha[x <= alpha] = 0
        return dx, torch.sum(dalpha)


class PACTReLU(nn.Module):
    """Drop-in replacement for nn.ReLU with a learnable clipping level alpha."""

    def __init__(self, alpha=6.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha))

    def forward(self, x):
        return PACTClip.apply(x, self.alpha)
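

# --- Usage sketch (assumption: not part of the original file) ---
# Shows PACTReLU substituted for nn.ReLU in a toy model; layer sizes and the
# __main__ guard are illustrative only. The PACT paper additionally applies
# L2 regularization to alpha, which would be an extra penalty term on the loss.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 32), PACTReLU(), nn.Linear(32, 10))
    x = torch.randn(4, 16)
    loss = model(x).sum()
    loss.backward()
    # alpha receives its gradient from PACTClip.backward (torch.sum(dalpha)).
    print(model[1].alpha.grad)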