talk.wasm : polishing + adding many AI personalities
ggerganov committed Nov 22, 2022
1 parent 385236d commit 9aea96f
Showing 4 changed files with 383 additions and 48 deletions.
2 changes: 1 addition & 1 deletion bindings/javascript/whisper.js

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions examples/talk.wasm/README.md
@@ -31,6 +31,15 @@ In order to run this demo efficiently, you need to have the following:
 - Speak phrases that are no longer than 10 seconds - this is the audio context of the AI
 - The web-page uses about 1.4GB of RAM
 
+Note that this demo uses the smallest GPT-2 model, so the generated text responses are not always very good.
+The prompting strategy can also likely be improved to achieve better results.
+
+The demo is computationally heavy - it is not common to run transformer models like these in a browser; they
+typically run on powerful GPU hardware. So for a good experience, you need a powerful computer.
+
+Mobile browsers will likely start to support the WASM SIMD capabilities in the near future, which will make it
+possible to run this demo on a phone or tablet. For now, this does not appear to be supported (at least on iPhone).
+
 ## Feedback
 
 If you have any comments or ideas for improvement, please drop a comment in the following discussion:
42 changes: 29 additions & 13 deletions examples/talk.wasm/emscripten.cpp
@@ -988,7 +988,7 @@ std::atomic<bool> g_running(false);
 
 bool g_force_speak = false;
 std::string g_text_to_speak = "";
-std::string g_status = "idle";
+std::string g_status = "";
 std::string g_status_forced = "";
 
 std::string gpt2_gen_text(const std::string & prompt) {
@@ -997,7 +997,7 @@ std::string gpt2_gen_text(const std::string & prompt) {
     std::vector<float> embd_w;
 
     // tokenize the prompt
-    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, g_gpt2.prompt_base + prompt);
+    std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(g_gpt2.vocab, prompt);
 
     g_gpt2.n_predict = std::min(g_gpt2.n_predict, g_gpt2.model.hparams.n_ctx - (int) embd_inp.size());
 
@@ -1088,6 +1088,8 @@ void talk_main(size_t index) {
         printf("gpt-2: model loaded in %d ms\n", (int) (t_load_us/1000));
     }
 
+    printf("talk: using %d threads\n", N_THREAD);
+
     std::vector<float> pcmf32;
 
     auto & ctx = g_contexts[index];
@@ -1214,53 +1216,60 @@ void talk_main(size_t index) {
             printf("whisper: number of tokens: %d, '%s'\n", (int) tokens.size(), text_heard.c_str());
 
             std::string text_to_speak;
+            std::string prompt_base;
+
+            {
+                std::lock_guard<std::mutex> lock(g_mutex);
+                prompt_base = g_gpt2.prompt_base;
+            }
 
             if (tokens.size() > 0) {
-                text_to_speak = gpt2_gen_text(text_heard + "\n");
+                text_to_speak = gpt2_gen_text(prompt_base + text_heard + "\n");
                 text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
                 text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
 
                 std::lock_guard<std::mutex> lock(g_mutex);
 
                 // remove first 2 lines of base prompt
                 {
-                    const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
+                    const size_t pos = prompt_base.find_first_of("\n");
                     if (pos != std::string::npos) {
-                        g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
+                        prompt_base = prompt_base.substr(pos + 1);
                     }
                 }
                 {
-                    const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
+                    const size_t pos = prompt_base.find_first_of("\n");
                     if (pos != std::string::npos) {
-                        g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
+                        prompt_base = prompt_base.substr(pos + 1);
                     }
                 }
-                g_gpt2.prompt_base += text_heard + "\n" + text_to_speak + "\n";
+                prompt_base += text_heard + "\n" + text_to_speak + "\n";
             } else {
-                text_to_speak = gpt2_gen_text("");
+                text_to_speak = gpt2_gen_text(prompt_base);
                 text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
                 text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of("\n"));
 
                 std::lock_guard<std::mutex> lock(g_mutex);
 
-                const size_t pos = g_gpt2.prompt_base.find_first_of("\n");
+                const size_t pos = prompt_base.find_first_of("\n");
                 if (pos != std::string::npos) {
-                    g_gpt2.prompt_base = g_gpt2.prompt_base.substr(pos + 1);
+                    prompt_base = prompt_base.substr(pos + 1);
                 }
-                g_gpt2.prompt_base += text_to_speak + "\n";
+                prompt_base += text_to_speak + "\n";
             }
 
             printf("gpt-2: %s\n", text_to_speak.c_str());
 
             //printf("========================\n");
-            //printf("gpt-2: prompt_base:\n'%s'\n", g_gpt2.prompt_base.c_str());
+            //printf("gpt-2: prompt_base:\n'%s'\n", prompt_base.c_str());
             //printf("========================\n");
 
             {
                 std::lock_guard<std::mutex> lock(g_mutex);
                 t_last = std::chrono::high_resolution_clock::now();
                 g_text_to_speak = text_to_speak;
                 g_pcmf32.clear();
+                g_gpt2.prompt_base = prompt_base;
             }
 
             talk_set_status("speaking ...");
@@ -1376,4 +1385,11 @@ EMSCRIPTEN_BINDINGS(talk) {
             g_status_forced = status;
         }
     }));
+
+    emscripten::function("set_prompt", emscripten::optional_override([](const std::string & prompt) {
+        {
+            std::lock_guard<std::mutex> lock(g_mutex);
+            g_gpt2.prompt_base = prompt;
+        }
+    }));
 }
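
The new set_prompt binding is what enables the "many AI personalities" of the commit title: the web page can swap in a different base dialogue at runtime (with embind, the binding is exposed on the Emscripten module, e.g. Module.set_prompt(text) from JavaScript). Below is a hypothetical example of such a base prompt - the dialogue is purely illustrative and not from the repository; it follows the two-lines-per-exchange format that the rolling-context logic above appears to assume.

#include <string>

// Hypothetical personality prompt (not from the repository). Each exchange is
// exactly two lines, matching the "remove first 2 lines" rotation above.
const std::string prompt_pirate =
    "Hello, who are you?\n"
    "Arr, I be an old pirate of the seven seas!\n"
    "What do you do all day?\n"
    "I sail, I plunder, and I chat with curious landlubbers like yourself!\n";
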
