Skip to content

Commit

Permalink
Better PriorityNoise
Browse files Browse the repository at this point in the history
  • Loading branch information
semjon00 committed May 25, 2024
1 parent 213ace2 commit 9d84d3b
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions components.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,15 @@ def forward(self, x: Tensor):


class PriorityNoise(RandMachine):
def __init__(self, k_min, k_max, fun, embed_dim):
def __init__(self, k_min, k_max, fun, input_dim):
super().__init__(k_min, k_max, fun)
assert fun == 'lin', 'PriorityNoise only supports lin mode'
self.embed_dim = embed_dim
self.importance = nn.Sequential(
nn.Linear(embed_dim, 1),
nn.Linear(input_dim, 1),
nn.Sigmoid(),
nn.Flatten()
)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.input_dim = input_dim

def forward(self, x: Tensor):
val = super().get_val()
Expand All @@ -73,8 +71,11 @@ def forward(self, x: Tensor):
k = torch.max(torch.zeros_like(p), (val - p) / (1 - p))
quality = k.unsqueeze(-1) + s.unsqueeze(-1) * unadjusted_quality

noise_levels = -torch.log(quality) # eps not needed sigmoid (as we scale it above) never touches 0 or 1
noise_levels = einops.repeat(noise_levels, '... -> ... a', a=self.embed_dim)
noise = torch.randn_like(noise_levels)
# eps for log is not needed, since sigmoid (as we scale it above) never touches 0 or 1
noise_levels = -torch.log(quality) * (torch.norm(x, dim=-1) / self.input_dim)
noise_levels = einops.repeat(noise_levels, '... -> ... a', a=self.input_dim)
noise = noise_levels * torch.randn_like(noise_levels)

return self.norm2(self.norm1(x) + noise_levels * noise)
x = x + noise
x = x / torch.norm(x, dim=-1).unsqueeze(-1) * self.input_dim
return x

0 comments on commit 9d84d3b

Please sign in to comment.