Fail to fix training instabilities
neverix committed Jul 8, 2024
1 parent 6513612 commit ea61a29
Showing 2 changed files with 21 additions and 12 deletions.
saex/sae.py (14 additions, 7 deletions)
@@ -65,6 +65,7 @@ class SAEConfig:
     use_model_parallel: bool = True
     loss_scaling: float = 100.0
     param_dtype: str = "float32"
+    misc_dtype: str = "float32"

 class SAEOutput(NamedTuple):
     losses: Dict[str, jax.Array]
@@ -126,11 +127,11 @@ def __init__(self, config, mesh: jshard.Mesh, key=None):
         self.W_enc = jax.device_put(self.W_enc, sharding["W_enc"])

         self.b_enc = jnp.full((self.d_hidden,), config.encoder_bias_init_mean,
-                              device=sharding["b_enc"], dtype=jnp.float32)
-        self.s = jnp.ones((self.d_hidden,), device=sharding["s"], dtype=jnp.float32)
-        self.s_gate = jnp.zeros((self.d_hidden,), device=sharding["s"], dtype=jnp.float32)
-        self.b_gate = jnp.zeros((self.d_hidden,), device=sharding["b_enc"], dtype=jnp.float32)
-        self.b_dec = jnp.zeros((config.n_dimensions,), device=sharding["b_dec"], dtype=jnp.float32)
+                              device=sharding["b_enc"], dtype=self.misc_dtype)
+        self.s = jnp.ones((self.d_hidden,), device=sharding["s"], dtype=self.misc_dtype)
+        self.s_gate = jnp.zeros((self.d_hidden,), device=sharding["s"], dtype=self.misc_dtype)
+        self.b_gate = jnp.zeros((self.d_hidden,), device=sharding["b_enc"], dtype=self.misc_dtype)
+        self.b_dec = jnp.zeros((config.n_dimensions,), device=sharding["b_dec"], dtype=self.misc_dtype)

         if config.decoder_init_method == "kaiming":
             self.W_dec = jax.nn.initializers.kaiming_uniform()(w_dec_subkey,
@@ -171,7 +172,9 @@ def __init__(self, config, mesh: jshard.Mesh, key=None):
         self.avg_l0 = eqx.nn.StateIndex(jnp.zeros((self.d_hidden,),
                                                    device=state_sharding["avg_l0"]))
         self.activated_buffer = eqx.nn.StateIndex(jnp.zeros((self.config.buffer_size, self.d_hidden),
-                                                            device=state_sharding["activated_buffer"]))
+                                                            device=state_sharding["activated_buffer"],
+                                                            dtype=jnp.float32
+                                                            ))
         self.ds_mean_norm = eqx.nn.StateIndex(jnp.array(self.config.min_norm, dtype=self.mean_norm_dtype))
         self.mean_norm = jnp.array(self.config.min_norm, dtype=self.mean_norm_dtype)
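
Pinning the activation buffer to float32 while other tensors move to misc_dtype is consistent with protecting running statistics from low-precision accumulation; a minimal illustration (my reading of the change, not something stated in the commit) of how bfloat16 silently drops small increments:

    import jax.numpy as jnp

    # bfloat16 keeps only ~8 bits of significand, so adding 1 to 256 is lost.
    x = jnp.asarray(256.0, dtype=jnp.bfloat16)
    print(x + 1)                      # 256
    print(x.astype(jnp.float32) + 1)  # 257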

@@ -186,6 +189,10 @@ def mean_norm_dtype(self):
         # return jnp.float16
         # return jnp.bfloat16

+    @property
+    def misc_dtype(self):
+        return getattr(jnp, self.config.misc_dtype)
+
     @property
     def scaler(self):
         return jmp.StaticLossScale(self.config.loss_scaling)
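
For reference, a minimal sketch (the ToyConfig name is hypothetical, not the saex API) of how a dtype stored as a string, like the new misc_dtype field, resolves to a JAX dtype through getattr:

    from dataclasses import dataclass
    import jax.numpy as jnp

    @dataclass
    class ToyConfig:                  # hypothetical stand-in for SAEConfig
        param_dtype: str = "bfloat16"
        misc_dtype: str = "float32"

    config = ToyConfig()
    misc_dtype = getattr(jnp, config.misc_dtype)  # -> jnp.float32
    b_enc = jnp.zeros((8,), dtype=misc_dtype)
    print(b_enc.dtype)                            # float32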
@@ -596,7 +603,7 @@ def restore(self, weights_path: os.PathLike):
                 load_param = param
                 try:
                     self = eqx.tree_at(lambda s: getattr(s, param), self,
-                                       jax.device_put(f.get_tensor(load_param).astype(self.param_dtype if param.startswith("W") else jnp.float32),
+                                       jax.device_put(f.get_tensor(load_param).astype(self.param_dtype if param.startswith("W") else self.misc_dtype),
                                                       self.sharding[param]))
                 except safetensors.SafetensorError:
                     print("Can't load parameter", param)
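
The restore change applies a simple per-parameter rule when casting loaded tensors; a sketch of that rule in isolation (toy helper, not part of saex):

    import jax.numpy as jnp

    def target_dtype(name, param_dtype=jnp.bfloat16, misc_dtype=jnp.float32):
        # Weight matrices (names starting with "W") keep the parameter dtype;
        # biases, scales, and other small tensors are cast to misc_dtype.
        return param_dtype if name.startswith("W") else misc_dtype

    assert target_dtype("W_enc") == jnp.bfloat16
    assert target_dtype("b_dec") == jnp.float32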
scripts/train_gemma_sae.py (7 additions, 5 deletions)
@@ -71,8 +71,8 @@ def train(
             anthropic_norm=True,
             # norm_input="wes-clip",
             # norm_input="wes",
-            # norm_input="wes-mean",
-            norm_input="wes-mean-fixed",
+            norm_input="wes-mean",
+            # norm_input="wes-mean-fixed",
             # wes_clip=(0.25, 0.25),
             death_penalty_threshold=death_penalty_threshold,
             death_penalty_coefficient=0.25,
@@ -81,6 +81,7 @@ def train(
             sparsity_tracking_epsilon=0.1,
             is_gated=is_gated,
             param_dtype="bfloat16",
+            misc_dtype="bfloat16",
             # param_dtype="float16",
             # param_dtype="float32",
         ),
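
With param_dtype and misc_dtype both set to "bfloat16" here, one way to double-check which dtype each array actually ends up with is a generic pytree audit (plain JAX, not saex-specific; the dict below only stands in for the model):

    import jax
    import jax.numpy as jnp

    def dtype_report(model):
        # Map each array leaf's path to its dtype, e.g. to confirm which
        # leaves are bfloat16 and which buffers stay float32.
        return {jax.tree_util.keystr(path): leaf.dtype
                for path, leaf in jax.tree_util.tree_leaves_with_path(model)
                if hasattr(leaf, "dtype")}

    print(dtype_report({"W_enc": jnp.zeros((2, 4), dtype=jnp.bfloat16),
                        "b_enc": jnp.zeros((4,), dtype=jnp.float32)}))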
@@ -93,8 +94,8 @@ def train(
         save_buffer=False,
         model_config=MicrlhfModelConfig(
             tokenizer_path="alpindale/gemma-2b",
-            # gguf_path="weights/gemma-2b.gguf",
-            gguf_path="../micrlhf-progress/models/gemma-2b-it.gguf",
+            gguf_path="weights/gemma-2b.gguf",
+            # gguf_path="../micrlhf-progress/models/gemma-2b-it.gguf",
             device_map=f"auto:mp={mp_devices}" if n_devices > 1 else "tpu:0",
             use_flash=max_seq_len >= 128 and max_seq_len % 128 == 0,
             layer=layer,
@@ -131,7 +132,8 @@ def main(layer: int = 12, restore: Optional[str] = None, min_sfc=2e-5, max_sfc=5
         # death_penalty_threshold="auto",
         death_penalty_threshold=5e-6, # <= 70 (L0) / 90k (features)
         train_steps=150_000,
-        push_to_hub=("nev/gemma-2b-saex-test", f"it-l{layer}-{sae_type}-test-run-0"),
+        push_to_hub=("nev/gemma-2b-saex-test", f"l{layer}-{sae_type}-test-run-6"),
+        # push_to_hub=("nev/gemma-2b-saex-test", f"it-l{layer}-{sae_type}-test-run-0"),
         restore=restore,
         sae_type=sae_type,
     )
