forked from THUDM/CodeGeeX
Commit a2b6a15 (1 parent: 601b3aa)
Showing 11 changed files with 1,902 additions and 0 deletions.
@@ -0,0 +1,99 @@
import pkg_resources
import torch
import ctypes

from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

RESOURCE_PACKAGE_NAME = __name__


class Kernel:
    def __init__(self, filename: str, function_names: List[str]):
        filename = filename + ".fatbin"
        if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
            raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
        self.filename = filename
        self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
        self._function_names = function_names
        self._cmodule = LazyKernelCModule(self.code)

        for name in self._function_names:
            setattr(self, name, KernelFunction(self._cmodule, name))


kernels = Kernel(
    "quantization",
    [
        "int4WeightCompression",
        "int4WeightExtractionFloat",
        "int4WeightExtractionHalf",
        "int8WeightExtractionFloat",
        "int8WeightExtractionHalf",
    ],
)


def compress_int4_weight(weight: torch.Tensor):  # (n, m)
    # Pack two int4 values into each int8 byte, halving the second dimension.
    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        assert m % 2 == 0
        m = m // 2
        out = torch.empty(n, m, dtype=torch.int8, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        kernels.int4WeightCompression(
            gridDim,
            blockDim,
            0,
            stream,
            [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
        )
        return out


def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
    # Dequantize an int4/int8 weight matrix to fp16, one scale per row.
    if source_bit_width == 8:
        func = kernels.int8WeightExtractionHalf
    elif source_bit_width == 4:
        func = kernels.int4WeightExtractionHalf
    else:
        assert False, "Unsupported bit-width"

    with torch.cuda.device(weight.device):
        n, m = weight.size(0), weight.size(1)
        out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
        stream = torch.cuda.current_stream()

        gridDim = (n, 1, 1)
        blockDim = (min(round_up(m, 32), 1024), 1, 1)

        func(
            gridDim,
            blockDim,
            0,
            stream,
            [
                ctypes.c_void_p(weight.data_ptr()),
                ctypes.c_void_p(scale_list.data_ptr()),
                ctypes.c_void_p(out.data_ptr()),
                ctypes.c_int32(n),
                ctypes.c_int32(m),
            ],
        )
        return out


if __name__ == "__main__":
    weight = torch.randn(4, 32).to(torch.int8).cuda()
    scale = torch.ones(weight.size(0)).to(torch.half).cuda()

    print(weight)
    b = compress_int4_weight(weight)
    print(b)

    a = extract_weight_to_half(b, scale, source_bit_width=4)
    print(a)
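As a sanity check on what the extraction kernels are expected to produce, here is a minimal pure-PyTorch sketch of the int8 path, assuming (from how `extract_weight_to_half` is used by the quantized linear layer below) that each output row is dequantized with its own fp16 scale; the CUDA kernel is expected to compute the same thing on-device:

import torch

def extract_int8_weight_to_half_reference(weight: torch.Tensor, scale_list: torch.Tensor) -> torch.Tensor:
    # Assumed semantics of the int8 path: out[i, j] = weight[i, j] * scale_list[i], in fp16.
    return weight.to(torch.half) * scale_list.to(torch.half)[:, None]

# Illustrative comparison against the CUDA kernel (requires a GPU):
# w = torch.randint(-128, 128, (4, 32), dtype=torch.int8, device="cuda")
# s = torch.rand(4, dtype=torch.half, device="cuda")
# torch.testing.assert_close(extract_weight_to_half(w, s, source_bit_width=8),
#                            extract_int8_weight_to_half_reference(w, s))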
Binary file not shown.
@@ -0,0 +1 @@
from .quantize import quantize
@@ -0,0 +1,139 @@
import torch

from torch.nn.parameter import Parameter
from codegeex.kernels import compress_int4_weight, extract_weight_to_half


class W8A16Linear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
        ctx.inp_shape = inp.size()
        ctx.weight_shape = quant_w.size()
        ctx.weight_bit_width = weight_bit_width
        out_features = quant_w.size(0)
        inp = inp.contiguous().view(-1, inp.size(-1))
        # Dequantize the weight to fp16 on the fly, then run a regular fp16 matmul.
        weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
        output = inp.mm(weight.t())
        ctx.save_for_backward(inp, quant_w, scale_w)
        return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        inp, quant_w, scale_w = ctx.saved_tensors
        weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
        grad_output = grad_output.contiguous().view(-1, weight.size(0))
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(inp)
        return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None


class QuantizedLinear(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_bit_width: int,
        weight: torch.Tensor = None,
        bias: torch.Tensor = None,
        *args,
        **kwargs
    ):
        super(QuantizedLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight_bit_width = weight_bit_width

        if weight is None:
            # Assumed storage layout when no weight is supplied: the same shape the
            # quantized branch below produces (the commit referenced an undefined `shape` here).
            self.weight = torch.empty(
                in_features, out_features * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(in_features, dtype=kwargs["params_dtype"], device=kwargs["device"])
        else:
            # Per-row absmax quantization: scale each row so its largest value maps
            # to the largest representable signed integer for the chosen bit width.
            self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
            self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
            if weight_bit_width == 4:
                self.weight = compress_int4_weight(self.weight)

        if bias is None:
            self.register_parameter('bias', None)
        else:
            self.bias = bias

        self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
        self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)

    def forward(self, input_):
        # Matrix multiply.
        output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width)
        if self.bias is not None:
            output = output + self.bias

        return output


def quantize(model, weight_bit_width):
    """Replace fp16 linear layers with quantized linear layers."""

    for i in range(len(model.language_model.transformer.layers) + 1):
        if i == len(model.language_model.transformer.layers):
            layer = model.language_model.transformer.topQueryLayer
        else:
            layer = model.language_model.transformer.layers[i]

        layer.attention.query = QuantizedLinear(
            in_features=layer.attention.query.weight.shape[0],
            out_features=layer.attention.query.weight.shape[1],
            weight_bit_width=weight_bit_width,
            weight=layer.attention.query.weight.to(torch.cuda.current_device()),
            bias=layer.attention.query.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.query.weight.device,
        )
        layer.attention.value = QuantizedLinear(
            in_features=layer.attention.value.weight.shape[0],
            out_features=layer.attention.value.weight.shape[1],
            weight_bit_width=weight_bit_width,
            weight=layer.attention.value.weight.to(torch.cuda.current_device()),
            bias=layer.attention.value.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.value.weight.device,
        )
        layer.attention.key = QuantizedLinear(
            in_features=layer.attention.key.weight.shape[0],
            out_features=layer.attention.key.weight.shape[1],
            weight_bit_width=weight_bit_width,
            weight=layer.attention.key.weight.to(torch.cuda.current_device()),
            bias=layer.attention.key.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.key.weight.device,
        )
        layer.attention.dense = QuantizedLinear(
            in_features=layer.attention.dense.weight.shape[0],
            out_features=layer.attention.dense.weight.shape[1],
            weight_bit_width=weight_bit_width,
            weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
            bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.attention.dense.weight.device,
        )
        layer.mlp.dense_h_to_4h = QuantizedLinear(
            in_features=layer.mlp.dense_h_to_4h.weight.shape[0],
            out_features=layer.mlp.dense_h_to_4h.weight.shape[1],
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
            bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.mlp.dense_h_to_4h.weight.device,
        )
        layer.mlp.dense_4h_to_h = QuantizedLinear(
            in_features=layer.mlp.dense_4h_to_h.weight.shape[0],
            out_features=layer.mlp.dense_4h_to_h.weight.shape[1],
            weight_bit_width=weight_bit_width,
            weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
            bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
            params_dtype=torch.half,
            device=layer.mlp.dense_4h_to_h.weight.device,
        )

    return model
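As a quick worked example of the per-row absmax scheme used by `QuantizedLinear` (shown in fp32 for readability; the layer itself does the same arithmetic in fp16):

import torch

w = torch.tensor([[0.5, -1.0, 0.25]])                    # one row of a weight matrix
scale = w.abs().max(dim=-1).values / (2 ** (8 - 1) - 1)  # absmax / 127 for 8-bit weights
q = torch.round(w / scale[:, None]).to(torch.int8)       # tensor([[  64, -127,   32]], dtype=torch.int8)
w_hat = q.to(w.dtype) * scale[:, None]                   # roughly [[0.5039, -1.0000, 0.2520]]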
@@ -0,0 +1 @@
from .tokenizer import CodeGeeXTokenizer
@@ -0,0 +1,87 @@
import torch
from typing import *
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast


def encode_whitespaces(text, start_extra_id: int, max_len: int):
    """ Encode runs of whitespace as GPT-J extra tokens.

    >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
    'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
    """

    def push_acc_space(acc_len: int, text: str):
        if acc_len == 0:
            return text
        if acc_len == 1:
            return text + ' '
        assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
        extra_id = start_extra_id - 2 + acc_len
        extra_token = f'<|extratoken_{extra_id}|>'
        return text + extra_token

    acc_len = 0
    res = ''
    for ch in text:
        if ch == ' ':
            acc_len += 1
            if acc_len == max_len:
                res = push_acc_space(acc_len, res)
                acc_len = 0
        else:
            res = push_acc_space(acc_len, res)
            acc_len = 0
            res = res + ch

    res = push_acc_space(acc_len, res)

    return res


def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
    """ Decode the whitespace-encoded strings produced by encode_whitespaces.

    >>> text = 'a\\n  b\\n   c'
    >>> s, l = 10, 10
    >>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
    True
    """
    for l in range(2, max_len + 1):
        token_id = start_extra_id - 2 + l
        token = f'<|extratoken_{token_id}|>'
        text = text.replace(token, ' ' * l)
    return text
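
# Worked example of the run-length mapping above: with start_extra_id=10, a run
# of k consecutive spaces (2 <= k <= max_len) becomes the single token
# <|extratoken_{10 - 2 + k}|>, so 4 leading spaces map to <|extratoken_12|>:
#
#     >>> encode_whitespaces('def f():\n    return 1', 10, 10)
#     'def f():\n<|extratoken_12|>return 1'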


class CodeGeeXTokenizer(object):
    def __init__(
        self,
        tokenizer: GPT2TokenizerFast = None,
        tokenizer_path: str = "EleutherAI/gpt-j-6B",
        start_extra_id: int = 10,
        max_len: int = 10,
        mode='codegeex-13b',
        dict_file: str = None,
    ):
        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path)
        if mode not in ['codegeex-13b']:
            raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']")
        self.start_extra_id = start_extra_id
        self.max_len = max_len
        self.mode = mode
        self.eos_token_id = self.tokenizer.eos_token_id

    def encode_code(self, code: str):
        if self.mode == 'codegeex-13b':
            code = encode_whitespaces(code, self.start_extra_id, self.max_len)
            input_ids = self.tokenizer(code, is_split_into_words=False).input_ids

        return input_ids

    def decode_code(self, input_ids):
        if self.mode == 'codegeex-13b':
            text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
            output_code = decode_whitespaces(text, self.start_extra_id, self.max_len)

        return output_code
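A hypothetical usage sketch (assumes the `EleutherAI/gpt-j-6B` tokenizer can be downloaded; the identifiers below are illustrative, not part of the commit): code is passed through `encode_whitespaces` before BPE and through `decode_whitespaces` after decoding, so the round trip is expected to reproduce the input.

tokenizer = CodeGeeXTokenizer()                  # downloads the GPT-J tokenizer by default
src = "def add(a, b):\n    return a + b"
ids = tokenizer.encode_code(src)                 # runs of spaces become <|extratoken_k|> tokens
print(tokenizer.decode_code(ids) == src)         # expected: True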
@@ -0,0 +1 @@
from .codegeex_model import CodeGeeXModel