Commit d2d8747

Hotfix(MInference): fix the import warnings, fix the apply_rotary_pos_emb_single, fix phi-3 vs kernel (microsoft#30)
Feature(MInference): remove pycuda, support multi-gpu

Co-authored-by: Yucheng Li <[email protected]>
Co-authored-by: Chengruidong Zhang <[email protected]>
3 people authored Jul 12, 2024
1 parent 19c3c7e commit d2d8747
Showing 7 changed files with 48 additions and 10 deletions.
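
For context, the "support multi-gpu" half of this commit is exercised by sharding a model across devices at load time and then applying the MInference patch. The sketch below is illustrative, not part of the diff; it assumes the patching API shown in the project README (`MInference("minference", model_name)` followed by `minference_patch(model)`), and the checkpoint name is only an example:

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from minference import MInference

    model_name = "microsoft/Phi-3-mini-128k-instruct"  # example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" spreads layers across all visible GPUs, which is what
    # makes the cross-device tensor moves added in this commit necessary
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    minference_patch = MInference("minference", model_name)
    model = minference_patch(model)
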
README.md (2 changes: 1 addition & 1 deletion)

@@ -113,7 +113,7 @@ for a local gradio demo <a href='https://github.com/gradio-app/gradio'><img src=
 git clone https://huggingface.co/spaces/microsoft/MInference
 cd MInference
 pip install -r requirments.txt
-pip install flash_attn pycuda==2023.1
+pip install flash_attn
 python app.py
 ```
experiments/README.md (1 change: 0 additions & 1 deletion)

@@ -37,7 +37,6 @@ python run_infinitebench.py \
 Environment parameters:
 - CUDA 12.3
 - Triton 2.1.0
-- PyCuda 2023.1

 ### End-to-End Benchmark

minference/modules/minference_forward.py (44 changes: 39 additions & 5 deletions)

@@ -32,15 +32,25 @@ def set_rope_type(self):
     if ROPE_TYPE is not None:
         return
     if "seq_len" in inspect.signature(self.rotary_emb.forward).parameters:
-        ROPE_TYPE = "seq_len"
+        if "position_ids" in inspect.signature(self.rotary_emb.forward).parameters:
+            ROPE_TYPE = "seq_len,position_ids"
+        else:
+            ROPE_TYPE = "seq_len"
     elif "max_seq_len" in inspect.signature(self.rotary_emb.forward).parameters:
         ROPE_TYPE = "max_seq_len"
     else:
         ROPE_TYPE = "position_ids"

 def get_cos_sin(self, value_states, kv_seq_len, position_ids):
+    if self.rotary_emb.inv_freq is not None and value_states.device != self.rotary_emb.inv_freq.device:
+        value_states = value_states.to(self.rotary_emb.inv_freq.device)
+        position_ids = position_ids.to(self.rotary_emb.inv_freq.device)
+    if value_states.device != position_ids.device:
+        position_ids = position_ids.to(value_states.device)
     if ROPE_TYPE == "seq_len":
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    elif ROPE_TYPE == "seq_len,position_ids":
+        cos, sin = self.rotary_emb(value_states, position_ids=position_ids, seq_len=kv_seq_len)
     elif ROPE_TYPE == "max_seq_len":
         cos = self.rotary_emb(kv_seq_len)
     if position_ids is not None:

@@ -442,10 +452,12 @@ def block_sparse_kernel(q, k, v, vertical_size=None, slash_size=None):
     return fc(q, k, v, vertical_size, slash_size)

 def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1):
-    # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
+    if len(cos.size()) == 2:
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    else:
+        cos = cos.unsqueeze(unsqueeze_dim)
+        sin = sin.unsqueeze(unsqueeze_dim)
     return (q * cos) + (rotate_half(q) * sin)

 def minference_forward():

@@ -490,9 +502,13 @@ def forward(
     set_rope_type(self)
     cos, sin = get_cos_sin(self, value_states, kv_seq_len, position_ids)
     if ROPE_TYPE == "max_seq_len":
+        if cos.device != query_states.device:
+            cos = cos.to(query_states.device)
         query_states = apply_rotary_pos_emb(query_states, cos)
         key_states = apply_rotary_pos_emb(key_states, cos)
     else:
+        if position_ids is not None and position_ids.device != cos.device:
+            position_ids = position_ids.to(cos.device)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

     if past_key_value is not None:

@@ -571,18 +587,28 @@ def forward(
        part_k, part_v = None, None
        for head in range(self.num_heads):
            if "q_proj" in self.__dict__["_modules"]:
+               if hidden_states.device != self.q_proj.weight.device:
+                   hidden_states = hidden_states.to(self.q_proj.weight.device)
+                   attn_out = attn_out.to(self.q_proj.weight.device)
                part_q = F.linear(hidden_states, self.q_proj.weight.view(self.num_heads, self.head_dim, hidden_dim)[head]).unsqueeze(2)
                if self.q_proj.bias is not None:
                    part_q += self.q_proj.bias.view(self.num_heads, self.head_dim)[head]
            else:
+               if hidden_states.device != self.qkv_proj.weight.device:
+                   hidden_states = hidden_states.to(self.qkv_proj.weight.device)
+                   attn_out = attn_out.to(self.qkv_proj.weight.device)
                query_pos = self.num_heads * self.head_dim
                part_q = F.linear(hidden_states, self.qkv_proj.weight[:query_pos].view(self.num_heads, self.head_dim, hidden_dim)[head]).unsqueeze(2)
                if self.qkv_proj.bias is not None:
                    part_q += self.qkv_proj.bias[:query_pos].view(self.num_heads, self.head_dim)[head]

            if ROPE_TYPE == "max_seq_len":
+               if cos.device != part_q.device:
+                   cos = cos.to(part_q.device)
                part_q = apply_rotary_pos_emb(part_q.transpose(1, 2), cos)
            else:
+               if position_ids is not None and position_ids.device != cos.device:
+                   position_ids = position_ids.to(cos.device)
                part_q = apply_rotary_pos_emb_single(part_q.transpose(1, 2), cos, sin, position_ids)

            if head % self.num_key_value_groups == 0:

@@ -602,8 +628,12 @@ def forward(
                    part_v += self.qkv_proj.bias[query_pos:].view(2, act_num_heads, self.head_dim)[1][head // self.num_key_value_groups]

                if ROPE_TYPE == "max_seq_len":
+                   if cos.device != part_k.device:
+                       cos = cos.to(part_k.device)
                    part_k = apply_rotary_pos_emb(part_k.transpose(1, 2), cos)
                else:
+                   if position_ids is not None and position_ids.device != cos.device:
+                       position_ids = position_ids.to(cos.device)
                    part_k = apply_rotary_pos_emb_single(part_k.transpose(1, 2), cos, sin, position_ids)
                if use_cache and past_key_value is not None:
                    k[:,head // self.num_key_value_groups] = part_k.to(kv_cache_cpu_device)

@@ -669,9 +699,13 @@ def forward(
     set_rope_type(self)
     cos, sin = get_cos_sin(self, value_states, kv_seq_len, position_ids)
     if ROPE_TYPE == "max_seq_len":
+        if cos.device != query_states.device:
+            cos = cos.to(query_states.device)
         query_states = apply_rotary_pos_emb(query_states, cos)
         key_states = apply_rotary_pos_emb(key_states, cos)
     else:
+        if position_ids is not None and position_ids.device != cos.device:
+            position_ids = position_ids.to(cos.device)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
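
The apply_rotary_pos_emb_single fix named in the commit title is the heart of this file's changes: older transformers versions cache cos/sin as a 2-D (max_seq_len, head_dim) table that must be indexed by position, while newer versions already return per-position values and only need a head axis added. A self-contained sketch of that branching, where the rotate_half helper and every tensor shape are toy stand-ins chosen for illustration:

    import torch

    def rotate_half(x):
        # split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_single(q, cos, sin, position_ids, unsqueeze_dim=1):
        if len(cos.size()) == 2:
            # legacy layout: cos/sin cached as (max_seq_len, head_dim),
            # so gather rows by position before broadcasting over heads
            cos = cos[position_ids].unsqueeze(unsqueeze_dim)
            sin = sin[position_ids].unsqueeze(unsqueeze_dim)
        else:
            # newer layout: per-position values, e.g. (batch, seq_len, head_dim);
            # only the head axis is missing
            cos = cos.unsqueeze(unsqueeze_dim)
            sin = sin.unsqueeze(unsqueeze_dim)
        return (q * cos) + (rotate_half(q) * sin)

    q = torch.randn(1, 2, 4, 8)         # (batch, heads, seq, head_dim)
    pos = torch.arange(4).unsqueeze(0)  # (batch, seq)
    cos2d, sin2d = torch.randn(16, 8), torch.randn(16, 8)
    cos3d, sin3d = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
    assert apply_rotary_pos_emb_single(q, cos2d, sin2d, pos).shape == q.shape
    assert apply_rotary_pos_emb_single(q, cos3d, sin3d, pos).shape == q.shape
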
minference/ops/block_sparse_flash_attention.py (5 changes: 3 additions & 2 deletions)

@@ -2,12 +2,13 @@
 # Licensed under The MIT License [see LICENSE for details]

 import numpy as np
-import pycuda.autoprimaryctx
 import torch
 import triton
 import triton.language as tl
 from flash_attn import flash_attn_varlen_func
-from pycuda.compiler import SourceModule
+
+# import pycuda.autoprimaryctx
+# from pycuda.compiler import SourceModule


 # @triton.autotune(
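
The commit simply comments the pycuda imports out, which silences the import warnings mentioned in the title. A common alternative, sketched here as a hypothetical refactor rather than anything this commit does, is to guard the import so the module still loads when the dependency is absent but code paths that need it can check first:

    try:
        import pycuda.autoprimaryctx
        from pycuda.compiler import SourceModule
        HAS_PYCUDA = True
    except ImportError:
        # pycuda was dropped from the install requirements in this commit,
        # so treat it as an optional extra rather than a hard dependency
        HAS_PYCUDA = False
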
minference/ops/pit_sparse_flash_attention_v2.py (2 changes: 2 additions & 0 deletions)

@@ -1,6 +1,8 @@
 # Copyright (c) 2024 Microsoft
 # Licensed under The MIT License [see LICENSE for details]

+import math
+
 import torch
 import triton
 import triton.language as tl
minference/patch.py (2 changes: 2 additions & 0 deletions)

@@ -570,6 +570,8 @@ def forward_llama_decoder_layer(
         use_cache=use_cache,
         padding_mask=padding_mask,
     )
+    if residual.device != hidden_states.device:
+        residual = residual.to(hidden_states.device)
     hidden_states = residual + hidden_states

     # Fully Connected
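
The two lines added above are the whole multi-GPU fix for the decoder layer: when layers are sharded across GPUs, the residual saved before the attention call can live on a different device than the hidden states the call returns, and adding them would raise a cross-device RuntimeError. A minimal standalone illustration, with tensor shapes and device names of our own choosing (it assumes at least one CUDA device is available):

    import torch

    def add_residual(residual: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        # align devices before the elementwise add; under sharded loading
        # (e.g. device_map="auto") the two tensors may sit on different GPUs
        if residual.device != hidden_states.device:
            residual = residual.to(hidden_states.device)
        return residual + hidden_states

    residual = torch.randn(1, 4, 16, device="cpu")
    hidden_states = torch.randn(1, 4, 16, device="cuda:0")
    out = add_residual(residual, hidden_states)  # result lands on cuda:0
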
setup.py (2 changes: 1 addition & 1 deletion)

@@ -6,6 +6,7 @@
 import subprocess
 import sys
 import urllib
+import warnings

 import torch
 from packaging.version import Version, parse

@@ -38,7 +39,6 @@
     "torch",
     "triton",
     "flash_attn",
-    "pycuda==2023.1",
 ]
 QUANLITY_REQUIRES = [
     "black==21.4b0",
