Add conversion interface & fix python api
wildkid1024 committed Jul 4, 2023
1 parent b703414 commit 6848068
Showing 6 changed files with 157 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
*.log
*.pyc
token
/cmake-build-debug/
/build-tfacc/
91 changes: 84 additions & 7 deletions pyfastllm/README.md
@@ -1,12 +1,24 @@
# pyfastllm

by [wildkid1024](https://github.com/wildkid1024)

pyfastllm is a Python API built on top of fastllm. With pyfastllm you can write more flexible, Pythonic code and cover more complex, customized business needs.

- Plug into web frameworks such as FastAPI and Flask to expose data endpoints (see the sketch below this list)
- Use Python's yield generators for streaming question-answer responses
- Hook into fine-tuning methods such as LoRA and P-Tuning so downstream tasks can be fine-tuned (in development...)
- Seamlessly accelerate models from the HuggingFace hub and migrate existing business code painlessly (in development...)
- And more...
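
As a sketch of the web-framework and streaming use cases above (not part of this repo): it assumes the extension has been built and is importable as `pyfastllm`, and that `launch_response`/`fetch_response` (documented in the API section below) return text pieces until an empty result; adjust the stop condition to your build's actual return convention.

```
# Hypothetical FastAPI wiring around pyfastllm; the route name and the
# end-of-stream check are illustrative assumptions, not part of fastllm.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import pyfastllm as fastllm

app = FastAPI()
fastllm.set_threads(8)
model = fastllm.create_llm("chatglm-6b-int8.bin")  # placeholder weight file

def stream_answer(prompt: str):
    handle = model.launch_response(prompt)    # fill the first token, get a handle id
    while True:
        piece = model.fetch_response(handle)  # fetch the next piece for this handle
        if not piece:                         # assumed end-of-stream convention
            break
        yield piece

@app.get("/chat")
def chat(prompt: str):
    return StreamingResponse(stream_answer(prompt), media_type="text/plain")
```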

## Build and Install

Build and install fastllm's Python interface locally. It can be built and run in two ways:
1. cpp approach: build a shared library, which must be placed in a directory Python loads from
2. python approach: build a wheel package, but CUDA is not supported yet
1. Shared-library approach: build a shared library, which must be placed in a directory Python loads from
2. Wheel-package approach: build a wheel package installed into Python's site-packages, but CUDA is not supported yet

### cpp approach
### Shared-library approach

Manual build
Manual C++ build
```
mkdir build-py
cd build-py
@@ -15,19 +27,84 @@ make -j4
python cli.py -p chatglm-6b-int8.bin -t 8 # output matches the cpp build
```

Script build
Python script build

```
cd pyfastllm
python build_libs --cuda
python cli.py -p chatglm-6b-int8.bin -t 8
```

### python approach
### Wheel-package approach

```
cd pyfastllm
python setup.py build
python setup.py install
python cli.py -p chatglm-6b-int8.bin -t 8
```
```

## API Programming Interface


### fastllm data structures

> fastllm.Tensor data types
- fastllm.float32
- fastllm.bfloat16
- fastllm.int16
- fastllm.int8
- fastllm.int4
- fastllm.int2
- fastllm.float16

> fastllm.Tensor: fastllm's basic tensor structure (see the sketch after this list)
- fastllm.Tensor()
- fastllm.Tensor(Datatype)
- fastllm.Tensor(Datatype, Dims:list[int])
- fastllm.Tensor(Datatype, Dims:list[int], Data:list[float])
- fastllm.Tensor(Data:fastllm.Tensor)
- fastllm.Tensor.to_list() # convert the Tensor to a Python list and return it
- fastllm.Tensor.to() # move the Tensor to the target device
- fastllm.Tensor.zeros(Dims:list[int]) # create an all-zero Tensor with the given Dims
- fastllm.Tensor.cat(Data:list[fastllm.Tensor], axis:int) # concatenate Tensors along axis (default 0)
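
A minimal sketch of the constructors and helpers above, assuming the extension was built and is importable as `pyfastllm`; the method spellings follow the list above, so verify them against your build:

```
# Sketch only: exercise the Tensor API listed above.
import pyfastllm as fastllm

a = fastllm.Tensor(fastllm.float32, [2, 2], [1.0, 2.0, 3.0, 4.0])  # dtype, dims, flat data
b = fastllm.Tensor.zeros([2, 2])                                    # all-zero tensor of the same shape
c = fastllm.Tensor.cat([a, b], 0)                                   # concatenate along axis 0
print(c.to_list())                                                  # back to a plain Python list
```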

### fastllm functions

> fastllm.get_llm_type(model_path:str)->str # get the model type from a local weight file
> fastllm.set_threads(thread:int) -> None # set the number of worker threads (default 4)
> fastllm.get_threads()->int # get the current number of worker threads
> fastllm.create_llm(model_path: str)-> fastllm.model # create the matching model instance from a local weight file, selected by rule matching (usage sketch below)
> fastllm.set_low_memory() # run in low-memory mode (default False)
> fastllm.get_low_memory() # check whether low-memory mode is currently enabled
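
For example, a typical setup sketch (the weight file name is a placeholder):

```
# Sketch only: global configuration plus rule-matched model creation.
import pyfastllm as fastllm

fastllm.set_threads(8)                              # use 8 worker threads instead of the default 4
print(fastllm.get_threads())                        # -> 8
print(fastllm.get_llm_type("chatglm-6b-int8.bin"))  # model type inferred from the weight file
model = fastllm.create_llm("chatglm-6b-int8.bin")   # instance selected by rule matching
```
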
### fastllm modules

> fastllm.Tokenizer: tokenization and encoding/decoding utilities (sketch after this list)
> Tips: this class cannot be instantiated directly; access the instance via model.weight.tokenizer
- fastllm.Tokenizer.encode(prompt:str) # tokenize and encode the prompt
- fastllm.Tokenizer.decode(output_ids:fastllm.Tensor) # decode a fastllm.Tensor into the corresponding string
- fastllm.Tokenizer.decode(output_ids: list[int]) # decode a list[int] into the corresponding string
- fastllm.Tokenizer.decode_byte(output_ids: fastllm.Tensor) # decode a Tensor into the corresponding byte stream
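
A round-trip sketch; per the tip above, the tokenizer is reached through an existing model rather than constructed directly (the weight file name is a placeholder):

```
# Sketch only: encode a prompt and decode the ids back to text.
import pyfastllm as fastllm

model = fastllm.create_llm("chatglm-6b-int8.bin")  # placeholder weight file
tokenizer = model.weight.tokenizer                 # Tokenizer cannot be instantiated directly
ids = tokenizer.encode("Hello, fastllm")           # prompt -> token ids (fastllm.Tensor)
print(tokenizer.decode(ids))                       # token ids -> string
```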

### fastllm models

> fastllm.ChatGLMModel: a concrete model class; ChatGLM can be swapped for Llama, Alpaca, MOSS and other models (usage sketch after this list)
- fastllm.ChatGLMModel() # create a model instance
- __call__(input_ids:fastllm.Tensor, attention_mask:fastllm.Tensor, position_ids:fastllm.Tensor, penalty_factor:fastllm.Tensor, pastKeyValues:memory_view) # run inference by calling the model like a function
- fastllm.ChatGLMModel.load_weights(model_path:str) # load model weights from a file path
- fastllm.ChatGLMModel.response(inputs:str, callback:function) # send a string to the model and handle the returned answer in the callback
- fastllm.ChatGLMModel.response_batch(inputs:list[str], callback:function) -> outputs:list[str] # send a list of strings to the model and handle the returned answers in the callback
- fastllm.ChatGLMModel.warmup() # GPU warm-up: pre-fill the GPU to avoid a cold start
- fastllm.ChatGLMModel.launch_response(inputs:str)->handle_id:int # for multi-threaded use: feeds the first token and returns a handle id
- fastllm.ChatGLMModel.fetch_response(handle_id:int) # fetch and return the next message for the given handle id from the message queue
- fastllm.ChatGLMModel.save_lowbit_model(model_path:str, q_bit:int) # quantize the weights to low-bit precision and save the model
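
A blocking-chat sketch using the methods above; the callback signature is an assumption (an index plus the text generated so far), so adjust it to your build:

```
# Sketch only: load converted weights and run one chat turn with a callback.
import pyfastllm as fastllm

model = fastllm.ChatGLMModel()
model.load_weights("chatglm-6b-int8.bin")  # converted weight file (placeholder name)
model.warmup()                             # optional warm-up to avoid a cold start

def on_update(idx, content):               # assumed callback signature
    print(content, end="", flush=True)

model.response("Hello, who are you?", on_update)
```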


## Development Plan

- [ ] A model runtime-parameter class wrapping run-time settings: model path, thread count, low-memory mode, penalty factor, temperature, etc.
- [ ] Deep copy and shallow copy for Tensor, plus basic operator overloading
- [ ] Optimize the encode/decode path and unify the different return types
- [ ] Expose more low-level APIs; define model building blocks as modules so custom models can be assembled
- [ ] Change response_batch's output_str handling so responses are returned as return values
3 changes: 2 additions & 1 deletion pyfastllm/fastllm/__init__.py
@@ -1 +1,2 @@
from pyfastllm import *
from pyfastllm import *
from . import convert
41 changes: 28 additions & 13 deletions pyfastllm/fastllm/convert.py
@@ -6,15 +6,17 @@
import argparse
from .utils import torch2flm

HF_INSTALLED = False
try:
    import torch
    from transformers import AutoTokenizer, AutoModel  # chatglm
    from transformers import LlamaTokenizer, LlamaForCausalLM  # alpaca
    from transformers import AutoModelForCausalLM, AutoTokenizer  # baichuan, moss
    from peft import PeftModel
    HF_INSTALLED = True
except Exception as e:
    logging.error("Make sure that you installed transformers and peft!!!")
    sys.exit(1)
    logging.info("transformers and peft are not installed; install them before running convert!!!")


MODEL_DICT = {
"alpaca":{
@@ -37,21 +39,23 @@
}

def parse_args():
    # -p  model path or HF path
    # -o --out_path  export path
    # -q  quantization bit width
    parser = argparse.ArgumentParser(description='build fastllm libs')
    parser.add_argument('-o', dest='export_path', default=None,
                        help='output export path')
    parser.add_argument('-m', dest='model', default='', required=True,
                        help='model name (one of: alpaca, baichuan7B, chatglm6B, moss)')
    parser.add_argument('-p', dest='model_path', type=list,
                        help='a list with "tokenizer", "model", "peft model", such as: -p THUDM/chatglm-6b THUDM/chatglm-6b')
    parser.add_argument('-m', dest='model', default='',
                        help='model name (one of: alpaca, baichuan7B, chatglm6B, moss)')
    parser.add_argument('-q', dest='qbit', type=int, choices=[16, 8, 4, 2],
                        help='quantization bit width, one of 16/8/4/2')
    parser.add_argument('-o', dest='export_path', default=None,
                        help='output export path')


    args = parser.parse_args()
    return args


def convert(model, tokenizer, export_path):
    assert HF_INSTALLED, "Make sure that you installed transformers and peft before convert!!!"
    torch2flm.tofile(model, tokenizer, export_path)


@@ -82,11 +86,11 @@ def moss(model_path, ):
    model = model.eval()
    return model, tokenizer

def main(args):
def export_model(args):
    assert HF_INSTALLED, "Make sure that you installed transformers and peft before convert!!!"
    if args.model not in MODEL_DICT:
        raise NotImplementedError(f"Not Support {args.model} Yet!!!")


    model_args = {}
    model_args["model_path"] = MODEL_DICT[args.model].get("model")
    if "peft" in MODEL_DICT[args.model]:
@@ -101,6 +105,17 @@ def main(args):
    export_path = args.export_path or f"{args.model}-fp32.bin"
    torch2flm.tofile(export_path, model.model, tokenizer)

if __name__ == "__main__":
    if args.qbit:
        import pyfastllm as fastllm
        export_name, export_ext = export_path.rsplit('.', 1)
        q_export_path = export_name + f"-q{args.qbit}." + export_ext
        flm_model = fastllm.create_llm(export_path)
        flm_model.save_lowbit_model(q_export_path, args.qbit)

def convert_main():
    assert HF_INSTALLED, "Make sure that you installed transformers and peft before convert!!!"
    args = parse_args()
    main(args)
    export_model(args)

if __name__ == "__main__":
    convert_main()
41 changes: 15 additions & 26 deletions pyfastllm/setup.py
@@ -5,45 +5,29 @@

import sys
import argparse

parser = argparse.ArgumentParser(description='build pyfastllm wheel')
parser.add_argument('--cuda', dest='cuda', action='store_true', default=False,
                    help='build with cuda support')
args, unknown = parser.parse_known_args()
sys.argv = [sys.argv[0]] + unknown

__VERSION__ = "'0.1.0'"

def get_version():
    root_dir = os.getenv('PROJECT_ROOT', os.path.dirname(os.path.dirname(os.getcwd())))
    version_header = os.path.join(root_dir, 'include/MNN/MNNDefine.h')
    version_major = version_minor = version_patch = 'x'
    for line in open(version_header, 'rt').readlines():
        if '#define FASTLLM_VERSION_MAJOR' in line:
            version_major = int(line.strip().split(' ')[-1])
        if '#define FASTLLM_VERSION_MINOR' in line:
            version_minor = int(line.strip().split(' ')[-1])
        if '#define FASTLLM_VERSION_PATCH' in line:
            version_patch = int(line.strip().split(' ')[-1])
    return '{}.{}.{}'.format(version_major, version_minor, version_patch)


__VERSION__ = "'0.1.2'"
BASE_DIR = os.path.dirname(os.path.dirname(__file__))

ext_modules = []
try:
    from pybind11.setup_helpers import Pybind11Extension, ParallelCompile, naive_recompile

    # `N` is to set the number of threads
    # `naive_recompile` makes it recompile only if the source file changes. It does not check header files!
    ParallelCompile("NPY_NUM_BUILD_JOBS", needs_recompile=naive_recompile, default=4).install()

    from pybind11.setup_helpers import Pybind11Extension
    # could only be relative paths, otherwise the `build` command would fail if you use a MANIFEST.in to distribute your package
    # only source files (.cpp, .c, .cc) are needed
    #
    source_files = glob.glob(os.path.join(BASE_DIR, "src/**/*.cpp"), recursive=True)
    # source_files.append(os.path.join(BASE_DIR, "src/devices/cpu/cpudevice.cpp"))
    source_files.remove('/public/Code/Cpp/fastllm/src/devices/cuda/cudadevice.cpp')
    print(source_files)

    # remove cuda sources (iterate over a copy so removal is safe)
    for file in list(source_files):
        if file.endswith('cudadevice.cpp'):
            source_files.remove(file)

    extra_compile_args = ["-w", "-DPY_API"]
    # If any libraries are used, e.g. libabc.so
@@ -80,16 +64,21 @@ def get_version():
    print(e)
    sys.exit(1)

cmdclass = {}

cmdclass = {}
setup(
    name='fastllm',  # used by `pip install`
    version='0.0.1',
    version=eval(__VERSION__),
    description='python api for fastllm',
    long_description='',
    ext_modules=ext_modules,
    packages=find_packages(),  # the directory would be installed to site-packages
    cmdclass=cmdclass,
    entry_points={
        'console_scripts': [
            'fastllm-convert = fastllm.convert:convert_main',
        ]
    },
    setup_requires=["pybind11"],
    install_requires=[],
    python_requires='>=3.6',
Expand Down
32 changes: 27 additions & 5 deletions src/pybinding.cpp
@@ -13,6 +13,9 @@ using namespace pybind11::literals;


#ifdef PY_API
// template <typename... Args>
// using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;


// PYBIND11_MAKE_OPAQUE(std::vector<std::pair<fastllm::Data,fastllm::Data>>);
PYBIND11_MAKE_OPAQUE(fastllm::Data);
@@ -107,7 +110,9 @@ PYBIND11_MODULE(pyfastllm, m) {

py::class_<fastllm::Tokenizer>(m, "Tokenizer")
.def("encode", &fastllm::Tokenizer::Encode)
.def("decode", &fastllm::Tokenizer::Decode)
// .def("decode", &fastllm::Tokenizer::Decode)
.def("decode", py::overload_cast<const fastllm::Data &>(&fastllm::Tokenizer::Decode), "Decode from Tensor")
.def("decode", py::overload_cast<const std::vector<int> &>(&fastllm::Tokenizer::Decode), "Decode from Vector")
.def("decode_byte", [](fastllm::Tokenizer &tokenizer, const fastllm::Data &data){
std::string ret = tokenizer.Decode(data);
return py::bytes(ret);
@@ -128,7 +133,13 @@ PYBIND11_MODULE(pyfastllm, m) {
.def_readonly("eos_token_id", &fastllm::ChatGLMModel::eos_token_id)
.def("load_weights", &fastllm::ChatGLMModel::LoadFromFile)
.def("response", &fastllm::ChatGLMModel::Response)
.def("batch_response", &fastllm::ChatGLMModel::ResponseBatch)
.def("batch_response", [](fastllm::ChatGLMModel model,
const std::vector <std::string> &inputs,
RuntimeResultBatch retCb)->std::vector<std::string>outputs {
std::vector <std::string> &outputs,
model.ResponseBatch(inputs, outputs, retCb);
return outputs;
})
.def("warmup", &fastllm::ChatGLMModel::WarmUp)
.def("__call__",
[](fastllm::ChatGLMModel &model,
@@ -153,7 +164,13 @@ PYBIND11_MODULE(pyfastllm, m) {
.def_readonly("eos_token_id", &fastllm::MOSSModel::eos_token_id)
.def("load_weights", &fastllm::MOSSModel::LoadFromFile)
.def("response", &fastllm::MOSSModel::Response)
.def("batch_response", &fastllm::MOSSModel::ResponseBatch)
.def("batch_response", [](fastllm::MOSSModel model,
const std::vector <std::string> &inputs,
RuntimeResultBatch retCb)->std::vector<std::string>outputs {
std::vector <std::string> &outputs,
model.ResponseBatch(inputs, outputs, retCb);
return outputs;
})
.def("__call__",
[](fastllm::MOSSModel &model,
const fastllm::Data &inputIds,
@@ -177,7 +194,13 @@ PYBIND11_MODULE(pyfastllm, m) {
.def_readonly("eos_token_id", &fastllm::LlamaModel::eos_token_id)
.def("load_weights", &fastllm::LlamaModel::LoadFromFile)
.def("response", &fastllm::LlamaModel::Response)
.def("batch_response", &fastllm::LlamaModel::ResponseBatch)
.def("batch_response", [](fastllm::LlamaModel model,
const std::vector <std::string> &inputs,
RuntimeResultBatch retCb)->std::vector<std::string>outputs {
std::vector <std::string> &outputs,
model.ResponseBatch(inputs, outputs, retCb);
return outputs;
})
.def("warmup", &fastllm::LlamaModel::WarmUp)
.def("__call__",
[](fastllm::LlamaModel &model,
@@ -192,7 +215,6 @@ PYBIND11_MODULE(pyfastllm, m) {
.def("launch_response", &fastllm::LlamaModel::LaunchResponseTokens)
.def("fetch_response", &fastllm::LlamaModel::FetchResponseTokens)
.def("save_lowbit_model", &fastllm::LlamaModel::SaveLowBitModel);

#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
#else
