quantize_mlx.py
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
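"""Quantize the Moshi MLX language-model weights with mlx.nn.quantize.

Example invocation (a sketch; the checkpoint filenames are placeholders, not
files shipped with this repository):

    python quantize_mlx.py moshi_bf16.safetensors --out moshi_q8.safetensors --bits 8 --group-size 64
"""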

import argparse

import mlx.core as mx
import mlx.nn as nn

import moshi_mlx


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("original_weights", type=str)
    parser.add_argument("--out", type=str)
    parser.add_argument("--bits", type=int, default=8)
    parser.add_argument("--group-size", type=int, default=64)
    args = parser.parse_args()

    model_file = args.original_weights
    lm_config = moshi_mlx.models.config_v0_1()
    print(f"model config:\n{lm_config}")

    # Build the language model and load the original bf16 weights.
    model = moshi_mlx.models.Lm(lm_config)
    model.set_dtype(mx.bfloat16)
    print(f"loading weights {model_file}")
    model.load_weights(model_file, strict=True)
    print("weights loaded")

    # Quantize the model in place, then save the quantized checkpoint.
    nn.quantize(model, bits=args.bits, group_size=args.group_size)
    print(f"saving the quantized q{args.bits} weights in {args.out}")
    model.save_weights(args.out)


if __name__ == "__main__":
    main()
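
# Loading the quantized checkpoint later requires converting the model to its
# quantized form before reading the weights back. A minimal sketch, assuming
# the same --bits/--group-size used above and a placeholder output filename:
#
#   model = moshi_mlx.models.Lm(moshi_mlx.models.config_v0_1())
#   model.set_dtype(mx.bfloat16)
#   nn.quantize(model, bits=8, group_size=64)
#   model.load_weights("moshi_q8.safetensors", strict=True)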