Skip to content

Commit

Permalink
Fix unsharded Falcon pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Nov 30, 2023
1 parent ed3dda9 commit 5c66948
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions apps/language_models/src/pipelines/falcon_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,20 +669,17 @@ def get_tokenizer(self):
def get_src_model(self):
    """Load and return the source Falcon model from Hugging Face.

    Builds the `from_pretrained` keyword arguments from instance state:
    always requests float32 weights, trusts remote code (Falcon ships
    custom modeling code), and authenticates with `self.hf_auth_token`.
    For int4 precision a GPTQ 4-bit quantization config is added and the
    model is pinned to CPU via `device_map`.

    Returns:
        The loaded `AutoModelForCausalLM` instance for
        `self.hf_model_path`.
    """
    print("Loading src model: ", self.model_name)
    kwargs = {
        "torch_dtype": torch.float32,
        "trust_remote_code": True,
        "token": self.hf_auth_token,
    }
    if self.precision == "int4":
        # disable_exllama: exllama GPTQ kernels require CUDA; keep the
        # CPU-compatible path since the model is mapped to CPU below.
        quantization_config = GPTQConfig(bits=4, disable_exllama=True)
        kwargs["quantization_config"] = quantization_config
        kwargs["device_map"] = "cpu"
    falcon_model = AutoModelForCausalLM.from_pretrained(
        self.hf_model_path, **kwargs
    )
    return falcon_model

def compile(self):
Expand Down

0 comments on commit 5c66948

Please sign in to comment.