Basically align the BPE Tokenizer with huggingface tokenizers
cgli authored and TylunasLi committed Feb 12, 2024
1 parent ace0522 commit f64d250
Showing 5 changed files with 72 additions and 19 deletions.
37 changes: 32 additions & 5 deletions docs/llama_cookbook.md
@@ -8,7 +8,13 @@ LLaMA-family models share essentially the same structure, but their weights and prompt construction differ.

The configurations below are compiled from each model's source code; there is no guarantee that inference results will exactly match the original.

## Modify the Script and Convert

## Modification Methods

Currently, both the conversion script and the two-line acceleration approach can be used for llama-family models. Whichever approach is used, be sure to reserve enough free memory (swap space may be used).

In float16 mode, conversion requires roughly 4 × (parameter count) + 1 GB of free memory (about 4 × 7 GB + 1 GB ≈ 29 GB for a 7B-parameter model, for example).

### Conversion Script

Using base models with various Llama architectures as the example, this section explains how to apply this document.

@@ -40,17 +46,36 @@ LLaMA-family models share essentially the same structure, but their weights and prompt construction differ.

To add a token ID rather than a literal string (as with the baichuan-chat model), use the format "<FLM_FIX_TOKEN_{ID}>" (for example, `history_sep = "<FLM_FIX_TOKEN_3>"` inserts the token with ID 3).

* Run the script

```shell
python3 tools/alpaca2flm.py [output file] [precision] [original model name or path]
```
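
For example, a run might look like this (the output filename is illustrative; the model path shown is the script's built-in default): `python3 tools/alpaca2flm.py alpaca-33b-fp16.flm float16 minlik/chinese-alpaca-33b-merged`.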

### Two-Line Acceleration

```python
conf = model.config.__dict__
conf["model_type"] = "llama"
llm.from_hf(model, tokenizer, pre_prompt = "",
user_role = "", bot_role = "", history_sep = "",
dtype = dtype)
```
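
For reference, a minimal end-to-end sketch of the two-line approach might look as follows; the model path and the float16 loading are illustrative assumptions, not prescribed by this document:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from fastllm_pytools import llm

model_name = "minlik/chinese-alpaca-33b-merged"  # illustrative model path
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

dtype = "float16"
conf = model.config.__dict__
conf["model_type"] = "llama"
# hand the in-memory HF model over to fastllm
llm_model = llm.from_hf(model, tokenizer, pre_prompt = "",
                        user_role = "", bot_role = "", history_sep = "",
                        dtype = dtype)
```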

## Alignment

To make the fastllm model behave essentially the same as the original transformers model, the most important step is aligning the tokenizer.
If the model uses huggingface's accelerated Tokenizers (i.e., the model directory contains a `tokenizer.json`, which is used preferentially), the current conversion script **can align the tokenizer only when converting from local files**.

Check whether the result returned by the original tokenizer's `encode()` method has a space prepended. If the original tokenizer does not add the space, you need to set:

```python
conf["tokenizer_add_dummy_prefix"] = False
```
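
A quick way to check is to encode a short string and inspect the first token. A minimal sketch, assuming the standard transformers API and a placeholder `model_name` (for sentencepiece-style tokenizers the dummy prefix shows up as a leading "▁"):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
ids = tok.encode("hello", add_special_tokens=False)
print(tok.convert_ids_to_tokens(ids))
# ['▁hello'] -> a dummy prefix space is added (keep the default)
# ['hello']  -> no dummy prefix; set conf["tokenizer_add_dummy_prefix"] = False
```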

## Base Model

See "[Modification Methods](#修改方案)" above.
See "[Modification Methods](#修改方式)" above.

Some models require bos_token_id to be specified; assuming bos_token_id is 1, it can be configured as follows:
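
A minimal sketch of such a configuration, assuming the same `conf` dict pattern used above (the key name follows the HuggingFace config field; the value 1 is only the assumed example):

```python
conf = model.config.__dict__
conf["model_type"] = "llama"
conf["bos_token_id"] = 1  # illustrative; use the model's actual BOS token id
```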

@@ -96,10 +121,12 @@ python3 tools/internlm2flm.py internlm-7b-int4.flm float16 internlm/internlm-cha
```python
conf = model.config.__dict__
conf["model_type"] = "llama"
conf["tokenizer_add_dummy_prefix"] = False
torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "",
user_role = "Human: ", bot_role = "\n\nAssistant: ",
history_sep = "<FLM_FIX_TOKEN_3>", dtype = dtype)
```
The XVERSE-13B-Chat V1 release requires NFKC normalization of its input, which fastllm does not support yet, so the original tokenizer must be used.
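
A sketch of applying NFKC normalization on the Python side before encoding, using the standard-library `unicodedata` module (the variable names are illustrative):

```python
import unicodedata

text = unicodedata.normalize("NFKC", user_input)   # normalize before tokenizing
input_ids = original_tokenizer(text).input_ids     # encode with the original transformers tokenizer
```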

### Other llama1-Series Models

@@ -174,7 +201,7 @@ python3 tools/internlm2flm.py internlm-7b-int4.flm float16 internlm/internlm-cha
```python
torch2flm.tofile(exportPath, model, tokenizer,
pre_prompt="The following is a conversation between a human and an AI assistant namely YuLan, developed by GSAI, Renmin University of China. " \
"The AI assistant gives helpful, detailed, and polite answers to the user's questions.\n"
"The AI assistant gives helpful, detailed, and polite answers to the user's questions.\n",
user_role="[|Human|]:", bot_role="\n[|AI|]:", history_sep="\n", dtype=dtype)
```

@@ -185,7 +212,7 @@ python3 tools/internlm2flm.py internlm-7b-int4.flm float16 internlm/internlm-cha

```python
torch2flm.tofile(exportPath, model, tokenizer,
pre_prompt="Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
pre_prompt="Below is an instruction that describes a task. " \
"Write a response that appropriately completes the request.\n\n",
user_role="### Instruction:\n", bot_role="\n\n### Response:", history_sep="\n", dtype=dtype)
```
2 changes: 1 addition & 1 deletion src/model.cpp
@@ -88,7 +88,7 @@ namespace fastllm {
model = (basellm*)(new ChatGLMModel());
} else if (modelType == "moss") {
model = (basellm*)(new MOSSModel());
model->weight.tokenizer.type = Tokenizer::TokenizerType::NORMAL;
model->weight.tokenizer.type = Tokenizer::TokenizerType::BPE;
model->eos_token_id = 106068;
} else if (modelType == "baichuan") {
model = (basellm*)(new LlamaModel());
30 changes: 21 additions & 9 deletions tools/fastllm_pytools/hf_model.py
@@ -1,5 +1,6 @@
from fastllm_pytools import llm;
import ctypes;
import builtins, os, json
import numpy as np
import torch
from transformers import PreTrainedTokenizerFast
@@ -118,20 +119,31 @@ def create(model,
else:
tokenizer = tokenizer.tokenizer
if (hasattr(tokenizer, "sp_model")):
piece_size = tokenizer.sp_model.piece_size();
piece_size = tokenizer.sp_model.piece_size()
for i in range(piece_size):
llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, tokenizer.sp_model.id_to_piece(i).encode(),
i, ctypes.c_float(tokenizer.sp_model.get_score(i)));
else:
vocab = tokenizer.get_vocab();
merges = {}
if (modelInfo["model_type"] == "moss"):
merges = {"".join(bpe_tokens): token_index for bpe_tokens, token_index in sorted(tokenizer.bpe_ranks.items(), key=lambda kv: kv[1])}
elif isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer_file = os.path.join(tokenizer.name_or_path, tokenizer.vocab_files_names['tokenizer_file'])
if os.path.exists(tokenizer_file):
with open(tokenizer_file, "r", encoding='utf-8') as f:
bpe_merges = json.load(f)["model"]["merges"]
bpe_merges = [pair.replace(" ", "") for pair in bpe_merges]
merges = builtins.dict(zip(bpe_merges, range(0, -len(bpe_merges), -1)))  # earlier merges get higher (less negative) scores
vocab = tokenizer.get_vocab()
for v in vocab.keys():
score = merges[v] if v in merges else 1.0
if (modelInfo["model_type"] == "moss"):
vv = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v];
llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, vv, vocab[v], ctypes.c_float(1.0));
s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]
llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, s, vocab[v], ctypes.c_float(score));
elif (modelInfo["model_type"] == "qwen"):
llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v, vocab[v], ctypes.c_float(1.0));
else:
llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v.encode(), vocab[v], ctypes.c_float(1.0));
llm.fastllm_lib.add_tokenizer_word_llm_model(model_handle, v.encode(), vocab[v], ctypes.c_float(score));

weight_type_dict = {}
module_dict = {}
@@ -157,13 +169,13 @@ def tofile(exportPath,
to_data_type = 0

if (cur_weight_type == 1):
to_data_type = fastllm_data_type_dict[dtype];
to_data_type = fastllm_data_type_dict[dtype]
if (to_data_type == 7):
ori_data_type = 7;
ori_np_data_type = np.float16;
ori_data_type = 7
ori_np_data_type = np.float16
elif (cur_weight_type == 2):
# TODO bfloat
to_data_type = 0;
to_data_type = 0

weight_name = key
if hasattr(model, "peft_config"):
14 changes: 13 additions & 1 deletion tools/fastllm_pytools/torch2flm.py
@@ -1,4 +1,5 @@
import struct
import builtins, os, json
import numpy as np
import torch
from transformers import PreTrainedTokenizerFast
@@ -174,9 +175,20 @@ def tofile(exportPath,
fo.write(struct.pack('i', i))
fo.write(struct.pack('f', float(tokenizer.sp_model.get_score(i))))
else:
merges = {}
if (modelInfo["model_type"] == "moss"):
merges = {"".join(bpe_tokens): token_index for bpe_tokens, token_index in sorted(tokenizer.bpe_ranks.items(), key=lambda kv: kv[1])}
elif isinstance(tokenizer, PreTrainedTokenizerFast):
tokenizer_file = os.path.join(tokenizer.name_or_path, tokenizer.vocab_files_names['tokenizer_file'])
if os.path.exists(tokenizer_file):
with open(tokenizer_file, "r", encoding='utf-8') as f:
bpe_merges = json.load(f)["model"]["merges"]
bpe_merges = [pair.replace(" ", "") for pair in bpe_merges]
merges = builtins.dict(zip(bpe_merges, range(0, -len(bpe_merges), -1)))  # earlier merges get higher (less negative) scores
vocab = tokenizer.get_vocab()
fo.write(struct.pack('i', len(vocab)))
for v in vocab.keys():
score = merges[v] if v in merges else 1.0
if (modelInfo["model_type"] == "moss"):
s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]
elif (modelInfo["model_type"] == "qwen"):
@@ -187,7 +199,7 @@ def tofile(exportPath,
for c in s:
fo.write(struct.pack('i', c))
fo.write(struct.pack('i', vocab[v]))
fo.write(struct.pack('f', 1.0))
fo.write(struct.pack('f', score))
else:
fo.write(struct.pack('i', 0))

8 changes: 5 additions & 3 deletions tools/scripts/alpaca2flm.py
@@ -1,11 +1,13 @@
import sys
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from fastllm_pytools import torch2flm

if __name__ == "__main__":
model_name = sys.argv[3] if len(sys.argv) >= 4 else 'minlik/chinese-alpaca-33b-merged'
tokenizer = LlamaTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name).float()
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# `torch_dtype=torch.float16` is set by default; if it does not cause an OOM error, you can load the model in float32.
model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
conf = model.config.__dict__
conf["model_type"] = "llama"
dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
