Commit 48c309f

Add continuous activations
1 parent 24b8e88 commit 48c309f

File tree

2 files changed: +105 −1 lines changed

stribor/flows/activations.py

+76 −1
@@ -1,8 +1,9 @@
-from typing import Optional
+from typing import Callable, Optional
 from torchtyping import TensorType
 
 import math
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import constraints
 
@@ -119,3 +120,77 @@ def log_diag_jacobian(
         left = torch.zeros_like(x)
         right = torch.ones_like(x) * math.log(self.negative_slope)
         return torch.where(x >= 0., left, right)
+
+
+class ContinuousActivation(ElementwiseTransform):
+    """
+    Continuous activation function.
+    At t=0 it is the identity; as t grows it acts more like the original activation.
+
+    Args:
+        activation (callable): An activation function (e.g., `torch.tanh`)
+        temperature (float): How fast the activation becomes like the original
+            as t increases. Higher temperature -> faster
+        learnable (bool): Whether temperature is a learnable parameter
+    """
+    def __init__(
+        self,
+        activation: Callable,
+        temperature: float = 1.,
+        learnable: bool = False,
+    ):
+        super().__init__()
+        self.activation = activation
+        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable)
+
+    def forward(
+        self, x: TensorType[..., 'dim'], t: TensorType[..., 1], **kwargs,
+    ) -> TensorType[..., 'dim']:
+        w = torch.tanh(self.temperature * t)
+        return x * (1 - w) + self.activation(x) * w
+
+    def inverse(self, y: TensorType, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_diag_jacobian(self, x: TensorType, y: TensorType, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_det_jacobian(self, x: TensorType, y: TensorType, **kwargs) -> None:
+        raise NotImplementedError
+
+
+class ContinuousTanh(ElementwiseTransform):
+    """
+    Continuous activation that is the solution to `dx(t)/dt = tanh(x(t))`.
+
+    Args:
+        log_time (bool): Whether to use time in the log domain.
+    """
+    def __init__(self, log_time: bool = False):
+        super().__init__()
+        self.log_time = log_time
+
+    def forward(
+        self, x: TensorType[..., 'dim'], t: TensorType[..., 1], reverse: bool = False, **kwargs,
+    ) -> TensorType[..., 'dim']:
+        if self.log_time:
+            t = torch.log1p(t)
+        if reverse:
+            t = -t
+        return torch.asinh(t.exp() * torch.sinh(x))
+
+    def inverse(
+        self, y: TensorType[..., 'dim'], t: TensorType[..., 1], **kwargs,
+    ) -> TensorType[..., 'dim']:
+        return self.forward(y, t, reverse=True)
+
+    def log_diag_jacobian(self, x: TensorType, y: TensorType, t: TensorType, **kwargs) -> TensorType:
+        if self.log_time:
+            t = torch.log1p(t)
+        t = t.exp()
+        diag_jac = t * torch.cosh(x) / torch.sqrt(torch.square(t * torch.sinh(x)) + 1)
+        return diag_jac.log()
+
+    def log_det_jacobian(self, x: TensorType, y: TensorType, t: TensorType, **kwargs) -> TensorType:
+        log_diag_jac = self.log_diag_jacobian(x, y, t)
+        return log_diag_jac.sum(-1, keepdim=True)
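
For intuition about the new ContinuousActivation transform: the forward pass blends the identity with the wrapped activation via the weight w = tanh(temperature * t), so the output is exactly x at t = 0 and approaches activation(x) as t grows. The following is a minimal, self-contained sketch in plain PyTorch (independent of stribor; the helper name `blend` exists only in this example) that checks those two limits.

import torch

def blend(x, t, activation, temperature=1.0):
    # w goes from 0 (at t = 0) to 1 (as t -> infinity).
    w = torch.tanh(temperature * t)
    return x * (1 - w) + activation(x) * w

x = torch.randn(4, 3, dtype=torch.float64)

# At t = 0 the transform is the identity.
t0 = torch.zeros(4, 1, dtype=torch.float64)
assert torch.allclose(blend(x, t0, torch.tanh), x)

# For large t the transform matches the wrapped activation.
t_big = torch.full((4, 1), 100.0, dtype=torch.float64)
assert torch.allclose(blend(x, t_big, torch.tanh), torch.tanh(x))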

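The ContinuousTanh forward pass, asinh(exp(t) * sinh(x)), is the closed-form solution of `dx(t)/dt = tanh(x(t))`: since sinh(asinh(z)) = z and cosh(asinh(z)) = sqrt(1 + z^2), differentiating with respect to t gives exactly tanh of the current value. A small standalone finite-difference check of that property (plain PyTorch, names local to this sketch):

import torch

def continuous_tanh(x, t):
    # Closed-form flow of dx/dt = tanh(x), started at x and run for time t.
    return torch.asinh(t.exp() * torch.sinh(x))

x = torch.randn(4, 3, dtype=torch.float64)
t = torch.full((4, 1), 0.7, dtype=torch.float64)
eps = 1e-6

y = continuous_tanh(x, t)
dy_dt = (continuous_tanh(x, t + eps) - continuous_tanh(x, t - eps)) / (2 * eps)

# The time derivative of the flow equals tanh of its current value.
assert torch.allclose(dy_dt, torch.tanh(y), atol=1e-6)

# And at t = 0 the flow is the identity.
assert torch.allclose(continuous_tanh(x, torch.zeros_like(t)), x)
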
stribor/test/test_activations.py

+29
@@ -11,3 +11,32 @@ def test_activations(name):
 
     check_inverse_transform(f, x)
     check_log_jacobian_determinant(f, x)
+
+@pytest.mark.parametrize('func', [torch.tanh, torch.sigmoid, torch.relu])
+def test_continuous_activation(func):
+    torch.manual_seed(123)
+
+    f = st.ContinuousActivation(func)
+    x = torch.randn(3, 4, 5)
+
+    # Initial condition: identity at t=0
+    t = torch.zeros(3, 4, 1)
+    assert torch.allclose(f(x, t), x)
+
+    # For large t, same behavior as the original activation
+    t = torch.ones(3, 4, 1) * 100
+    assert torch.allclose(f(x, t), func(x))
+
+    f = st.ContinuousActivation(func, learnable=True)
+    check_gradients_not_nan(f, x, t=t)
+
+@pytest.mark.parametrize('log_time', [True, False])
+def test_continuous_tanh(log_time):
+    torch.manual_seed(123)
+
+    f = st.ContinuousTanh(log_time)
+    x = torch.randn(100, 5, 10)
+    t = torch.rand_like(x[..., :1]) * 5
+
+    check_inverse_transform(f, x, t=t)
+    check_log_jacobian_determinant(f, x, t=t)
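
The tests above rely on the repo's check helpers. As a rough illustration of what a log-Jacobian-determinant check verifies for an elementwise transform like ContinuousTanh, the analytic log-determinant can be compared against an autograd Jacobian. This is a standalone sketch, not the actual test helper; `ct_forward` and `ct_log_det` are names local to this example.

import torch

def ct_forward(x, t):
    return torch.asinh(t.exp() * torch.sinh(x))

def ct_log_det(x, t):
    s = t.exp()
    diag = s * torch.cosh(x) / torch.sqrt(torch.square(s * torch.sinh(x)) + 1)
    return diag.log().sum(-1, keepdim=True)

x = torch.randn(5, dtype=torch.float64)
t = torch.rand(1, dtype=torch.float64)

# An elementwise map has a diagonal Jacobian, so log|det J| = sum_i log(dy_i/dx_i).
jac = torch.autograd.functional.jacobian(lambda v: ct_forward(v, t), x)
assert torch.allclose(jac.diagonal().log().sum(), ct_log_det(x, t).squeeze())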
