Skip to content

Commit

Permalink
增加基本的GLM基座模型(例如glm-large-chinese)支持。
Browse files Browse the repository at this point in the history
  • Loading branch information
fluxlinkage committed Sep 20, 2023
1 parent 5cb58c0 commit ede16bc
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 8 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ endif()
message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp
src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp
src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp)
src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp)

include_directories(include)
include_directories(include/utils)
Expand Down Expand Up @@ -113,4 +113,4 @@ else()
)
endif()

endif()
endif()
3 changes: 2 additions & 1 deletion include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ namespace fastllm {
// Identifies which tokenization algorithm Tokenizer::Encode should apply.
// Values are explicit and serialized-stable; do not renumber existing entries.
enum TokenizerType {
    BPE = 0,     // byte-pair-encoding merge tokenizer (LLaMA-style)
    NORMAL = 1,  // plain trie / longest-match tokenizer
    QWEN = 2,    // QWen-specific tokenizer variant
    GLM = 3      // GLM base models (e.g. glm-large-chinese): blank-prefix handling plus [MASK]/[sMASK]/[gMASK] special tokens
};

struct TrieNode {
Expand Down
122 changes: 121 additions & 1 deletion src/fastllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,126 @@ namespace fastllm {
TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
}

std::vector<float> v;
for (int i = 0; i < symbols.size(); i++) {
if (symbols[i].len > 0) {
v.push_back(symbols[i].node->tokenId);
} else if (symbols[i].node == nullptr) {
if (symbols[i].fixId != -999999) {
v.push_back(symbols[i].fixId);
} else {
// 未识别的字符
uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]);
std::string now = "<0x00>";
now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16));
now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16));
if (stringToTokenDict.find(now) != stringToTokenDict.end()) {
v.push_back(stringToTokenDict[now]);
}
}
}
}
return Data (DataType::FLOAT32, {1, (int)v.size()}, v);
} else if (this->type == TokenizerType::GLM) {
const std::map<std::string, int> specialTokens = {{"[MASK]", 50003}, {"[sMASK]", 50008}, {"[gMASK]", 50009}};
std::string blank = "";
blank += 226, blank += 150, blank += 129;
std::string s = blank;
if (15 < ori.size() && ori.substr(0, 15) == "<FLM_FIX_TOKEN_") {
s = "";
}
for (int i = 0; i < ori.size(); i++) {
if (ori[i] == ' ') {
if (i != 0 && ori[i - 1] != ' ') {
s += blank;
}
} else {
s += ori[i];
}
}

std::vector<Symbol> symbols;
for (int i = 0; i < s.size(); i++) {
if (i + 3 < s.size() && s[i] == '<' && s[i + 1] == 'F' && s[i + 2] == 'L' && s[i + 3] == 'M') {
if (i + 15 < s.size() && s.substr(i, 15) == "<FLM_FIX_TOKEN_") {
i += 15;
int now = 0;
while (s[i] >= '0' && s[i] <= '9') {
now = now * 10 + s[i] - '0';
i++;
}
symbols.push_back(Symbol(nullptr, (char *) s.data(), i, 0, (int) symbols.size() - 1,
(int) symbols.size() + 1, now));
continue;
}
}

int tokenId = -999999, pos = i - 1;
TrieNode *now = this->root;
for(const auto &special:specialTokens){
const std::string &tokenTxt=special.first;
int sz=tokenTxt.size();
if(i+sz<=s.size()&&s.substr(i,sz)==tokenTxt){
tokenId=special.second;
pos=i+sz-1;
for(int k=0;k<sz;k++){
now=now->next[s[i+k]];
}
break;
}
}
if(tokenId<0){
for (int j = i; j < s.size(); j++) {
if (now->next.find(s[j]) != now->next.end()) {
now = now->next[s[j]];
if (now->tokenId != -999999) {
tokenId = now->tokenId;
pos = j;
break;
}
} else {
break;
}
}
}
if (pos >= i) {
symbols.push_back(Symbol(now, (char *) s.data(), i, pos - i + 1, (int) symbols.size() - 1,
(int) symbols.size() + 1, -999999));
i = pos;
} else {
symbols.push_back(Symbol(nullptr, (char *) s.data(), i, 0, (int) symbols.size() - 1,
(int) symbols.size() + 1, -999999));
}
}
symbols.back().next = -1;

std::priority_queue<SymbolPairs> workQueue;
for (int i = 1; i < symbols.size(); i++) {
TryMergePairs(symbols, i - 1, i, workQueue);
}

while (!workQueue.empty()) {
auto top = workQueue.top();
workQueue.pop();
if (symbols[top.l].len == 0 || symbols[top.r].len == 0 ||
symbols[top.l].len + symbols[top.r].len != top.size) {
continue;
}

for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) {
symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]];
}
symbols[top.l].len += symbols[top.r].len;
symbols[top.r].len = 0;
symbols[top.l].next = symbols[top.r].next;
if (symbols[top.r].next >= 0) {
symbols[symbols[top.r].next].prev = top.l;
}

TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue);
TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue);
}

std::vector<float> v;
for (int i = 0; i < symbols.size(); i++) {
if (symbols[i].len > 0) {
Expand Down Expand Up @@ -1951,4 +2071,4 @@ namespace fastllm {
std::map <std::string, int> GetDeviceMap() {
return defaultDeviceMap;
}
}
}
15 changes: 12 additions & 3 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "moss.h"
#include "llama.h"
#include "qwen.h"
#include "glm.h"

namespace fastllm {
void basellm::LoadFromFile(const std::string &fileName) {
Expand All @@ -16,15 +17,21 @@ namespace fastllm {

void basellm::InitParams() {
if (this->weight.dicts.find("bos_token_id") != this->weight.dicts.end()) {
this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
if(this->weight.dicts["bos_token_id"]!="None"){
this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str());
}
if(this->weight.dicts["eos_token_id"]!="None"){
this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str());
}
}
if (this->weight.dicts.find("im_start_id") != this->weight.dicts.end()) {
this->bos_token_id = atoi(this->weight.dicts["im_start_id"].c_str());
this->eos_token_id = atoi(this->weight.dicts["im_end_id"].c_str());
}
if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str());
}else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) {
block_cnt = atoi(this->weight.dicts["num_layers"].c_str());
}
if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) {
embed_dim = atoi(this->weight.dicts["hidden_size"].c_str());
Expand Down Expand Up @@ -77,6 +84,8 @@ namespace fastllm {
} else if (modelType == "qwen") {
model = (basellm *) (new QWenModel());
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
} else if (modelType == "glm") {
model = (basellm*)(new GLMModel());
} else {
ErrorInFastLLM("Unkown model type: " + modelType);
}
Expand All @@ -95,4 +104,4 @@ namespace fastllm {
basellm *model = CreateModelWithType(modelType);
return std::unique_ptr<fastllm::basellm> (model);
}
}
}
2 changes: 1 addition & 1 deletion src/models/basellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -777,4 +777,4 @@ printf("tot = %d\n", tot);
// Clears the active LoRA/adapter selection so subsequent forward passes use
// only the base model weights. Counterpart to the adapter-enable path.
void basellm::DisableAdapter() {
    adapterName = "";  // empty name == "no adapter selected"
}
}
}

0 comments on commit ede16bc

Please sign in to comment.