added extra endpoints for abort gen and polled streaming
LostRuins committed Jun 10, 2023
1 parent 5bd9cef commit 43f7e40
Showing 5 changed files with 63 additions and 25 deletions.
15 changes: 8 additions & 7 deletions expose.cpp
@@ -20,13 +20,6 @@
 #include "expose.h"
 #include "model_adapter.cpp"
 
-std::string executable_path = "";
-std::string lora_filename = "";
-
-
-bool generation_finished;
-std::vector<std::string> generated_tokens;
-
 extern "C"
 {
 
@@ -225,4 +218,12 @@ extern "C"
     bool has_finished() {
         return generation_finished;
     }
+
+    const char* get_pending_output() {
+        return gpttype_get_pending_output().c_str();
+    }
+
+    bool abort_generate() {
+        return gpttype_generate_abort();
+    }
 }
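
The two exports above are plain C functions, so any ctypes host can call them once their return types are declared; the koboldcpp.py changes later in this commit do exactly that inside init_library(). A minimal standalone caller might look like the sketch below (the shared-library path is a placeholder, not something this commit defines):

import ctypes

# Placeholder path; the actual library name depends on the platform and build.
handle = ctypes.CDLL("./koboldcpp.so")

# Mirror the restype declarations added to init_library() in koboldcpp.py.
handle.abort_generate.restype = ctypes.c_bool
handle.get_pending_output.restype = ctypes.c_char_p

def pending_text():
    # Text accumulated by the current (or last) generation so far.
    return ctypes.string_at(handle.get_pending_output()).decode("UTF-8", "ignore")

def request_abort():
    # Asks the running generation to stop; True means the request was accepted.
    return bool(handle.abort_generate())
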
3 changes: 1 addition & 2 deletions expose.h
@@ -18,7 +18,6 @@ struct load_model_inputs
     const int clblast_info = 0;
     const int blasbatchsize = 512;
     const bool debugmode;
-    const bool stream_sse;
     const int forceversion = 0;
     const int gpulayers = 0;
 };
@@ -40,6 +39,7 @@ struct generation_inputs
     const float mirostat_eta;
     const float mirostat_tau;
     const char * stop_sequence[stop_token_max];
+    const bool stream_sse;
 };
 struct generation_outputs
 {
@@ -49,6 +49,5 @@
 
 extern std::string executable_path;
 extern std::string lora_filename;
-
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
26 changes: 21 additions & 5 deletions gpttype_adapter.cpp
@@ -28,6 +28,11 @@
 #include "neox_v3.cpp"
 #include "mpt_v3.cpp"
 
+//shared
+std::string executable_path = "";
+std::string lora_filename = "";
+bool generation_finished;
+std::vector<std::string> generated_tokens;
 
 //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
 static FileFormat file_format = FileFormat::BADFORMAT;
@@ -63,7 +68,6 @@ static bool useSmartContext = false;
 static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static bool debugmode = false;
-static bool stream_sse = true;
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
 static std::vector<gpt_vocab::id> current_context_tokens;
@@ -72,6 +76,8 @@ static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
 static std::vector<llama_token_data> top_picks;
+static int remaining_tokens = 0;
+static std::string concat_output = "";
 
 inline bool IsNanCheck(float f)
 {
@@ -707,6 +713,16 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 
 }
 
+bool gpttype_generate_abort()
+{
+    remaining_tokens = 0;
+    return true;
+}
+
+const std::string & gpttype_get_pending_output()
+{
+    return concat_output;
+}
 
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
 {
@@ -735,6 +751,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_ctx = inputs.max_context_length;
     params.n_batch = n_batch;
     params.n_threads = n_threads;
+    bool stream_sse = inputs.stream_sse;
 
     generation_finished = false; // Set current generation status
     generated_tokens.clear(); // New Generation, new tokens
@@ -837,11 +854,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
     current_context_tokens.resize(n_past);
 
-    int remaining_tokens = params.n_predict;
+    remaining_tokens = params.n_predict;
     int stopper_unused_tokens = 0;
     int input_consumed = 0;
     std::mt19937 rng(params.seed);
-    std::string concat_output = "";
+    concat_output = "";
 
     bool startedsampling = false;
 
@@ -1153,8 +1170,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             for (auto id : embd)
             {
                 std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
-
-                if (stream_sse)
+                if(stream_sse)
                 {
                     generated_tokens.push_back(tokenizedstr);
                 }
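
Taken together, the adapter changes implement a simple cooperative scheme: remaining_tokens and concat_output now live at file scope, gpttype_generate_abort() zeroes the token budget so the sampling loop winds down on its next iteration, and gpttype_get_pending_output() exposes the text concatenated so far. The Python sketch below restates that pattern in miniature; the names and timings are illustrative only, not the project's code.

import threading
import time

remaining_tokens = 0        # shared budget the worker loop checks each iteration
concat_output = ""          # text accumulated so far
generation_finished = False

def generate(n_predict):
    global remaining_tokens, concat_output, generation_finished
    remaining_tokens, concat_output, generation_finished = n_predict, "", False
    while remaining_tokens > 0:      # an abort zeroes this, ending the loop early
        concat_output += " tok"      # stand-in for sampling one token
        remaining_tokens -= 1
        time.sleep(0.05)
    generation_finished = True

def generate_abort():
    global remaining_tokens
    remaining_tokens = 0             # cooperative: takes effect on the next loop check
    return True

def get_pending_output():
    return concat_output             # partial text, safe to poll while generating

worker = threading.Thread(target=generate, args=(100,))
worker.start()
time.sleep(0.2)
print(get_pending_output())          # a handful of placeholder tokens
generate_abort()
worker.join()
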
41 changes: 31 additions & 10 deletions koboldcpp.py
@@ -45,7 +45,8 @@ class generation_inputs(ctypes.Structure):
                 ("mirostat", ctypes.c_int),
                 ("mirostat_tau", ctypes.c_float),
                 ("mirostat_eta", ctypes.c_float),
-                ("stop_sequence", ctypes.c_char_p * stop_token_max)]
+                ("stop_sequence", ctypes.c_char_p * stop_token_max),
+                ("stream_sse", ctypes.c_bool)]
 
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -139,6 +140,8 @@ def init_library():
     handle.new_token.argtypes = [ctypes.c_int]
     handle.get_stream_count.restype = ctypes.c_int
     handle.has_finished.restype = ctypes.c_bool
+    handle.abort_generate.restype = ctypes.c_bool
+    handle.get_pending_output.restype = ctypes.c_char_p
 
 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -167,7 +170,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret
 
-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
@@ -181,6 +184,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     inputs.tfs = tfs
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
+    inputs.stream_sse = stream_sse
     if args.usemirostat and args.usemirostat[0]>0:
         inputs.mirostat = int(args.usemirostat[0])
         inputs.mirostat_tau = float(args.usemirostat[1])
@@ -215,7 +219,7 @@ def utfprint(str):
 maxlen = 256
 modelbusy = False
 defaultport = 5001
-KcppVersion = "1.29"
+KcppVersion = "1.30"
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
@@ -229,7 +233,7 @@ def __init__(self, addr, port, embedded_kailite):
     def __call__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    async def generate_text(self, newprompt, genparams, basic_api_flag):
+    async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag):
         loop = asyncio.get_event_loop()
         executor = ThreadPoolExecutor()
 
@@ -247,8 +251,9 @@ def run_blocking():
                     rep_pen=genparams.get('rep_pen', 1.1),
                     rep_pen_range=genparams.get('rep_pen_range', 128),
                     seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                    )
+                    stop_sequence=genparams.get('stop_sequence', []),
+                    stream_sse=stream_flag)
+
             else:
                 return generate(prompt=newprompt,
                     max_context_length=genparams.get('max_context_length', maxctx),
@@ -262,8 +267,9 @@ def run_blocking():
                     rep_pen=genparams.get('rep_pen', 1.1),
                     rep_pen_range=genparams.get('rep_pen_range', 128),
                     seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                    )
+                    stop_sequence=genparams.get('stop_sequence', []),
+                    stream_sse=stream_flag)
+
 
         recvtxt = await loop.run_in_executor(executor, run_blocking)
 
@@ -300,7 +306,7 @@ async def handle_sse_stream(self):
 
                 current_token += 1
 
-                tokenStr = ctypes.string_at(token).decode('utf-8')
+                tokenStr = ctypes.string_at(token).decode("UTF-8","ignore")
                 event_data = {"token": tokenStr}
                 event_str = json.dumps(event_data)
                 await self.send_sse_event("message", event_str)
@@ -319,7 +325,7 @@ async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag
         if stream_flag:
             tasks.append(self.handle_sse_stream())
 
-        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag))
         tasks.append(generate_task)
 
         try:
@@ -395,6 +401,21 @@ def do_POST(self):
         kai_sse_stream_flag = False
         self.path = self.path.rstrip('/')
 
+        if self.path.endswith('/api/extra/abort'):
+            ag = handle.abort_generate()
+            self.send_response(200)
+            self.end_headers()
+            self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
+            print("Generation Aborted")
+            return
+
+        if self.path.endswith('/api/extra/generate/check'):
+            pendtxt = handle.get_pending_output()
+            pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
+            self.send_response(200)
+            self.end_headers()
+            self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
+            return
+
         if modelbusy:
             self.send_response(503)
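
These two handlers give clients a polled alternative to SSE streaming: POST /api/extra/generate/check returns the partial text of the generation in progress, and POST /api/extra/abort stops it early. A minimal client sketch using only the standard library, assuming a server on the default port 5001 with a generation already running:

import json
import urllib.request

BASE = "http://localhost:5001"   # defaultport in koboldcpp.py

def post(path, payload=None):
    data = json.dumps(payload or {}).encode()
    req = urllib.request.Request(BASE + path, data=data, method="POST")
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("UTF-8", "ignore"))

# Read whatever text has been generated so far.
check = post("/api/extra/generate/check")
print(check["results"][0]["text"])

# Ask the server to stop the current generation; it replies {"success": "true"} on success.
print(post("/api/extra/abort"))
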
3 changes: 2 additions & 1 deletion model_adapter.h
@@ -56,7 +56,8 @@ enum ModelLoadResult
 
 ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in_file_format);
 generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
-
+bool gpttype_generate_abort();
+const std::string & gpttype_get_pending_output();
 
 void timer_start();
 double timer_check();
