//
// cli_demo.cpp
//
// Created by MNN on 2023/03/24.
// ZhaodeWang
//
#include "llm.hpp"
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
// Run every prompt in prompt_file through the model and report
// aggregate prefill/decode token counts, times, and speeds.
void benchmark(Llm* llm, std::string prompt_file) {
    std::cout << "prompt file is " << prompt_file << std::endl;
    std::ifstream prompt_fs(prompt_file);
    std::vector<std::string> prompts;
    std::string prompt;
    while (std::getline(prompt_fs, prompt)) {
        // prompts starting with '#' are treated as comments and skipped
        if (prompt.substr(0, 1) == "#") {
            continue;
        }
        prompts.push_back(prompt);
    }
    int prompt_len = 0;
    int decode_len = 0;
    int64_t prefill_time = 0;
    int64_t decode_time = 0;
    llm->warmup();
    for (size_t i = 0; i < prompts.size(); i++) {
        llm->response(prompts[i]);
        // accumulate the per-prompt counters exposed by Llm, then reset
        // its state so each prompt is measured independently
        prompt_len += llm->prompt_len_;
        decode_len += llm->gen_seq_len_;
        prefill_time += llm->prefill_us_;
        decode_time += llm->decode_us_;
        llm->reset();
    }
    float prefill_s = prefill_time / 1e6;
    float decode_s = decode_time / 1e6;
    printf("\n#################################\n");
    printf("prompt tokens num = %d\n", prompt_len);
    printf("decode tokens num = %d\n", decode_len);
    printf("prefill time = %.2f s\n", prefill_s);
    printf(" decode time = %.2f s\n", decode_s);
    printf("prefill speed = %.2f tok/s\n", prompt_len / prefill_s);
    printf(" decode speed = %.2f tok/s\n", decode_len / decode_s);
    printf("#################################\n");
}
int main(int argc, const char* argv[]) {
    if (argc < 2) {
        std::cout << "Usage: " << argv[0] << " model_dir <prompt.txt>" << std::endl;
        return 0;
    }
    std::string model_dir = argv[1];
    std::cout << "model path is " << model_dir << std::endl;
    std::unique_ptr<Llm> llm(Llm::createLLM(model_dir));
    llm->load(model_dir);
    if (argc < 3) {
        // no prompt file given: run interactive chat and exit, rather
        // than falling through and reading a missing argv[2]
        llm->chat();
        return 0;
    }
    std::string prompt_file = argv[2];
    benchmark(llm.get(), prompt_file);
    return 0;
}
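
// Example usage (a sketch; the model directory name and prompt contents
// below are hypothetical, and the actual layout depends on how the MNN
// model was exported):
//
//   ./cli_demo ./qwen-1.8b-mnn              # interactive chat mode
//   ./cli_demo ./qwen-1.8b-mnn prompt.txt   # benchmark mode
//
// A prompt file holds one prompt per line; lines beginning with '#'
// are skipped, e.g.:
//
//   # benchmark prompts
//   Hello, please introduce yourself.
//   What is the capital of France?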