Configurable approximate topK
neverix committed Aug 2, 2024
1 parent e67e0ef commit 56400ff
Showing 3 changed files with 13 additions and 11 deletions.
10 changes: 5 additions & 5 deletions saex/sae.py
```diff
@@ -31,7 +31,7 @@ class SAEConfig:
 
     expansion_factor: float = 32
     topk_k: Optional[int] = None
-    topk_approx: bool = False
+    topk_approx: Optional[float] = None
 
     encoder_bias_init_mean: float = 0.0
     use_encoder_bias: bool = False
```
```diff
@@ -269,8 +269,8 @@ def encode(self, activations: jax.Array, dot_enc=None):
         if self.config.topk_k is not None:
             og_shape = post_relu.shape
             post_relu = post_relu.reshape(-1, post_relu.shape[-1])
-            if self.config.topk_approx:
-                values, indices = jax.lax.approx_max_k(post_relu, self.config.topk_k, aggregate_to_topk=True)
+            if self.config.topk_approx is not None:
+                values, indices = jax.lax.approx_max_k(post_relu, self.config.topk_k, aggregate_to_topk=True, recall_target=self.config.topk_approx)
             else:
                 values, indices = jax.lax.top_k(post_relu, self.config.topk_k)
             post_relu = jax.vmap(lambda a, v, i: jnp.zeros_like(a).at[i].set(v))(post_relu, values, indices)
```
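This hunk is the core of the change: `topk_approx` goes from a boolean to an optional float that is passed to `jax.lax.approx_max_k` as its `recall_target`, making the speed/recall trade-off of the approximate TopK tunable. A self-contained sketch of what the branch above computes (the helper name `topk_mask` and the toy shapes are mine, not from the repo):

```python
import jax
import jax.numpy as jnp

def topk_mask(post_relu, k, recall_target=None):
    """Keep the k largest entries of each row, zeroing the rest."""
    og_shape = post_relu.shape
    x = post_relu.reshape(-1, post_relu.shape[-1])
    if recall_target is not None:
        # Approximate TopK: faster on TPU; expects to recover roughly
        # recall_target of the true top-k entries.
        values, indices = jax.lax.approx_max_k(
            x, k, aggregate_to_topk=True, recall_target=recall_target
        )
    else:
        values, indices = jax.lax.top_k(x, k)
    # Scatter the surviving values back into a dense zero tensor, row by row.
    x = jax.vmap(lambda a, v, i: jnp.zeros_like(a).at[i].set(v))(x, values, indices)
    return x.reshape(og_shape)

acts = jax.nn.relu(jax.random.normal(jax.random.PRNGKey(0), (4, 1024)))
exact = topk_mask(acts, k=64)                       # jax.lax.top_k
approx = topk_mask(acts, k=64, recall_target=0.95)  # jax.lax.approx_max_k
```

`approx_max_k` avoids an exact sort by keeping per-partition maxima, and `aggregate_to_topk=True` aggregates those candidates down to exactly `k` results, so the output shape matches the exact path.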
```diff
@@ -581,10 +581,10 @@ def requantize(x):
     if is_transpose:
         x = x.T
     og_shape = x.shape
-    x = x.reshape(-1, 32).astype(jnp.float32)
+    x = x.reshape(-1, 16).astype(jnp.bfloat16)
     zero = x.min(axis=1, keepdims=True)
     x = x - zero
-    mx = 127
+    mx = 255
     scale = x.max(axis=1, keepdims=True) / mx
     quants = x / scale
     quants = quants.clip(0, mx).round()
```
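For context, `requantize` performs asymmetric min/max quantization per block of weights: subtract the block minimum (the zero point), then spread the remaining range over `[0, mx]` integer steps. The commit halves the block size (32 → 16), computes in `bfloat16` instead of `float32`, and widens the code range from 0–127 to 0–255 (full unsigned 8 bits). A hedged sketch of that logic as a standalone helper (the name `blockwise_quantize` and the return convention are mine):

```python
import jax.numpy as jnp

def blockwise_quantize(x, block=16, mx=255):
    """Asymmetric min/max quantization of x in groups of `block` weights."""
    og_shape = x.shape
    x = x.reshape(-1, block).astype(jnp.bfloat16)
    zero = x.min(axis=1, keepdims=True)        # per-block zero point
    x = x - zero
    scale = x.max(axis=1, keepdims=True) / mx  # spread block range over [0, mx]
    quants = (x / scale).clip(0, mx).round()   # 8-bit codes when mx = 255
    # Dequantization is quants * scale + zero; a real implementation would
    # also guard against constant blocks, where scale == 0.
    return quants, scale, zero, og_shape
```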
7 changes: 4 additions & 3 deletions scripts/train_gemma_sae.py
```diff
@@ -62,6 +62,7 @@ def train(
         batch_size=batch_size,
         expansion_factor=16,
         use_encoder_bias=True,
+        # use_encoder_bias=False,
         remove_decoder_bias=False,
         encoder_init_method="orthogonal",
         decoder_init_method="pseudoinverse",
```
```diff
@@ -95,8 +96,8 @@ def train(
         # weights_8bit=True,
         use_aqt=False,
         topk_k=None,
-        # topk_k=64,
-        # topk_approx=True,
+        # topk_k=128,
+        # topk_approx=0.5,
         topk_approx=False,
     ),
     # optimizer="adafactor",
```
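One consequence of the new `Optional[float]` field: the exact `jax.lax.top_k` path now runs only when `topk_approx` is `None`, so the commented-out `topk_approx=0.5` would request a 50% recall target once enabled, and the leftover `topk_approx=False` appears to no longer select the exact path under the `is not None` check. A hypothetical configuration sketch (values illustrative, assuming the remaining `SAEConfig` fields keep their defaults):

```python
# Illustrative only; these are not the repo's chosen settings.
config = SAEConfig(
    topk_k=128,       # keep the 128 largest activations per token
    topk_approx=0.5,  # recall_target for approx_max_k; None => exact top_k
)
```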
Expand Down Expand Up @@ -148,7 +149,7 @@ def main(layer: int = 12, restore: Optional[str] = None, min_sfc=2e-5, max_sfc=5
n_devices=4, use_recip=is_recip,
# death_penalty_threshold="auto",
death_penalty_threshold=5e-6, # <= 70 (L0) / 90k (features)
train_steps=50_000,
train_steps=150_000,
# push_to_hub=("nev/gemma-2b-saex-test", f"l{layer}-{sae_type}-test-run-6"),
push_to_hub=("nev/gemma-2b-saex-test", f"it-l{layer}-{sae_type}-test-run-1"),
restore=restore,
7 changes: 4 additions & 3 deletions train_gemmas.py
```diff
@@ -3,7 +3,8 @@
 # layers = [12, 14, 16, 11, 13, 15, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0]
 # layers = [12, 11, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0][2:]
 # layers = [14, 16, 13, 15, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0]
-layers = [12]
+# layers = [12]
+layers = [1, 2, 3, 4, 5]
 for layer_idx in range(len(layers)):
     layer = layers[layer_idx]
     restore = None  # if layer_idx == 0 else f"weights/phi-l{layers[layer_idx-1]}-gated.safetensors"
```
```diff
@@ -13,8 +14,8 @@
     # cf = 8
     # cf = 14
     cf = 1
-    # for s, sae_type in ((8e-6 * cf, "transcoder"),):
-    for s, sae_type in ((2e-5, "residual"), (2e-5, "attn_out"))[:1]:
+    for s, sae_type in ((8e-6 * cf, "transcoder"),):
+    # for s, sae_type in ((2e-5, "residual"), (2e-5, "attn_out"))[:1]:
         min_sfc, max_sfc = fn(s), fn(s)
         # min_sfc, max_sfc = fn(1e-5), fn(1e-5)
         min_sfc, max_sfc = min_sfc, min_sfc
```
