From c89cb50aeaf43c78e82392d3467b737b7761eee0 Mon Sep 17 00:00:00 2001
From: chigkim <22120994+chigkim@users.noreply.github.com>
Date: Fri, 3 Jan 2025 14:37:01 -0500
Subject: [PATCH 1/3] Chat in CLI. (#168)

* Chat with CLI.

* pre-commit run --all

---------

Co-authored-by: Chi Kim
Co-authored-by: Prince Canuma
Co-authored-by: chigkim
---
 mlx_vlm/generate.py | 67 ++++++++++++++++++++++++++++++++++++---------
 1 file changed, 54 insertions(+), 13 deletions(-)

diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py
index 16fbde4..1527c2b 100644
--- a/mlx_vlm/generate.py
+++ b/mlx_vlm/generate.py
@@ -2,7 +2,14 @@
 import codecs
 
 from .prompt_utils import apply_chat_template
-from .utils import generate, get_model_path, load, load_config, load_image_processor
+from .utils import (
+    generate,
+    get_model_path,
+    load,
+    load_config,
+    load_image_processor,
+    stream_generate,
+)
 
 DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
 DEFAULT_IMAGE = []
@@ -49,6 +56,12 @@ def parse_arguments():
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model.",
     )
+    parser.add_argument(
+        "--system",
+        type=str,
+        default=None,
+        help="System message for the model.",
+    )
     parser.add_argument(
         "--max-tokens",
         type=int,
@@ -58,6 +71,7 @@ def parse_arguments():
     parser.add_argument(
         "--temp", type=float, default=DEFAULT_TEMP, help="Temperature for sampling."
     )
+    parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.")
     parser.add_argument("--verbose", action="store_false", help="Detailed output.")
     return parser.parse_args()
 
@@ -93,18 +107,45 @@ def main():
         else resize_shape
     )
 
-    output = generate(
-        model,
-        processor,
-        prompt,
-        image=args.image,
-        temp=args.temp,
-        max_tokens=args.max_tokens,
-        verbose=args.verbose,
-        **kwargs,
-    )
-    if not args.verbose:
-        print(output)
+    if args.chat:
+        chat = []
+        if args.system:
+            chat.append({"role": "system", "content": args.system})
+        while user := input("User:"):
+            chat.append({"role": "user", "content": user})
+            prompt = apply_chat_template(
+                processor, config, chat, num_images=len(args.image)
+            )
+            response = ""
+            print("Assistant:", end="")
+            for chunk in stream_generate(
+                model,
+                processor,
+                prompt,
+                args.image,
+                max_tokens=args.max_tokens,
+                temp=args.temp,
+                **kwargs,
+            ):
+                response += chunk.text
+                print(chunk.text, end="")
+
+            chat.append({"role": "assistant", "content": response})
+            print()
+
+    else:
+        output = generate(
+            model,
+            processor,
+            prompt,
+            image=args.image,
+            temp=args.temp,
+            max_tokens=args.max_tokens,
+            verbose=args.verbose,
+            **kwargs,
+        )
+        if not args.verbose:
+            print(output)
 
 
 if __name__ == "__main__":
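Taken together, the new --chat branch is a small REPL around stream_generate. Below is a minimal standalone sketch of that loop, assuming only the mlx_vlm API used in the diff above (load, load_config, apply_chat_template, stream_generate); the model name reuses DEFAULT_MODEL_PATH and the image path is a placeholder, not part of the patch:

from mlx_vlm import load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config, stream_generate

model, processor = load("mlx-community/nanoLLaVA-1.5-8bit")
config = load_config("mlx-community/nanoLLaVA-1.5-8bit")
images = ["example.jpg"]  # placeholder image path

chat = [{"role": "system", "content": "You are a helpful assistant."}]
while user := input("User:"):
    chat.append({"role": "user", "content": user})
    # Re-render the full history each turn so the model sees prior exchanges
    prompt = apply_chat_template(processor, config, chat, num_images=len(images))
    response = ""
    print("Assistant:", end="")
    for chunk in stream_generate(model, processor, prompt, images, max_tokens=100, temp=0.5):
        response += chunk.text
        print(chunk.text, end="")
    chat.append({"role": "assistant", "content": response})
    print()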
From 71a3f666d708733e8b7fb8ec7fad10084b07c3c2 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Sat, 4 Jan 2025 17:36:48 +0100
Subject: [PATCH 2/3] Fix skip-vision predicate and add utils unit test
 (quantize and inputs) (#172)

* fix skip vision and remove unused import

* add test utils (quantization)

* add prepare inputs test

* add resize shape and load time to smoke test

* add load time, resize shape and skip showUI
---
 mlx_vlm/__init__.py         |   9 +-
 mlx_vlm/tests/test_smoke.py |  19 +++-
 mlx_vlm/tests/test_utils.py | 191 ++++++++++++++++++++++++++++++++++++
 mlx_vlm/utils.py            |   6 +-
 4 files changed, 218 insertions(+), 7 deletions(-)
 create mode 100644 mlx_vlm/tests/test_utils.py

diff --git a/mlx_vlm/__init__.py b/mlx_vlm/__init__.py
index 7a22a6e..2d716f6 100644
--- a/mlx_vlm/__init__.py
+++ b/mlx_vlm/__init__.py
@@ -1,3 +1,10 @@
 from .prompt_utils import apply_chat_template, get_message_json
-from .utils import convert, generate, load, prepare_inputs, process_image
+from .utils import (
+    convert,
+    generate,
+    load,
+    prepare_inputs,
+    process_image,
+    quantize_model,
+)
 from .version import __version__
diff --git a/mlx_vlm/tests/test_smoke.py b/mlx_vlm/tests/test_smoke.py
index f74d894..70877a0 100644
--- a/mlx_vlm/tests/test_smoke.py
+++ b/mlx_vlm/tests/test_smoke.py
@@ -4,6 +4,7 @@
 import subprocess
 import sys
 import textwrap
+import time
 
 import mlx.core as mx
 import psutil
@@ -50,6 +51,7 @@ def parse_args():
     parser.add_argument(
         "--max-tokens", type=int, default=100, help="Maximum tokens to generate"
     )
+    parser.add_argument("--resize-shape", type=int, default=None, help="Resize shape")
     return parser.parse_args()
 
@@ -73,9 +75,13 @@ def get_device_info():
 def test_model_loading(model_path):
     try:
         console.print("[bold green]Loading model...")
+        start_time = time.time()
         model, processor = load(model_path, trust_remote_code=True)
         config = load_config(model_path, trust_remote_code=True)
-        console.print("[bold green]✓[/] Model loaded successfully")
+        end_time = time.time()
+        console.print(
+            f"[bold green]✓[/] Model loaded successfully in {end_time - start_time:.2f} seconds"
+        )
         return model, processor, config, False
     except Exception as e:
         console.print(f"[bold red]✗[/] Failed to load model: {str(e)}")
@@ -112,11 +118,13 @@ def test_generation(
 
         output = generate(**generate_args)
 
-        # Deepseek-vl2-tiny outputs are empty on VLM generation
+        # Deepseek-vl2-tiny and ShowUI outputs are empty on VLM generation
         # Paligemma outputs are empty on language-only generation
         # So we skip the assertion for these models
-        if ("deepseek-vl2-tiny" not in model_path and vision_language) or (
-            "paligemma" not in model_path and not vision_language
+        if (
+            not any(x in model_path for x in ["deepseek-vl2-tiny", "ShowUI"])
+            and vision_language
+            or ("paligemma" not in model_path and not vision_language)
         ):
             assert isinstance(output, str) and len(output) > 0
 
@@ -142,6 +150,9 @@ def main():
         "kwargs": {
             "temp": args.temperature,
            "max_tokens": args.max_tokens,
+            "resize_shape": (
+                (args.resize_shape, args.resize_shape) if args.resize_shape else None
+            ),
         },
     }
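A note on the rewritten skip condition above: Python binds "and" tighter than "or", so the unparenthesized expression groups exactly like the old two-clause version, (A and B) or C. A quick illustration with hypothetical values:

# `A and B or C` parses as `(A and B) or C`, so the condition reads:
# (model is not deepseek-vl2-tiny/ShowUI AND this is a VLM run)
# OR (model is not paligemma AND this is a language-only run)
vision_language = True
model_path = "mlx-community/deepseek-vl2-tiny"
should_assert = (
    not any(x in model_path for x in ["deepseek-vl2-tiny", "ShowUI"])
    and vision_language
    or ("paligemma" not in model_path and not vision_language)
)
assert should_assert is False  # deepseek-vl2-tiny is skipped on VLM runs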
diff --git a/mlx_vlm/tests/test_utils.py b/mlx_vlm/tests/test_utils.py
new file mode 100644
index 0000000..3d398f4
--- /dev/null
+++ b/mlx_vlm/tests/test_utils.py
@@ -0,0 +1,191 @@
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_vlm.utils import (
+    get_class_predicate,
+    prepare_inputs,
+    quantize_model,
+    sanitize_weights,
+    update_module_configs,
+)
+
+
+def test_sanitize_weights():
+    class DummyModel:
+        def __init__(self, config=None):
+            self.config = config
+
+        def sanitize(self, weights):
+            weights["sanitized"] = True
+            return weights
+
+    weights = {"test": mx.array([1, 2, 3])}
+    # Need to instantiate DummyModel first since sanitize is an instance method
+    model = DummyModel()
+    sanitized = sanitize_weights(model, weights)
+    assert sanitized["sanitized"] is True
+
+    # Test with config
+    config = {"test": "config"}
+    sanitized = sanitize_weights(DummyModel, weights, config)
+    assert sanitized["sanitized"] is True
+
+
+def test_update_module_configs():
+    class ModelConfig:
+        def __init__(self):
+            self.text_config = None
+            self.vision_config = None
+
+    class TextConfig:
+        @classmethod
+        def from_dict(cls, d):
+            return "text_config"
+
+    class VisionConfig:
+        @classmethod
+        def from_dict(cls, d):
+            return "vision_config"
+
+    # Define DummyModel after the other classes
+    class DummyModel:
+        pass
+
+    # Set the classes as attributes after DummyModel is defined
+    DummyModel.ModelConfig = ModelConfig
+    DummyModel.TextConfig = TextConfig
+    DummyModel.VisionConfig = VisionConfig
+
+    config = {
+        "text_config": {"test": "text"},
+        "vision_config": {"test": "vision"},
+    }
+    model_config = ModelConfig()
+    updated = update_module_configs(
+        model_config, DummyModel, config, ["text", "vision"]
+    )
+
+    assert updated.text_config == "text_config"
+    assert updated.vision_config == "vision_config"
+
+
+def test_get_class_predicate():
+    class DummyModule:
+        def __init__(self, shape):
+            self.weight = mx.zeros(shape)
+            self.to_quantized = True
+
+    # Test skip_vision=True
+    pred = get_class_predicate(skip_vision=True)
+    module = DummyModule((10, 64))
+    assert pred("language_model", module) is True
+    assert pred("vision_model", module) is False
+
+    # Test skip_vision=True with weights
+    weights = {
+        "language_model.scales": mx.array([1, 2, 3]),
+        "vision_model.scales": mx.array([4, 5, 6]),
+    }
+    pred = get_class_predicate(skip_vision=True, weights=weights)
+    assert pred("language_model", module) is True
+    assert pred("vision_model", module) is False
+
+    # Test skip_vision=False without weights
+    pred = get_class_predicate(skip_vision=False)
+    assert pred("", module) is True
+    module = DummyModule((10, 63))  # Not divisible by 64
+    assert pred("", module) is False
+
+    # Test skip_vision=False with weights
+    weights = {
+        "language_model.scales": mx.array([1, 2, 3]),
+        "vision_model.scales": mx.array([4, 5, 6, 7]),  # Not divisible by 64
+    }
+    pred = get_class_predicate(skip_vision=False, weights=weights)
+    assert pred("language_model", DummyModule((10, 64))) is True
+    assert pred("vision_model", DummyModule((10, 63))) is False
+
+
+def test_quantize_module():
+    class DummyModule(nn.Module):
+        def __init__(self, shape):
+            super().__init__()
+            self.language_model = nn.Linear(shape[1], shape[1])
+            self.vision_model = nn.Linear(shape[1], shape[1])
+
+    # Test basic quantization
+    module = DummyModule((10, 64))
+    config = {}
+    _, updated_config = quantize_model(
+        module, config, q_group_size=64, q_bits=4, skip_vision=False
+    )
+
+    # Check quantization parameters
+    assert hasattr(module.language_model, "scales")
+    assert hasattr(module.vision_model, "scales")
+    assert module.language_model.scales.shape == (64, 1)
+    assert module.language_model.bits == 4
+    assert module.language_model.group_size == 64
+    assert module.vision_model.scales.shape == (64, 1)
+    assert module.vision_model.bits == 4
+    assert module.vision_model.group_size == 64
+
+    # Check config is updated correctly
+    assert updated_config["quantization"] == {"group_size": 64, "bits": 4}
+
+    # Test skip_vision=True
+    module = DummyModule((10, 64))
+    config = {}
+    _, updated_config = quantize_model(
+        module, config, q_group_size=64, q_bits=4, skip_vision=True
+    )
+
+    # Vision module should not be quantized
+    assert hasattr(module.language_model, "scales")
+    assert not hasattr(module.vision_model, "scales")
+
+    # Check config is updated correctly
+    assert updated_config["vision_config"]["skip_vision"] is True
+
+
+def test_prepare_inputs():
+    """Test prepare_inputs function."""
+
+    # Mock processor
+    class MockProcessor:
+        def __init__(self):
+            self.tokenizer = type(
+                "DummyTokenizer", (), {"pad_token": None, "eos_token": "[EOS]"}
+            )()
+
+        def __call__(self, text=None, images=None, padding=None, return_tensors=None):
+            return {
+                "input_ids": mx.array([1, 2, 3]),
+                "pixel_values": mx.array([4, 5, 6]),
+                "attention_mask": mx.array([7, 8, 9]),
+            }
+
+    processor = MockProcessor()
+
+    # Test text-only input
+    inputs = prepare_inputs(
+        processor, prompts="test", images=None, image_token_index=None
+    )
+    assert "input_ids" in inputs
+    assert mx.array_equal(inputs["input_ids"], mx.array([1, 2, 3]))
+
+    # Test image-only input
+    image = mx.zeros((3, 224, 224))
+    inputs = prepare_inputs(
+        processor, prompts=None, images=image, image_token_index=None
+    )
+    assert "input_ids" in inputs
+    assert mx.array_equal(inputs["input_ids"], mx.array([1, 2, 3]))
+
+    # Test both text and image
+    image = mx.zeros((3, 224, 224))
+    inputs = prepare_inputs(
+        processor, prompts="test", images=image, image_token_index=None
+    )
+    assert "input_ids" in inputs
+    assert mx.array_equal(inputs["input_ids"], mx.array([1, 2, 3]))
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 0e262f5..e1d0c37 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -9,7 +9,7 @@
 from io import BytesIO
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -226,7 +226,9 @@ def update_module_configs(model_config, model_class, config, modules):
 
 def get_class_predicate(skip_vision, weights=None):
     if skip_vision:
-        return lambda _, m: not ("vision_model" in m.name or "vision_tower" in m.name)
+        return lambda p, m: hasattr(m, "to_quantized") and not (
+            "vision_model" in p or "vision_tower" in p
+        )
     else:
         if weights:
             return lambda p, m: (
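The utils.py hunk is the actual bug fix: the old lambda tested m.name, which MLX modules generally don't expose, and ignored the module path that MLX passes as the predicate's first argument. A minimal sketch of the fixed predicate in use, assuming mlx.nn's quantize(model, group_size, bits, class_predicate) API; TinyVLM is a made-up stand-in:

import mlx.nn as nn
from mlx_vlm.utils import get_class_predicate

class TinyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(64, 64)
        self.vision_model = nn.Linear(64, 64)

model = TinyVLM()
# nn.quantize calls the predicate with (path, module): vision layers are
# filtered by their path, and only modules exposing to_quantized qualify.
nn.quantize(model, group_size=64, bits=4,
            class_predicate=get_class_predicate(skip_vision=True))
assert hasattr(model.language_model, "scales")
assert not hasattr(model.vision_model, "scales")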
From 1e4579ad0c99b41c71f7309a4cb59d0c4e49fca1 Mon Sep 17 00:00:00 2001
From: Prince Canuma
Date: Wed, 8 Jan 2025 14:48:25 +0100
Subject: [PATCH 3/3] refactor topk to use mlx (#175)

---
 mlx_vlm/models/deepseek_vl_v2/language.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlx_vlm/models/deepseek_vl_v2/language.py b/mlx_vlm/models/deepseek_vl_v2/language.py
index 0d904f3..774b6ab 100644
--- a/mlx_vlm/models/deepseek_vl_v2/language.py
+++ b/mlx_vlm/models/deepseek_vl_v2/language.py
@@ -408,9 +408,9 @@ def __call__(self, x):
 
         # Calculate group scores using top-2 sum per group
         scores_reshaped = scores_for_choice.reshape(bsz * seq_len, self.n_group, -1)
-        k = 2
-        group_scores_topk = mx.sort(scores_reshaped, axis=-1)[..., -k:]
-        group_scores = group_scores_topk.sum(axis=-1)
+
+        # Get top 2 scores per group
+        group_scores = mx.topk(scores_reshaped, 2, axis=-1).sum(axis=-1)
 
         # Get top groups
         k = self.n_group - self.topk_group
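The refactor is behavior-preserving: summing the last two entries of an ascending sort equals summing the two largest values that mx.topk returns per group. A quick equivalence check, with illustrative shapes standing in for (batch * seq, n_group, experts per group):

import mlx.core as mx

scores = mx.random.uniform(shape=(4, 8, 16))
old = mx.sort(scores, axis=-1)[..., -2:].sum(axis=-1)  # sort-and-slice (removed)
new = mx.topk(scores, 2, axis=-1).sum(axis=-1)         # mx.topk (added)
assert mx.allclose(old, new)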