Skip to content

Commit

Permalink
add ctranslate2-bindings / tabby rust packages (TabbyML#146)
Browse files Browse the repository at this point in the history
* add ctranslate2-bindings

* add fixme for linux build

* turn off shared lib

* add tabby-cli
  • Loading branch information
wsxiaoys authored May 25, 2023
1 parent c08f5ac commit a2476af
Show file tree
Hide file tree
Showing 16 changed files with 3,326 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "crates/ctranslate2-bindings/CTranslate2"]
path = crates/ctranslate2-bindings/CTranslate2
url = https://github.com/OpenNMT/CTranslate2.git
2 changes: 2 additions & 0 deletions crates/ctranslate2-bindings/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/target
/Cargo.lock
14 changes: 14 additions & 0 deletions crates/ctranslate2-bindings/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "ctranslate2-bindings"
version = "0.1.0"
edition = "2021"

[dependencies]
cxx = "1.0"
derive_builder = "0.12.0"
tokenizers = "0.13.3"

[build-dependencies]
bindgen = "0.53.1"
cxx-build = "1.0"
cmake = "0.1"
32 changes: 32 additions & 0 deletions crates/ctranslate2-bindings/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use cmake::Config;

fn main() {
    // Build CTranslate2 (vendored as a git submodule) via CMake and install
    // it into Cargo's OUT_DIR; `dst` is the install prefix.
    let dst = Config::new("CTranslate2")
        // Default flags.
        .define("CMAKE_BUILD_TYPE", "Release")
        .define("BUILD_CLI", "OFF")
        .define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON")
        // FIXME(meng): support linux build.
        // OSX flags.
        .define("CMAKE_OSX_ARCHITECTURES", "arm64")
        .define("WITH_ACCELERATE", "ON")
        .define("WITH_MKL", "OFF")
        .define("OPENMP_RUNTIME", "NONE")
        .define("WITH_RUY", "ON")
        .build();

    // Link the freshly built library.
    println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
    println!("cargo:rustc-link-lib=ctranslate2");

    // Tell cargo to invalidate the built crate whenever the wrapper changes.
    println!("cargo:rerun-if-changed=include/ctranslate2.h");
    println!("cargo:rerun-if-changed=src/ctranslate2.cc");
    println!("cargo:rerun-if-changed=src/lib.rs");

    // Compile the cxx bridge glue together with our C++ wrapper.
    // Use `.include()` to add the CTranslate2 headers: the previous
    // `flag_if_supported("-I...")` probes the flag and silently drops the
    // include path if the probe fails, which would surface as a confusing
    // missing-header error instead.
    cxx_build::bridge("src/lib.rs")
        .file("src/ctranslate2.cc")
        .flag_if_supported("-std=c++17")
        .include(dst.join("include"))
        .compile("cxxbridge");
}
1 change: 1 addition & 0 deletions crates/ctranslate2-bindings/ctranslate2
Submodule ctranslate2 added at 692fb6
19 changes: 19 additions & 0 deletions crates/ctranslate2-bindings/include/ctranslate2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "rust/cxx.h"

namespace tabby {

// Abstract interface over a text-generation engine, exposed to Rust through
// the cxx bridge — hence the rust::* types in the signature. The concrete
// implementation lives in src/ctranslate2.cc; the mirrored Rust declarations
// are in src/lib.rs (mod ffi).
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
// Runs generation on a pre-tokenized prompt and returns the output tokens.
// Parameters map 1:1 onto the decoding options used by the implementation.
virtual rust::Vec<rust::String> inference(
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature,
size_t beam_size
) const = 0;
};

// Factory: loads a model from `model_path` and returns an engine instance.
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
} // namespace
47 changes: 47 additions & 0 deletions crates/ctranslate2-bindings/src/ctranslate2.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include "ctranslate2-bindings/include/ctranslate2.h"

#include "ctranslate2/translator.h"

namespace tabby {
// Out-of-line definition of the virtual destructor declared in the header.
TextInferenceEngine::~TextInferenceEngine() {}

// Concrete engine backed by a ctranslate2::Translator.
class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(const std::string& model_path) {
ctranslate2::models::ModelLoader loader(model_path);
translator_ = std::make_unique<ctranslate2::Translator>(loader);
}

~TextInferenceEngineImpl() {}

rust::Vec<rust::String> inference(
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature,
size_t beam_size
) const {
// Create options.
ctranslate2::TranslationOptions options;
options.max_decoding_length = max_decoding_length;
options.sampling_temperature = sampling_temperature;
options.beam_size = beam_size;

// Inference: run a batch of one and take its (first) result.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
const auto& output_tokens = result.output();

// Convert to rust vec. std::copy works because rust::String is
// constructible from std::string.
rust::Vec<rust::String> output;
output.reserve(output_tokens.size());
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
return output;
}
private:
std::unique_ptr<ctranslate2::Translator> translator_;
};

