Skip to content

Commit

Permalink
Start training new L12 SAE
Browse files (browse the repository at this point in the history)
  • Loading branch information
neverix committed Jul 23, 2024
1 parent ffe3482 commit 2b3c8af
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions train_gemmas.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,8 @@
# layers = [3, 4, 7, 8, 9]
# layers = [12, 14, 16, 11, 13, 15, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0]
# layers = [12, 11, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0][2:]
- layers = [14, 16, 13, 15, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0]
+ # layers = [14, 16, 13, 15, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1, 0]
+ layers = [12]
for layer_idx in range(len(layers)):
layer = layers[layer_idx]
restore = None # if layer_idx == 0 else f"weights/phi-l{layers[layer_idx-1]}-gated.safetensors"
Expand All @@ -12,7 +13,8 @@
# cf = 8
# cf = 14
cf = 1
- for s, sae_type in ((8e-6 * cf, "transcoder"),):
+ # for s, sae_type in ((8e-6 * cf, "transcoder"),):
+ for s, sae_type in ((2e-5, "residual"), (2e-5, "attn_out"))[:1]:
min_sfc, max_sfc = fn(s), fn(s)
# min_sfc, max_sfc = fn(1e-5), fn(1e-5)
min_sfc, max_sfc = min_sfc, min_sfc
Expand Down

0 comments on commit 2b3c8af

Please sign in to comment.