Add int8 quantization
Stanislas0 committed Nov 20, 2022
1 parent 601b3aa commit a2b6a15
Showing 11 changed files with 1,902 additions and 0 deletions.
99 changes: 99 additions & 0 deletions codegeex/kernels/__init__.py
@@ -0,0 +1,99 @@
import pkg_resources
import torch
import ctypes

from typing import List
from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up

RESOURCE_PACKAGE_NAME = __name__


class Kernel:
def __init__(self, filename: str, function_names: List[str]):
filename = filename + ".fatbin"
if not pkg_resources.resource_exists(RESOURCE_PACKAGE_NAME, filename):
raise RuntimeError("File `%s` not found in `%s`" % (filename, RESOURCE_PACKAGE_NAME))
self.filename = filename
self.code = pkg_resources.resource_string(RESOURCE_PACKAGE_NAME, filename)
self._function_names = function_names
self._cmodule = LazyKernelCModule(self.code)

for name in self._function_names:
setattr(self, name, KernelFunction(self._cmodule, name))


kernels = Kernel(
"quantization",
[
"int4WeightCompression",
"int4WeightExtractionFloat",
"int4WeightExtractionHalf",
"int8WeightExtractionFloat",
"int8WeightExtractionHalf",
],
)


def compress_int4_weight(weight: torch.Tensor): # (n, m)
with torch.cuda.device(weight.device):
n, m = weight.size(0), weight.size(1)
assert m % 2 == 0
m = m // 2
out = torch.empty(n, m, dtype=torch.int8, device="cuda")
stream = torch.cuda.current_stream()

gridDim = (n, 1, 1)
blockDim = (min(round_up(m, 32), 1024), 1, 1)

kernels.int4WeightCompression(
gridDim,
blockDim,
0,
stream,
[ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)],
)
return out


def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int):
if source_bit_width == 8:
func = kernels.int8WeightExtractionHalf
elif source_bit_width == 4:
func = kernels.int4WeightExtractionHalf
else:
assert False, "Unsupported bit-width"

with torch.cuda.device(weight.device):
n, m = weight.size(0), weight.size(1)
out = torch.empty(n, m * (8 // source_bit_width), dtype=torch.half, device="cuda")
stream = torch.cuda.current_stream()

gridDim = (n, 1, 1)
blockDim = (min(round_up(m, 32), 1024), 1, 1)

func(
gridDim,
blockDim,
0,
stream,
[
ctypes.c_void_p(weight.data_ptr()),
ctypes.c_void_p(scale_list.data_ptr()),
ctypes.c_void_p(out.data_ptr()),
ctypes.c_int32(n),
ctypes.c_int32(m),
],
)
return out


if __name__ == "__main__":
weight = torch.randn(4, 32).to(torch.int8).cuda()
scale = torch.ones(weight.size(0)).to(torch.half).cuda()

print(weight)
b = compress_int4_weight(weight)
print(b)

a = extract_weight_to_half(b, scale, source_bit_width=4)
print(a)
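
For reference, the per-row symmetric quantization these kernels assume can be written in plain PyTorch. The sketch below is illustrative only (the quantize_rowwise_int8 / dequantize_rowwise_int8 helpers are hypothetical and not part of this commit); on GPU, extract_weight_to_half replaces the dequantization step with the fatbin kernels above.

import torch

# Hypothetical helpers (not in this commit): symmetric per-row int8
# quantization, mirroring what QuantizedLinear does before handing the
# int8 weights and per-row scales to the CUDA extraction kernels.
def quantize_rowwise_int8(weight: torch.Tensor):
    # One scale per output row; int8 covers [-127, 127].
    scale = weight.abs().max(dim=-1).values / 127.0
    q = torch.round(weight / scale[:, None]).to(torch.int8)
    return q, scale


def dequantize_rowwise_int8(q: torch.Tensor, scale: torch.Tensor):
    # Pure-PyTorch equivalent of the int8WeightExtraction* kernels.
    return q.to(scale.dtype) * scale[:, None]


if __name__ == "__main__":
    w = torch.randn(4, 32)
    q, s = quantize_rowwise_int8(w)
    w_hat = dequantize_rowwise_int8(q, s)
    print((w - w_hat).abs().max())  # small reconstruction error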
Binary file added codegeex/kernels/quantization.fatbin
Binary file not shown.
1 change: 1 addition & 0 deletions codegeex/quantization/__init__.py
@@ -0,0 +1 @@
from .quantize import quantize
139 changes: 139 additions & 0 deletions codegeex/quantization/quantize.py
@@ -0,0 +1,139 @@
import torch

from torch.nn.parameter import Parameter
from codegeex.kernels import compress_int4_weight, extract_weight_to_half


class W8A16Linear(torch.autograd.Function):
@staticmethod
def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width):
ctx.inp_shape = inp.size()
ctx.weight_shape = quant_w.size()
ctx.weight_bit_width = weight_bit_width
out_features = quant_w.size(0)
inp = inp.contiguous().view(-1, inp.size(-1))
weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width)
output = inp.mm(weight.t())
ctx.save_for_backward(inp, quant_w, scale_w)
return output.view(*(ctx.inp_shape[:-1] + (out_features,)))

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
inp, quant_w, scale_w = ctx.saved_tensors
weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width)
grad_output = grad_output.contiguous().view(-1, weight.size(0))
grad_input = grad_output.mm(weight)
grad_weight = grad_output.t().mm(inp)
return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None


class QuantizedLinear(torch.nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
weight_bit_width: int,
weight: torch.Tensor = None,
bias: torch.Tensor = None,
*args,
**kwargs
):
super(QuantizedLinear, self).__init__()

self.in_features = in_features
self.out_features = out_features
self.weight_bit_width = weight_bit_width

if weight is None:
            self.weight = torch.empty(
                out_features, in_features * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"]
            )
            self.weight_scale = torch.empty(out_features, dtype=kwargs["params_dtype"], device=kwargs["device"])
else:
self.weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half()
self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8)
if weight_bit_width == 4:
self.weight = compress_int4_weight(self.weight)

if bias is None:
self.register_parameter('bias', None)
else:
self.bias = bias

self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False)
self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False)

def forward(self, input_):
# Matrix multiply.
output = W8A16Linear.apply(input_, self.weight, self.weight_scale, self.weight_bit_width)
if self.bias is not None:
output = output + self.bias

return output


def quantize(model, weight_bit_width):
"""Replace fp16 linear with quantized linear"""

for i in range(len(model.language_model.transformer.layers) + 1):
if i == len(model.language_model.transformer.layers):
layer = model.language_model.transformer.topQueryLayer
else:
layer = model.language_model.transformer.layers[i]

layer.attention.query = QuantizedLinear(
            in_features=layer.attention.query.weight.shape[1],
            out_features=layer.attention.query.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.query.weight.to(torch.cuda.current_device()),
bias=layer.attention.query.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.query.weight.device,
)
layer.attention.value = QuantizedLinear(
            in_features=layer.attention.value.weight.shape[1],
            out_features=layer.attention.value.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.value.weight.to(torch.cuda.current_device()),
bias=layer.attention.value.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.value.weight.device,
)
layer.attention.key = QuantizedLinear(
            in_features=layer.attention.key.weight.shape[1],
            out_features=layer.attention.key.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.key.weight.to(torch.cuda.current_device()),
bias=layer.attention.key.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.key.weight.device,
)
layer.attention.dense = QuantizedLinear(
            in_features=layer.attention.dense.weight.shape[1],
            out_features=layer.attention.dense.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.attention.dense.weight.to(torch.cuda.current_device()),
bias=layer.attention.dense.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.attention.dense.weight.device,
)
layer.mlp.dense_h_to_4h = QuantizedLinear(
            in_features=layer.mlp.dense_h_to_4h.weight.shape[1],
            out_features=layer.mlp.dense_h_to_4h.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_h_to_4h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_h_to_4h.weight.device,
)
layer.mlp.dense_4h_to_h = QuantizedLinear(
            in_features=layer.mlp.dense_4h_to_h.weight.shape[1],
            out_features=layer.mlp.dense_4h_to_h.weight.shape[0],
weight_bit_width=weight_bit_width,
weight=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()),
bias=layer.mlp.dense_4h_to_h.bias.to(torch.cuda.current_device()),
params_dtype=torch.half,
device=layer.mlp.dense_4h_to_h.weight.device,
)

return model
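
As a quick smoke test, here is a hedged, self-contained sketch (assuming a CUDA GPU so the fatbin kernels can run) that wraps a single fp16 torch.nn.Linear the same way quantize() wraps each attention and MLP projection:

import torch
from codegeex.quantization.quantize import QuantizedLinear

# Build a small fp16 linear layer and wrap it like quantize() does for
# every attention/MLP projection in the model.
lin = torch.nn.Linear(64, 128).half().cuda()
qlin = QuantizedLinear(
    in_features=64,
    out_features=128,
    weight_bit_width=8,
    weight=lin.weight.detach(),
    bias=lin.bias.detach(),
    params_dtype=torch.half,
    device=lin.weight.device,
)

x = torch.randn(2, 64, dtype=torch.half, device="cuda")
print(qlin(x).shape)  # torch.Size([2, 128]); weights are dequantized on the fly in W8A16Linear

For a full model, quantize(model, weight_bit_width=8) performs the same replacement for every transformer layer, including the top query layer.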
1 change: 1 addition & 0 deletions codegeex/tokenizer/__init__.py
@@ -0,0 +1 @@
from .tokenizer import CodeGeeXTokenizer
87 changes: 87 additions & 0 deletions codegeex/tokenizer/tokenizer.py
@@ -0,0 +1,87 @@
import torch
from typing import *
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast


def encode_whitespaces(text, start_extra_id: int, max_len: int):
""" Encode whitespaces to extra tokens in GPT-J.
    >>> encode_whitespaces('a\\n  b\\n   c', 10, 10)
'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
"""

def push_acc_space(acc_len: int, text: str):
if acc_len == 0:
return text
if acc_len == 1:
return text + ' '
        assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
extra_id = start_extra_id - 2 + acc_len
extra_token = f'<|extratoken_{extra_id}|>'
return text + extra_token

acc_len = 0
res = ''
for ch in text:
if ch == ' ':
acc_len += 1
if acc_len == max_len:
res = push_acc_space(acc_len, res)
acc_len = 0
else:
res = push_acc_space(acc_len, res)
acc_len = 0
res = res + ch

res = push_acc_space(acc_len, res)

return res


def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
""" Decode the whitespace-encoded strings produced by encode_whitespace.
>>> text = 'a\\n b\\n c'
>>> s, l = 10, 10
>>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
True
"""
for l in range(2, max_len + 1):
token_id = start_extra_id - 2 + l
token = f'<|extratoken_{token_id}|>'
text = text.replace(token, ' ' * l)
return text


class CodeGeeXTokenizer(object):
def __init__(
self,
tokenizer: GPT2TokenizerFast = None,
tokenizer_path: str = "EleutherAI/gpt-j-6B",
start_extra_id: int = 10,
        max_len: int = 10,
mode='codegeex-13b',
dict_file: str = None,
):
self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_path)
if mode not in ['codegeex-13b']:
raise ValueError(f"Invalid mode {mode}, choose from ['codegeex-13b']")
self.start_extra_id = start_extra_id
self.max_len = max_len
self.mode = mode
self.eos_token_id = self.tokenizer.eos_token_id

def encode_code(self, code: str):
if self.mode == 'codegeex-13b':
code = encode_whitespaces(code, self.start_extra_id, self.max_len)
input_ids = self.tokenizer(code, is_split_into_words=False).input_ids

return input_ids

def decode_code(self, input_ids):
if self.mode == 'codegeex-13b':
text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
output_code = decode_whitespaces(text, self.start_extra_id, self.max_len)

return output_code
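
To see the whitespace handling in isolation (no Hugging Face download required), a small hedged sketch of the round trip:

from codegeex.tokenizer.tokenizer import encode_whitespaces, decode_whitespaces

# Runs of two or more spaces map to GPT-J's reserved <|extratoken_NN|> ids,
# so indentation survives BPE tokenization as a single token per run.
code = "def f():\n    return 1"
encoded = encode_whitespaces(code, start_extra_id=10, max_len=10)
print(encoded)  # def f():\n<|extratoken_12|>return 1
assert decode_whitespaces(encoded, 10, 10) == code

CodeGeeXTokenizer applies exactly this transform around the underlying GPT-J tokenizer in encode_code and decode_code.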
1 change: 1 addition & 0 deletions codegeex/torch/__init__.py
@@ -0,0 +1 @@
from .codegeex_model import CodeGeeXModel