forked from TabbyML/tabby
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ctranslate2-bindings / tabby rust packages (TabbyML#146)
* add ctranslate2-bindings * add fixme for linux build * turn off shared lib * add tabby-cli
- Loading branch information
Showing
16 changed files
with
3,326 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "crates/ctranslate2-bindings/CTranslate2"] | ||
path = crates/ctranslate2-bindings/CTranslate2 | ||
url = https://github.com/OpenNMT/CTranslate2.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
/target | ||
/Cargo.lock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[package] | ||
name = "ctranslate2-bindings" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
cxx = "1.0" | ||
derive_builder = "0.12.0" | ||
tokenizers = "0.13.3" | ||
|
||
[build-dependencies] | ||
bindgen = "0.53.1" | ||
cxx-build = "1.0" | ||
cmake = "0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use cmake::Config; | ||
|
||
fn main() { | ||
let dst = Config::new("CTranslate2") | ||
// Default flags. | ||
.define("CMAKE_BUILD_TYPE", "Release") | ||
.define("BUILD_CLI", "OFF") | ||
.define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON") | ||
|
||
// FIXME(meng): support linux build. | ||
// OSX flags. | ||
.define("CMAKE_OSX_ARCHITECTURES", "arm64") | ||
.define("WITH_ACCELERATE", "ON") | ||
.define("WITH_MKL", "OFF") | ||
.define("OPENMP_RUNTIME", "NONE") | ||
.define("WITH_RUY", "ON") | ||
.build(); | ||
|
||
println!("cargo:rustc-link-search=native={}", dst.join("lib").display()); | ||
println!("cargo:rustc-link-lib=ctranslate2"); | ||
|
||
// Tell cargo to invalidate the built crate whenever the wrapper changes | ||
println!("cargo:rerun-if-changed=include/ctranslate2.h"); | ||
println!("cargo:rerun-if-changed=src/ctranslate2.cc"); | ||
println!("cargo:rerun-if-changed=src/lib.rs"); | ||
|
||
cxx_build::bridge("src/lib.rs") | ||
.file("src/ctranslate2.cc") | ||
.flag_if_supported("-std=c++17") | ||
.flag_if_supported(&format!("-I{}", dst.join("include").display())) | ||
.compile("cxxbridge"); | ||
} |
Submodule ctranslate2
added at
692fb6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#pragma once

// Included directly because this header names size_t and std::unique_ptr;
// do not rely on "rust/cxx.h" pulling them in transitively.
#include <cstddef>
#include <memory>

#include "rust/cxx.h"

namespace tabby {

// Abstract text inference engine exposed to Rust through the cxx bridge.
// Concrete instances are obtained from create_engine().
class TextInferenceEngine {
 public:
  virtual ~TextInferenceEngine();

  // Runs inference on a single tokenized input and returns the output tokens.
  //
  //   tokens                - input tokens.
  //   max_decoding_length   - decoding option passed to the implementation.
  //   sampling_temperature  - decoding option passed to the implementation.
  //   beam_size             - decoding option passed to the implementation.
  virtual rust::Vec<rust::String> inference(
      rust::Slice<const rust::String> tokens,
      size_t max_decoding_length,
      float sampling_temperature,
      size_t beam_size
  ) const = 0;
};

// Creates an engine for the model located at `model_path`.
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
}  // namespace tabby
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#include "ctranslate2-bindings/include/ctranslate2.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include "ctranslate2/translator.h"

namespace tabby {
// Out-of-line definition of the interface's virtual destructor.
TextInferenceEngine::~TextInferenceEngine() {}

// TextInferenceEngine implementation backed by a ctranslate2::Translator.
class TextInferenceEngineImpl : public TextInferenceEngine {
 public:
  // Loads the model found at `model_path` and builds a translator from it.
  // `explicit` prevents accidental implicit conversion from std::string.
  explicit TextInferenceEngineImpl(const std::string& model_path) {
    ctranslate2::models::ModelLoader loader(model_path);
    translator_ = std::make_unique<ctranslate2::Translator>(loader);
  }

  ~TextInferenceEngineImpl() override = default;

  // Translates a single tokenized input and returns the tokens of the
  // result's output as a Rust vector of strings.
  rust::Vec<rust::String> inference(
      rust::Slice<const rust::String> tokens,
      size_t max_decoding_length,
      float sampling_temperature,
      size_t beam_size
  ) const override {
    // Create options.
    ctranslate2::TranslationOptions options;
    options.max_decoding_length = max_decoding_length;
    options.sampling_temperature = sampling_temperature;
    options.beam_size = beam_size;

    // Inference. The batch holds exactly one input; keep the returned vector
    // alive and borrow its first entry instead of copying the whole result.
    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
    const std::vector<ctranslate2::TranslationResult> results =
        translator_->translate_batch({ input_tokens }, options);
    const auto& output_tokens = results[0].output();

    // Convert to rust vec.
    rust::Vec<rust::String> output;
    output.reserve(output_tokens.size());
    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
    return output;
  }

 private:
  std::unique_ptr<ctranslate2::Translator> translator_;
};

// Factory used by the Rust side: converts the borrowed Rust string into a
// std::string and constructs the concrete engine.
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
  return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
}
}  // namespace tabby
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use std::sync::Mutex; | ||
use tokenizers::tokenizer::{Model, Tokenizer}; | ||
|
||
#[macro_use] | ||
extern crate derive_builder; | ||
|
||
#[cxx::bridge(namespace = "tabby")] | ||
mod ffi { | ||
unsafe extern "C++" { | ||
include!("ctranslate2-bindings/include/ctranslate2.h"); | ||
|
||
type TextInferenceEngine; | ||
|
||
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>; | ||
fn inference( | ||
&self, | ||
tokens: &[String], | ||
max_decoding_length: usize, | ||
sampling_temperature: f32, | ||
beam_size: usize, | ||
) -> Vec<String>; | ||
} | ||
} | ||
|
||
#[derive(Builder, Debug)] | ||
pub struct TextInferenceOptions { | ||
#[builder(default = "256")] | ||
max_decoding_length: usize, | ||
|
||
#[builder(default = "1.0")] | ||
sampling_temperature: f32, | ||
|
||
#[builder(default = "2")] | ||
beam_size: usize, | ||
} | ||
|
||
pub struct TextInferenceEngine { | ||
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>, | ||
tokenizer: Tokenizer, | ||
} | ||
|
||
unsafe impl Send for TextInferenceEngine {} | ||
unsafe impl Sync for TextInferenceEngine {} | ||
|
||
impl TextInferenceEngine { | ||
pub fn create(model_path: &str, tokenizer_path: &str) -> Self where { | ||
return TextInferenceEngine { | ||
engine: Mutex::new(ffi::create_engine(model_path)), | ||
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(), | ||
}; | ||
} | ||
|
||
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { | ||
let encoding = self.tokenizer.encode(prompt, true).unwrap(); | ||
let output_tokens = self.engine.lock().unwrap().inference( | ||
encoding.get_tokens(), | ||
options.max_decoding_length, | ||
options.sampling_temperature, | ||
options.beam_size, | ||
); | ||
|
||
let model = self.tokenizer.get_model(); | ||
let output_ids: Vec<u32> = output_tokens | ||
.iter() | ||
.map(|x| model.token_to_id(x).unwrap()) | ||
.collect(); | ||
self.tokenizer.decode(output_ids, true).unwrap() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/target |
Oops, something went wrong.