Skip to content

Commit 947b944

Browse files
committed
whisper : improve handling of prompts (whisper/1981)
* whisper : improve handling of prompts
* whisper : add whisper_token_count helper
1 parent 2f55f41 commit 947b944

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

examples/whisper/main.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
207207
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
208208
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
209209
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
210-
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
210+
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
211211
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
212212
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
213213
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());

examples/whisper/whisper.cpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
37213721

37223722
if (n_max_tokens < (int) res.size()) {
37233723
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
3724-
return -1;
3724+
return -(int) res.size();
37253725
}
37263726

37273727
for (int i = 0; i < (int) res.size(); i++) {
@@ -3731,6 +3731,10 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to
37313731
return res.size();
37323732
}
37333733

3734+
int whisper_token_count(struct whisper_context * ctx, const char * text) {
3735+
return -whisper_tokenize(ctx, text, NULL, 0);
3736+
}
3737+
37343738
int whisper_lang_max_id() {
37353739
auto max_id = 0;
37363740
for (const auto & kv : g_lang) {
@@ -5313,7 +5317,12 @@ int whisper_full_with_state(
53135317
// initial prompt
53145318
if (!params.prompt_tokens && params.initial_prompt) {
53155319
prompt_tokens.resize(1024);
5316-
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
5320+
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
5321+
if (n_needed < 0) {
5322+
prompt_tokens.resize(-n_needed);
5323+
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
5324+
}
5325+
prompt_tokens.resize(n_needed);
53175326
params.prompt_tokens = prompt_tokens.data();
53185327
params.prompt_n_tokens = prompt_tokens.size();
53195328
}

examples/whisper/whisper.h

+7-1
Original file line numberDiff line numberDiff line change
@@ -337,14 +337,18 @@ extern "C" {
337337
// Convert the provided text into tokens.
338338
// The tokens pointer must be large enough to hold the resulting tokens.
339339
// Returns the number of tokens on success, no more than n_max_tokens
340-
// Returns -1 on failure
340+
// Returns a negative number on failure: the negated count of tokens that would have been returned
341341
// TODO: not sure if correct
342342
WHISPER_API int whisper_tokenize(
343343
struct whisper_context * ctx,
344344
const char * text,
345345
whisper_token * tokens,
346346
int n_max_tokens);
347347

348+
// Return the number of tokens in the provided text
349+
// Equivalent to: -whisper_tokenize(ctx, text, NULL, 0)
350+
int whisper_token_count(struct whisper_context * ctx, const char * text);
351+
348352
// Largest language id (i.e. number of available languages - 1)
349353
WHISPER_API int whisper_lang_max_id();
350354

@@ -503,6 +507,8 @@ extern "C" {
503507

504508
// tokens to provide to the whisper decoder as initial prompt
505509
// these are prepended to any existing text context from a previous call
510+
// use whisper_tokenize() to convert text to tokens
511+
// maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
506512
const char * initial_prompt;
507513
const whisper_token * prompt_tokens;
508514
int prompt_n_tokens;

0 commit comments

Comments
 (0)