refactor: cleanup chat api make it message oriented (TabbyML#497)
* refactor: refactor into /chat/completions api

* Revert "feat: support request level stop words (TabbyML#492)"

This reverts commit 0d6840e.

* feat: adjust interface

* switch interface in tabby-playground

* move to chat/prompt, add unit test

* update interface
wsxiaoys authored Oct 2, 2023
1 parent dfdd037 commit f05dd3a
Showing 25 changed files with 346 additions and 202 deletions.
24 changes: 24 additions & 0 deletions Cargo.lock

3 changes: 0 additions & 3 deletions clients/tabby-playground/components/chat.tsx
@@ -39,9 +39,6 @@ export function Chat({ id, initialMessages, className }: ChatProps) {
}
}
})
if (messages.length > 2) {
setMessages(messages.slice(messages.length - 2, messages.length))
}
return (
<>
<div className={cn('pb-[200px] pt-4 md:pt-10', className)}>
15 changes: 4 additions & 11 deletions clients/tabby-playground/lib/hooks/use-patch-fetch.ts
@@ -1,5 +1,6 @@
import { type Message } from 'ai/react'
import { CohereStream, StreamingTextResponse } from 'ai'
import { StreamingTextResponse } from 'ai'
import { TabbyStream } from '@/lib/tabby-stream'
import { useEffect } from 'react'

const serverUrl =
@@ -15,25 +16,17 @@ export function usePatchFetch() {
}

const { messages } = JSON.parse(options!.body as string)
const res = await fetch(`${serverUrl}/v1beta/generate_stream`, {
const res = await fetch(`${serverUrl}/v1beta/chat/completions`, {
...options,
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
prompt: messagesToPrompt(messages)
})
})

const stream = CohereStream(res, undefined)
const stream = TabbyStream(res, undefined)
return new StreamingTextResponse(stream)
}
}, [])
}

function messagesToPrompt(messages: Message[]) {
const instruction = messages[messages.length - 1].content
const prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n${instruction}\n\n### Response:`
return prompt
}
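
With this change the playground no longer flattens the chat history into a single instruction prompt; it forwards the `messages` array from the request body straight to the new `/v1beta/chat/completions` route. A minimal Rust sketch of the request shape implied by the hunk above (struct names are illustrative; the field names `messages`, `role`, and `content` follow the ai SDK message objects, and the authoritative server-side schema lives elsewhere in this commit):

use serde::Serialize;

#[derive(Serialize)]
struct Message {
    role: String, // e.g. "user" or "assistant"
    content: String,
}

#[derive(Serialize)]
struct ChatCompletionRequest {
    messages: Vec<Message>,
}

fn main() {
    // Forward the chat history as-is; no prompt flattening on the client.
    let request = ChatCompletionRequest {
        messages: vec![Message {
            role: "user".into(),
            content: "Write a quicksort in Rust.".into(),
        }],
    };
    // POST this JSON body to `${serverUrl}/v1beta/chat/completions`.
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}
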
71 changes: 71 additions & 0 deletions clients/tabby-playground/lib/tabby-stream.ts
@@ -0,0 +1,71 @@
import {
type AIStreamCallbacksAndOptions,
createCallbacksTransformer,
createStreamDataTransformer
} from 'ai';

const utf8Decoder = new TextDecoder('utf-8');

async function processLines(
lines: string[],
controller: ReadableStreamDefaultController<string>,
) {
for (const line of lines) {
const { content } = JSON.parse(line);
controller.enqueue(content);
}
}

async function readAndProcessLines(
reader: ReadableStreamDefaultReader<Uint8Array>,
controller: ReadableStreamDefaultController<string>,
) {
let segment = '';

while (true) {
const { value: chunk, done } = await reader.read();
if (done) {
break;
}

segment += utf8Decoder.decode(chunk, { stream: true });

const linesArray = segment.split(/\r\n|\n|\r/g);
segment = linesArray.pop() || '';

await processLines(linesArray, controller);
}

if (segment) {
const linesArray = [segment];
await processLines(linesArray, controller);
}

controller.close();
}

function createParser(res: Response) {
const reader = res.body?.getReader();

return new ReadableStream<string>({
async start(controller): Promise<void> {
if (!reader) {
controller.close();
return;
}

await readAndProcessLines(reader, controller);
},
});
}

export function TabbyStream(
reader: Response,
callbacks?: AIStreamCallbacksAndOptions,
): ReadableStream {
return createParser(reader)
.pipeThrough(createCallbacksTransformer(callbacks))
.pipeThrough(
createStreamDataTransformer(callbacks?.experimental_streamData),
);
}
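
TabbyStream treats the response body as newline-delimited JSON, where each line is an object carrying a `content` fragment (see `processLines` above). A rough Rust equivalent of that per-line decoding, under the assumption that each chunk object contains only a `content` string:

use serde::Deserialize;

#[derive(Deserialize)]
struct Chunk {
    content: String,
}

// Concatenate the `content` fields of a newline-delimited JSON stream body.
fn decode_stream(body: &str) -> serde_json::Result<String> {
    let mut text = String::new();
    for line in body.lines().filter(|l| !l.trim().is_empty()) {
        let chunk: Chunk = serde_json::from_str(line)?;
        text.push_str(&chunk.content);
    }
    Ok(text)
}

fn main() {
    let body = "{\"content\":\"Hello\"}\n{\"content\":\", world\"}\n";
    assert_eq!(decode_stream(body).unwrap(), "Hello, world");
}
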
2 changes: 1 addition & 1 deletion crates/ctranslate2-bindings/src/lib.rs
@@ -137,7 +137,7 @@ impl TextGeneration for CTranslate2Engine {

let decoding = self
.decoding_factory
.create(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), &options.stop_words, options.static_stop_words);
.create_incremental_decoding(self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), options.stop_words);

let (sender, mut receiver) = channel::<String>(8);
let context = InferenceContext::new(sender, decoding, cancel_for_inference);
7 changes: 2 additions & 5 deletions crates/http-api-bindings/src/fastchat.rs
@@ -58,11 +58,8 @@ impl FastChatEngine {
#[async_trait]
impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> = options
.static_stop_words
.iter()
.map(|x| x.to_string())
.collect();
let _stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();

let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {
2 changes: 1 addition & 1 deletion crates/http-api-bindings/src/vertex_ai.rs
@@ -67,7 +67,7 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options
.static_stop_words
.stop_words
.iter()
.map(|x| x.to_string())
// vertex supports at most 5 stop sequence.
11 changes: 5 additions & 6 deletions crates/llama-cpp-bindings/src/engine.cc
@@ -10,7 +10,8 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}

namespace {
static size_t N_BATCH = 512;
static size_t N_BATCH = 512; // # per batch inference.
static size_t N_CTX = 4096; // # max kv history.

template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
@@ -59,7 +60,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
return std::distance(logits, std::max_element(logits, logits + n_vocab));
}

bool eval(llama_token* data, size_t size, bool reset) {
void eval(llama_token* data, size_t size, bool reset) {
if (reset) {
n_past_ = 0;
}
@@ -76,12 +77,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
auto* ctx = ctx_.get();
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
if (llama_decode(ctx, batch_)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
throw std::runtime_error("Failed to eval");
}

n_past_ += size;
return true;
}

size_t n_past_;
@@ -127,7 +126,7 @@ std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
}

llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;
ctx_params.n_ctx = N_CTX;
ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params);

8 changes: 5 additions & 3 deletions crates/llama-cpp-bindings/src/lib.rs
@@ -18,7 +18,7 @@ mod ffi {
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;

fn start(self: Pin<&mut TextInferenceEngine>, input_token_ids: &[u32]);
fn step(self: Pin<&mut TextInferenceEngine>) -> u32;
fn step(self: Pin<&mut TextInferenceEngine>) -> Result<u32>;
fn end(self: Pin<&mut TextInferenceEngine>);

fn eos_token(&self) -> u32;
@@ -75,10 +75,12 @@ impl TextGeneration for LlamaEngine {

let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create(self.tokenizer.clone(), input_token_ids, &options.stop_words, options.static_stop_words);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let next_token_id = engine.as_mut().step();
let Ok(next_token_id) = engine.as_mut().step() else {
panic!("Failed to eval");
};
if next_token_id == eos_token {
break;
}
38 changes: 10 additions & 28 deletions crates/tabby-inference/src/decoding.rs
@@ -24,35 +24,16 @@ impl Default for DecodingFactory {
}

impl DecodingFactory {
pub fn create(
pub fn create_incremental_decoding(
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
stop_words: &Vec<String>,
static_stop_words: &'static Vec<&'static str>,
stop_words: &'static Vec<&'static str>,
) -> IncrementalDecoding {
IncrementalDecoding::new(
tokenizer,
vec![
self.get_static_re(static_stop_words),
self.get_re(stop_words),
]
.into_iter()
.flatten()
.collect(),
input_token_ids,
)
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
}

fn get_re(&self, stop_words: &Vec<String>) -> Option<Regex> {
if !stop_words.is_empty() {
Some(create_stop_regex(stop_words))
} else {
None
}
}

fn get_static_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
fn get_re(&self, stop_words: &'static Vec<&'static str>) -> Option<Regex> {
if stop_words.is_empty() {
None
} else {
@@ -67,8 +48,8 @@ impl DecodingFactory {
}
}

fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x.as_ref())).collect();
fn create_stop_regex(stop_words: &[&str]) -> Regex {
let tokens: Vec<String> = stop_words.iter().map(|x| reverse(*x)).collect();

// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
@@ -78,7 +59,7 @@ fn create_stop_regex<T: AsRef<str>>(stop_words: &[T]) -> Regex {

pub struct IncrementalDecoding {
tokenizer: Arc<Tokenizer>,
stop_re: Vec<Regex>,
stop_re: Option<Regex>,

token_ids: Vec<u32>,
prefix_offset: usize,
@@ -88,7 +69,7 @@ pub struct IncrementalDecoding {
}

impl IncrementalDecoding {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Vec<Regex>, input_token_ids: &[u32]) -> Self {
pub fn new(tokenizer: Arc<Tokenizer>, stop_re: Option<Regex>, input_token_ids: &[u32]) -> Self {
let text = tokenizer
.decode(input_token_ids, /* skip_special_token = */ true)
.expect("Cannot decode token from tokenizer.");
@@ -129,7 +110,8 @@ impl IncrementalDecoding {

if !new_text.is_empty() {
self.reversed_text = reverse(new_text) + &self.reversed_text;
for re in &self.stop_re {

if let Some(re) = &self.stop_re {
if re.find(&self.reversed_text).is_some() {
return None;
}
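
With request-level stop words reverted, the factory builds at most one stop-word regex. The trick, visible in `create_stop_regex` and `reversed_text` above, is to reverse both the stop words and the generated text so that a stop word at the end of the output can be matched with a start-of-string anchor. A standalone sketch of that idea (the example stop words are placeholders and are assumed to contain no regex metacharacters):

use regex::Regex;

fn reverse(s: &str) -> String {
    s.chars().rev().collect()
}

// Match any stop word at the *end* of the generated text by anchoring at the
// *start* of its reversal. \A is the absolute beginning of the haystack.
fn create_stop_regex(stop_words: &[&str]) -> Regex {
    let tokens: Vec<String> = stop_words.iter().map(|x| reverse(x)).collect();
    Regex::new(&format!(r"(?m)\A({})", tokens.join("|"))).unwrap()
}

fn main() {
    let re = create_stop_regex(&["</s>", "\n\n"]);
    let generated = "fn main() {}\n\n";
    // The reversed text begins with a reversed stop word, so generation stops.
    assert!(re.find(&reverse(generated)).is_some());
}
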
5 changes: 1 addition & 4 deletions crates/tabby-inference/src/lib.rs
@@ -16,10 +16,7 @@ pub struct TextGenerationOptions {
pub sampling_temperature: f32,

#[builder(default = "&EMPTY_STOP_WORDS")]
pub static_stop_words: &'static Vec<&'static str>,

#[builder(default = "vec![]")]
pub stop_words: Vec<String>,
pub stop_words: &'static Vec<&'static str>,
}

static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
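
With the request-level stop words gone, `TextGenerationOptions::stop_words` again borrows a process-lifetime list. A sketch of supplying such a `'static` list, assuming `once_cell` for lazy initialization (the crate choice and the example stop words are assumptions, not part of this commit):

use once_cell::sync::Lazy;

// One 'static stop-word list per prompt style, initialized on first use.
static CHAT_STOP_WORDS: Lazy<Vec<&'static str>> =
    Lazy::new(|| vec!["\n### Instruction:", "</s>"]);

fn chat_stop_words() -> &'static Vec<&'static str> {
    Lazy::force(&CHAT_STOP_WORDS)
}

fn main() {
    // Pass the borrowed list when building options, e.g. via the
    // derive_builder-style builder suggested by the #[builder] attributes above:
    // TextGenerationOptionsBuilder::default().stop_words(chat_stop_words())...
    println!("{:?}", chat_stop_words());
}
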
1 change: 1 addition & 0 deletions crates/tabby/Cargo.toml
@@ -39,6 +39,7 @@ http-api-bindings = { path = "../http-api-bindings" }
futures = { workspace = true }
async-stream = { workspace = true }
axum-streams = { version = "0.9.1", features = ["json"] }
minijinja = { version = "1.0.8", features = ["loader"] }

[target.'cfg(all(target_os="macos", target_arch="aarch64"))'.dependencies]
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
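
The new `minijinja` dependency lines up with the "move to chat/prompt" step in the commit message: the chat messages are presumably rendered into a model prompt through a Jinja-style template. A hypothetical sketch of that rendering (the actual template and module layout are not shown in this excerpt):

use minijinja::{context, Environment};
use serde::Serialize;

#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

fn build_prompt(messages: &[Message]) -> Result<String, minijinja::Error> {
    let mut env = Environment::new();
    // Placeholder template; the one shipped with the commit may differ.
    env.add_template(
        "chat",
        "{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}ASSISTANT:",
    )?;
    env.get_template("chat")?.render(context! { messages => messages })
}

fn main() {
    let messages = vec![Message {
        role: "user".into(),
        content: "Explain borrow checking.".into(),
    }];
    println!("{}", build_prompt(&messages).unwrap());
}
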
(Remaining file diffs not shown.)
