
Commit 953cf8d

fixes; isort; 8-bit weights work, 8-bit Adam doesn't work

neverix committed Jul 9, 2024
1 parent 72a3f1a
Showing 13 changed files with 754 additions and 669 deletions.
1,248 changes: 654 additions & 594 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions saex/buffer.py
@@ -1,13 +1,12 @@
from functools import partial
from typing import List

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.sharding as jshard
import numpy as np
from jax.sharding import PartitionSpec as P

import equinox as eqx
from safetensors import safe_open
from safetensors.flax import save_file

3 changes: 1 addition & 2 deletions saex/haver.py
@@ -1,11 +1,10 @@
from typing import Optional

import equinox as eqx
import jax
import jax.sharding as jshard
import numpy as np

import equinox as eqx

from . import utils
from .buffer import ActivationBuffer
from .iterable_dataset import IterableDatasetConfig, create_iterable_dataset
5 changes: 2 additions & 3 deletions saex/models/micrlhf_model.py
@@ -2,14 +2,13 @@
from dataclasses import dataclass, field, replace
from typing import List, Literal

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.sharding as jshard
import jax.tree_util

import equinox as eqx
from micrlhf.llama import LlamaBlock, LlamaTransformer
from micrlhf.flash import flashify
from micrlhf.llama import LlamaBlock, LlamaTransformer
from micrlhf.scan import sequential_to_scan
from penzai import pz
from penzai.toolshed import jit_wrapper
3 changes: 1 addition & 2 deletions saex/models/transformers_model.py
@@ -4,12 +4,11 @@
from functools import partial
from typing import Any, Dict, List

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.sharding as jshard
import numpy as np

import equinox as eqx
import transformers
from oryx.core import plant, sow

50 changes: 35 additions & 15 deletions saex/sae.py
@@ -4,16 +4,15 @@
from tempfile import NamedTemporaryFile
from typing import Dict, Literal, NamedTuple, Optional, Tuple, Union

import equinox as eqx
import jax
import jmp
import jax.numpy as jnp
import jax.sharding as jshard
import jmp
import numpy as np
import safetensors
from jax.experimental.checkify import checkify
from jax.sharding import PartitionSpec as P

import equinox as eqx
import safetensors
from jaxtyping import Array, Float, PyTree
from safetensors.flax import save_file

@@ -67,6 +66,7 @@ class SAEConfig:
param_dtype: str = "float32"
bias_dtype: str = "float32"
misc_dtype: str = "float32"
weights_8bit: bool = False

class SAEOutput(NamedTuple):
losses: Dict[str, jax.Array]
@@ -453,8 +453,11 @@ def adjust_mean(b_dec, opt_state):
b_dec = b_dec + jnp.mean(last_target - last_output.output, axis=0)
elif self.config.decoder_bias_init_method == "geom_median":
b_dec = b_dec + geometric_median(last_target - last_output.output)
opt_state = eqx.tree_at(lambda s: get_adam(s).mu.b_dec, opt_state, jnp.zeros_like(get_adam(opt_state).mu.b_dec))
opt_state = eqx.tree_at(lambda s: get_adam(s).nu.b_dec, opt_state, jnp.zeros_like(get_adam(opt_state).nu.b_dec))
try:
opt_state = eqx.tree_at(lambda s: get_adam(s).mu.b_dec, opt_state, jnp.zeros_like(get_adam(opt_state).mu.b_dec))
opt_state = eqx.tree_at(lambda s: get_adam(s).nu.b_dec, opt_state, jnp.zeros_like(get_adam(opt_state).nu.b_dec))
except AttributeError:
pass
return b_dec, opt_state
updated_b_dec, opt_state = jax.lax.switch(jnp.astype(step == 1, jnp.int32), (lambda *a: a, adjust_mean), updated.b_dec, opt_state)
updated = eqx.tree_at(lambda s: s.b_dec, updated, updated_b_dec)
@@ -490,16 +493,19 @@ def resample(updated, state, opt_state):
updated = eqx.tree_at(lambda s: s.W_dec, updated, W_dec)
updated = eqx.tree_at(lambda s: s.s, updated, jnp.where(dead, 1, updated.s))

# reset momentum and variance
adam = get_adam(opt_state)

opt_state = eqx.tree_at(lambda s: get_adam(s).mu.W_enc, opt_state, jnp.where(dead[None, :], 0, adam.mu.W_enc))
opt_state = eqx.tree_at(lambda s: get_adam(s).mu.b_enc, opt_state, jnp.where(dead, 0, adam.mu.b_enc))
# opt_state = eqx.tree_at(lambda s: s[adam_idx].mu.W_dec, opt_state, jnp.where(dead[:, None], 0, opt_state[adam_idx].mu.W_dec))
try:
# reset momentum and variance
adam = get_adam(opt_state)

opt_state = eqx.tree_at(lambda s: get_adam(s).mu.W_enc, opt_state, jnp.where(dead[None, :], 0, adam.mu.W_enc))
opt_state = eqx.tree_at(lambda s: get_adam(s).mu.b_enc, opt_state, jnp.where(dead, 0, adam.mu.b_enc))
# opt_state = eqx.tree_at(lambda s: s[adam_idx].mu.W_dec, opt_state, jnp.where(dead[:, None], 0, opt_state[adam_idx].mu.W_dec))

opt_state = eqx.tree_at(lambda s: get_adam(s).nu.W_enc, opt_state, jnp.where(dead[None, :], 0, adam.nu.W_enc))
opt_state = eqx.tree_at(lambda s: get_adam(s).nu.b_enc, opt_state, jnp.where(dead, 0, adam.nu.b_enc))
# opt_state = eqx.tree_at(lambda s: s[adam_idx].nu.W_dec, opt_state, jnp.where(dead[:, None], 0, opt_state[adam_idx].nu.W_dec))
opt_state = eqx.tree_at(lambda s: get_adam(s).nu.W_enc, opt_state, jnp.where(dead[None, :], 0, adam.nu.W_enc))
opt_state = eqx.tree_at(lambda s: get_adam(s).nu.b_enc, opt_state, jnp.where(dead, 0, adam.nu.b_enc))
# opt_state = eqx.tree_at(lambda s: s[adam_idx].nu.W_dec, opt_state, jnp.where(dead[:, None], 0, opt_state[adam_idx].nu.W_dec))
except AttributeError:
pass

state = state.set(self.time_since_fired, jnp.where(dead, 0, state.get(self.time_since_fired)))
return updated, state, opt_state
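
Both try/except AttributeError blocks added above guard the momentum/variance resets, which only make sense when the optimizer state actually carries Adam's .mu/.nu moments; with adafactor-style or 8-bit state the attribute access fails and the reset is simply skipped. A minimal sketch of the state-shape difference behind that guard, using only public optax transforms (parameter names and shapes are illustrative):

# Adam-style state carries .mu/.nu per parameter; factored (adafactor-style) state does not.
import jax.numpy as jnp
import optax

params = {"W_enc": jnp.zeros((8, 16)), "b_enc": jnp.zeros(16), "b_dec": jnp.zeros(8)}

adam_state = optax.scale_by_adam().init(params)
print(hasattr(adam_state, "mu"), hasattr(adam_state, "nu"))   # True True -> reset can run

rms_state = optax.scale_by_factored_rms().init(params)
print(hasattr(rms_state, "mu"), hasattr(rms_state, "nu"))     # False False -> AttributeError path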
@@ -509,6 +515,20 @@ def resample(updated, state, opt_state):
(lambda *a: a, resample),
updated_params, state, opt_state)
updated = eqx.combine(updated_params, updated_static)

def requantize(x):
# simulating 9-bit quantization
og_shape = x.shape
x = x.reshape(-1, 32)
scale = jnp.abs(x).max(-1, keepdims=True) / 127
quants = x / scale
quants = quants.clip(-127, 127).round()

return (quants * scale).reshape(og_shape)

if self.config.weights_8bit:
for selector in (lambda s: s.W_enc, lambda s: s.W_dec):
updated = eqx.tree_at(selector, updated, replace_fn=requantize)

return updated, state, opt_state
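
The requantize helper added above fakes low-bit weight storage: each 32-element block is scaled by its absmax over 127, rounded, clipped to ±127, and immediately dequantized. Despite the "9-bit" comment, the stored values land on a signed 8-bit grid with one shared float scale per block, which is what the new weights_8bit flag toggles. A self-contained sketch of the same round-trip (block size and clip range taken from the code above; the random test tensor and error print are illustrative only):

# Sketch of the blockwise absmax quantization that requantize above simulates.
# Assumes the tensor's element count is a multiple of the 32-element block size.
import jax
import jax.numpy as jnp

def fake_quantize(x: jax.Array, block: int = 32) -> jax.Array:
    og_shape = x.shape
    x = x.reshape(-1, block)
    scale = jnp.abs(x).max(-1, keepdims=True) / 127      # one float scale per block
    quants = jnp.round(jnp.clip(x / scale, -127, 127))   # values land on a signed 8-bit grid
    return (quants * scale).reshape(og_shape)            # dequantize back in place

w = jax.random.normal(jax.random.PRNGKey(0), (256, 128))
w_q = fake_quantize(w)
print(float(jnp.abs(w - w_q).max()))                     # error is at most half a quantization step per block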

7 changes: 4 additions & 3 deletions saex/train_script.py
@@ -1,11 +1,12 @@
import gc
from itertools import chain
from typing import List, Union
import wandb
import jax
import gc

import jax
import jax_smi

import wandb

from .trainer_cache import BufferCacher, BufferTrainer, BufferTrainerConfig


35 changes: 22 additions & 13 deletions saex/trainer_cache.py
@@ -1,19 +1,20 @@
import json
from dataclasses import dataclass, is_dataclass
from functools import partial
from typing import Optional, Tuple
from typing import Literal, Optional, Tuple

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.sharding as jshard
import jax_smi
import numpy as np
import optax
from jax.sharding import PartitionSpec as P
from tqdm.auto import tqdm, trange
from micrlhf.adam_8bit import scale_by_adam_8bit

import equinox as eqx
import jax_smi
import optax
import wandb
from tqdm.auto import tqdm, trange

from . import utils
from .buffer import ActivationBuffer
@@ -58,6 +59,8 @@ class BufferTrainerConfig:
cache_batch_size: int = 16
cache_every_steps: int = 1
cache_acc: int = 1

optimizer: Literal["adam", "adafactor", "adam8"] = "adam"

buffer_max_samples: int = 0
buffer_dtype: str = "float32"
@@ -210,7 +213,11 @@ def get_final_params():
n_cycles = int(self.config.train_iterations / scheduler_cycle)
optimizer = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(self.config.lr, b1=self.config.beta1, b2=self.config.beta2),

optax.adam(self.config.lr, b1=self.config.beta1, b2=self.config.beta2) if self.config.optimizer == "adam" else
optax.adafactor(self.config.lr) if self.config.optimizer == "adafactor" else
scale_by_adam_8bit(b1=self.config.beta1, b2=self.config.beta2) if self.config.optimizer == "adam8" else 1/0,

optax.scale_by_schedule(
optax.join_schedules(
[optax.linear_schedule(0, 1, self.config.scheduler_warmup)]
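
The nested conditional above selects the gradient transformation from the new optimizer config field ("adam", "adafactor", or "adam8"), with 1/0 acting as a crash-on-unknown fallback. A sketch of an equivalent but more explicit dispatch, using the same optax calls and the scale_by_adam_8bit import from this commit (the standalone function and its lr/beta parameters are an assumption for illustration; the trainer reads these values from self.config instead):

# Equivalent optimizer dispatch sketch for the "adam" / "adafactor" / "adam8" options above.
import optax
from micrlhf.adam_8bit import scale_by_adam_8bit  # added to trainer_cache.py in this commit

def make_optimizer(name: str, lr: float, beta1: float, beta2: float) -> optax.GradientTransformation:
    if name == "adam":
        inner = optax.adam(lr, b1=beta1, b2=beta2)
    elif name == "adafactor":
        inner = optax.adafactor(lr)
    elif name == "adam8":
        # As in the diff, no learning rate is folded in here; per the commit
        # message, this 8-bit Adam path does not work yet.
        inner = scale_by_adam_8bit(b1=beta1, b2=beta2)
    else:
        raise ValueError(f"unknown optimizer: {name!r}")
    return optax.chain(optax.clip_by_global_norm(1.0), inner)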
@@ -237,11 +244,12 @@ def train_step(
batch = jnp.nan_to_num(batch)
targets = jnp.nan_to_num(targets)
sae_params = eqx.filter_shard(sae_params, self.sharding_sae)
# SAE state is pretty small and there's no quadratic scaling, so we don't need to shard it as hard
# (I don't think equinox state can be sharded... StateIndex is not ordered, so tree_flatten won't work)
# sae_state = eqx.filter_shard(sae_state, self.sharding_sae_state)
opt_state = eqx.tree_at(lambda o: o[1][0].mu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
opt_state = eqx.tree_at(lambda o: o[1][0].nu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
if self.config.optimizer == "adam":
# SAE state is pretty small and there's no quadratic scaling, so we don't need to shard it as hard
# (I don't think equinox state can be sharded... StateIndex is not ordered, so tree_flatten won't work)
# sae_state = eqx.filter_shard(sae_state, self.sharding_sae_state)
opt_state = eqx.tree_at(lambda o: o[1][0].mu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
opt_state = eqx.tree_at(lambda o: o[1][0].nu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))

batch = jax.lax.with_sharding_constraint(batch, jshard.NamedSharding(self.mesh, P("dp", None)))
targets = jax.lax.with_sharding_constraint(targets, jshard.NamedSharding(self.mesh, P("dp", None)))
@@ -258,8 +266,9 @@ def train_step(
sae_params, _ = eqx.partition(sae, is_trainable)
sae_params = eqx.filter_shard(sae_params, self.sharding_sae)
# sae_state = eqx.filter_shard(sae_state, self.sharding_sae_state)
opt_state = eqx.tree_at(lambda o: o[1][0].mu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
opt_state = eqx.tree_at(lambda o: o[1][0].nu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
if self.config.optimizer == "adam":
opt_state = eqx.tree_at(lambda o: o[1][0].mu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
opt_state = eqx.tree_at(lambda o: o[1][0].nu, opt_state, replace_fn=lambda x: eqx.filter_shard(x, self.sharding_sae))
if self.config.ema:
def ema_update(ema_params, sae_params):
return jax.tree.map(lambda ema, sae: ema * self.config.ema + sae * (1 - self.config.ema),
33 changes: 12 additions & 21 deletions scripts/cache_features.py
@@ -9,35 +9,26 @@
except NameError:
pass

from saex.iterable_dataset import IterableDatasetConfig
from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.haver import ModelHaver, SAEHaver
from saex.sae import SAEConfig
from more_itertools import chunked
import dataclasses

from micrlhf.utils.load_sae import get_sae
import os

from collections import defaultdict, Counter
import random
from collections import Counter, defaultdict
from functools import partial
from tqdm.auto import trange
import jax.numpy as jnp
import numpy as np

import equinox as eqx
import random
import jax

import pyarrow.parquet as pq
import pyarrow as pa

import jax.numpy as jnp
import numpy as np
import pyarrow as pa
from tqdm.auto import trange
import pyarrow.parquet as pq
from micrlhf.utils.load_sae import get_sae
from more_itertools import chunked
from tqdm.auto import tqdm, trange

from tqdm.auto import tqdm
import pyarrow as pa
import numpy as np
from saex.haver import ModelHaver, SAEHaver
from saex.iterable_dataset import IterableDatasetConfig
from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.sae import SAEConfig

stride = 0.25
n_strides = 128
8 changes: 4 additions & 4 deletions scripts/feature_visualizer.py
@@ -1,10 +1,10 @@
import os

import gradio as gr
import pyarrow.parquet as pq
import pyarrow.compute as pc
from transformers import AutoTokenizer
import pyarrow.parquet as pq
from matplotlib import pyplot as plt
import os

from transformers import AutoTokenizer

token_table = pq.read_table("weights/tokens.parquet")
cache_path = "weights/caches"
15 changes: 10 additions & 5 deletions scripts/train_gemma_sae.py
@@ -1,8 +1,9 @@
import jax
from typing import Optional

import fire
import numpy as np
import jax
import jax.numpy as jnp
from typing import Optional
import numpy as np

from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.train_script import train_main
@@ -90,7 +91,11 @@ def train(
restrict_dec_norm=None,
project_grads_from_dec=False,
project_updates_from_dec=False,
weights_8bit=False,
# weights_8bit=True,
),
optimizer="adam",
# optimizer="adam8",
sae_restore=restore,
cache_every_steps=int(cache_size / batch_size * cache_ratio),
cache_batch_size=cache_batch_size,
@@ -103,7 +108,7 @@ def train(
# gguf_path="weights/gemma-2b.gguf",
gguf_path="../micrlhf-progress/models/gemma-2b-it.gguf",
device_map=f"auto:mp={mp_devices}" if n_devices > 1 else "tpu:0",
use_flash=max_seq_len >= 128 and max_seq_len % 128 == 0,
use_flash=False,
layer=layer,
max_seq_len=max_seq_len,
from_type="gemma",
@@ -139,7 +144,7 @@ def main(layer: int = 12, restore: Optional[str] = None, min_sfc=2e-5, max_sfc=5
death_penalty_threshold=5e-6, # <= 70 (L0) / 90k (features)
train_steps=150_000,
# push_to_hub=("nev/gemma-2b-saex-test", f"l{layer}-{sae_type}-test-run-6"),
push_to_hub=("nev/gemma-2b-saex-test", f"it-l{layer}-{sae_type}-test-run-0"),
push_to_hub=("nev/gemma-2b-saex-test", f"it-l{layer}-{sae_type}-test-run-1"),
restore=restore,
sae_type=sae_type,
)
4 changes: 2 additions & 2 deletions scripts/train_phi_sae.py
@@ -1,9 +1,9 @@
import re
from typing import Optional

import fire
import numpy as np
import jax.numpy as jnp
from typing import Optional
import numpy as np

from saex.models.micrlhf_model import MicrlhfModelConfig
from saex.train_script import train_main
9 changes: 6 additions & 3 deletions train_gemmas.py
@@ -1,12 +1,15 @@
import os
# layers = [3, 4, 7, 8, 9][3:]
layers = [12, 14, 16, 11, 13, 15, 10, 6, 8, 9, 7, 5]
# layers = [3, 4, 7, 8, 9]
layers = [12, 14, 16, 11, 13, 15, 10, 6, 8, 9, 7, 5, 4, 3, 2, 1]
for layer_idx in range(len(layers)):
layer = layers[layer_idx]
restore = None # if layer_idx == 0 else f"weights/phi-l{layers[layer_idx-1]}-gated.safetensors"
# fn = lambda x: x * ((layer / 12) ** 2)
fn = lambda x: x
for s, sae_type in ((2e-5, "residual"), (2e-5, "attn_out"))[:1]:
# for s, sae_type in ((2e-5, "residual"), (2e-5, "attn_out"))[:1]:
# cf = 8
cf = 1
for s, sae_type in ((2e-5 * cf, "residual"), (2e-5 * cf, "attn_out")):
min_sfc, max_sfc = fn(s), fn(s)
# min_sfc, max_sfc = fn(1e-5), fn(1e-5)
min_sfc, max_sfc = min_sfc, min_sfc