optimized positional embeddings
AleHD committed Aug 1, 2023 · commit c504227 · 1 parent 3850005
Showing 4 changed files with 115 additions and 29 deletions.
55 changes: 27 additions & 28 deletions megatron/model/positional_embeddings.py
@@ -5,32 +5,31 @@
 from typing import Optional, Tuple


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    inv_freq = 1.0 / (theta**(torch.arange(0, dim, 2).float()/dim))
-    t = torch.arange(end)
-    freqs = torch.einsum("i,j->ij", t, inv_freq)
-    emb = torch.cat((freqs, freqs), dim=-1)
-    cos = emb.cos()
-    sin = emb.sin()
-    return torch.stack([cos, sin], dim=0)
-
-
-def rotate_half(x):
-    x1 = x[..., :x.size(-1)//2]
-    x2 = x[..., x.size(-1)//2:]
-    return torch.cat([-x2, x1], dim=-1)
-
-
-def apply_rotary_emb(
-        xq: torch.Tensor,  # [seq_len, batch, heads, dim]
-        xk: torch.Tensor,
-        freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-
-    freqs_cis = freqs_cis.to(xq.device)
-    cos, sin = freqs_cis  # [seq_len, dim] both
-    cos = cos[:, None, None, :]
-    sin = sin[:, None, None, :]
-    xq_out = (xq*cos) + (rotate_half(xq)*sin)
-    xk_out = (xk*cos) + (rotate_half(xk)*sin)
+def precompute_freqs_cis(dim: int, end: int,
+                         theta: float = 10000.0) -> torch.Tensor:
+
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    return torch.polar(torch.ones_like(freqs), freqs)  # complex64
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor,
+                          x: torch.Tensor) -> torch.Tensor:
+
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
+    shape = [d if i == 0 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor,
+                     freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq.device)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
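
For reference, a minimal usage sketch of the new interface (not part of the commit). It assumes the `[seq_len, batch, heads, head_dim]` query/key layout noted in the old code's comment, that the repository root is on `PYTHONPATH`, and purely illustrative sizes:

```python
import torch

from megatron.model.positional_embeddings import apply_rotary_emb, precompute_freqs_cis

seq_len, batch, n_heads, head_dim = 16, 2, 4, 32      # illustrative sizes only
freqs_cis = precompute_freqs_cis(head_dim, seq_len)   # [seq_len, head_dim // 2], complex64
xq = torch.randn(seq_len, batch, n_heads, head_dim)
xk = torch.randn(seq_len, batch, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)  # rotated queries/keys, same shapes as inputs
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
```

Note that the new path rotates adjacent channel pairs via a complex multiply (the Llama convention), whereas the old `rotate_half` path rotated the two halves of the head dimension; this layout change is presumably why the README below asks you to run `permute_qkv.py` on checkpoints produced with older commits.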
6 changes: 6 additions & 0 deletions weights2megatron/README.md
@@ -34,6 +34,12 @@ See also `examples/weights2megatron.sh`.
 Llama weights are not so easily available, but the MLO lab has access to them, so we are ok.
 In this case you also need to specify `--cache-dir`; the specified directory will be used to fetch the llama weights, for instance run:
 
+**IMPORTANT**: If you are using Megatron-converted weights produced at commit [264a745](https://github.com/epfLLM/old-Megatron-LM/commit/264a745b045912c2972a44aa8883a03b9ffe7c98) or earlier, you will need to update your weights.
+Use:
+```
+python weights2megatron/permute_qkv.py --input-dir=/path/to/old/checkpoint/ --output-dir=/path/to/new/checkpoint/
+```
+
 ## Correctness verification
 
 **Warning**: The current code does not support model parallelism; this is still work in progress.
79 changes: 79 additions & 0 deletions weights2megatron/permute_qkv.py
@@ -0,0 +1,79 @@
import re
import sys
import os
import shutil
from pathlib import Path
from argparse import ArgumentParser

import torch
from tqdm.auto import tqdm


def permute_qkv(qkv_w: torch.Tensor, dim: int, n_heads: int,
                n_heads_kv: int) -> torch.Tensor:

    def permute(x):
        # reorder the rows of one head: view them as two halves and interleave
        # the halves pairwise (half-split rotary layout -> interleaved layout)
        return x.view(2, head_dim//2, dim).transpose(0, 1).reshape(head_dim, dim)

    head_dim = dim//n_heads
    n_qs_per_kv = n_heads//n_heads_kv
    n_groups = qkv_w.size(0)//head_dim//(n_qs_per_kv + 2)
    # each group holds the queries sharing one kv head, followed by that k and v
    groups = torch.chunk(qkv_w, n_groups, dim=0)
    new = []
    for group in groups:
        *qs, k, v = torch.split(group, head_dim, dim=0)
        assert len(qs) == n_qs_per_kv, f"{len(qs)}, {n_qs_per_kv}"
        new += list(map(permute, qs)) + [permute(k), v]
    return torch.cat(new, dim=0)


def update_checkpoint(input_dir: Path, output_dir: Path, overwrite_ok: bool = False):
    # make sure megatron is importable
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))

    # prepare output dir
    if output_dir.exists():
        if not overwrite_ok:
            raise FileExistsError(f"Output directory {output_dir} already exists")
        print(f"Removing {output_dir}")
        shutil.rmtree(output_dir)
    output_dir.mkdir(exist_ok=True)

    # determine latest checkpointed iteration
    with open(input_dir/"latest_checkpointed_iteration.txt") as f:
        it = f.read()
    print("Updating weights of iteration", it)
    with open(output_dir/"latest_checkpointed_iteration.txt", "w+") as f:
        f.write(it)
    (output_dir/it).mkdir()

    # convert weights
    for fname in tqdm(list((input_dir/it).iterdir())):
        checkpoint = torch.load(fname/"model_optim_rng.pt")
        args = checkpoint["args"]
        # (dim, n_heads, n_heads_kv) arguments expected by permute_qkv
        args = (args.hidden_size, args.num_attention_heads,
                args.num_attention_heads_kv)
        if "transformer" in checkpoint["model"]["language_model"]:
            key = "transformer"
            attn_key = "attention"
        else:
            key = "encoder"
            attn_key = "self_attention"
        states = checkpoint["model"]["language_model"][key]
        for name, weight in states.items():
            if re.match(rf"^layers\.[0-9]+\.{attn_key}\.query_key_value\.weight$", name):
                states[name] = permute_qkv(weight, *args)
        (output_dir/it/fname.stem).mkdir()
        torch.save(checkpoint, output_dir/it/fname.stem/"model_optim_rng.pt")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=Path)
    parser.add_argument("--output-dir", type=Path)
    parser.add_argument("--overwrite-ok", action="store_true")
    args = parser.parse_args()
    update_checkpoint(args.input_dir, args.output_dir, args.overwrite_ok)
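
As a quick, illustrative sanity check (not part of the commit): `permute_qkv` only reorders rows within each head, so the output shape must match the input. The sizes below are made up, and the snippet assumes it is run from inside `weights2megatron/` so the import resolves:

```python
import torch

from permute_qkv import permute_qkv

dim, n_heads, n_heads_kv = 512, 8, 4               # illustrative grouped-query setup
head_dim = dim // n_heads                          # 64
n_qs_per_kv = n_heads // n_heads_kv                # 2 queries per kv head
rows = (n_qs_per_kv + 2) * n_heads_kv * head_dim   # fused rows: queries + k + v per kv group
qkv_w = torch.randn(rows, dim)

out = permute_qkv(qkv_w, dim, n_heads, n_heads_kv)
assert out.shape == qkv_w.shape                    # rows permuted within q and k heads, none added or dropped
```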
4 changes: 3 additions & 1 deletion weights2megatron/weights2megatron.py
@@ -8,6 +8,7 @@
 from tqdm.auto import trange
 from transformers import AutoModelForCausalLM
 
+from permute_qkv import permute_qkv
 from merge_llama import merge_llama
 
 
@@ -44,7 +45,7 @@ def falcon_to_megatron(weights: dict, size: int) -> dict:
             weights[f"{prefix2}.mlp.dense_4h_to_h.weight"]
         # qkv weights
         transformer[f"{prefix1}.attention.query_key_value.weight"] = \
-            weights[f"{prefix2}.self_attention.query_key_value.weight"]
+            permute_qkv(weights[f"{prefix2}.self_attention.query_key_value.weight"])
         # dense
         transformer[f"{prefix1}.self_attention.dense.weight"] = \
             weights[f"{prefix2}.self_attention.dense.weight"]
@@ -93,6 +94,7 @@ def get_wqkv(llama_config, layer_prefix, n_heads=32):
             w_qkv += [wq_convert[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)]
             w_qkv += [wk_convert[i], wv_convert[i]]
         out = torch.concat(w_qkv, dim=0)
+        out = permute_qkv(out, dim, n_heads, n_kv_heads)
         return out
 
     # dictionary
