From ede16bc927b47b99cfc98600f9fce6762484d795 Mon Sep 17 00:00:00 2001 From: test Date: Wed, 20 Sep 2023 17:33:06 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=9F=BA=E6=9C=AC=E7=9A=84GL?= =?UTF-8?q?M=E5=9F=BA=E5=BA=A7=E6=A8=A1=E5=9E=8B=EF=BC=88=E4=BE=8B?= =?UTF-8?q?=E5=A6=82glm-large-chinese=EF=BC=89=E6=94=AF=E6=8C=81=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 4 +- include/fastllm.h | 3 +- src/fastllm.cpp | 122 ++++++++++++++++++++++++++++++++++++++++- src/model.cpp | 15 ++++- src/models/basellm.cpp | 2 +- 5 files changed, 138 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b4708eae..1f10d361 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ endif() message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS}) set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp - src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp) + src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp) include_directories(include) include_directories(include/utils) @@ -113,4 +113,4 @@ else() ) endif() -endif() \ No newline at end of file +endif() diff --git a/include/fastllm.h b/include/fastllm.h index 4e7d6e4b..abb04bf7 100644 --- a/include/fastllm.h +++ b/include/fastllm.h @@ -306,7 +306,8 @@ namespace fastllm { enum TokenizerType { BPE = 0, NORMAL = 1, - QWEN = 2 + QWEN = 2, + GLM = 3 }; struct TrieNode { diff --git a/src/fastllm.cpp b/src/fastllm.cpp index a5eade2b..829a3c39 100644 --- a/src/fastllm.cpp +++ b/src/fastllm.cpp @@ -894,6 +894,126 @@ namespace fastllm { TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue); } + std::vector v; + for (int i = 0; i < symbols.size(); i++) { + if (symbols[i].len > 0) { + v.push_back(symbols[i].node->tokenId); + } else if (symbols[i].node == nullptr) { + if (symbols[i].fixId != -999999) { + v.push_back(symbols[i].fixId); + } else { + // 未识别的字符 + uint8_t c = (uint8_t) (symbols[i].s[symbols[i].pos]); + std::string now = "<0x00>"; + now[3] = (c / 16 > 9 ? ('A' + c / 16 - 10) : ('0' + c / 16)); + now[4] = (c % 16 > 9 ? ('A' + c % 16 - 10) : ('0' + c % 16)); + if (stringToTokenDict.find(now) != stringToTokenDict.end()) { + v.push_back(stringToTokenDict[now]); + } + } + } + } + return Data (DataType::FLOAT32, {1, (int)v.size()}, v); + } else if (this->type == TokenizerType::GLM) { + const std::map specialTokens = {{"[MASK]", 50003}, {"[sMASK]", 50008}, {"[gMASK]", 50009}}; + std::string blank = ""; + blank += 226, blank += 150, blank += 129; + std::string s = blank; + if (15 < ori.size() && ori.substr(0, 15) == " symbols; + for (int i = 0; i < s.size(); i++) { + if (i + 3 < s.size() && s[i] == '<' && s[i + 1] == 'F' && s[i + 2] == 'L' && s[i + 3] == 'M') { + if (i + 15 < s.size() && s.substr(i, 15) == "= '0' && s[i] <= '9') { + now = now * 10 + s[i] - '0'; + i++; + } + symbols.push_back(Symbol(nullptr, (char *) s.data(), i, 0, (int) symbols.size() - 1, + (int) symbols.size() + 1, now)); + continue; + } + } + + int tokenId = -999999, pos = i - 1; + TrieNode *now = this->root; + for(const auto &special:specialTokens){ + const std::string &tokenTxt=special.first; + int sz=tokenTxt.size(); + if(i+sz<=s.size()&&s.substr(i,sz)==tokenTxt){ + tokenId=special.second; + pos=i+sz-1; + for(int k=0;knext[s[i+k]]; + } + break; + } + } + if(tokenId<0){ + for (int j = i; j < s.size(); j++) { + if (now->next.find(s[j]) != now->next.end()) { + now = now->next[s[j]]; + if (now->tokenId != -999999) { + tokenId = now->tokenId; + pos = j; + break; + } + } else { + break; + } + } + } + if (pos >= i) { + symbols.push_back(Symbol(now, (char *) s.data(), i, pos - i + 1, (int) symbols.size() - 1, + (int) symbols.size() + 1, -999999)); + i = pos; + } else { + symbols.push_back(Symbol(nullptr, (char *) s.data(), i, 0, (int) symbols.size() - 1, + (int) symbols.size() + 1, -999999)); + } + } + symbols.back().next = -1; + + std::priority_queue workQueue; + for (int i = 1; i < symbols.size(); i++) { + TryMergePairs(symbols, i - 1, i, workQueue); + } + + while (!workQueue.empty()) { + auto top = workQueue.top(); + workQueue.pop(); + if (symbols[top.l].len == 0 || symbols[top.r].len == 0 || + symbols[top.l].len + symbols[top.r].len != top.size) { + continue; + } + + for (int i = symbols[top.r].pos; i < symbols[top.r].pos + symbols[top.r].len; i++) { + symbols[top.l].node = symbols[top.l].node->next[symbols[top.r].s[i]]; + } + symbols[top.l].len += symbols[top.r].len; + symbols[top.r].len = 0; + symbols[top.l].next = symbols[top.r].next; + if (symbols[top.r].next >= 0) { + symbols[symbols[top.r].next].prev = top.l; + } + + TryMergePairs(symbols, symbols[top.l].prev, top.l, workQueue); + TryMergePairs(symbols, top.l, symbols[top.l].next, workQueue); + } + std::vector v; for (int i = 0; i < symbols.size(); i++) { if (symbols[i].len > 0) { @@ -1951,4 +2071,4 @@ namespace fastllm { std::map GetDeviceMap() { return defaultDeviceMap; } -} \ No newline at end of file +} diff --git a/src/model.cpp b/src/model.cpp index 61e990bc..401e5ab9 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -7,6 +7,7 @@ #include "moss.h" #include "llama.h" #include "qwen.h" +#include "glm.h" namespace fastllm { void basellm::LoadFromFile(const std::string &fileName) { @@ -16,8 +17,12 @@ namespace fastllm { void basellm::InitParams() { if (this->weight.dicts.find("bos_token_id") != this->weight.dicts.end()) { - this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str()); - this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str()); + if(this->weight.dicts["bos_token_id"]!="None"){ + this->bos_token_id = atoi(this->weight.dicts["bos_token_id"].c_str()); + } + if(this->weight.dicts["eos_token_id"]!="None"){ + this->eos_token_id = atoi(this->weight.dicts["eos_token_id"].c_str()); + } } if (this->weight.dicts.find("im_start_id") != this->weight.dicts.end()) { this->bos_token_id = atoi(this->weight.dicts["im_start_id"].c_str()); @@ -25,6 +30,8 @@ namespace fastllm { } if (this->weight.dicts.find("num_hidden_layers") != this->weight.dicts.end()) { block_cnt = atoi(this->weight.dicts["num_hidden_layers"].c_str()); + }else if (this->weight.dicts.find("num_layers") != this->weight.dicts.end()) { + block_cnt = atoi(this->weight.dicts["num_layers"].c_str()); } if (this->weight.dicts.find("hidden_size") != this->weight.dicts.end()) { embed_dim = atoi(this->weight.dicts["hidden_size"].c_str()); @@ -77,6 +84,8 @@ namespace fastllm { } else if (modelType == "qwen") { model = (basellm *) (new QWenModel()); model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN; + } else if (modelType == "glm") { + model = (basellm*)(new GLMModel()); } else { ErrorInFastLLM("Unkown model type: " + modelType); } @@ -95,4 +104,4 @@ namespace fastllm { basellm *model = CreateModelWithType(modelType); return std::unique_ptr (model); } -} \ No newline at end of file +} diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp index caf04948..2f0dfb82 100644 --- a/src/models/basellm.cpp +++ b/src/models/basellm.cpp @@ -777,4 +777,4 @@ printf("tot = %d\n", tot); void basellm::DisableAdapter() { adapterName = ""; } -} \ No newline at end of file +}