
Commit

cpu push
b0kch01 committed Mar 2, 2023
1 parent 3ab9757 commit 40d0eb2
Showing 4 changed files with 48 additions and 24 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1,4 +1,9 @@
 # Byte-compiled / optimized / DLL files
+.DS_Store
+download.sh
+weights/
+.vscode
+llama.egg-info
 __pycache__/
 *.py[cod]
 *$py.class
20 changes: 13 additions & 7 deletions example.py
@@ -20,9 +20,9 @@ def setup_model_parallel() -> Tuple[int, int]:
 local_rank = int(os.environ.get("LOCAL_RANK", -1))
 world_size = int(os.environ.get("WORLD_SIZE", -1))

-torch.distributed.init_process_group("nccl")
+torch.distributed.init_process_group("gloo")
 initialize_model_parallel(world_size)
-torch.cuda.set_device(local_rank)
+# torch.cuda.set_device(local_rank)

 # seed must be the same in all processes
 torch.manual_seed(1)
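For context, a minimal sketch (not part of this diff) of what the switched-over setup does on a CPU-only machine: gloo needs no GPU, which is why the torch.cuda.set_device call can be commented out. The rendezvous environment variables below are placeholder assumptions for a single local process.

import os
import torch.distributed as dist

# Placeholder single-process rendezvous settings (assumed values, normally set by the launcher).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group("gloo")   # gloo runs on CPU; nccl requires CUDA devices
print(dist.get_rank(), dist.get_world_size())
dist.destroy_process_group()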
@@ -41,12 +41,13 @@ def load(ckpt_dir: str, tokenizer_path: str, local_rank: int, world_size: int) -
 with open(Path(ckpt_dir) / "params.json", "r") as f:
 params = json.loads(f.read())

-model_args: ModelArgs = ModelArgs(max_seq_len=1024, max_batch_size=32, **params)
+model_args: ModelArgs = ModelArgs(
+max_seq_len=1024, max_batch_size=32, **params)
 tokenizer = Tokenizer(model_path=tokenizer_path)
 model_args.vocab_size = tokenizer.n_words
-torch.set_default_tensor_type(torch.cuda.HalfTensor)
+torch.set_default_tensor_type(torch.HalfTensor)
 model = Transformer(model_args)
-torch.set_default_tensor_type(torch.FloatTensor)
+torch.set_default_tensor_type(torch.HalfTensor)
 model.load_state_dict(checkpoint, strict=False)

 generator = LLaMA(model, tokenizer)
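A small sketch (an illustration, not from the commit) of what the default-tensor-type change means: tensors created without an explicit dtype now come out as float16 on the CPU rather than on the GPU.

import torch

torch.set_default_tensor_type(torch.HalfTensor)    # fp16 on CPU becomes the default
w = torch.empty(2, 2)
print(w.dtype, w.device)                           # torch.float16 cpu
torch.set_default_tensor_type(torch.FloatTensor)   # restore the usual fp32 default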
@@ -60,8 +61,13 @@ def main(ckpt_dir: str, tokenizer_path: str, temperature: float = 0.8, top_p: fl
 sys.stdout = open(os.devnull, 'w')

 generator = load(ckpt_dir, tokenizer_path, local_rank, world_size)
-prompts = ["The capital of Germany is the city of", "Here is my sonnet in the style of Shakespeare about an artificial intelligence:"]
-results = generator.generate(prompts, max_gen_len=256, temperature=temperature, top_p=top_p)
+prompts = [input("Enter prompt: ")]
+
+start_time = time.time()
+results = generator.generate(
+prompts, max_gen_len=256, temperature=temperature, top_p=top_p)
+print(f"responded in {time.time() - start_time:.2f} seconds")
+return generator

 for result in results:
 print(result)
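Because main() now returns the generator, it could be reused for follow-up prompts without reloading the checkpoint. A hypothetical sketch; the paths and sampling values below are assumptions, not taken from the commit.

# Hypothetical interactive reuse; "weights/7B" and "weights/tokenizer.model" are assumed paths.
generator = main(ckpt_dir="weights/7B", tokenizer_path="weights/tokenizer.model")
results = generator.generate([input("Enter prompt: ")],
                             max_gen_len=256, temperature=0.8, top_p=0.95)
print(results[0])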
6 changes: 4 additions & 2 deletions llama/generation.py
@@ -25,14 +25,16 @@ def generate(
 params = self.model.params
 assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

-prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+prompt_tokens = [self.tokenizer.encode(
+x, bos=True, eos=False) for x in prompts]

 min_prompt_size = min([len(t) for t in prompt_tokens])
 max_prompt_size = max([len(t) for t in prompt_tokens])

 total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

-tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long()
+tokens = torch.full(
+(bsz, total_len), self.tokenizer.pad_id).long()
 for k, t in enumerate(prompt_tokens):
 tokens[k, : len(t)] = torch.tensor(t).long()
 input_text_mask = tokens != self.tokenizer.pad_id
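A toy sketch (illustrative, with made-up token ids) of the padded prompt buffer once .cuda() is dropped: everything stays a plain CPU LongTensor and prompts are copied in place.

import torch

pad_id, bsz, total_len = -1, 2, 8
prompt_tokens = [[5, 6, 7], [9, 10]]                    # assumed toy token ids

tokens = torch.full((bsz, total_len), pad_id).long()    # CPU buffer, no .cuda()
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t).long()
input_text_mask = tokens != pad_id
print(tokens)
print(input_text_mask)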
41 changes: 26 additions & 15 deletions llama/model.py
@@ -45,7 +45,8 @@ def forward(self, x):


 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
+[: (dim // 2)].float() / dim))
 t = torch.arange(end, device=freqs.device)  # type: ignore
 freqs = torch.outer(t, freqs).float()  # type: ignore
 freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
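The wrapped lines above are only a re-flow of the same expression; for reference, a toy run (assumed sizes) showing what precompute_freqs_cis produces.

import torch

dim, end, theta = 8, 5, 10000.0      # toy head_dim and sequence length (assumed)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end)
freqs_cis = torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs).float())
print(freqs_cis.shape, freqs_cis.dtype)   # torch.Size([5, 4]) torch.complex64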
@@ -56,7 +57,8 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 ndim = x.ndim
 assert 0 <= 1 < ndim
 assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+shape = [d if i == 1 or i == ndim -
+1 else 1 for i, d in enumerate(x.shape)]
 return freqs_cis.view(*shape)


@@ -110,11 +112,13 @@ def __init__(self, args: ModelArgs):
 )

 self.cache_k = torch.zeros(
-(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
-).cuda()
+(args.max_batch_size, args.max_seq_len,
+self.n_local_heads, self.head_dim)
+)
 self.cache_v = torch.zeros(
-(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
-).cuda()
+(args.max_batch_size, args.max_seq_len,
+self.n_local_heads, self.head_dim)
+)

 def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
 bsz, seqlen, _ = x.shape
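Dropping .cuda() leaves the KV caches on CPU at construction time; the existing .to(xq) calls in forward() (kept below) then match them to the activations' dtype and device, so the same module runs on CPU or GPU. A small sketch of that behaviour with toy shapes (assumed):

import torch

cache_k = torch.zeros(4, 16, 8, 32)              # fp32 on CPU at __init__ time
xq = torch.randn(1, 3, 8, 32, dtype=torch.half)  # activations arrive in fp16
cache_k = cache_k.to(xq)                         # adopts xq's dtype and device
print(cache_k.dtype, cache_k.device)             # torch.float16 cpu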
@@ -129,20 +133,23 @@ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask
 self.cache_k = self.cache_k.to(xq)
 self.cache_v = self.cache_v.to(xq)

-self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
-self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
+self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv

 keys = self.cache_k[:bsz, : start_pos + seqlen]
 values = self.cache_v[:bsz, : start_pos + seqlen]

 xq = xq.transpose(1, 2)
 keys = keys.transpose(1, 2)
 values = values.transpose(1, 2)
-scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+scores = torch.matmul(xq, keys.transpose(2, 3)) / \
+math.sqrt(self.head_dim)
 if mask is not None:
-scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
+# (bs, n_local_heads, slen, cache_len + slen)
+scores = scores + mask
 scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
+# (bs, n_local_heads, slen, head_dim)
+output = torch.matmul(scores, values)
 output = output.transpose(
 1, 2
 ).contiguous().view(bsz, seqlen, -1)
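The re-wrapped lines keep the usual scaled dot-product attention; a toy shape trace (assumed sizes, not from the commit) matching the annotated comments:

import math
import torch

bs, n_heads, slen, cache_len, head_dim = 1, 2, 3, 4, 8
xq = torch.randn(bs, n_heads, slen, head_dim)
keys = torch.randn(bs, n_heads, cache_len + slen, head_dim)
values = torch.randn(bs, n_heads, cache_len + slen, head_dim)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
print(scores.shape)    # (bs, n_heads, slen, cache_len + slen)
output = torch.matmul(scores.softmax(dim=-1), values)
print(output.shape)    # (bs, n_heads, slen, head_dim)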
@@ -159,7 +166,8 @@ def __init__(
 ):
 super().__init__()
 hidden_dim = int(2 * hidden_dim / 3)
-hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+hidden_dim = multiple_of * \
+((hidden_dim + multiple_of - 1) // multiple_of)

 self.w1 = ColumnParallelLinear(
 dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
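A worked example of the hidden-size arithmetic that the wrapped lines re-flow: two thirds of the caller's 4*dim, rounded up to a multiple of multiple_of. The 7B-style numbers are assumptions for illustration.

dim, multiple_of = 4096, 256            # assumed, in the style of the 7B config
hidden_dim = 4 * dim                    # 16384, as passed in by the caller
hidden_dim = int(2 * hidden_dim / 3)                                        # 10922
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)  # 11008
print(hidden_dim)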
@@ -190,7 +198,9 @@ def __init__(self, layer_id: int, args: ModelArgs):
 self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

 def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
-h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
+h = x + \
+self.attention.forward(self.attention_norm(
+x), start_pos, freqs_cis, mask)
 out = h + self.feed_forward.forward(self.ffn_norm(h))
 return out

@@ -224,11 +234,12 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
 _bsz, seqlen = tokens.shape
 h = self.tok_embeddings(tokens)
 self.freqs_cis = self.freqs_cis.to(h.device)
-freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]

 mask = None
 if seqlen > 1:
-mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+mask = torch.full((1, 1, seqlen, seqlen),
+float("-inf"), device=tokens.device)
 mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

 for layer in self.layers:
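For completeness, a toy rendering (assumed seqlen, not part of the commit) of the causal mask built above: -inf strictly above the diagonal, so each position attends only to itself and earlier tokens.

import torch

seqlen, start_pos = 4, 0
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=start_pos + 1)
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])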
