Commit 3849424
Initial implementation of the Baichuan model (input and output are already aligned; preprocessing (tokenization) and postprocessing still have some issues)
huangyuyang committed Jun 15, 2023
1 parent 4d99a41 commit 3849424
Showing 5 changed files with 50 additions and 7 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -19,7 +19,8 @@ endif()
include_directories(include)

message(STATUS "CMAKE_CXX_FLAGS" ${CMAKE_CXX_FLAGS})
-set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/devices/cpu/cpudevice.cpp src/executor.cpp src/chatglm.cpp src/moss.cpp src/vicuna.cpp)
+set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/devices/cpu/cpudevice.cpp src/executor.cpp
+        src/chatglm.cpp src/moss.cpp src/vicuna.cpp src/baichuan.cpp)

if (USE_CUDA)
enable_language(CUDA)
5 changes: 5 additions & 0 deletions include/factoryllm.h
@@ -3,11 +3,13 @@
#include "moss.h"
#include "basellm.h"
#include "vicuna.h"
#include "baichuan.h"

enum LLM_TYPE {
LLM_TYPE_CHATGLM = 0,
LLM_TYPE_MOSS = 1,
LLM_TYPE_VICUNA = 2,
+LLM_TYPE_BAICHUAN = 3,
};

class factoryllm {
@@ -26,6 +28,9 @@ class factoryllm {
case LLM_TYPE_VICUNA:
pLLM = new fastllm::VicunaModel;
break;
+case LLM_TYPE_BAICHUAN:
+pLLM = new fastllm::BaichuanModel;
+break;
default:
break;
}
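
The factory above is the single entry point for constructing any supported model. A minimal usage sketch, assuming only what this commit shows (createllm returns a heap-allocated fastllm::basellm*, and main.cpp frees models with plain delete; the weight-file path is a placeholder):

    #include "factoryllm.h"

    int main() {
        factoryllm fllm;
        // For unknown types the factory's switch hits its default branch;
        // check the pointer before use (assumed to be null in that case).
        fastllm::basellm *model = fllm.createllm(LLM_TYPE_BAICHUAN);
        if (model) {
            model->LoadFromFile("baichuan.bin"); // placeholder path
            delete model;                        // main.cpp also frees models with plain delete
        }
        return 0;
    }
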
38 changes: 35 additions & 3 deletions main.cpp
@@ -6,11 +6,12 @@ static char* modelpath = NULL;
static fastllm::basellm* chatGlm = fllm.createllm(LLM_TYPE_CHATGLM);
static fastllm::basellm* moss = fllm.createllm(LLM_TYPE_MOSS);
static fastllm::basellm* vicuna = fllm.createllm(LLM_TYPE_VICUNA);
+static fastllm::basellm* baichuan = fllm.createllm(LLM_TYPE_BAICHUAN);
static int sRound = 0;
static std::string history;

std::map <std::string, int> modelDict = {
{"chatglm", 0}, {"moss", 1}, {"vicuna", 2}
{"chatglm", 0}, {"moss", 1}, {"vicuna", 2}, {"baichuan", 3}
};

struct RunConfig {
@@ -75,6 +76,9 @@ int initLLMConf(int model,bool isLowMem, const char* modelPath, int threads) {
}
if (modeltype == 2) {
vicuna->LoadFromFile(modelPath);
}
+if (modeltype == 3) {
+baichuan->LoadFromFile(modelPath);
+}
return 0;
}
@@ -141,6 +145,29 @@ int chat(const char* prompt) {
history += (ret + "</s>");
}

+if (modeltype == LLM_TYPE_BAICHUAN) {
+if (history == "") {
+history = "ASSISTANT: A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. ";
+}
+
+auto prompt = history + "USER: " + input + " ASSISTANT: ";
+
+prompt = "登鹳雀楼->王之涣\n夜雨寄北->\n";
+printf("prompt: %s\n", prompt.c_str());
+ret = baichuan->Response(prompt, [](int index, const char* content) {
+if (index == 0) {
+printf("BAICHUAN: %s", content);
+}
+if (index > 0) {
+printf("%s", content);
+}
+if (index == -1) {
+printf("\n");
+}
+});
+history += (ret + "</s>");
+}

long len = ret.length();
return len;
}
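
The lambda passed to Response above follows the same streaming contract the other models in main.cpp already use: index 0 carries the first chunk of the reply, positive indices carry continuation chunks, and -1 signals the end of generation, after which Response returns the complete text (ret is a std::string). A condensed sketch of that contract, kept capture-less since the callback's exact parameter type is not shown in this diff:

    // Streaming contract assumed from the usage above, not a documented API:
    //   index == 0  -> first chunk of the reply
    //   index  > 0  -> continuation chunks
    //   index == -1 -> generation finished; content is unused
    std::string full = baichuan->Response(prompt, [](int index, const char *content) {
        if (index == 0) {
            printf("BAICHUAN: %s", content);
        } else if (index > 0) {
            printf("%s", content);
        } else {
            printf("\n");
        }
    });
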
@@ -161,6 +188,10 @@ void uninitLLM()
delete vicuna;
vicuna = NULL;
}
+if (baichuan) {
+delete baichuan;
+baichuan = NULL;
+}
}

int main(int argc, char **argv) {
@@ -188,11 +219,12 @@ int main(int argc, char **argv) {
}
chat(input.c_str());
}
-} else if (config.model == LLM_TYPE_VICUNA) {
+} else if (config.model == LLM_TYPE_VICUNA || config.model == LLM_TYPE_BAICHUAN) {
while (true) {
printf("用户: ");
std::string input;
-std::getline(std::cin, input);
+//std::getline(std::cin, input);
+input = "登鹳雀楼->王之涣\n夜雨寄北->";
if (input == "stop") {
break;
}
6 changes: 3 additions & 3 deletions src/fastllm.cpp
@@ -408,7 +408,7 @@ namespace fastllm {
printf("\n");
*/
int n = Count(0) / dims.back(), m = dims.back();
-for (int i = 0; i < n && i < 10; i++) {
+for (int i = 0; i < n; i++) {
for (int j = 0; j < 10 && j < m; j++) {
printf("%f ", ((float*)cpuData)[i * m + j]);
}
@@ -884,8 +884,8 @@
}

void RMSNorm(const Data &input, const Data &weight, float eps, Data &output) {
curExecutor->Run("LayerNorm", {
{"input", (Data*)&input}, {"wegiht", (Data*)&weight}, {"output", &output}
curExecutor->Run("RMSNorm", {
{"input", (Data*)&input}, {"weight", (Data*)&weight}, {"output", &output}
}, {{"eps", eps}}, {});
}

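
The second hunk fixes two bugs at once: the operation was dispatched under the wrong name ("LayerNorm" instead of "RMSNorm"), and the weight tensor was registered under the misspelled key "wegiht", so the kernel could never have looked it up. For reference, a minimal scalar sketch of what an RMSNorm kernel computes, based on the standard definition used by LLaMA-style models such as Baichuan (not taken from fastllm's cpudevice.cpp):

    #include <cmath>

    // Per row: y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]
    void RMSNormRow(const float *x, const float *w, float *y, int n, float eps) {
        float sumSquares = 0.0f;
        for (int i = 0; i < n; i++) {
            sumSquares += x[i] * x[i];
        }
        float scale = 1.0f / std::sqrt(sumSquares / n + eps);
        for (int i = 0; i < n; i++) {
            y[i] = x[i] * scale * w[i];
        }
    }
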
5 changes: 5 additions & 0 deletions tools/quant.cpp
@@ -9,6 +9,7 @@
#include "moss.h"
#include "chatglm.h"
#include "vicuna.h"
#include "baichuan.h"

struct QuantConfig {
std::string model = "chatglm"; // model type: chatglm or moss
@@ -66,6 +67,10 @@ int main(int argc, char **argv) {
fastllm::VicunaModel vicuna;
vicuna.LoadFromFile(config.path);
vicuna.SaveLowBitModel(config.output, config.bits);
} else if (config.model == "baichuan") {
fastllm::BaichuanModel baichuan;
baichuan.LoadFromFile(config.path);
baichuan.SaveLowBitModel(config.output, config.bits);
} else {
Usage();
exit(-1);
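
The quantization path for Baichuan mirrors the other models exactly: load the full-precision weights, then write a low-bit copy. Condensed into a standalone sketch (file names are placeholders; LoadFromFile and SaveLowBitModel are used with the signatures shown in this diff):

    #include "baichuan.h"

    int main() {
        fastllm::BaichuanModel baichuan;
        baichuan.LoadFromFile("baichuan.bin");            // placeholder input path
        baichuan.SaveLowBitModel("baichuan-int8.bin", 8); // write an 8-bit copy
        return 0;
    }
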
