Add cached generation
neverix committed Aug 17, 2024
1 parent 214144d commit 85f68ed
Showing 3 changed files with 21 additions and 16 deletions.
11 changes: 5 additions & 6 deletions src/jif/__main__.py
@@ -32,7 +32,7 @@ def update_fn(*args, **kwargs):


 def main(
-    batch_size=512,
+    batch_size=256,
     seq_len = 128,
     diffusion_eps = 1e-3,
     ema_decay=0.99,
@@ -41,10 +41,10 @@ def main(
     n_classes = 258,
     bos_token=256,
     pad_token=257,
-    schedule_free=True,
+    schedule_free=False,
     b1=0.9,
     b2=0.98,
-    warmup_steps=200,
+    warmup_steps=100,
     n_mp=1,
     seed=0,
     grad_clip_norm=10.0,
@@ -137,17 +137,16 @@ def get_loss(model, rng, state, sample):
 def get_samples(trainer, batch_size, seq_len, key, num_steps=None):
     if num_steps is None:
         num_steps = seq_len * 16
+    trainer_state = trainer.state.value
+    model = trainer.model
     if ema_decay is not None:
-        trainer_state = trainer.state.value
         ema_params = trainer_state.loss_fn_state["ema"]
         ema_model = jax.tree.map(lambda x: x.unfreeze_as_copy() if isinstance(x, pz.ParameterValue) else x, model, is_leaf=lambda x: isinstance(x, pz.ParameterValue))
         ema_treedef, param_types = pz.unbind_params(ema_model)
         ema_params = [pz.ParameterValue(value=pz.nx.wrap(ep, *pt.value.named_shape.keys()), label=pt.label,) for ep, pt in zip(ema_params, param_types)]
         ema_model = pz.bind_variables(ema_treedef, ema_params)
     else:
-        trainer_state = trainer.state.value
         optim_state = trainer_state.opt_state[-1]  # remove gradient processors
-        model = trainer.model
         ema_treedef, params = pz.unbind_params(model, freeze=True)
         ema_model_params = optax.contrib.schedule_free_eval_params(optim_state, params)
         ema_model = pz.bind_variables(ema_treedef, ema_model_params)
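Note on this hunk: the sampling path now reads `trainer_state` and `model` once, then evaluates either an exponential moving average of the parameters (when `ema_decay` is set, pulled from `trainer_state.loss_fn_state["ema"]`) or, in the `schedule_free` branch, the averaged parameters from `optax.contrib.schedule_free_eval_params`. The sketch below shows how such an EMA is typically maintained; the helper name `update_ema` and the plain-pytree layout are illustrative assumptions, not code from this repository.

import jax

# Hypothetical helper: ema <- decay * ema + (1 - decay) * param for every leaf,
# mirroring the ema_decay=0.99 default in main(). The real trainer stores these
# values in loss_fn_state["ema"]; here they are just a pytree of arrays.
def update_ema(ema_params, params, decay=0.99):
    return jax.tree.map(lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params)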
2 changes: 2 additions & 0 deletions src/jif/data.py
@@ -16,6 +16,8 @@ def detokenize(x):
 def collate(generator, batch_size, seq_len, pad_token_id=1, epochs=None):
     for _ in (range(epochs) if epochs is not None else iter(int, 1)):
         for batch in chunked(generator, batch_size):
+            if len(batch) < batch_size:
+                break
             batch = [text[:seq_len] for text in batch]
             lengths = [len(text) for text in batch]
             mask = [[1] * len(text) + [0] * (seq_len - len(text)) for text in batch]
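The two added lines drop a final short batch instead of yielding it, so every batch keeps the same `batch_size` (handy when downstream code is jit-compiled for a fixed shape). A toy illustration of the effect, assuming `chunked` comes from `more_itertools` (the import sits outside this hunk):

from more_itertools import chunked

texts = ["a", "b", "c", "d", "e"]
# chunked() yields a short trailing chunk; mirroring the new check, skip it
batches = [batch for batch in chunked(texts, 2) if len(batch) >= 2]
# batches == [['a', 'b'], ['c', 'd']]; the leftover ['e'] is discarded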
24 changes: 14 additions & 10 deletions src/jif/diffusion.py
@@ -189,23 +189,27 @@ def sample(self, score_fn, key, n_steps, batch_shape, denoise=True, projector=la
         timesteps = jnp.linspace(1, 0, n_steps + (1 if denoise else 0))
         alphas = self.alpha(timesteps)
         full_projector = lambda x: self.replace_bos(projector(x))
+        x = full_projector(x)

         def update(i, carry):
-            key, x = carry
+            key, x, last_probs, was_updated = carry
             key, subkey = jax.random.split(key)
             a_prev = jnp.full(x.shape, alphas[i])
             a_post = jnp.full(x.shape, alphas[i + 1])
-            x = full_projector(x)
-            logits = self.process_logits(score_fn(x, a_prev))
-            probs = jax.nn.softmax(logits, axis=-1)
-            probs = jnp.concatenate((probs * (a_post - a_prev)[..., None], (1 - a_post)[..., None]), axis=-1) / (1 - a_prev)[..., None]
-            x = jax.random.categorical(subkey, jnp.log(1e-10 + probs))
-            return key, x
-        key, x = jax.lax.fori_loop(0, n_steps, update, (key, x))
+            def compute_probs():
+                logits = self.process_logits(score_fn(x, a_prev))
+                probs = jax.nn.softmax(logits, axis=-1)
+                return probs
+            # TODO any way to bucket at large batch sizes?
+            probs = jax.lax.switch(was_updated.any().astype(jnp.uint8), (lambda: last_probs, compute_probs))
+            probs_full = jnp.concatenate((probs * (a_post - a_prev)[..., None], (1 - a_post)[..., None]), axis=-1) / (1 - a_prev)[..., None]
+            new_x = jnp.where(x == self.n_classes, jax.random.categorical(subkey, jnp.log(1e-10 + probs_full)), x)
+            new_x = full_projector(new_x)
+            return key, new_x, probs, new_x != x
+        key, x, _, _ = jax.lax.fori_loop(0, n_steps, update, (key, x, jnp.zeros(x.shape + (self.n_classes,), dtype=jnp.float32), jnp.ones(batch_shape, dtype=jnp.bool_)))

         if denoise:
             # denoising step
-            x = full_projector(x)
             t = jnp.full(x.shape, alphas[-1])
-            x = score_fn(x, t).argmax(-1)
+            x = jnp.where(x == self.n_classes, score_fn(x, t).argmax(-1), x)
         return full_projector(x)
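This hunk is the "cached generation" of the commit title: the loop carry now tracks the previous step's probabilities and a flag for whether any position actually changed, `jax.lax.switch` re-runs the score network only when something did change, and positions that are no longer the mask id (`self.n_classes`) are left untouched. Below is a standalone sketch of one such step under assumed names (`score_fn` returning per-position logits, `n_classes` as the mask token); it is not the repository's API, just the caching idea in isolation.

import jax
import jax.numpy as jnp

def cached_step(score_fn, x, last_probs, was_updated, key, n_classes):
    # Recompute probabilities only if the previous step changed something;
    # otherwise the network's output would be identical, so reuse the cache.
    probs = jax.lax.cond(
        was_updated.any(),
        lambda: jax.nn.softmax(score_fn(x), axis=-1),
        lambda: last_probs,
    )
    # Only still-masked positions get resampled; decoded tokens stay frozen.
    proposal = jax.random.categorical(key, jnp.log(1e-10 + probs))
    new_x = jnp.where(x == n_classes, proposal, x)
    return new_x, probs, new_x != x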
