Skip to content

Commit

Permalink
whisper : add loader class to allow loading from buffer and others (g…
Browse files Browse the repository at this point in the history
…gerganov#353)

* whisper : add loader to allow loading from other than file

* whisper : rename whisper_init to whisper_init_from_file

* whisper : add whisper_init_from_buffer

* android : Delete local.properties

* android : load models directly from assets

* whisper : adding <stddef.h> needed for size_t + code style

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
prsyahmi and ggerganov authored Jan 8, 2023
1 parent 52a3e0c commit 1512545
Show file tree
Hide file tree
Showing 20 changed files with 230 additions and 75 deletions.
2 changes: 1 addition & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ var (
func Whisper_init(path string) *Context {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
if ctx := C.whisper_init(cPath); ctx != nil {
if ctx := C.whisper_init_from_file(cPath); ctx != nil {
return (*Context)(ctx)
} else {
return nil
Expand Down
2 changes: 1 addition & 1 deletion bindings/javascript/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct whisper_context * g_context;
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_context == nullptr) {
g_context = whisper_init(path_model.c_str());
g_context = whisper_init_from_file(path_model.c_str());
if (g_context != nullptr) {
return true;
} else {
Expand Down
2 changes: 1 addition & 1 deletion examples/bench.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ EMSCRIPTEN_BINDINGS(bench) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
if (g_worker.joinable()) {
g_worker.join();
Expand Down
2 changes: 1 addition & 1 deletion examples/bench/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ int main(int argc, char ** argv) {

// whisper init

struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());

{
fprintf(stderr, "\n");
Expand Down
2 changes: 1 addition & 1 deletion examples/command.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ EMSCRIPTEN_BINDINGS(command) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
Expand Down
2 changes: 1 addition & 1 deletion examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,7 @@ int main(int argc, char ** argv) {

// whisper init

struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());

// print some info about the processing
{
Expand Down
2 changes: 1 addition & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ int main(int argc, char ** argv) {

// whisper init

struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());

if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
Expand Down
2 changes: 1 addition & 1 deletion examples/stream.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ EMSCRIPTEN_BINDINGS(stream) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
Expand Down
2 changes: 1 addition & 1 deletion examples/stream/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ int main(int argc, char ** argv) {
exit(0);
}

struct whisper_context * ctx = whisper_init(params.model.c_str());
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());

std::vector<float> pcmf32 (n_samples_30s, 0.0f);
std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
Expand Down
2 changes: 1 addition & 1 deletion examples/talk.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
g_running = true;
if (g_worker.joinable()) {
Expand Down
2 changes: 1 addition & 1 deletion examples/talk/talk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ int main(int argc, char ** argv) {

// whisper init

struct whisper_context * ctx_wsp = whisper_init(params.model_wsp.c_str());
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());

// gpt init

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,22 @@ class MainScreenViewModel(private val application: Application) : ViewModel() {
private suspend fun copyAssets() = withContext(Dispatchers.IO) {
modelsPath.mkdirs()
samplesPath.mkdirs()
application.copyData("models", modelsPath, ::printMessage)
//application.copyData("models", modelsPath, ::printMessage)
application.copyData("samples", samplesPath, ::printMessage)
printMessage("All data copied to working directory.\n")
}

private suspend fun loadBaseModel() = withContext(Dispatchers.IO) {
printMessage("Loading model...\n")
val firstModel = modelsPath.listFiles()!!.first()
whisperContext = WhisperContext.createContext(firstModel.absolutePath)
printMessage("Loaded model ${firstModel.name}.\n")
val models = application.assets.list("models/")
if (models != null) {
val inputstream = application.assets.open("models/" + models[0])
whisperContext = WhisperContext.createContextFromInputStream(inputstream)
printMessage("Loaded model ${models[0]}.\n")
}

//val firstModel = modelsPath.listFiles()!!.first()
//whisperContext = WhisperContext.createContextFromFile(firstModel.absolutePath)
}

fun transcribeSample() = viewModelScope.launch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import android.os.Build
import android.util.Log
import kotlinx.coroutines.*
import java.io.File
import java.io.InputStream
import java.util.concurrent.Executors

private const val LOG_TAG = "LibWhisper"
Expand Down Expand Up @@ -39,13 +40,22 @@ class WhisperContext private constructor(private var ptr: Long) {
}

companion object {
fun createContext(filePath: String): WhisperContext {
fun createContextFromFile(filePath: String): WhisperContext {
val ptr = WhisperLib.initContext(filePath)
if (ptr == 0L) {
throw java.lang.RuntimeException("Couldn't create context with path $filePath")
}
return WhisperContext(ptr)
}

/**
 * Builds a [WhisperContext] by handing [stream] to the native
 * `initContextFromInputStream` entry point.
 *
 * @param stream source of the model bytes; this function does not close it.
 * @throws java.lang.RuntimeException when the native side returns a 0 handle.
 */
fun createContextFromInputStream(stream: InputStream): WhisperContext {
    val nativeHandle = WhisperLib.initContextFromInputStream(stream)

    return if (nativeHandle != 0L) {
        WhisperContext(nativeHandle)
    } else {
        throw java.lang.RuntimeException("Couldn't create context from input stream")
    }
}
}
}

Expand Down Expand Up @@ -76,6 +86,7 @@ private class WhisperLib {
}

// JNI methods
external fun initContextFromInputStream(inputStream: InputStream): Long
external fun initContext(modelPath: String): Long
external fun freeContext(contextPtr: Long)
external fun fullTranscribe(contextPtr: Long, audioData: FloatArray)
Expand Down
76 changes: 75 additions & 1 deletion examples/whisper.android/app/src/main/jni/whisper/jni.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <android/log.h>
#include <stdlib.h>
#include <sys/sysinfo.h>
#include <string.h>
#include "whisper.h"

#define UNUSED(x) (void)(x)
Expand All @@ -17,13 +18,86 @@ static inline int max(int a, int b) {
return (a > b) ? a : b;
}

// State handed to the whisper_model_loader callbacks (as their `void * ctx`)
// when the model is read from a java.io.InputStream through JNI.
struct input_stream_context {
// Running total of bytes handed back by inputStreamRead.
size_t offset;
// JNI environment of the calling thread; used for every call into Java.
JNIEnv * env;
// Receiver of the native method; stored but not read by the callbacks in this file.
jobject thiz;
// The Java InputStream the model is read from.
jobject input_stream;

// Method ID looked up as "available" with signature ()I.
jmethodID mid_available;
// Method ID looked up as "read" with signature ([BII)I.
jmethodID mid_read;
};

size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
struct input_stream_context* is = (struct input_stream_context*)ctx;

jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;

jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);

jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);

if (size_to_copy != read_size || size_to_copy != n_read) {
LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
}

jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
memcpy(output, byte_array_elements, size_to_copy);
(*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);

(*is->env)->DeleteLocalRef(is->env, byte_array);

is->offset += size_to_copy;

return size_to_copy;
}
bool inputStreamEof(void * ctx) {
struct input_stream_context* is = (struct input_stream_context*)ctx;

jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
return result <= 0;
}
/**
 * whisper_model_loader `close` callback.
 *
 * Deliberately a no-op: nothing in this file closes the Java stream.
 * NOTE(review): presumably the Java/Kotlin caller is expected to close it - confirm.
 */
void inputStreamClose(void * ctx) {
    (void) ctx;
}

/**
 * JNI entry point for WhisperLib.initContextFromInputStream.
 *
 * Wraps the given java.io.InputStream in a whisper_model_loader and passes it
 * to whisper_init.
 *
 * @param env           JNI environment of the calling thread
 * @param thiz          receiver object; only stored in the loader context
 * @param input_stream  java.io.InputStream that yields the model bytes
 * @return the whisper_context pointer as a jlong, or 0 when the stream is
 *         NULL, a required method cannot be resolved, or whisper_init
 *         returns NULL.
 */
JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContextFromInputStream(
        JNIEnv *env, jobject thiz, jobject input_stream) {
    struct whisper_context *context = NULL;
    struct whisper_model_loader loader = {0};
    struct input_stream_context inp_ctx = {0};

    if (input_stream == NULL) {
        return 0;
    }

    inp_ctx.offset = 0;
    inp_ctx.env = env;
    inp_ctx.thiz = thiz;
    inp_ctx.input_stream = input_stream;

    jclass cls = (*env)->GetObjectClass(env, input_stream);
    inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
    inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
    (*env)->DeleteLocalRef(env, cls);

    if (inp_ctx.mid_available == NULL || inp_ctx.mid_read == NULL) {
        // GetMethodID left a NoSuchMethodError pending; let Java see it.
        return 0;
    }

    // inp_ctx lives on this stack frame.
    // NOTE(review): assumes whisper_init does not keep the loader after it returns - confirm.
    loader.context = &inp_ctx;
    loader.read = inputStreamRead;
    loader.eof = inputStreamEof;
    loader.close = inputStreamClose;

    context = whisper_init(&loader);
    return (jlong) context;
}

JNIEXPORT jlong JNICALL
Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext(
JNIEnv *env, jobject thiz, jstring model_path_str) {
UNUSED(thiz);
struct whisper_context *context = NULL;
const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL);
context = whisper_init(model_path_chars);
context = whisper_init_from_file(model_path_chars);
(*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars);
return (jlong) context;
}
Expand Down
10 changes: 0 additions & 10 deletions examples/whisper.android/local.properties

This file was deleted.

2 changes: 1 addition & 1 deletion examples/whisper.objc/whisper.objc/ViewController.m
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ - (void)viewDidLoad {
NSLog(@"Loading model from %@", modelPath);

// create ggml context
stateInp.ctx = whisper_init([modelPath UTF8String]);
stateInp.ctx = whisper_init_from_file([modelPath UTF8String]);

// check if the model was loaded successfully
if (stateInp.ctx == NULL) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ actor WhisperContext {
}

static func createContext(path: String) throws -> WhisperContext {
let context = whisper_init(path)
let context = whisper_init_from_file(path)
if let context {
return WhisperContext(context: context)
} else {
Expand Down
2 changes: 1 addition & 1 deletion examples/whisper.wasm/emscripten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ EMSCRIPTEN_BINDINGS(whisper) {

for (size_t i = 0; i < g_contexts.size(); ++i) {
if (g_contexts[i] == nullptr) {
g_contexts[i] = whisper_init(path_model.c_str());
g_contexts[i] = whisper_init_from_file(path_model.c_str());
if (g_contexts[i] != nullptr) {
return i + 1;
} else {
Expand Down
Loading

0 comments on commit 1512545

Please sign in to comment.