Commit aa2b655
fix input_embdings.dtype=None.dtype (ModelTC#390)
WANDY666 authored Apr 11, 2024
1 parent 15a050a commit aa2b655
Showing 2 changed files with 4 additions and 2 deletions.
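
Both files fix the same bug: `input_embdings` is set to `None` right after the last-token slice is taken, so the embedding tensor can be freed early, but the tensor-parallel branch further down still read `input_embdings.dtype`, which raises `AttributeError: 'NoneType' object has no attribute 'dtype'` whenever `world_size_ > 1`. The fix caches the dtype before dropping the reference. A minimal standalone sketch of the failure mode and the fix (hypothetical names, not the lightllm code itself):

```python
import torch

def buggy(input_embdings: torch.Tensor) -> torch.Tensor:
    input_embdings = None  # drop the reference so the allocator can reuse the memory
    # AttributeError: 'NoneType' object has no attribute 'dtype'
    return torch.empty((4, 2), dtype=input_embdings.dtype)

def fixed(input_embdings: torch.Tensor) -> torch.Tensor:
    input_embdings_dtype = input_embdings.dtype  # cache the dtype first
    input_embdings = None                        # now it is safe to drop the tensor
    return torch.empty((4, 2), dtype=input_embdings_dtype)

x = torch.randn(4, 2, dtype=torch.float16)
print(fixed(x).dtype)   # torch.float16
# buggy(x)              # would raise AttributeError
```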
lightllm/models/bloom/layer_infer/post_layer_infer.py (2 additions, 1 deletion)

```diff
@@ -36,6 +36,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight
         else:
             last_input[:, :] = input_embdings[-batch_size:, :]

+        input_embdings_dtype = input_embdings.dtype
         input_embdings = None
         last_input = self._norm(last_input, infer_state, layer_weight)
         last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, batch_size)
@@ -44,7 +45,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight
         if self.world_size_ == 1:
             gather_data = logic_batch
         else:
-            gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings.dtype)
+            gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings_dtype)
         split_size = self.vocab_size_ // self.world_size_
         dist.all_gather([gather_data[i * split_size: (i + 1) * split_size, :]
                          for i in range(self.world_size_)], logic_batch, group=None, async_op=False)
```
lightllm/models/llama/layer_infer/post_layer_infer.py (2 additions, 1 deletion)

```diff
@@ -72,6 +72,7 @@ def token_forward(
         return_logics=False,
     ):
         last_input, token_num = self._slice_get_last_input(input_embdings, infer_state)
+        input_embdings_dtype = input_embdings.dtype
         input_embdings = None
         last_input = self._norm(last_input, infer_state, layer_weight)
         last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num)
@@ -81,7 +82,7 @@ def token_forward(
         if self.world_size_ == 1:
             gather_data = logic_batch
         else:
-            gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings.dtype)
+            gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype)
         split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
         dist.all_gather(
             [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
```
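
For reference, the llama version partitions the vocabulary with `np.linspace`, so ranks get near-equal slices even when `vocab_size_` is not divisible by `world_size_`, while the bloom version uses a plain integer split (`vocab_size_ // world_size_`). A quick sketch of the llama partition arithmetic, with assumed example sizes:

```python
import numpy as np

vocab_size, world_size = 32003, 4  # assumed example values
split_indexes = np.linspace(0, vocab_size, world_size + 1, dtype=np.int64)
print(split_indexes)  # [    0  8000 16001 24002 32003]
# Rank i's logits occupy gather_data[split_indexes[i]:split_indexes[i + 1], :]
print(np.diff(split_indexes))  # [8000 8001 8001 8001] -- near-equal slices
```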
