From 43f7e40470607220468c05fea4d4dc31d7b6ffd2 Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Sat, 10 Jun 2023 18:13:26 +0800
Subject: [PATCH] added extra endpoints for abort gen and polled streaming

---
 expose.cpp          | 15 ++++++++-------
 expose.h            |  3 +--
 gpttype_adapter.cpp | 26 +++++++++++++++++++++-----
 koboldcpp.py        | 41 +++++++++++++++++++++++++++++++----------
 model_adapter.h     |  3 ++-
 5 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/expose.cpp b/expose.cpp
index 6e4f656e146be..aa0daaaead318 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -20,13 +20,6 @@
 #include "expose.h"
 #include "model_adapter.cpp"
 
-std::string executable_path = "";
-std::string lora_filename = "";
-
-
-bool generation_finished;
-std::vector<std::string> generated_tokens;
-
 extern "C"
 {
 
@@ -225,4 +218,12 @@ extern "C"
     bool has_finished() {
         return generation_finished;
     }
+
+    const char* get_pending_output() {
+        return gpttype_get_pending_output().c_str();
+    }
+
+    bool abort_generate() {
+        return gpttype_generate_abort();
+    }
 }
diff --git a/expose.h b/expose.h
index bb9c5920b078d..8ec9fa42abc67 100644
--- a/expose.h
+++ b/expose.h
@@ -18,7 +18,6 @@ struct load_model_inputs
     const int clblast_info = 0;
     const int blasbatchsize = 512;
     const bool debugmode;
-    const bool stream_sse;
     const int forceversion = 0;
     const int gpulayers = 0;
 };
@@ -40,6 +39,7 @@ struct generation_inputs
     const float mirostat_eta;
     const float mirostat_tau;
     const char * stop_sequence[stop_token_max];
+    const bool stream_sse;
 };
 struct generation_outputs
 {
@@ -49,6 +49,5 @@
 
 extern std::string executable_path;
 extern std::string lora_filename;
-
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 3c9862980efe2..adb30672a3167 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -28,6 +28,11 @@
 #include "neox_v3.cpp"
 #include "mpt_v3.cpp"
 
+//shared
+std::string executable_path = "";
+std::string lora_filename = "";
+bool generation_finished;
+std::vector<std::string> generated_tokens;
 
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
 static FileFormat file_format = FileFormat::BADFORMAT;
@@ -63,7 +68,6 @@ static bool useSmartContext = false;
 static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static bool debugmode = false;
-static bool stream_sse = true;
 static std::string modelname;
 static std::vector<int> last_n_tokens;
 static std::vector<int> current_context_tokens;
@@ -72,6 +76,8 @@ static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
 static std::vector top_picks;
+static int remaining_tokens = 0;
+static std::string concat_output = "";
 
 inline bool IsNanCheck(float f)
 {
@@ -707,6 +713,16 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 }
 
+bool gpttype_generate_abort()
+{
+    remaining_tokens = 0;
+    return true;
+}
+
+const std::string & gpttype_get_pending_output()
+{
+    return concat_output;
+}
 
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
 {
 
@@ -735,6 +751,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_ctx = inputs.max_context_length;
     params.n_batch = n_batch;
     params.n_threads = n_threads;
+    bool stream_sse = inputs.stream_sse;
 
     generation_finished = false; // Set current generation status
     generated_tokens.clear(); // New Generation, new tokens
@@ -837,11 +854,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
     current_context_tokens.resize(n_past);
 
-    int remaining_tokens = params.n_predict;
+    remaining_tokens = params.n_predict;
     int stopper_unused_tokens = 0;
     int input_consumed = 0;
     std::mt19937 rng(params.seed);
-    std::string concat_output = "";
+    concat_output = "";
 
     bool startedsampling = false;
 
@@ -1153,8 +1170,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         for (auto id : embd)
         {
             std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
-
-            if (stream_sse)
+            if(stream_sse)
             {
                 generated_tokens.push_back(tokenizedstr);
             }
diff --git a/koboldcpp.py b/koboldcpp.py
index 84aaa76999f71..d0ab580ba44e1 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -45,7 +45,8 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
-                ("stop_sequence", ctypes.c_char_p * stop_token_max)]
+                ("stop_sequence", ctypes.c_char_p * stop_token_max),
+                ("stream_sse", ctypes.c_bool)]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -139,6 +140,8 @@ def init_library():
     handle.new_token.argtypes = [ctypes.c_int]
     handle.get_stream_count.restype = ctypes.c_int
     handle.has_finished.restype = ctypes.c_bool
+    handle.abort_generate.restype = ctypes.c_bool
+    handle.get_pending_output.restype = ctypes.c_char_p
 
 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -167,7 +170,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -181,6 +184,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     inputs.tfs = tfs
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
+    inputs.stream_sse = stream_sse
     if args.usemirostat and args.usemirostat[0]>0:
         inputs.mirostat = int(args.usemirostat[0])
         inputs.mirostat_tau = float(args.usemirostat[1])
@@ -215,7 +219,7 @@ def utfprint(str):
 maxlen = 256
 modelbusy = False
 defaultport = 5001
-KcppVersion = "1.29"
+KcppVersion = "1.30"
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
@@ -229,7 +233,7 @@ def __init__(self, addr, port, embedded_kailite):
     def __call__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    async def generate_text(self, newprompt, genparams, basic_api_flag):
+    async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag):
        loop = asyncio.get_event_loop()
        executor = ThreadPoolExecutor()

@@ -247,8 +251,9 @@ def run_blocking():
                    rep_pen=genparams.get('rep_pen', 1.1),
                    rep_pen_range=genparams.get('rep_pen_range', 128),
                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                    )
+                    stop_sequence=genparams.get('stop_sequence', []),
+                    stream_sse=stream_flag)
+
            else:
                return generate(prompt=newprompt,
                    max_context_length=genparams.get('max_context_length', maxctx),
@@ -262,8 +267,9 @@ def run_blocking():
                    rep_pen=genparams.get('rep_pen', 1.1),
                    rep_pen_range=genparams.get('rep_pen_range', 128),
                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                    )
+                    stop_sequence=genparams.get('stop_sequence', []),
+                    stream_sse=stream_flag)
+

        recvtxt = await loop.run_in_executor(executor, run_blocking)

@@ -300,7 +306,7 @@ async def handle_sse_stream(self):
                    current_token += 1

-                    tokenStr = ctypes.string_at(token).decode('utf-8')
+                    tokenStr = ctypes.string_at(token).decode("UTF-8","ignore")
                    event_data = {"token": tokenStr}
                    event_str = json.dumps(event_data)
                    await self.send_sse_event("message", event_str)

@@ -319,7 +325,7 @@ async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag
        if stream_flag:
            tasks.append(self.handle_sse_stream())

-        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag))
        tasks.append(generate_task)

        try:
@@ -395,6 +401,21 @@ def do_POST(self):
        kai_sse_stream_flag = False
        self.path = self.path.rstrip('/')

+        if self.path.endswith('/api/extra/abort'):
+            ag = handle.abort_generate()
+            self.send_response(200)
+            self.end_headers()
+            self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
+            print("Generation Aborted")
+            return
+
+        if self.path.endswith('/api/extra/generate/check'):
+            pendtxt = handle.get_pending_output()
+            pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
+            self.send_response(200)
+            self.end_headers()
+            self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
+            return

        if modelbusy:
            self.send_response(503)
diff --git a/model_adapter.h b/model_adapter.h
index 65dd7e282ea25..2b43f7e6fa3cd 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -56,7 +56,8 @@ enum ModelLoadResult
 
 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format);
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
-
+bool gpttype_generate_abort();
+const std::string & gpttype_get_pending_output();
 void timer_start();
 double timer_check();
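
For reference, a minimal client sketch for the two endpoints this patch introduces: /api/extra/generate/check (polled partial output) and /api/extra/abort (cancel the current generation). The host, port, and polling interval below are assumptions for illustration only, not part of the patch; the response shapes mirror the JSON written in do_POST above.

# Minimal client sketch (assumption: koboldcpp is running on localhost:5001,
# the default port, and a generation was already started via /api/v1/generate).
import json
import time
import urllib.request

BASE_URL = "http://localhost:5001"  # assumed host/port for illustration

def check_pending_output():
    # POST /api/extra/generate/check returns {"results": [{"text": "..."}]},
    # i.e. the concat_output accumulated so far by gpttype_generate.
    req = urllib.request.Request(BASE_URL + "/api/extra/generate/check", data=b"", method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))["results"][0]["text"]

def abort_generation():
    # POST /api/extra/abort returns {"success": "true"|"false"}; on the server
    # it sets remaining_tokens to 0 so the generation loop stops early.
    req = urllib.request.Request(BASE_URL + "/api/extra/abort", data=b"", method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8")).get("success") == "true"

if __name__ == "__main__":
    # Poll the partial output a few times, then abort the generation.
    for _ in range(5):
        print(check_pending_output())
        time.sleep(1)  # assumed polling interval
    print("aborted:", abort_generation())

Polling trades a little bandwidth (the full pending text is resent on every check) for a much simpler client than the SSE streaming path.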