diff --git a/candle-examples/examples/qwen/main.rs b/candle-examples/examples/qwen/main.rs index 008bada20e..53f2f70dd1 100644 --- a/candle-examples/examples/qwen/main.rs +++ b/candle-examples/examples/qwen/main.rs @@ -144,6 +144,14 @@ enum WhichModel { W72b, #[value(name = "moe-a2.7b")] MoeA27b, + #[value(name = "2-0.5b")] + W2_0_5b, + #[value(name = "2-1.5b")] + W2_1_5b, + #[value(name = "2-7b")] + W2_7b, + #[value(name = "2-72b")] + W2_72b, } #[derive(Parser, Debug)] @@ -234,16 +242,20 @@ fn main() -> Result<()> { let model_id = match args.model_id { Some(model_id) => model_id, None => { - let size = match args.model { - WhichModel::W0_5b => "0.5B", - WhichModel::W1_8b => "1.8B", - WhichModel::W4b => "4B", - WhichModel::W7b => "7B", - WhichModel::W14b => "14B", - WhichModel::W72b => "72B", - WhichModel::MoeA27b => "MoE-A2.7B", + let (version, size) = match args.model { + WhichModel::W2_0_5b => ("2", "0.5B"), + WhichModel::W2_1_5b => ("2", "1.5B"), + WhichModel::W2_7b => ("2", "7B"), + WhichModel::W2_72b => ("2", "72B"), + WhichModel::W0_5b => ("1.5", "0.5B"), + WhichModel::W1_8b => ("1.5", "1.8B"), + WhichModel::W4b => ("1.5", "4B"), + WhichModel::W7b => ("1.5", "7B"), + WhichModel::W14b => ("1.5", "14B"), + WhichModel::W72b => ("1.5", "72B"), + WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"), }; - format!("Qwen/Qwen1.5-{size}") + format!("Qwen/Qwen{version}-{size}") } }; let repo = api.repo(Repo::with_revision( @@ -261,11 +273,15 @@ fn main() -> Result<()> { .map(std::path::PathBuf::from) .collect::>(), None => match args.model { - WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?], + WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => { + vec![repo.get("model.safetensors")?] + } WhichModel::W4b | WhichModel::W7b + | WhichModel::W2_7b | WhichModel::W14b | WhichModel::W72b + | WhichModel::W2_72b | WhichModel::MoeA27b => { candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")? } diff --git a/candle-transformers/src/models/qwen2.rs b/candle-transformers/src/models/qwen2.rs index 16ee8b01bd..3dce5c6a6a 100644 --- a/candle-transformers/src/models/qwen2.rs +++ b/candle-transformers/src/models/qwen2.rs @@ -360,8 +360,12 @@ pub struct ModelForCausalLM { impl ModelForCausalLM { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; - let base_model = Model::new(cfg, vb)?; + let base_model = Model::new(cfg, vb.clone())?; + let lm_head = if vb.contains_tensor("lm_head") { + linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))? + } else { + Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None) + }; Ok(Self { base_model, lm_head,