add support for colqwen
jpetrantoni committed Sep 27, 2024
1 parent b26cf34 commit 41284e9
Showing 3 changed files with 82 additions and 36 deletions.
93 changes: 58 additions & 35 deletions byaldi/colpali.py
@@ -7,7 +7,7 @@
 
 import srsly
 import torch
-from colpali_engine.models import ColPali, ColPaliProcessor
+from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor
 from pdf2image import convert_from_path
 from PIL import Image
 
@@ -32,9 +32,12 @@ def __init__(
         if isinstance(pretrained_model_name_or_path, Path):
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
 
-        if "colpali" not in pretrained_model_name_or_path.lower():
+        if (
+            "colpali" not in pretrained_model_name_or_path.lower()
+            and "colqwen2" not in pretrained_model_name_or_path.lower()
+        ):
             raise ValueError(
-                "This pre-release version of Byaldi only supports ColPali for now. Incorrect model name specified."
+                "This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified."
             )
 
         if verbose > 0:
@@ -48,9 +51,7 @@ def __init__(
         device = (
             device or "cuda"
             if torch.cuda.is_available()
-            else "mps"
-            if torch.backends.mps.is_available()
-            else "cpu"
+            else "mps" if torch.backends.mps.is_available() else "cpu"
         )
         self.index_name = index_name
         self.verbose = verbose
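A side note on the collapsed fallback above: Python's conditional expression binds looser than `or`, so the chain resolves as in the sketch below (the explicit grouping is ours, not code from this commit). In particular, an explicitly passed `device` only survives when CUDA is available.

```python
import torch

device = "cpu"  # illustrative caller-supplied value

# Equivalent explicit grouping of the fallback in the hunk above.
resolved = (
    (device or "cuda")
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)
```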
@@ -64,26 +65,48 @@ def __init__(
         self.doc_ids_to_file_names = {}
         self.doc_ids = set()
 
-        self.model = ColPali.from_pretrained(
-            self.pretrained_model_name_or_path,
-            torch_dtype=torch.bfloat16,
-            device_map=(
-                "cuda"
-                if device == "cuda"
-                or (isinstance(device, torch.device) and device.type == "cuda")
-                else None
-            ),
-            token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
-        )
-        self.model = self.model.eval()
-
-        self.processor = cast(
-            ColPaliProcessor,
-            ColPaliProcessor.from_pretrained(
+        if "colpali" in pretrained_model_name_or_path.lower():
+            self.model = ColPali.from_pretrained(
                 self.pretrained_model_name_or_path,
+                torch_dtype=torch.bfloat16,
+                device_map=(
+                    "cuda"
+                    if device == "cuda"
+                    or (isinstance(device, torch.device) and device.type == "cuda")
+                    else None
+                ),
                 token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
-            ),
-        )
+            )
+        elif "colqwen2" in pretrained_model_name_or_path.lower():
+            self.model = ColQwen2.from_pretrained(
+                self.pretrained_model_name_or_path,
+                torch_dtype=torch.bfloat16,
+                device_map=(
+                    "cuda"
+                    if device == "cuda"
+                    or (isinstance(device, torch.device) and device.type == "cuda")
+                    else None
+                ),
+                token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+            )
+        self.model = self.model.eval()
+
+        if "colpali" in pretrained_model_name_or_path.lower():
+            self.processor = cast(
+                ColPaliProcessor,
+                ColPaliProcessor.from_pretrained(
+                    self.pretrained_model_name_or_path,
+                    token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+                ),
+            )
+        elif "colqwen2" in pretrained_model_name_or_path.lower():
+            self.processor = cast(
+                ColQwen2Processor,
+                ColQwen2Processor.from_pretrained(
+                    self.pretrained_model_name_or_path,
+                    token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+                ),
+            )
 
         self.device = device
         if device != "cuda" and not (
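With this dispatch in place, a ColQwen2 checkpoint loads through the same entry point as ColPali. A minimal sketch, reusing the `vidore/colqwen2-v0.1` checkpoint from the new test below; the input path and index name are illustrative, not part of this commit:

```python
from byaldi import RAGMultiModalModel

# "colqwen2" in the checkpoint name routes model and processor loading
# through the ColQwen2/ColQwen2Processor branches above.
model = RAGMultiModalModel.from_pretrained("vidore/colqwen2-v0.1")

# Indexing is unchanged from the ColPali path; storing the collection
# keeps base64 page images alongside the index for later retrieval.
model.index(
    input_path="docs/report.pdf",
    index_name="colqwen_demo",
    store_collection_with_index=True,
)
```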
@@ -240,9 +263,9 @@ def _export_index(self):
             "model_name": self.model_name,
             "full_document_collection": self.full_document_collection,
             "highest_doc_id": self.highest_doc_id,
-            "resize_stored_images": True
-            if self.max_image_width and self.max_image_height
-            else False,
+            "resize_stored_images": (
+                True if self.max_image_width and self.max_image_height else False
+            ),
             "max_image_width": self.max_image_width,
             "max_image_height": self.max_image_height,
             "library_version": VERSION,
Expand Down Expand Up @@ -468,9 +491,9 @@ def _process_and_add_to_index(
with tempfile.TemporaryDirectory() as path:
images = convert_from_path(
item,
thread_count=os.cpu_count()-1,
thread_count=os.cpu_count() - 1,
output_folder=path,
paths_only=True
paths_only=True,
)
for i, image_path in enumerate(images):
image = Image.open(image_path)
@@ -613,9 +636,11 @@ def search(
                     page_num=int(doc_info["page_id"]),
                     score=float(scores[0][embed_id]),
                     metadata=self.doc_id_to_metadata.get(int(doc_info["doc_id"]), {}),
-                    base64=self.collection.get(int(embed_id))
-                    if return_base64_results
-                    else None,
+                    base64=(
+                        self.collection.get(int(embed_id))
+                        if return_base64_results
+                        else None
+                    ),
                 )
                 query_results.append(result)
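Continuing the sketch above, the `return_base64_results` flag wired into this hunk controls whether each hit carries the stored page image; the query text and `k` below are illustrative:

```python
# Fields populated above: page_num, score, metadata, and base64,
# which stays None unless return_base64_results is set.
results = model.search("revenue by quarter", k=3, return_base64_results=True)
for result in results:
    print(result.page_num, result.score, result.base64 is not None)
```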

@@ -655,9 +680,7 @@ def encode_image(
             # Process PDF
             with tempfile.TemporaryDirectory() as path:
                 pdf_images = convert_from_path(
-                    item,
-                    thread_count=os.cpu_count()-1,
-                    output_folder=path
+                    item, thread_count=os.cpu_count() - 1, output_folder=path
                 )
                 images.extend(pdf_images)
         elif item.lower().endswith(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ maintainers = [
 ]
 
 dependencies = [
-    "colpali-engine>=0.3.0,<0.4.0",
+    "colpali-engine>=0.3.1,<0.4.0",
     "ml-dtypes",
     "mteb==1.6.35",
     "ninja",
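The floor moves from 0.3.0 to 0.3.1, presumably the first colpali-engine release that ships the `ColQwen2` and `ColQwen2Processor` classes imported in `byaldi/colpali.py` above.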
23 changes: 23 additions & 0 deletions tests/test_colqwen.py
@@ -0,0 +1,23 @@
+from typing import Generator
+
+import pytest
+from colpali_engine.models import ColQwen2
+from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch
+
+from byaldi import RAGMultiModalModel
+from byaldi.colpali import ColPaliModel
+
+
+@pytest.fixture(scope="module")
+def colqwen_rag_model() -> Generator[RAGMultiModalModel, None, None]:
+    device = get_torch_device("auto")
+    print(f"Using device: {device}")
+    yield RAGMultiModalModel.from_pretrained("vidore/colqwen2-v0.1", device=device)
+    tear_down_torch()
+
+
+@pytest.mark.slow
+def test_load_colqwen_from_pretrained(colqwen_rag_model: RAGMultiModalModel):
+    assert isinstance(colqwen_rag_model, RAGMultiModalModel)
+    assert isinstance(colqwen_rag_model.model, ColPaliModel)
+    assert isinstance(colqwen_rag_model.model.model, ColQwen2)
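The module-scoped fixture yields the loaded model so `tear_down_torch()` runs once after the module's tests finish, and the chained `isinstance` checks pin down the wrapping order: `RAGMultiModalModel` wraps byaldi's `ColPaliModel`, which now holds a `ColQwen2` underneath. Assuming the `slow` marker is registered in the project's pytest configuration, the test can be run selectively with `pytest -m slow tests/test_colqwen.py`.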
