wip submitting of llava image to backend
LostRuins committed Mar 10, 2024
1 parent 6990d07 commit d943c73
Showing 5 changed files with 51 additions and 33 deletions.
2 changes: 1 addition & 1 deletion class.py
@@ -273,7 +273,7 @@ def _load(self, save_model: bool, initial_load: bool) -> None:
 unbantokens=False, bantokens=None, usemirostat=None, forceversion=0, nommap=self.kcpp_nommap,
 usemlock=False, noavx2=self.kcpp_noavx2, debugmode=self.kcpp_debugmode, skiplauncher=True, hordeconfig=None, noblas=self.kcpp_noblas,
 useclblast=self.kcpp_useclblast, usecublas=self.kcpp_usecublas, usevulkan=self.kcpp_usevulkan, gpulayers=self.kcpp_gpulayers, tensor_split=self.kcpp_tensor_split, config=None,
-onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=False, nocertify=False, sdconfig=None)
+onready='', multiuser=False, foreground=False, preloadstory=None, noshift=False, remotetunnel=False, ssl=False, benchmark=False, nocertify=False, sdconfig=None, mmproj=None)
 
 
 #koboldcpp.main(kcppargs,False) #initialize library without enabling Lite http server
1 change: 1 addition & 0 deletions expose.cpp
@@ -34,6 +34,7 @@ extern "C"
 std::string model = inputs.model_filename;
 lora_filename = inputs.lora_filename;
 lora_base = inputs.lora_base;
+mmproj_filename = inputs.mmproj_filename;
 
 int forceversion = inputs.forceversion;
 
2 changes: 2 additions & 0 deletions expose.h
@@ -41,6 +41,7 @@ struct load_model_inputs
 const char * model_filename;
 const char * lora_filename;
 const char * lora_base;
+const char * mmproj_filename;
 const bool use_mmap;
 const bool use_mlock;
 const bool use_smartcontext;
@@ -133,6 +134,7 @@ struct sd_generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
 extern std::string lora_base;
+extern std::string mmproj_filename;
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
 extern float last_eval_time;
47 changes: 19 additions & 28 deletions gpttype_adapter.cpp
@@ -618,6 +618,24 @@ static void load_grammar(const std::string & gammarstr)
 }
 }
 
+static bool kcpp_eval_image(llama_context * ctx_llama, float * img_embd, int num_img_tokens, int n_batch, int * n_past) {
+    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
+
+    for (int i = 0; i < num_img_tokens; i += n_batch) {
+        int n_eval = num_img_tokens - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        llama_batch batch = {int32_t(n_eval), nullptr, (img_embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "\n%s : failed to eval image\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
+    }
+    return true;
+}
+
 //given an old GGUF context and a new context that has some middle portion removed,
 //find and remove the middle portion from the old context from the KV. Does not fast forward after this destructive action
 void PurgeMissingTokens(llama_context * ctx, std::vector<int> &current_context_tokens, std::vector<int> &new_context_tokens, const int genamt, const int nctx)
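
For orientation, here is a minimal Python mirror of the batching arithmetic in kcpp_eval_image above (a sketch, not part of the commit): the image embedding buffer is consumed in slices of up to n_batch rows, each slice decoded at the current n_past, which then advances by the slice length.

# Sketch only, not part of the commit: mirrors kcpp_eval_image's batching
# arithmetic. Each tuple is one llama_decode call: (row offset into the
# embedding buffer, rows decoded this call, n_past at decode time).
def eval_image_chunk_schedule(num_img_tokens, n_batch, n_past):
    calls = []
    for i in range(0, num_img_tokens, n_batch):
        n_eval = min(n_batch, num_img_tokens - i)
        calls.append((i, n_eval, n_past))
        n_past += n_eval
    return calls

For a typical 576-token LLaVA image embedding with n_batch=512, this yields two decode calls: (0, 512, p) and (512, 64, p+512).
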
@@ -1064,6 +1082,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 
 if(mmproj_filename != "")
 {
+    printf("\nAttempting to apply Multimodal Projector: %s\n", mmproj_filename.c_str());
     clp_ctx = clip_model_load(mmproj_filename.c_str(), /*verbosity=*/ 1);
     if(clp_ctx == nullptr) {
         fprintf(stderr, "%s: error: failed to load mmproj model!\n", __func__);
@@ -1672,34 +1691,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
 }
 }
 
-// for (int i = 0; i < img.image_tokens; i += n_batch)
-// {
-//     int n_eval = img.image_tokens - i;
-//     if (n_eval > n_batch)
-//     {
-//         n_eval = n_batch;
-//     }
-
-//     const int n_embd = llama_n_embd(model);
-//     llama_batch batch_img = {
-//         n_eval,
-//         nullptr,
-//         (img.image_embedding + i * n_embd),
-//         nullptr,
-//         nullptr,
-//         nullptr,
-//         nullptr,
-//         slot.n_past,
-//         1, 0
-//     };
-//     if (llama_decode(ctx, batch_img))
-//     {
-//         LOG_TEE("%s : failed to eval image\n", __func__);
-//         return false;
-//     }
-//     slot.n_past += n_eval;
-// }
-
 if(addedmemory!="")
 {
     TokenizeString(addedmemory, embd_inp_mem, file_format);
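
Note: the commented-out block removed here (its slot.n_past and LOG_TEE naming suggests it was adapted from the llama.cpp server example) is superseded by the new kcpp_eval_image helper added above, which performs the same chunked embedding decode against the local context.
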
32 changes: 28 additions & 4 deletions koboldcpp.py
@@ -42,6 +42,7 @@ class load_model_inputs(ctypes.Structure):
 ("model_filename", ctypes.c_char_p),
 ("lora_filename", ctypes.c_char_p),
 ("lora_base", ctypes.c_char_p),
+("mmproj_filename", ctypes.c_char_p),
 ("use_mmap", ctypes.c_bool),
 ("use_mlock", ctypes.c_bool),
 ("use_smartcontext", ctypes.c_bool),
@@ -352,6 +353,8 @@ def load_model(model_filename):
 inputs.use_mmap = False
 if len(args.lora) > 1:
     inputs.lora_base = args.lora[1].encode("UTF-8")
+
+inputs.mmproj_filename = args.mmproj.encode("UTF-8") if args.mmproj else "".encode("UTF-8")
 inputs.use_smartcontext = args.smartcontext
 inputs.use_contextshift = (0 if args.noshift else 1)
 inputs.blasbatchsize = args.blasbatchsize
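
Note the sentinel convention: when no projector is chosen, an empty string is sent rather than a null pointer, and the C++ side's if(mmproj_filename != "") check (see the gpttype_adapter.cpp hunk above) treats that as "no multimodal projector".
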
@@ -590,6 +593,7 @@ def bring_terminal_to_foreground():
 friendlymodelname = "inactive"
 friendlysdmodelname = "inactive"
 fullsdmodelpath = "" #if empty, it's not initialized
+mmprojpath = "" #if empty, it's not initialized
 maxctx = 2048
 maxhordectx = 2048
 maxhordelen = 256
@@ -938,7 +942,7 @@ def noscript_webui(self):
 self.wfile.write(finalhtml)
 
 def do_GET(self):
-    global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath
+    global maxctx, maxhordelen, friendlymodelname, KcppVersion, totalgens, preloaded_story, exitcounter, currentusergenkey, friendlysdmodelname, fullsdmodelpath, mmprojpath
     self.path = self.path.rstrip('/')
     response_body = None
     content_type = 'application/json'
@@ -976,7 +980,9 @@ def do_GET(self):
 response_body = (json.dumps({"value": maxctx}).encode())
 
 elif self.path.endswith(('/api/extra/version')):
-    response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
+    has_txt2img = not (friendlysdmodelname=="inactive" or fullsdmodelpath=="")
+    has_vision = (mmprojpath!="")
+    response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion,"txt2img":has_txt2img,"vision":has_vision}).encode())
 
 elif self.path.endswith(('/api/extra/perf')):
     lastp = handle.get_last_process_time()
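
With these flags, a client can feature-detect vision support before sending images. A sketch using only the standard library (host and port are assumptions; 5001 is the usual KoboldCpp default):

import json, urllib.request

# Sketch: probe a running KoboldCpp instance for the capability flags
# added to /api/extra/version in this commit.
with urllib.request.urlopen("http://localhost:5001/api/extra/version") as resp:
    info = json.load(resp)
print("vision:", info.get("vision", False), "| txt2img:", info.get("txt2img", False))
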
@@ -1434,6 +1440,7 @@ def hide_tooltip(event):
 lora_var = ctk.StringVar()
 lora_base_var = ctk.StringVar()
 preloadstory_var = ctk.StringVar()
+mmproj_var = ctk.StringVar()
 
 port_var = ctk.StringVar(value=defaultport)
 host_var = ctk.StringVar(value="")
@@ -1882,7 +1889,8 @@ def togglerope(a,b,c):
 makefileentry(model_tab, "Model:", "Select GGML Model File", model_var, 1, onchoosefile=on_picked_model_file,tooltiptxt="Select a GGUF or GGML model file on disk to be loaded.")
 makefileentry(model_tab, "Lora:", "Select Lora File",lora_var, 3,tooltiptxt="Select an optional GGML LoRA adapter to use.\nLeave blank to skip.")
 makefileentry(model_tab, "Lora Base:", "Select Lora Base File", lora_base_var, 5,tooltiptxt="Select an optional F16 GGML LoRA base file to use.\nLeave blank to skip.")
-makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 7,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
+makefileentry(model_tab, "LLaVA mmproj:", "Select LLaVA mmproj File", mmproj_var, 7,tooltiptxt="Select a mmproj file to use for LLaVA.\nLeave blank to skip.")
+makefileentry(model_tab, "Preloaded Story:", "Select Preloaded Story File", preloadstory_var, 9,tooltiptxt="Select an optional KoboldAI JSON savefile \nto be served on launch to any client.")
 
 # Network Tab
 network_tab = tabcontent["Network"]
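
Note: in the hunk above, the Preloaded Story entry moves from grid row 7 to row 9 so the new LLaVA mmproj picker can take its place; the file entries sit two rows apart (1, 3, 5, 7, 9), presumably one row for the label and one for the field.
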
@@ -2006,6 +2014,7 @@ def export_vars():
 args.model_param = None if model_var.get() == "" else model_var.get()
 args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])
 args.preloadstory = None if preloadstory_var.get() == "" else preloadstory_var.get()
+args.mmproj = None if mmproj_var.get() == "" else mmproj_var.get()
 
 args.ssl = None if (ssl_cert_var.get() == "" or ssl_key_var.get() == "") else ([ssl_cert_var.get(), ssl_key_var.get()])
 
@@ -2121,6 +2130,9 @@ def import_vars(dict):
 else:
     lora_var.set(dict["lora"][0])
 
+if "mmproj" in dict and dict["mmproj"]:
+    mmproj_var.set(dict["mmproj"])
+
 if "ssl" in dict and dict["ssl"]:
     if len(dict["ssl"]) == 2:
         ssl_cert_var.set(dict["ssl"][0])
@@ -2572,7 +2584,7 @@ def sanitize_string(input_string):
 return sanitized_string
 
 def main(launch_args,start_server=True):
-    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath
+    global args, friendlymodelname, friendlysdmodelname, fullsdmodelpath, mmprojpath
     args = launch_args
     embedded_kailite = None
     embedded_kcpp_docs = None
@@ -2696,6 +2708,17 @@ def main(launch_args,start_server=True):
 else:
     args.lora[1] = os.path.abspath(args.lora[1])
 
+if args.mmproj and args.mmproj!="":
+    if not os.path.exists(args.mmproj):
+        exitcounter = 999
+        print(f"Cannot find mmproj file: {args.mmproj}")
+        time.sleep(3)
+        sys.exit(2)
+    else:
+        global mmprojpath
+        args.mmproj = os.path.abspath(args.mmproj)
+        mmprojpath = args.mmproj
+
 if not args.blasthreads or args.blasthreads <= 0:
     args.blasthreads = args.threads
 
@@ -2943,5 +2966,6 @@ def range_checker(arg: str):
 parser.add_argument("--ssl", help="Allows all content to be served over SSL instead. A valid UNENCRYPTED SSL cert and key .pem files must be provided", metavar=('[cert_pem]', '[key_pem]'), nargs='+')
 parser.add_argument("--nocertify", help="Allows insecure SSL connections. Use this if you have cert errors and need to bypass certificate restrictions.", action='store_true')
 parser.add_argument("--sdconfig", help="Specify a stable diffusion safetensors model to enable image generation. If quick is specified, force optimal generation settings for speed.",metavar=('[sd_filename]', '[normal|quick|clamped] [threads] [quant|noquant]'), nargs='+')
+parser.add_argument("--mmproj", help="Select a multimodal projector file for LLaVA.", default="")
 
 main(parser.parse_args(),start_server=True)
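
With the new flag, a LLaVA-capable launch would look something like python koboldcpp.py --model model.gguf --mmproj mmproj.gguf (placeholder filenames); a client can then check the vision flag on /api/extra/version before submitting images.
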
