Skip to content

Commit

Permalink
Set default adapter source (predibase#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored Feb 5, 2024
1 parent 9f51118 commit a4f0e75
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
15 changes: 9 additions & 6 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ struct Args {
/// or it can be a local directory containing the necessary files
/// as saved by `save_pretrained(...)` methods of transformers.
/// Should be compatible with the model specified in `model_id`.
#[clap(default_value = "", long, env)]
adapter_id: String,
#[clap(long, env)]
adapter_id: Option<String>,

/// The source of the model to load.
/// Can be `hub` or `s3`.
Expand All @@ -115,7 +115,7 @@ struct Args {
#[clap(default_value = "hub", long, env)]
source: String,

/// The source of the model to load.
/// The source of the static adapter to load.
/// Can be `hub`, `s3`, or `pbase`.
/// `hub` will load the adapter from the Hugging Face Hub.
/// `s3` will load the adapter from the Predibase S3 bucket.
Expand Down Expand Up @@ -764,9 +764,10 @@ fn download_convert_model(
download_args.push(revision.to_string())
}

if !args.adapter_id.is_empty() {
// Only forward `--adapter-id` to the download step when an adapter ID was provided
if let Some(adapter_id) = &args.adapter_id {
download_args.push("--adapter-id".to_string());
download_args.push(args.adapter_id.clone());
download_args.push(adapter_id.to_string());
}

// Copy current process env
Expand Down Expand Up @@ -877,7 +878,7 @@ fn spawn_shards(
// Start shard processes
for rank in 0..num_shard {
let model_id = args.model_id.clone();
let adapter_id = args.adapter_id.clone();
let adapter_id = args.adapter_id.clone().unwrap_or_default();
let revision = args.revision.clone();
let source: String = args.source.clone();
let adapter_source: String = args.adapter_source.clone();
Expand Down Expand Up @@ -996,6 +997,8 @@ fn spawn_webserver(
format!("{}-0", args.shard_uds_path),
"--tokenizer-name".to_string(),
args.model_id,
"--adapter-source".to_string(),
args.adapter_source,
];

// Model optional max batch total tokens
Expand Down
8 changes: 3 additions & 5 deletions router/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ use std::hash;

use crate::AdapterParameters;

use crate::server::DEFAULT_ADAPTER_SOURCE;

/// "adapter ID" for the base model. The base model does not have an adapter ID,
/// but we reason about it in the same way. This must match the base model ID
/// used in the Python server.
pub const BASE_MODEL_ADAPTER_ID: &str = "__base_model__";

/// default adapter source. One TODO is to figure out how to do this
/// from within the proto definition, or lib.rs
pub const DEFAULT_ADAPTER_SOURCE: &str = "hub";

#[derive(Debug, Clone)]
pub(crate) struct Adapter {
/// adapter parameters
Expand Down Expand Up @@ -85,7 +83,7 @@ pub(crate) fn extract_adapter_params(
}
let mut adapter_source = adapter_source.clone();
if adapter_source.is_none() {
adapter_source = Some(DEFAULT_ADAPTER_SOURCE.to_string());
adapter_source = Some(DEFAULT_ADAPTER_SOURCE.get().unwrap().to_string());
}

let adapter_parameters = adapter_parameters.clone().unwrap_or(AdapterParameters {
Expand Down
4 changes: 4 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ struct Args {
ngrok_authtoken: Option<String>,
#[clap(long, env)]
ngrok_edge: Option<String>,
#[clap(default_value = "hub", long, env)]
adapter_source: String,
}

fn main() -> Result<(), RouterError> {
Expand Down Expand Up @@ -108,6 +110,7 @@ fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
adapter_source,
} = args;

// Validate args
Expand Down Expand Up @@ -323,6 +326,7 @@ fn main() -> Result<(), RouterError> {
ngrok,
ngrok_authtoken,
ngrok_edge,
adapter_source,
)
.await?;
Ok(())
Expand Down
7 changes: 7 additions & 0 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

static MODEL_ID: OnceCell<String> = OnceCell::new();
pub static DEFAULT_ADAPTER_SOURCE: OnceCell<String> = OnceCell::new();

/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
#[utoipa::path(
Expand Down Expand Up @@ -712,6 +713,7 @@ pub async fn run(
ngrok: bool,
ngrok_authtoken: Option<String>,
ngrok_edge: Option<String>,
adapter_source: String,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
Expand Down Expand Up @@ -874,6 +876,11 @@ pub async fn run(
MODEL_ID.set(model_id.clone()).unwrap_or_else(|_| {
panic!("MODEL_ID was already set!");
});
DEFAULT_ADAPTER_SOURCE
.set(adapter_source.clone())
.unwrap_or_else(|_| {
panic!("DEFAULT_ADAPTER_SOURCE was already set!");
});

// Create router
let app = Router::new()
Expand Down

0 comments on commit a4f0e75

Please sign in to comment.