Skip to content

Commit

Permalink
Approximate top-K works well
Browse files Browse the repository at this point in the history
  • Loading branch information
neverix committed Jul 13, 2024
1 parent abf2e35 commit bf3eb74
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
8 changes: 6 additions & 2 deletions saex/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SAEConfig:

expansion_factor: float = 32
topk_k: Optional[int] = None
+ topk_approx: bool = False

encoder_bias_init_mean: float = 0.0
use_encoder_bias: bool = False
Expand Down Expand Up @@ -256,7 +257,10 @@ def encode(self, activations: jax.Array, dot_enc=None):
if self.config.topk_k is not None:
og_shape = post_relu.shape
post_relu = post_relu.reshape(-1, post_relu.shape[-1])
- values, indices = jax.lax.top_k(post_relu, self.config.topk_k)
+ if self.config.topk_approx:
+     values, indices = jax.lax.approx_max_k(post_relu, self.config.topk_k, aggregate_to_topk=True)
+ else:
+     values, indices = jax.lax.top_k(post_relu, self.config.topk_k)
post_relu = jax.vmap(lambda a, v, i: jnp.zeros_like(a).at[i].set(v))(post_relu, values, indices)
post_relu = post_relu.reshape(og_shape)
if self.config.is_gated:
Expand Down Expand Up @@ -313,7 +317,7 @@ def __call__(self, activations: jax.Array, targets: jax.typing.ArrayLike, key: j
losses = {"reconstruction": reconstruction_loss, "sparsity": sparsity_loss}
if self.is_gated:
sg_gated = (lambda x: x) if self.config.anthropic_norm else jax.lax.stop_gradient
- g_out = (jax.nn.relu(pre_relu) * self.s) @ sg_gated(self.W_dec) + sg_gated(self.b_dec)
+ g_out = ((pre_relu * active) * self.s) @ sg_gated(self.W_dec) + sg_gated(self.b_dec)
gated_loss = self.reconstruction_loss(g_out, targets).astype(jnp.float32)
losses = {**losses, "gated": gated_loss}
loss = loss + gated_loss.mean()
Expand Down
4 changes: 3 additions & 1 deletion scripts/train_gemma_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def train(
# weights_8bit=True,
use_aqt=False,
topk_k=None,
- # topk_k=256,
+ # topk_k=64,
+ # topk_approx=True,
+ topk_approx=False,
),
# optimizer="adafactor",
optimizer="adam",
Expand Down
3 changes: 2 additions & 1 deletion train_gemmas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
# cf = 8
# cf = 14
cf = 1
- for s, sae_type in ((8e-6 * cf, "transcoder"),):
+ for s, sae_type in ((2e-5, "residual"), (2e-5, "attn_out")):
+ # for s, sae_type in ((8e-6 * cf, "transcoder"),):
min_sfc, max_sfc = fn(s), fn(s)
# min_sfc, max_sfc = fn(1e-5), fn(1e-5)
min_sfc, max_sfc = min_sfc, min_sfc
Expand Down

0 comments on commit bf3eb74

Please sign in to comment.