From dbcbc865b9f190df42946942066e1b56881c3cae Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 09:46:43 -0500 Subject: [PATCH 1/8] Eval & OOM issue --- ai-search-demo/ai_search_demo/ui.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ai-search-demo/ai_search_demo/ui.py b/ai-search-demo/ai_search_demo/ui.py index 4fdfd4e..8181e97 100644 --- a/ai-search-demo/ai_search_demo/ui.py +++ b/ai-search-demo/ai_search_demo/ui.py @@ -101,8 +101,9 @@ def process_and_ingest(): with open(os.path.join(collection_dir, COLLECTION_INFO_FILENAME), "w") as json_file: json.dump(collection_info, json_file) - # Run the processing and ingestion in a separate thread - threading.Thread(target=process_and_ingest).start() + # Run the processing and ingestion in the current function with a spinner + with st.spinner('Processing and ingesting PDFs...'): + process_and_ingest() def display_all_collections(): st.header("Previously Uploaded Collections") From fddfb903ac57f01eb93650aab9376b3e6b2ca645 Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 15:03:15 -0500 Subject: [PATCH 2/8] fix oom --- ai-search-demo/ai_search_demo/qdrant_inexing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ai-search-demo/ai_search_demo/qdrant_inexing.py b/ai-search-demo/ai_search_demo/qdrant_inexing.py index ec23b8f..cca0036 100644 --- a/ai-search-demo/ai_search_demo/qdrant_inexing.py +++ b/ai-search-demo/ai_search_demo/qdrant_inexing.py @@ -165,6 +165,9 @@ def pdfs_to_hf_dataset(path_to_folder): images, page_texts = get_pdf_images(str(pdf_file)) for page_number, (image, text) in enumerate(zip(images, page_texts)): + print(f"page_number = {page_number}") + print(f"image = {image}") + print(f"text = {text}") data.append({ "image": image, "index": global_index, From f8fd8bed7f46a72dd3cc6c80831ea9c6be976765 Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 15:08:09 -0500 Subject: [PATCH 3/8] add profile --- ai-search-demo/ai_search_demo/qdrant_inexing.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/ai-search-demo/ai_search_demo/qdrant_inexing.py b/ai-search-demo/ai_search_demo/qdrant_inexing.py index cca0036..29025dc 100644 --- a/ai-search-demo/ai_search_demo/qdrant_inexing.py +++ b/ai-search-demo/ai_search_demo/qdrant_inexing.py @@ -142,6 +142,8 @@ def search_images_by_text(self, query_text, collection_name: str, top_k=TOP_K): return search_result +import tracemalloc + def get_pdf_images(pdf_path): reader = PdfReader(pdf_path) page_texts = [] @@ -155,6 +157,7 @@ def get_pdf_images(pdf_path): return images, page_texts def pdfs_to_hf_dataset(path_to_folder): + tracemalloc.start() # Start tracing memory allocations data = [] global_index = 0 @@ -167,7 +170,6 @@ def pdfs_to_hf_dataset(path_to_folder): for page_number, (image, text) in enumerate(zip(images, page_texts)): print(f"page_number = {page_number}") print(f"image = {image}") - print(f"text = {text}") data.append({ "image": image, "index": global_index, @@ -176,6 +178,15 @@ def pdfs_to_hf_dataset(path_to_folder): "page_text": text }) global_index += 1 + + # Print memory usage after processing each PDF + current, peak = tracemalloc.get_traced_memory() + print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") + + current, peak = tracemalloc.get_traced_memory() + print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") + tracemalloc.stop() # Stop tracing memory allocations + print("Done processing") dataset = Dataset.from_list(data) print("Done converting to dataset") From ce22862e72d6741f53d2342bd7cf5e75d596883d Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 15:16:54 -0500 Subject: [PATCH 4/8] Deploy --- ai-search-demo/ai_search_demo/qdrant_inexing.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ai-search-demo/ai_search_demo/qdrant_inexing.py b/ai-search-demo/ai_search_demo/qdrant_inexing.py index 29025dc..f431a8f 100644 --- a/ai-search-demo/ai_search_demo/qdrant_inexing.py +++ b/ai-search-demo/ai_search_demo/qdrant_inexing.py @@ -178,13 +178,16 @@ def pdfs_to_hf_dataset(path_to_folder): "page_text": text }) global_index += 1 + # Print memory usage after processing each image + current, peak = tracemalloc.get_traced_memory() + print(f"IMAGE: Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") # Print memory usage after processing each PDF current, peak = tracemalloc.get_traced_memory() - print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") + print(f"PDF: Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") current, peak = tracemalloc.get_traced_memory() - print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") + print(f"TOTAL: Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") tracemalloc.stop() # Stop tracing memory allocations print("Done processing") From 2a88ab68d5a28287f598146cc33ed7dd630489d6 Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 15:46:21 -0500 Subject: [PATCH 5/8] Deploy --- ai-search-demo/ai_search_demo/qdrant_inexing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ai-search-demo/ai_search_demo/qdrant_inexing.py b/ai-search-demo/ai_search_demo/qdrant_inexing.py index f431a8f..fbe04db 100644 --- a/ai-search-demo/ai_search_demo/qdrant_inexing.py +++ b/ai-search-demo/ai_search_demo/qdrant_inexing.py @@ -152,7 +152,7 @@ def get_pdf_images(pdf_path): text = page.extract_text() page_texts.append(text) # Convert to PIL images - images = convert_from_path(pdf_path) + images = convert_from_path(pdf_path, dpi=100, fmt="jpeg", jpegopt={"quality": 75, "progressive": True, "optimize": True}) assert len(images) == len(page_texts) return images, page_texts @@ -180,7 +180,6 @@ def pdfs_to_hf_dataset(path_to_folder): global_index += 1 # Print memory usage after processing each image current, peak = tracemalloc.get_traced_memory() - print(f"IMAGE: Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB") # Print memory usage after processing each PDF current, peak = tracemalloc.get_traced_memory() From 58ea277572aa0452de5450337084f333827d984e Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 15:57:28 -0500 Subject: [PATCH 6/8] Deploy --- ai-search-demo/ai_search_demo/qdrant_inexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai-search-demo/ai_search_demo/qdrant_inexing.py b/ai-search-demo/ai_search_demo/qdrant_inexing.py index fbe04db..bf82020 100644 --- a/ai-search-demo/ai_search_demo/qdrant_inexing.py +++ b/ai-search-demo/ai_search_demo/qdrant_inexing.py @@ -152,7 +152,7 @@ def get_pdf_images(pdf_path): text = page.extract_text() page_texts.append(text) # Convert to PIL images - images = convert_from_path(pdf_path, dpi=100, fmt="jpeg", jpegopt={"quality": 75, "progressive": True, "optimize": True}) + images = convert_from_path(pdf_path, dpi=150, fmt="jpeg", jpegopt={"quality": 75, "progressive": True, "optimize": True}) assert len(images) == len(page_texts) return images, page_texts From 5e1b037f5ba4d46478607139cbc7f836afdaef92 Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 19:37:34 -0500 Subject: [PATCH 7/8] Eval table --- ai-search-demo/README.md | 50 ++++- .../ai_search_demo/evaluate_synthetic_data.py | 159 ++++++++++++++++ ai-search-demo/ai_search_demo/llm_serving.py | 172 ------------------ .../ai_search_demo/llm_serving_colpali.py | 109 ----------- .../ai_search_demo/llm_serving_load_models.py | 62 ------- .../ai_search_demo/qdrant_inexing.py | 19 +- ai-search-demo/ai_search_demo/ui.py | 10 +- 7 files changed, 214 insertions(+), 367 deletions(-) create mode 100644 ai-search-demo/ai_search_demo/evaluate_synthetic_data.py delete mode 100644 ai-search-demo/ai_search_demo/llm_serving.py delete mode 100644 ai-search-demo/ai_search_demo/llm_serving_colpali.py delete mode 100644 ai-search-demo/ai_search_demo/llm_serving_load_models.py diff --git a/ai-search-demo/README.md b/ai-search-demo/README.md index 74694d7..8418788 100644 --- a/ai-search-demo/README.md +++ b/ai-search-demo/README.md @@ -5,9 +5,6 @@ This is a small demo showing how to build AI search on top of visual reach data (PDFs, Images, etc) -## Evaluation - -Before developing this we want to understand how the system performs in general, for this we are going to generate synthetic data based on SmartHR data and evaluate. This is not a real estimate, but a starting point to automate some evaluation. In real life - data from actual use should be used for this. ## Architecture @@ -38,22 +35,57 @@ graph TD; H --> Q ``` + +## Evaluation + +Before developing this we want to understand how the system performs in general, for this we are going to generate synthetic data based on SmartHR data and evaluate. This is not a real estimate, but a starting point to automate some evaluation. In real life - data from actual use should be used for this. + +### Results: + +| Dataset | Language | NDCG@1 | NDCG@3 | Recall@1 | Recall@3 | Precision@1 | Precision@4 | +|---------|----------|--------|--------|----------|----------|-------------|-------------| +| [synthetic-data-single-image-single-query](https://huggingface.co/datasets/koml/smart-hr-synthetic-data-single-image-single-query) | English | 0.85 | 0.78 | 0.90 | 0.82 | 0.88 | 0.75 | +| [synthetic-data-single-image-single-query](https://huggingface.co/datasets/koml/smart-hr-synthetic-data-single-image-single-query) | Japanese | 0.80 | 0.76 | 0.85 | 0.80 | 0.83 | 0.70 | +| [Dataset 3](#) | English | 0.82 | 0.79 | 0.87 | 0.81 | 0.85 | 0.72 | +| [Dataset 4](#) | Japanese | 0.78 | 0.74 | 0.82 | 0.78 | 0.80 | 0.68 | + + +### Process: + +Evaluation process had 2 stage: we generate sytbeht data based on existng SmartHR PDFs and evaluate our visual retravala. To run small test: + +``` +python ai_search_demo/evaluate_synthetic_data.py create-synthetic-dataset ./example_data/smart-hr ./example_data/smart-hr-dataset-test koml/smart-hr-synthetic-data-test +python ai_search_demo/evaluate_synthetic_data.py evaluate-on-synthetic-dataset koml/smart-hr-synthetic-data-test --collection-name small-eval +``` + +To run large evaluation: + +``` +python ai_search_demo/evaluate_synthetic_data.py create-synthetic-dataset ./example_data/smart-hr ./example_data/smart-hr-synthetic-data-single-image-single-query koml/smart-hr-synthetic-data-single-image-single-query --num-samples 79 +python ai_search_demo/evaluate_synthetic_data.py evaluate-on-synthetic-dataset koml/smart-hr-synthetic-data-single-image-single-query --collection-name smart-hr-synthetic-data-single-image-single-query + + +python ai_search_demo/evaluate_synthetic_data.py create-synthetic-dataset ./example_data/smart-hr ./example_data/smart-hr-synthetic-data-single-image-multiple-queries koml/smart-hr-synthetic-data-single-image-multiple-queries --num-samples 1000 +``` + + ## LLM inference Download models ``` -modal run llm_serving_load_models.py --model-name Qwen/Qwen2.5-7B-Instruct --model-revision bb46c15ee4bb56c5b63245ef50fd7637234d6f75 -modal run llm_serving_load_models.py --model-name Qwen/Qwen2-VL-7B-Instruct --model-revision 51c47430f97dd7c74aa1fa6825e68a813478097f -modal run llm_serving_load_models.py --model-name Qwen/Qwen2-VL-72B-Instruct --model-revision bb46c15ee4bb56c5b63245ef50fd7637234d6f75 -modal run llm_serving_load_models.py --model-name vidore/colqwen2-v1.0-merged --model-revision 364a4f5df97231e233e15cbbaf0b9dbe352ba92c +modal run llm-inference/llm_serving_load_models.py --model-name Qwen/Qwen2.5-7B-Instruct --model-revision bb46c15ee4bb56c5b63245ef50fd7637234d6f75 +modal run llm-inference/llm_serving_load_models.py --model-name Qwen/Qwen2-VL-7B-Instruct --model-revision 51c47430f97dd7c74aa1fa6825e68a813478097f +modal run llm-inference/llm_serving_load_models.py --model-name Qwen/Qwen2-VL-72B-Instruct --model-revision bb46c15ee4bb56c5b63245ef50fd7637234d6f75 +modal run llm-inference/llm_serving_load_models.py --model-name vidore/colqwen2-v1.0-merged --model-revision 364a4f5df97231e233e15cbbaf0b9dbe352ba92c ``` Deploy models ``` -modal deploy llm_serving.py -modal deploy llm_serving_colpali.py +modal deploy llm-inference/llm_serving.py +modal deploy llm-inference/llm_serving_colpali.py ``` ## DB diff --git a/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py b/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py new file mode 100644 index 0000000..ee72c5d --- /dev/null +++ b/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py @@ -0,0 +1,159 @@ +import base64 +import os +from io import BytesIO +from typing import Dict, List +import random +import PIL +import typer +from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator +from datasets import Dataset, load_from_disk, load_dataset +from openai import OpenAI +from pydantic import BaseModel +from tqdm import tqdm +from rich import print +from rich.table import Table + +from ai_search_demo.qdrant_inexing import IngestClient, SearchClient, pdfs_to_hf_dataset + +# Initialize OpenAI client +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + +class DataSample(BaseModel): + japanese_query: str + english_query: str + +def generate_synthetic_question(image: PIL.Image.Image) -> DataSample: + # Convert PIL image to base64 string + buffered = BytesIO() + image.save(buffered, format="JPEG") + image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") + + prompt = """ + I am developing a visual retrieval dataset to evaluate my system. + Based on the image I provided, I want you to generate a query that this image will satisfy. + For example, if a user types this query into the search box, this image would be extremely relevant. + Generate the query in Japanese and English. + """ + # Generate synthetic question using OpenAI + chat_completion = client.beta.chat.completions.parse( + model="gpt-4o", + response_format=DataSample, + temperature=1, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, + }, + ], + } + ], + ) + + sample = chat_completion.choices[0].message.parsed + return sample + +def create_synthetic_dataset(input_folder: str, output_folder: str, hub_repo: str, num_samples: int = 10) -> None: + # Step 1: Read all PDFs and extract info + dataset = pdfs_to_hf_dataset(input_folder) + + # Step 2: Randomly sample data points + if num_samples > len(dataset): + indices = random.choices(range(len(dataset)), k=num_samples) + sampled_data = dataset.select(indices) + else: + sampled_data = dataset.shuffle().select(range(num_samples)) + + synthetic_data: List[Dict] = [] + + for index, data_point in enumerate(tqdm(sampled_data, desc="Generating synthetic questions")): + image = data_point['image'] + pdf_name = data_point['pdf_name'] + pdf_page = data_point['pdf_page'] + + # Step 3: Generate synthetic question + sample = generate_synthetic_question(image) + + # Step 4: Store samples in a new dataset + synthetic_data.append({ + "index": index, + "image": image, + "question_en": sample.english_query, + "question_jp": sample.japanese_query, + "pdf_name": pdf_name, + "pdf_page": pdf_page + }) + + # Create a new dataset from synthetic data + synthetic_dataset = Dataset.from_list(synthetic_data) + synthetic_dataset.save_to_disk(output_folder) + + # Save the dataset card + synthetic_dataset.push_to_hub(hub_repo, private=False) + +def evaluate_on_synthetic_dataset(hub_repo: str, collection_name: str = "synthetic-dataset-evaluate-full") -> None: + # Ingest collection with IngestClient + print("Load data") + synthetic_dataset = load_dataset(hub_repo)['train'] + + print("Ingest data to qdrant") + ingest_client = IngestClient() + ingest_client.ingest(collection_name, synthetic_dataset) + + run_evaluation(synthetic_dataset=synthetic_dataset, collection_name=collection_name, query_text_key='question_en') + run_evaluation(synthetic_dataset=synthetic_dataset, collection_name=collection_name, query_text_key='question_jp') + +def run_evaluation(synthetic_dataset: Dataset, collection_name: str, query_text_key: str) -> None: + search_client = SearchClient() + relevant_docs: Dict[str, Dict[str, int]] = {} + results: Dict[str, Dict[str, float]] = {} + + for x in synthetic_dataset: + query_id = f"{x['pdf_name']}_{x['pdf_page']}" + relevant_docs[query_id] = {query_id: 1} # The most relevant document is itself + + response = search_client.search_images_by_text(query_text=x['question_en'], collection_name=collection_name, top_k=10) + + results[query_id] = {} + for point in response.points: + doc_id = f"{point.payload['pdf_name']}_{point.payload['pdf_page']}" + results[query_id][doc_id] = point.score + + mteb_evaluator = CustomRetrievalEvaluator() + + ndcg, _map, recall, precision, naucs = mteb_evaluator.evaluate( + relevant_docs, + results, + mteb_evaluator.k_values, + ) + + mrr = mteb_evaluator.evaluate_custom(relevant_docs, results, mteb_evaluator.k_values, "mrr") + + scores = { + **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()}, + **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()}, + **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()}, + **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()}, + **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()}, + **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()}, + } + + # Use rich to print scores beautifully + table = Table(title=f"Evaluation Scores for {query_text_key}") + table.add_column("Metric", justify="right", style="cyan", no_wrap=True) + table.add_column("Score", style="magenta") + + for metric, score in scores.items(): + table.add_row(metric, f"{score:.4f}") + + print(table) + + +if __name__ == '__main__': + app = typer.Typer() + app.command()(create_synthetic_dataset) + app.command()(evaluate_on_synthetic_dataset) + app() \ No newline at end of file diff --git a/ai-search-demo/ai_search_demo/llm_serving.py b/ai-search-demo/ai_search_demo/llm_serving.py deleted file mode 100644 index 833ebfd..0000000 --- a/ai-search-demo/ai_search_demo/llm_serving.py +++ /dev/null @@ -1,172 +0,0 @@ -import modal - -vllm_image = modal.Image.debian_slim(python_version="3.12").pip_install( - "vllm==0.6.3post1", "fastapi[standard]==0.115.4" -) - - -MODELS_DIR = "/models" -MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct" - - -try: - volume = modal.Volume.lookup("models", create_if_missing=False) -except modal.exception.NotFoundError: - raise Exception("Download models first with modal run download_llama.py") - - - -app = modal.App("qwen2-vllm") - -N_GPU = 1 # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count -TOKEN = "super-secret-token" # auth token. for production use, replace with a modal.Secret - -MINUTES = 60 # seconds -HOURS = 60 * MINUTES - - -@app.function( - image=vllm_image, - gpu=modal.gpu.H100(count=N_GPU), - container_idle_timeout=5 * MINUTES, - timeout=24 * HOURS, - allow_concurrent_inputs=1000, - volumes={MODELS_DIR: volume}, -) -@modal.asgi_app() -def serve(): - import fastapi - import vllm.entrypoints.openai.api_server as api_server - from vllm.engine.arg_utils import AsyncEngineArgs - from vllm.engine.async_llm_engine import AsyncLLMEngine - from vllm.entrypoints.logger import RequestLogger - from vllm.entrypoints.openai.serving_chat import OpenAIServingChat - from vllm.entrypoints.openai.serving_completion import ( - OpenAIServingCompletion, - ) - from vllm.entrypoints.openai.serving_engine import BaseModelPath - from vllm.usage.usage_lib import UsageContext - - volume.reload() # ensure we have the latest version of the weights - - # create a fastAPI app that uses vLLM's OpenAI-compatible router - web_app = fastapi.FastAPI( - title=f"OpenAI-compatible {MODEL_NAME} server", - description="Run an OpenAI-compatible LLM server with vLLM on modal.com 🚀", - version="0.0.1", - docs_url="/docs", - ) - - # security: CORS middleware for external requests - http_bearer = fastapi.security.HTTPBearer( - scheme_name="Bearer Token", - description="See code for authentication details.", - ) - web_app.add_middleware( - fastapi.middleware.cors.CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # security: inject dependency on authed routes - async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): - if api_key.credentials != TOKEN: - raise fastapi.HTTPException( - status_code=fastapi.status.HTTP_401_UNAUTHORIZED, - detail="Invalid authentication credentials", - ) - return {"username": "authenticated_user"} - - router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) - - # wrap vllm's router in auth router - router.include_router(api_server.router) - # add authed vllm to our fastAPI app - web_app.include_router(router) - - engine_args = AsyncEngineArgs( - model=MODELS_DIR + "/" + MODEL_NAME, - tensor_parallel_size=N_GPU, - gpu_memory_utilization=0.90, - max_model_len=8096, - enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) - ) - - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER - ) - - model_config = get_model_config(engine) - - request_logger = RequestLogger(max_log_len=2048) - - base_model_paths = [ - BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) - ] - - api_server.chat = lambda s: OpenAIServingChat( - engine, - model_config=model_config, - base_model_paths=base_model_paths, - chat_template=None, - response_role="assistant", - lora_modules=[], - prompt_adapters=[], - request_logger=request_logger, - ) - api_server.completion = lambda s: OpenAIServingCompletion( - engine, - model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=[], - prompt_adapters=[], - request_logger=request_logger, - ) - - return web_app - - -def get_model_config(engine): - import asyncio - - try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1 - event_loop = asyncio.get_running_loop() - except RuntimeError: - event_loop = None - - if event_loop is not None and event_loop.is_running(): - # If the current is instanced by Ray Serve, - # there is already a running event loop - model_config = event_loop.run_until_complete(engine.get_model_config()) - else: - # When using single vLLM without engine_use_ray - model_config = asyncio.run(engine.get_model_config()) - - return model_config - - -def client(): - - model = "Qwen2-VL-7B-Instruct" - img_url = "https://www.google.com/images/branding/googlelogo/1x/googlelogo_color_272x92dp.png" - prompt = "Translate text on image to Japanese" - - response = client.chat.completions.create( - model=model, - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": {"url": f"{img_url}"}, - }, - ], - } - ], - ) - - print(response.choices[0].message) \ No newline at end of file diff --git a/ai-search-demo/ai_search_demo/llm_serving_colpali.py b/ai-search-demo/ai_search_demo/llm_serving_colpali.py deleted file mode 100644 index dac8e7c..0000000 --- a/ai-search-demo/ai_search_demo/llm_serving_colpali.py +++ /dev/null @@ -1,109 +0,0 @@ -import modal - -vllm_image = modal.Image.debian_slim(python_version="3.12").pip_install( - "vllm==0.6.3post1", "fastapi[standard]==0.115.4" -).pip_install("colpali-engine") - -MODELS_DIR = "/models" -MODEL_NAME = "vidore/colqwen2-v1.0-merged" - - -try: - volume = modal.Volume.lookup("models", create_if_missing=False) -except modal.exception.NotFoundError: - raise Exception("Download models first with modal run download_llama.py") - - - -app = modal.App("colpali-embedding") - -N_GPU = 1 -TOKEN = "super-secret-token" - -MINUTES = 60 # seconds -HOURS = 60 * MINUTES - - -@app.function( - image=vllm_image, - gpu=modal.gpu.H100(count=N_GPU), - container_idle_timeout=5 * MINUTES, - timeout=24 * HOURS, - allow_concurrent_inputs=1000, - volumes={MODELS_DIR: volume}, -) -@modal.asgi_app() -def serve(): - import fastapi - from colpali_engine.models import ColQwen2, ColQwen2Processor - from fastapi.middleware.cors import CORSMiddleware - from fastapi.security import HTTPBearer - from fastapi import HTTPException, Security, APIRouter, Depends - import torch - - volume.reload() # ensure we have the latest version of the weights - - # create a fastAPI app for serving the ColPali model - web_app = fastapi.FastAPI( - title=f"ColPali {MODEL_NAME} server", - description="Run a ColPali model server with fastAPI on modal.com 🚀", - version="0.0.1", - docs_url="/docs", - ) - - # security: CORS middleware for external requests - http_bearer = HTTPBearer( - scheme_name="Bearer Token", - description="See code for authentication details.", - ) - web_app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - # security: inject dependency on authed routes - async def is_authenticated(api_key: str = Security(http_bearer)): - if api_key.credentials != TOKEN: - raise HTTPException( - status_code=fastapi.status.HTTP_401_UNAUTHORIZED, - detail="Invalid authentication credentials", - ) - return {"username": "authenticated_user"} - - router = APIRouter(dependencies=[Depends(is_authenticated)]) - - # Define the model and processor - model_name = "/models/vidore/colqwen2-v1.0-merged" - colpali_model = ColQwen2.from_pretrained( - model_name, - torch_dtype=torch.bfloat16, - device_map="cuda:0", - ).eval() - - colpali_processor = ColQwen2Processor.from_pretrained(model_name) - - # Define a simple endpoint to process text queries - @router.post("/query") - async def query_model(query_text: str): - with torch.no_grad(): - batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device) - query_embedding = colpali_model(**batch_query) - return {"embedding": query_embedding[0].cpu().float().numpy().tolist()} - - @router.post("/process_image") - async def process_image(image: fastapi.UploadFile): - from PIL import Image - pil_image = Image.open(image.file) - with torch.no_grad(): - batch_image = colpali_processor.process_images([pil_image]).to(colpali_model.device) - image_embedding = colpali_model(**batch_image) - return {"embedding": image_embedding[0].cpu().float().numpy().tolist()} - - - # add authed router to our fastAPI app - web_app.include_router(router) - - return web_app diff --git a/ai-search-demo/ai_search_demo/llm_serving_load_models.py b/ai-search-demo/ai_search_demo/llm_serving_load_models.py deleted file mode 100644 index 84c610b..0000000 --- a/ai-search-demo/ai_search_demo/llm_serving_load_models.py +++ /dev/null @@ -1,62 +0,0 @@ -import modal - -MODELS_DIR = "/models" - -DEFAULT_NAME = "Qwen/Qwen2.5-7B-Instruct" -DEFAULT_REVISION = "bb46c15ee4bb56c5b63245ef50fd7637234d6f75" - -# Qwen/Qwen2-VL-7B-Instruct - - -volume = modal.Volume.from_name("models", create_if_missing=True) - -image = ( - modal.Image.debian_slim(python_version="3.10") - .pip_install( - [ - "huggingface_hub", - "hf-transfer", - ] - ) - .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) -) - - -MINUTES = 60 -HOURS = 60 * MINUTES - - -app = modal.App( - image=image, secrets=[modal.Secret.from_name("huggingface-secret")] -) - - -@app.function(volumes={MODELS_DIR: volume}, timeout=4 * HOURS) -def download_model(model_name, model_revision, force_download=False): - from huggingface_hub import snapshot_download - - volume.reload() - - snapshot_download( - model_name, - local_dir=MODELS_DIR + "/" + model_name, - ignore_patterns=[ - "*.pt", - "*.bin", - "*.pth", - "original/*", - ], # Ensure safetensors - revision=model_revision, - force_download=force_download, - ) - - volume.commit() - - -@app.local_entrypoint() -def main( - model_name: str = DEFAULT_NAME, - model_revision: str = DEFAULT_REVISION, - force_download: bool = False, -): - download_model.remote(model_name, model_revision, force_download) \ No newline at end of file diff --git a/ai-search-demo/ai_search_demo/qdrant_inexing.py b/ai-search-demo/ai_search_demo/qdrant_inexing.py index bf82020..3cd15ee 100644 --- a/ai-search-demo/ai_search_demo/qdrant_inexing.py +++ b/ai-search-demo/ai_search_demo/qdrant_inexing.py @@ -1,12 +1,13 @@ -from qdrant_client import QdrantClient -from qdrant_client.http import models -from tqdm import tqdm -from pdf2image import convert_from_path -from pypdf import PdfReader import io -import requests +import tracemalloc from pathlib import Path + +import requests from datasets import Dataset +from pdf2image import convert_from_path +from pypdf import PdfReader +from qdrant_client import QdrantClient +from qdrant_client.http import models from tqdm import tqdm # Constants @@ -142,7 +143,7 @@ def search_images_by_text(self, query_text, collection_name: str, top_k=TOP_K): return search_result -import tracemalloc + def get_pdf_images(pdf_path): reader = PdfReader(pdf_path) @@ -152,7 +153,7 @@ def get_pdf_images(pdf_path): text = page.extract_text() page_texts.append(text) # Convert to PIL images - images = convert_from_path(pdf_path, dpi=150, fmt="jpeg", jpegopt={"quality": 75, "progressive": True, "optimize": True}) + images = convert_from_path(pdf_path, dpi=150, fmt="jpeg", jpegopt={"quality": 100, "progressive": True, "optimize": True}) assert len(images) == len(page_texts) return images, page_texts @@ -168,8 +169,6 @@ def pdfs_to_hf_dataset(path_to_folder): images, page_texts = get_pdf_images(str(pdf_file)) for page_number, (image, text) in enumerate(zip(images, page_texts)): - print(f"page_number = {page_number}") - print(f"image = {image}") data.append({ "image": image, "index": global_index, diff --git a/ai-search-demo/ai_search_demo/ui.py b/ai-search-demo/ai_search_demo/ui.py index 8181e97..23aff9d 100644 --- a/ai-search-demo/ai_search_demo/ui.py +++ b/ai-search-demo/ai_search_demo/ui.py @@ -1,12 +1,12 @@ -import streamlit as st -import os import json +import os + import pandas as pd -import threading -from ai_search_demo.qdrant_inexing import IngestClient, pdfs_to_hf_dataset -from ai_search_demo.qdrant_inexing import SearchClient +import streamlit as st from datasets import load_from_disk +from ai_search_demo.qdrant_inexing import IngestClient, SearchClient, pdfs_to_hf_dataset + STORAGE_DIR = "storage" COLLECTION_INFO_FILENAME = "collection_info.json" HF_DATASET_DIRNAME = "hf_dataset" From 4940196ecb399f0e3c319e95f6c221e98fe4bc72 Mon Sep 17 00:00:00 2001 From: truskovskiyk Date: Sun, 1 Dec 2024 19:54:33 -0500 Subject: [PATCH 8/8] Eval table --- ai-search-demo/README.md | 12 +- .../ai_search_demo/evaluate_synthetic_data.py | 15 +- ai-search-demo/llm-inference/llm_serving.py | 172 ++++++++++++++++++ .../llm-inference/llm_serving_colpali.py | 109 +++++++++++ .../llm-inference/llm_serving_load_models.py | 62 +++++++ 5 files changed, 356 insertions(+), 14 deletions(-) create mode 100644 ai-search-demo/llm-inference/llm_serving.py create mode 100644 ai-search-demo/llm-inference/llm_serving_colpali.py create mode 100644 ai-search-demo/llm-inference/llm_serving_load_models.py diff --git a/ai-search-demo/README.md b/ai-search-demo/README.md index 8418788..99df6d5 100644 --- a/ai-search-demo/README.md +++ b/ai-search-demo/README.md @@ -8,7 +8,7 @@ This is a small demo showing how to build AI search on top of visual reach data ## Architecture -Hitht leve diagram of the systen +High-level diagram of the system ```mermaid @@ -42,17 +42,15 @@ Before developing this we want to understand how the system performs in general, ### Results: -| Dataset | Language | NDCG@1 | NDCG@3 | Recall@1 | Recall@3 | Precision@1 | Precision@4 | +| Dataset | Language | NDCG@1 | NDCG@5 | Recall@1 | Recall@5 | Precision@1 | Precision@5 | |---------|----------|--------|--------|----------|----------|-------------|-------------| -| [synthetic-data-single-image-single-query](https://huggingface.co/datasets/koml/smart-hr-synthetic-data-single-image-single-query) | English | 0.85 | 0.78 | 0.90 | 0.82 | 0.88 | 0.75 | -| [synthetic-data-single-image-single-query](https://huggingface.co/datasets/koml/smart-hr-synthetic-data-single-image-single-query) | Japanese | 0.80 | 0.76 | 0.85 | 0.80 | 0.83 | 0.70 | -| [Dataset 3](#) | English | 0.82 | 0.79 | 0.87 | 0.81 | 0.85 | 0.72 | -| [Dataset 4](#) | Japanese | 0.78 | 0.74 | 0.82 | 0.78 | 0.80 | 0.68 | +| [synthetic-data-single-image-single-query](https://huggingface.co/datasets/koml/smart-hr-synthetic-data-single-image-single-query) | English | 0.5190 | 0.7021 | 0.5190 | 0.8354 | 0.5190 | 0.1671 | +| [synthetic-data-single-image-single-query](https://huggingface.co/datasets/koml/smart-hr-synthetic-data-single-image-single-query) | Japanese | 0.7215 | 0.8342 | 0.7215 | 0.9241 | 0.7215 | 0.1848 | ### Process: -Evaluation process had 2 stage: we generate sytbeht data based on existng SmartHR PDFs and evaluate our visual retravala. To run small test: +The evaluation process had two stages: we generated synthetic data based on existing SmartHR PDFs and evaluated our visual retrieval. To run a small test: ``` python ai_search_demo/evaluate_synthetic_data.py create-synthetic-dataset ./example_data/smart-hr ./example_data/smart-hr-dataset-test koml/smart-hr-synthetic-data-test diff --git a/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py b/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py index ee72c5d..93689ea 100644 --- a/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py +++ b/ai-search-demo/ai_search_demo/evaluate_synthetic_data.py @@ -1,19 +1,20 @@ import base64 import os +import random from io import BytesIO from typing import Dict, List -import random + import PIL import typer from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator -from datasets import Dataset, load_from_disk, load_dataset +from datasets import Dataset, load_dataset from openai import OpenAI from pydantic import BaseModel -from tqdm import tqdm from rich import print from rich.table import Table +from tqdm import tqdm -from ai_search_demo.qdrant_inexing import IngestClient, SearchClient, pdfs_to_hf_dataset +from ai_search_demo.qdrant_inexing import SearchClient, pdfs_to_hf_dataset # Initialize OpenAI client client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) @@ -100,8 +101,8 @@ def evaluate_on_synthetic_dataset(hub_repo: str, collection_name: str = "synthet synthetic_dataset = load_dataset(hub_repo)['train'] print("Ingest data to qdrant") - ingest_client = IngestClient() - ingest_client.ingest(collection_name, synthetic_dataset) + # ingest_client = IngestClient() + # ingest_client.ingest(collection_name, synthetic_dataset) run_evaluation(synthetic_dataset=synthetic_dataset, collection_name=collection_name, query_text_key='question_en') run_evaluation(synthetic_dataset=synthetic_dataset, collection_name=collection_name, query_text_key='question_jp') @@ -115,7 +116,7 @@ def run_evaluation(synthetic_dataset: Dataset, collection_name: str, query_text_ query_id = f"{x['pdf_name']}_{x['pdf_page']}" relevant_docs[query_id] = {query_id: 1} # The most relevant document is itself - response = search_client.search_images_by_text(query_text=x['question_en'], collection_name=collection_name, top_k=10) + response = search_client.search_images_by_text(query_text=x[query_text_key], collection_name=collection_name, top_k=10) results[query_id] = {} for point in response.points: diff --git a/ai-search-demo/llm-inference/llm_serving.py b/ai-search-demo/llm-inference/llm_serving.py new file mode 100644 index 0000000..833ebfd --- /dev/null +++ b/ai-search-demo/llm-inference/llm_serving.py @@ -0,0 +1,172 @@ +import modal + +vllm_image = modal.Image.debian_slim(python_version="3.12").pip_install( + "vllm==0.6.3post1", "fastapi[standard]==0.115.4" +) + + +MODELS_DIR = "/models" +MODEL_NAME = "Qwen/Qwen2-VL-7B-Instruct" + + +try: + volume = modal.Volume.lookup("models", create_if_missing=False) +except modal.exception.NotFoundError: + raise Exception("Download models first with modal run download_llama.py") + + + +app = modal.App("qwen2-vllm") + +N_GPU = 1 # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count +TOKEN = "super-secret-token" # auth token. for production use, replace with a modal.Secret + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + + +@app.function( + image=vllm_image, + gpu=modal.gpu.H100(count=N_GPU), + container_idle_timeout=5 * MINUTES, + timeout=24 * HOURS, + allow_concurrent_inputs=1000, + volumes={MODELS_DIR: volume}, +) +@modal.asgi_app() +def serve(): + import fastapi + import vllm.entrypoints.openai.api_server as api_server + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.engine.async_llm_engine import AsyncLLMEngine + from vllm.entrypoints.logger import RequestLogger + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + from vllm.entrypoints.openai.serving_completion import ( + OpenAIServingCompletion, + ) + from vllm.entrypoints.openai.serving_engine import BaseModelPath + from vllm.usage.usage_lib import UsageContext + + volume.reload() # ensure we have the latest version of the weights + + # create a fastAPI app that uses vLLM's OpenAI-compatible router + web_app = fastapi.FastAPI( + title=f"OpenAI-compatible {MODEL_NAME} server", + description="Run an OpenAI-compatible LLM server with vLLM on modal.com 🚀", + version="0.0.1", + docs_url="/docs", + ) + + # security: CORS middleware for external requests + http_bearer = fastapi.security.HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + fastapi.middleware.cors.CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # security: inject dependency on authed routes + async def is_authenticated(api_key: str = fastapi.Security(http_bearer)): + if api_key.credentials != TOKEN: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = fastapi.APIRouter(dependencies=[fastapi.Depends(is_authenticated)]) + + # wrap vllm's router in auth router + router.include_router(api_server.router) + # add authed vllm to our fastAPI app + web_app.include_router(router) + + engine_args = AsyncEngineArgs( + model=MODELS_DIR + "/" + MODEL_NAME, + tensor_parallel_size=N_GPU, + gpu_memory_utilization=0.90, + max_model_len=8096, + enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) + ) + + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_API_SERVER + ) + + model_config = get_model_config(engine) + + request_logger = RequestLogger(max_log_len=2048) + + base_model_paths = [ + BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) + ] + + api_server.chat = lambda s: OpenAIServingChat( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + chat_template=None, + response_role="assistant", + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + api_server.completion = lambda s: OpenAIServingCompletion( + engine, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=[], + prompt_adapters=[], + request_logger=request_logger, + ) + + return web_app + + +def get_model_config(engine): + import asyncio + + try: # adapted from vLLM source -- https://github.com/vllm-project/vllm/blob/507ef787d85dec24490069ffceacbd6b161f4f72/vllm/entrypoints/openai/api_server.py#L235C1-L247C1 + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running(): + # If the current is instanced by Ray Serve, + # there is already a running event loop + model_config = event_loop.run_until_complete(engine.get_model_config()) + else: + # When using single vLLM without engine_use_ray + model_config = asyncio.run(engine.get_model_config()) + + return model_config + + +def client(): + + model = "Qwen2-VL-7B-Instruct" + img_url = "https://www.google.com/images/branding/googlelogo/1x/googlelogo_color_272x92dp.png" + prompt = "Translate text on image to Japanese" + + response = client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": {"url": f"{img_url}"}, + }, + ], + } + ], + ) + + print(response.choices[0].message) \ No newline at end of file diff --git a/ai-search-demo/llm-inference/llm_serving_colpali.py b/ai-search-demo/llm-inference/llm_serving_colpali.py new file mode 100644 index 0000000..16045f6 --- /dev/null +++ b/ai-search-demo/llm-inference/llm_serving_colpali.py @@ -0,0 +1,109 @@ +import modal + +vllm_image = modal.Image.debian_slim(python_version="3.12").pip_install( + "vllm==0.6.3post1", "fastapi[standard]==0.115.4" +).pip_install("colpali-engine") + +MODELS_DIR = "/models" +MODEL_NAME = "vidore/colqwen2-v1.0-merged" + + +try: + volume = modal.Volume.lookup("models", create_if_missing=False) +except modal.exception.NotFoundError: + raise Exception("Download models first with modal run download_llama.py") + + + +app = modal.App("colpali-embedding") + +N_GPU = 1 +TOKEN = "super-secret-token" + +MINUTES = 60 # seconds +HOURS = 60 * MINUTES + + +@app.function( + image=vllm_image, + gpu=modal.gpu.H100(count=N_GPU), + container_idle_timeout=5 * MINUTES, + timeout=24 * HOURS, + allow_concurrent_inputs=1000, + volumes={MODELS_DIR: volume}, +) +@modal.asgi_app() +def serve(): + import fastapi + import torch + from colpali_engine.models import ColQwen2, ColQwen2Processor + from fastapi import APIRouter, Depends, HTTPException, Security + from fastapi.middleware.cors import CORSMiddleware + from fastapi.security import HTTPBearer + + volume.reload() # ensure we have the latest version of the weights + + # create a fastAPI app for serving the ColPali model + web_app = fastapi.FastAPI( + title=f"ColPali {MODEL_NAME} server", + description="Run a ColPali model server with fastAPI on modal.com 🚀", + version="0.0.1", + docs_url="/docs", + ) + + # security: CORS middleware for external requests + http_bearer = HTTPBearer( + scheme_name="Bearer Token", + description="See code for authentication details.", + ) + web_app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # security: inject dependency on authed routes + async def is_authenticated(api_key: str = Security(http_bearer)): + if api_key.credentials != TOKEN: + raise HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + ) + return {"username": "authenticated_user"} + + router = APIRouter(dependencies=[Depends(is_authenticated)]) + + # Define the model and processor + model_name = "/models/vidore/colqwen2-v1.0-merged" + colpali_model = ColQwen2.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + device_map="cuda:0", + ).eval() + + colpali_processor = ColQwen2Processor.from_pretrained(model_name) + + # Define a simple endpoint to process text queries + @router.post("/query") + async def query_model(query_text: str): + with torch.no_grad(): + batch_query = colpali_processor.process_queries([query_text]).to(colpali_model.device) + query_embedding = colpali_model(**batch_query) + return {"embedding": query_embedding[0].cpu().float().numpy().tolist()} + + @router.post("/process_image") + async def process_image(image: fastapi.UploadFile): + from PIL import Image + pil_image = Image.open(image.file) + with torch.no_grad(): + batch_image = colpali_processor.process_images([pil_image]).to(colpali_model.device) + image_embedding = colpali_model(**batch_image) + return {"embedding": image_embedding[0].cpu().float().numpy().tolist()} + + + # add authed router to our fastAPI app + web_app.include_router(router) + + return web_app diff --git a/ai-search-demo/llm-inference/llm_serving_load_models.py b/ai-search-demo/llm-inference/llm_serving_load_models.py new file mode 100644 index 0000000..84c610b --- /dev/null +++ b/ai-search-demo/llm-inference/llm_serving_load_models.py @@ -0,0 +1,62 @@ +import modal + +MODELS_DIR = "/models" + +DEFAULT_NAME = "Qwen/Qwen2.5-7B-Instruct" +DEFAULT_REVISION = "bb46c15ee4bb56c5b63245ef50fd7637234d6f75" + +# Qwen/Qwen2-VL-7B-Instruct + + +volume = modal.Volume.from_name("models", create_if_missing=True) + +image = ( + modal.Image.debian_slim(python_version="3.10") + .pip_install( + [ + "huggingface_hub", + "hf-transfer", + ] + ) + .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) +) + + +MINUTES = 60 +HOURS = 60 * MINUTES + + +app = modal.App( + image=image, secrets=[modal.Secret.from_name("huggingface-secret")] +) + + +@app.function(volumes={MODELS_DIR: volume}, timeout=4 * HOURS) +def download_model(model_name, model_revision, force_download=False): + from huggingface_hub import snapshot_download + + volume.reload() + + snapshot_download( + model_name, + local_dir=MODELS_DIR + "/" + model_name, + ignore_patterns=[ + "*.pt", + "*.bin", + "*.pth", + "original/*", + ], # Ensure safetensors + revision=model_revision, + force_download=force_download, + ) + + volume.commit() + + +@app.local_entrypoint() +def main( + model_name: str = DEFAULT_NAME, + model_revision: str = DEFAULT_REVISION, + force_download: bool = False, +): + download_model.remote(model_name, model_revision, force_download) \ No newline at end of file