// Factory declared in include/ctranslate2.h; called from Rust via the bridge.
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
}
} // namespace tabby
69 changes: 69 additions & 0 deletions crates/ctranslate2-bindings/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::sync::Mutex;
use tokenizers::tokenizer::{Model, Tokenizer};

#[macro_use]
extern crate derive_builder;

// cxx bridge to the C++ side. These declarations must textually mirror
// include/ctranslate2.h; the glue code is generated at build time (build.rs).
#[cxx::bridge(namespace = "tabby")]
mod ffi {
    unsafe extern "C++" {
        include!("ctranslate2-bindings/include/ctranslate2.h");

        // Opaque C++ type; only usable behind UniquePtr on the Rust side.
        type TextInferenceEngine;

        // Loads a model from `model_path` (src/ctranslate2.cc).
        fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
        // Runs generation on pre-tokenized input and returns output tokens.
        fn inference(
            &self,
            tokens: &[String],
            max_decoding_length: usize,
            sampling_temperature: f32,
            beam_size: usize,
        ) -> Vec<String>;
    }
}

// Decoding options forwarded to the C++ engine. Construct via the
// derive_builder-generated `TextInferenceOptionsBuilder`; each field falls
// back to the default given in its `#[builder(default = ...)]` attribute.
#[derive(Builder, Debug)]
pub struct TextInferenceOptions {
    // Maximum number of tokens to generate.
    #[builder(default = "256")]
    max_decoding_length: usize,

    // Softmax sampling temperature.
    #[builder(default = "1.0")]
    sampling_temperature: f32,

    // Beam width for beam search.
    #[builder(default = "2")]
    beam_size: usize,
}

// Thread-shareable wrapper pairing the C++ engine with its tokenizer.
pub struct TextInferenceEngine {
    // The C++ engine is not known to be thread-safe, so all access goes
    // through this Mutex.
    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
    tokenizer: Tokenizer,
}

// SAFETY: the opaque C++ engine is only ever accessed while holding the
// Mutex above. NOTE(review): soundness also requires that the underlying
// ctranslate2 objects may be moved to / used from another thread — confirm
// against the CTranslate2 documentation.
unsafe impl Send for TextInferenceEngine {}
unsafe impl Sync for TextInferenceEngine {}

impl TextInferenceEngine {
    /// Loads a CTranslate2 model from `model_path` and a HuggingFace
    /// tokenizer from `tokenizer_path`.
    ///
    /// # Panics
    /// Panics if the tokenizer file cannot be read or parsed.
    pub fn create(model_path: &str, tokenizer_path: &str) -> Self {
        // Dropped the stray empty `where` clause and trailing `return`
        // from the original; behavior is unchanged.
        TextInferenceEngine {
            engine: Mutex::new(ffi::create_engine(model_path)),
            tokenizer: Tokenizer::from_file(tokenizer_path)
                .expect("failed to load tokenizer file"),
        }
    }

    /// Tokenizes `prompt`, runs generation with `options`, and returns the
    /// decoded completion text.
    ///
    /// # Panics
    /// Panics if tokenization fails, the engine mutex is poisoned, or the
    /// model emits a token unknown to the tokenizer.
    pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
        let encoding = self
            .tokenizer
            .encode(prompt, true)
            .expect("failed to tokenize prompt");
        let output_tokens = self.engine.lock().unwrap().inference(
            encoding.get_tokens(),
            options.max_decoding_length,
            options.sampling_temperature,
            options.beam_size,
        );

        // Map the output tokens back to ids so the tokenizer can apply its
        // full decoding pipeline (subword merging, special-token handling).
        let model = self.tokenizer.get_model();
        let output_ids: Vec<u32> = output_tokens
            .iter()
            .map(|token| {
                model
                    .token_to_id(token)
                    .expect("engine produced a token unknown to the tokenizer")
            })
            .collect();
        self.tokenizer
            .decode(output_ids, true)
            .expect("failed to decode output tokens")
    }
}
1 change: 1 addition & 0 deletions crates/tabby/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/target
Loading

0 comments on commit a2476af

Please sign in to comment.