Skip to content

Commit

Permalink
Fix load compilation (ml-explore#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Dec 27, 2023
1 parent 1f6ab6a commit 79c95b6
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/src/load.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
py::object file,
StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .safetensors file path string
return {load_safetensors(py::cast<std::string>(file), s)};
return load_safetensors(py::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load_safetensors(std::make_shared<PyFileReader>(file), s);
Expand All @@ -174,7 +174,7 @@ std::unordered_map<std::string, array> mlx_load_safetensor_helper(
arr.eval();
}
}
return {arr};
return arr;
}

throw std::invalid_argument(
Expand Down Expand Up @@ -217,20 +217,20 @@ std::unordered_map<std::string, array> mlx_load_npz_helper(
arr.eval();
}

return {array_dict};
return array_dict;
}

array mlx_load_npy_helper(py::object file, StreamOrDevice s) {
if (py::isinstance<py::str>(file)) { // Assume .npy file path string
return {load(py::cast<std::string>(file), s)};
return load(py::cast<std::string>(file), s);
} else if (is_istream_object(file)) {
// If we don't own the stream and it was passed to us, eval immediately
auto arr = load(std::make_shared<PyFileReader>(file), s);
{
py::gil_scoped_release gil;
arr.eval();
}
return {arr};
return arr;
}
throw std::invalid_argument(
"[load_npy] Input must be a file-like object, or string");
Expand Down

0 comments on commit 79c95b6

Please sign in to comment.