@@ -1,8 +1,9 @@
-from typing import Optional
+from typing import Callable, Optional
 from torchtyping import TensorType
 
 import math
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import constraints
 
@@ -119,3 +120,77 @@ def log_diag_jacobian(
         left = torch.zeros_like(x)
         right = torch.ones_like(x) * math.log(self.negative_slope)
         return torch.where(x >= 0., left, right)
+
+
+class ContinuousActivation(ElementwiseTransform):
+    """
+    Continuous activation function.
+    At t=0 it is the identity; as t grows it acts more like the original activation.
+
+    Args:
+        activation (callable): An activation function (e.g., `torch.tanh`)
+        temperature (float): How fast the activation approaches the original one
+            as t increases. Higher temperature -> faster
+        learnable (bool): Whether temperature is a learnable parameter
+    """
+    def __init__(
+        self,
+        activation: Callable,
+        temperature: float = 1.,
+        learnable: bool = False,
+    ):
+        super().__init__()
+        self.activation = activation
+        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable)
+
+    def forward(
+        self, x: TensorType[..., 'dim'], t: TensorType[..., 1], **kwargs,
+    ) -> TensorType[..., 'dim']:
+        w = torch.tanh(self.temperature * t)
+        return x * (1 - w) + self.activation(x) * w
+
+    def inverse(self, y: TensorType, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_diag_jacobian(self, x: TensorType, y: TensorType, **kwargs) -> None:
+        raise NotImplementedError
+
+    def log_det_jacobian(self, x: TensorType, y: TensorType, **kwargs) -> None:
+        raise NotImplementedError
+
+
+class ContinuousTanh(ElementwiseTransform):
+    """
+    Continuous activation that is the solution to `dx(t)/dt = tanh(x(t))`.
+
+    Args:
+        log_time (bool): Whether to use time in the log domain.
+    """
+    def __init__(self, log_time: bool = False):
+        super().__init__()
+        self.log_time = log_time
+
+    def forward(
+        self, x: TensorType[..., 'dim'], t: TensorType[..., 1], reverse: bool = False, **kwargs,
+    ) -> TensorType[..., 'dim']:
+        if self.log_time:
+            t = torch.log1p(t)
+        if reverse:
+            t = -t
+        return torch.asinh(t.exp() * torch.sinh(x))
+
+    def inverse(
+        self, y: TensorType[..., 'dim'], t: TensorType[..., 1], **kwargs,
+    ) -> TensorType[..., 'dim']:
+        return self.forward(y, t, reverse=True)
+
+    def log_diag_jacobian(self, x: TensorType, y: TensorType, t: TensorType, **kwargs) -> TensorType[..., 'dim']:
+        if self.log_time:
+            t = torch.log1p(t)
+        t = t.exp()
+        diag_jac = t * torch.cosh(x) / torch.sqrt(torch.square(t * torch.sinh(x)) + 1)
+        return diag_jac.log()
+
+    def log_det_jacobian(self, x: TensorType, y: TensorType, t: TensorType, **kwargs) -> TensorType[..., 1]:
+        log_diag_jac = self.log_diag_jacobian(x, y, t)
+        return log_diag_jac.sum(-1, keepdim=True)
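A minimal sanity check of the behaviour the `ContinuousActivation` docstring promises, mirroring its forward formula with plain torch (the helper name below is local to this sketch, not part of the PR):

```python
import torch

# Mirrors ContinuousActivation.forward: interpolate between the identity
# (w = 0 at t = 0) and the activation (w -> 1 as t grows).
def continuous_activation(x, t, activation=torch.tanh, temperature=1.):
    w = torch.tanh(temperature * t)
    return x * (1 - w) + activation(x) * w

x = torch.randn(4, 3)

# At t = 0 the transform is the identity.
t0 = torch.zeros(4, 1)
print(torch.allclose(continuous_activation(x, t0), x))                  # True

# For large t it is indistinguishable from the plain activation.
t_big = torch.full((4, 1), 50.)
print(torch.allclose(continuous_activation(x, t_big), torch.tanh(x)))   # True
```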
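For `ContinuousTanh`, note that if dx/dt = tanh(x) then d(sinh x)/dt = cosh(x) tanh(x) = sinh(x), so sinh(x(t)) = e^t sinh(x(0)) and x(t) = asinh(e^t sinh(x(0))), which is exactly what `forward` computes. A sketch (again mirroring the diff's formulas for the `log_time=False` case rather than importing the package) that checks the inverse and the diagonal Jacobian against autograd:

```python
import torch

# Closed-form solution of dx/dt = tanh(x): sinh(x(t)) = exp(t) * sinh(x(0)).
def continuous_tanh(x, t):
    return torch.asinh(t.exp() * torch.sinh(x))

# Elementwise log |dy/dx| of the map above.
def continuous_tanh_log_diag_jacobian(x, t):
    t = t.exp()
    return (t * torch.cosh(x) / torch.sqrt(torch.square(t * torch.sinh(x)) + 1)).log()

x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
t = torch.rand(4, 1, dtype=torch.float64)

y = continuous_tanh(x, t)

# Reversing time inverts the transform.
print(torch.allclose(continuous_tanh(y, -t), x))  # True

# The analytic diagonal Jacobian matches autograd's elementwise dy_i/dx_i.
autograd_log_diag = torch.autograd.grad(y.sum(), x)[0].log()
print(torch.allclose(continuous_tanh_log_diag_jacobian(x, t), autograd_log_diag))  # True
```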