Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
neverix committed Jul 23, 2024
2 parents 211f610 + bf3eb74 commit e67e0ef
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 6 additions & 2 deletions saex/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SAEConfig:

expansion_factor: float = 32
topk_k: Optional[int] = None
topk_approx: bool = False

encoder_bias_init_mean: float = 0.0
use_encoder_bias: bool = False
Expand Down Expand Up @@ -268,7 +269,10 @@ def encode(self, activations: jax.Array, dot_enc=None):
if self.config.topk_k is not None:
og_shape = post_relu.shape
post_relu = post_relu.reshape(-1, post_relu.shape[-1])
values, indices = jax.lax.top_k(post_relu, self.config.topk_k)
if self.config.topk_approx:
values, indices = jax.lax.approx_max_k(post_relu, self.config.topk_k, aggregate_to_topk=True)
else:
values, indices = jax.lax.top_k(post_relu, self.config.topk_k)
post_relu = jax.vmap(lambda a, v, i: jnp.zeros_like(a).at[i].set(v))(post_relu, values, indices)
post_relu = post_relu.reshape(og_shape)
if self.config.is_gated:
Expand Down Expand Up @@ -327,7 +331,7 @@ def __call__(self, activations: jax.Array, targets: jax.typing.ArrayLike, key: j
losses = {"reconstruction": reconstruction_loss, "sparsity": sparsity_loss}
if self.is_gated:
sg_gated = (lambda x: x) if self.config.anthropic_norm else jax.lax.stop_gradient
g_out = (jax.nn.relu(pre_relu) * self.s) @ sg_gated(self.W_dec) + sg_gated(self.b_dec)
g_out = ((pre_relu * active) * self.s) @ sg_gated(self.W_dec) + sg_gated(self.b_dec)
gated_loss = self.reconstruction_loss(g_out, targets).astype(jnp.float32)
losses = {**losses, "gated": gated_loss}
loss = loss + gated_loss.mean()
Expand Down
4 changes: 3 additions & 1 deletion scripts/train_gemma_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def train(
# weights_8bit=True,
use_aqt=False,
topk_k=None,
# topk_k=256,
# topk_k=64,
# topk_approx=True,
topk_approx=False,
),
# optimizer="adafactor",
optimizer="adam",
Expand Down

0 comments on commit e67e0ef

Please sign in to comment.