Fix whisper conversion for safetensors models (ml-explore#935)
* fix whisper conversion for safetensors-only models; raise an error in mlx_lm when the save path already exists

* fix tests
awni authored Aug 14, 2024
1 parent 3390544 commit 95840f3
Showing 3 changed files with 34 additions and 14 deletions.
13 changes: 10 additions & 3 deletions llms/mlx_lm/utils.py
@@ -660,6 +660,16 @@ def convert(
     revision: Optional[str] = None,
     dequantize: bool = False,
 ):
+    # Check the save path is empty
+    if isinstance(mlx_path, str):
+        mlx_path = Path(mlx_path)
+
+    if mlx_path.exists():
+        raise ValueError(
+            f"Cannot save to the path {mlx_path} as it already exists."
+            " Please delete the file/directory or specify a new path to save to."
+        )
+
     print("[INFO] Loading")
     model_path = get_model_path(hf_path, revision=revision)
     model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
@@ -681,9 +691,6 @@ def convert(
         model = dequantize_model(model)
         weights = dict(tree_flatten(model.parameters()))
 
-    if isinstance(mlx_path, str):
-        mlx_path = Path(mlx_path)
-
     del model
     save_weights(mlx_path, weights, donate_weights=True)
 
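For context, a minimal sketch of what the new guard does: `convert` now refuses to write into an existing path before doing any download or conversion work. The repo id and directory name below are placeholders, not part of the change.

```python
from pathlib import Path

from mlx_lm.utils import convert

mlx_path = Path("mlx_model")
mlx_path.mkdir(exist_ok=True)  # simulate a leftover output directory

try:
    # "some-org/some-model" is a placeholder Hugging Face path
    convert("some-org/some-model", mlx_path=mlx_path)
except ValueError as err:
    print(err)  # the path check fires before the model is even fetched
```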
1 change: 1 addition & 0 deletions llms/tests/test_utils.py
@@ -82,6 +82,7 @@ def test_convert(self):
         self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))
 
         # Check model weights have right type
+        mlx_path = os.path.join(self.test_dir, "mlx_model_bf16")
         utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
         model, _ = utils.load(mlx_path)
 
34 changes: 23 additions & 11 deletions whisper/convert.py
@@ -163,7 +163,12 @@ def load_torch_weights_and_config(
 
         name_or_path = snapshot_download(
             repo_id=name_or_path,
-            allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
+            allow_patterns=[
+                "*.json",
+                "pytorch_model.bin",
+                "model.safetensors",
+                "*.txt",
+            ],
         )
     else:
         raise RuntimeError(
@@ -176,10 +181,11 @@ def load_torch_weights_and_config(
         weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
     else:
         name_or_path = Path(name_or_path)
-        weights = torch.load(
-            name_or_path / "pytorch_model.bin",
-            map_location="cpu",
-        )
+        pt_path = name_or_path / "pytorch_model.bin"
+        if pt_path.is_file():
+            weights = torch.load(pt_path, map_location="cpu")
+        else:
+            weights = mx.load(str(name_or_path / "model.safetensors"))
         with open(name_or_path / "config.json", "r") as fp:
             config = json.load(fp)
         weights, config = hf_to_pt(weights, config)
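As background for the `remap` change in the next hunk: `torch.load` returns `torch.Tensor` values, whereas `mx.load` on a `.safetensors` file already returns `mx.array` values, so the tensor-to-array conversion has to become conditional. A small sketch, with placeholder file names:

```python
import mlx.core as mx
import torch

# PyTorch checkpoint: values are torch.Tensor and still need mx.array(value.detach())
pt_weights = torch.load("pytorch_model.bin", map_location="cpu")
print(type(next(iter(pt_weights.values()))))  # <class 'torch.Tensor'>

# safetensors checkpoint loaded with MLX: values are already mx.array
st_weights = mx.load("model.safetensors")
print(type(next(iter(st_weights.values()))))  # <class 'mlx.core.array'>
```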
@@ -230,7 +236,9 @@ def remap(key, value):
         key = key.replace("mlp.2", "mlp2")
         if "conv" in key and value.ndim == 3:
             value = value.swapaxes(1, 2)
-        return key, mx.array(value.detach()).astype(dtype)
+        if isinstance(value, torch.Tensor):
+            value = mx.array(value.detach())
+        return key, value.astype(dtype)
 
     weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
     weights.pop("encoder.positional_embedding", None)
@@ -262,12 +270,16 @@ def upload_to_hub(path: str, name: str, torch_name_or_path: str):
 ## Use with mlx
 ```bash
-git clone https://github.com/ml-explore/mlx-examples.git
-cd mlx-examples/whisper/
-pip install -r requirements.txt
+pip install mlx-whisper
 ```
 ```python
+import mlx_whisper
->> import whisper
->> whisper.transcribe("FILE_NAME")
+result = mlx_whisper.transcribe(
+    "FILE_NAME",
+    path_or_hf_repo={repo_id},
+)
 ```
 """
 card = ModelCard(text)
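As a usage sketch of the updated model card snippet, with a placeholder repo and audio file (the generated card substitutes the real `repo_id`):

```python
import mlx_whisper

result = mlx_whisper.transcribe(
    "speech.mp3",  # placeholder audio file
    path_or_hf_repo="mlx-community/whisper-tiny-mlx",  # placeholder converted model repo
)
print(result["text"])
```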
