Skip to content

Commit

Permalink
Merge pull request #22 from Athe-kunal/main
Browse files Browse the repository at this point in the history
Add metadata filtering support and fix a multi-document metadata issue
  • Loading branch information
bclavie authored Nov 13, 2024
2 parents 116e29c + 4de33ed commit 64aa6e0
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 140 deletions.
3 changes: 2 additions & 1 deletion byaldi/RAGModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def search(
self,
query: Union[str, List[str]],
k: int = 10,
filter_metadata: Optional[Dict[str,str]] = None,
return_base64_results: Optional[bool] = None,
) -> Union[List[Result], List[List[Result]]]:
"""Query an index.
Expand All @@ -171,7 +172,7 @@ def search(
Returns:
Union[List[Result], List[List[Result]]]: A list of Result objects or a list of lists of Result objects.
"""
return self.model.search(query, k, return_base64_results)
return self.model.search(query, k, filter_metadata, return_base64_results)

def get_doc_ids_to_file_names(self):
return self.model.get_doc_ids_to_file_names()
Expand Down
44 changes: 30 additions & 14 deletions byaldi/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,11 @@ def add_to_index(
raise ValueError(
f"Number of doc_ids ({len(doc_ids)}) does not match number of input items ({len(input_items)})"
)
if metadata and len(metadata) != len(input_items):
raise ValueError(
f"Number of metadata entries ({len(metadata)}) does not match number of input items ({len(input_items)})"
)

# Process each input item
for i, item in enumerate(input_items):
current_doc_id = doc_ids[i] if doc_ids else self.highest_doc_id + 1 + i
current_metadata = metadata[i] if metadata else None
current_metadata = metadata if metadata else None

if current_doc_id in self.doc_ids:
raise ValueError(
Expand Down Expand Up @@ -593,16 +589,31 @@ def _add_to_index(
def remove_from_index(self):
    """Remove a document from the index.

    Removal is not supported yet.

    Raises:
        NotImplementedError: Always, until removal support is added.
    """
    raise NotImplementedError("This method is not implemented yet.")

def filter_embeddings(self, filter_metadata: Dict[str, str]):
    """Select the indexed embeddings whose document metadata matches a filter.

    A document matches when at least one of its metadata entries has a key
    present in ``filter_metadata`` with an equal value (any-key match, same
    semantics as before).

    Args:
        filter_metadata: Mapping of metadata key to the required value.

    Returns:
        A tuple ``(req_embeddings, req_embedding_ids)`` where
        ``req_embedding_ids`` are the embed ids (keys of
        ``self.embed_id_to_doc_id``) of matching documents, in mapping
        iteration order, and ``req_embeddings`` are the corresponding entries
        of ``self.indexed_embeddings`` in index order.
    """
    # Use a set: the original list could collect the same doc id several
    # times (once per matching metadata key) and every membership test
    # below was an O(n) list scan, making the whole pass accidentally
    # quadratic. A set keeps membership O(1) without changing the result.
    req_doc_ids = {
        doc_id
        for doc_id, metadata_dict in self.doc_id_to_metadata.items()
        if any(
            key in filter_metadata and filter_metadata[key] == value
            for key, value in metadata_dict.items()
        )
    }

    req_embedding_ids = [
        eid
        for eid, doc in self.embed_id_to_doc_id.items()
        if doc["doc_id"] in req_doc_ids
    ]
    # Set lookup instead of scanning the id list once per embedding.
    embedding_id_set = set(req_embedding_ids)
    req_embeddings = [
        emb
        for idx, emb in enumerate(self.indexed_embeddings)
        if idx in embedding_id_set
    ]

    return req_embeddings, req_embedding_ids

def search(
self,
query: Union[str, List[str]],
k: int = 10,
filter_metadata: Optional[Dict[str,str]] = None,
return_base64_results: Optional[bool] = None,
) -> Union[List[Result], List[List[Result]]]:
# Set default value for return_base64_results if not provided
if return_base64_results is None:
return_base64_results = bool(self.collection)

valid_metadata_keys = list(self.doc_id_to_metadata.values())
# Ensure k is not larger than the number of indexed documents
k = min(k, len(self.indexed_embeddings))

Expand All @@ -620,27 +631,32 @@ def search(
batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
embeddings_query = self.model(**batch_query)
qs = list(torch.unbind(embeddings_query.to("cpu")))

if not filter_metadata:
req_embeddings = self.indexed_embeddings
else:
req_embeddings, req_embedding_ids = self.filter_embeddings(filter_metadata=filter_metadata)
# Compute scores
scores = self.processor.score(qs, self.indexed_embeddings).cpu().numpy()
scores = self.processor.score(qs,req_embeddings).cpu().numpy()

# Get top k relevant pages
top_pages = scores.argsort(axis=1)[0][-k:][::-1].tolist()

# Create Result objects
query_results = []
for embed_id in top_pages:
doc_info = self.embed_id_to_doc_id[int(embed_id)]
if filter_metadata:
adjusted_embed_id = req_embedding_ids[embed_id]
else:
adjusted_embed_id = int(embed_id)
doc_info = self.embed_id_to_doc_id[adjusted_embed_id]
result = Result(
doc_id=doc_info["doc_id"],
page_num=int(doc_info["page_id"]),
score=float(scores[0][embed_id]),
score=float(scores[0][int(embed_id)]),
metadata=self.doc_id_to_metadata.get(int(doc_info["doc_id"]), {}),
base64=(
self.collection.get(int(embed_id))
if return_base64_results
else None
),
base64=self.collection.get(adjusted_embed_id)
if return_base64_results
else None,
)
query_results.append(result)

Expand Down
194 changes: 69 additions & 125 deletions examples/quick_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,119 +6,33 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"Verbosity is set to 1 (active). Pass verbose=0 to make quieter.\n"
"/home/recoverx/.pyenv/versions/3.11.3/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/home/recoverx/.pyenv/versions/3.11.3/lib/python3.11/site-packages/transformers/utils/hub.py:128: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ac796108e8b54de9954ae222b9ba8c0e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"adapter_config.json: 0%| | 0.00/752 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71705bdd755040c8ba7b58a737403edb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"config.json: 0%| | 0.00/1.02k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6139b3b6f344342b77aa05fdc705f78",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors.index.json: 0%| | 0.00/66.3k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5b991cb1ab044a8e86790ca19e51cc02",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "627f7fc3c940479ba71bea3f021688f6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00002.safetensors: 0%| | 0.00/4.99G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5343f9c4b7e64091b5b1f307009c5f0d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00002.safetensors: 0%| | 0.00/862M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
"name": "stdout",
"output_type": "stream",
"text": [
"Verbosity is set to 1 (active). Pass verbose=0 to make quieter.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 10.68it/s]\n",
"`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n",
"Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n",
"`config.hidden_activation` if you want to override this behaviour.\n",
"See https://github.com/huggingface/transformers/pull/29402 for more details.\n"
"See https://github.com/huggingface/transformers/pull/29402 for more details.\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.12it/s]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "21405845b1ae4cef8dea83e281f768fb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand Down Expand Up @@ -221,11 +135,14 @@
],
"source": [
"# Test indexing\n",
"metadata = [{\"filename\":file_name} for file_name in os.listdir(\"docs\")]\n",
"\n",
"index_name = \"attention_index\"\n",
"model.index(\n",
" input_path=Path(\"docs/\"),\n",
" index_name=index_name,\n",
" store_collection_with_index=False,\n",
" metadata=metadata,\n",
" overwrite=True\n",
")\n",
"\n",
Expand Down Expand Up @@ -260,37 +177,22 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model_name': 'vidore/colpali-v1.2', 'full_document_collection': False, 'highest_doc_id': 1}\n",
"Verbosity is set to 1 (active). Pass verbose=0 to make quieter.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9d8cfef3127b4fcf96a22ed1ec277dc4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"Loading adapter...\n",
"Adapter name: vidore/colpali-v1.2\n"
"Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 17.65it/s]\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.12it/s]\n"
]
}
],
Expand All @@ -303,19 +205,26 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Search results for 'what's the BLEU score of this new strange method?':\n",
"Doc ID: 1, Page: 8, Score: 19.375\n",
"Doc ID: 1, Page: 9, Score: 19.375\n",
"Doc ID: 0, Page: 8, Score: 19.375\n",
"Doc ID: 0, Page: 9, Score: 19.375\n",
"Doc ID: 1, Page: 11, Score: 17.75\n"
"Doc ID: 0, Page: 1, Score: 19.875\n",
"Doc ID: 3, Page: 8, Score: 19.75\n",
"Doc ID: 4, Page: 8, Score: 19.75\n",
"Doc ID: 3, Page: 9, Score: 19.125\n",
"Doc ID: 4, Page: 9, Score: 19.125\n"
]
}
],
Expand All @@ -327,6 +236,41 @@
" print(f\"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## FILTER BASED ON METADATA"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metadata information: {0: {'filename': 'attention_table.png'}, 1: {'filename': 'product_c.png'}, 2: {'filename': 'financial_report.pdf'}, 3: {'filename': 'attention_with_a_mustache.pdf'}, 4: {'filename': 'attention.pdf'}}\n",
"Search results for 'what's the BLEU score of this new strange method?':\n",
"Doc ID: 4, Page: 8, Score: 19.75\n",
"Doc ID: 4, Page: 9, Score: 19.125\n",
"Doc ID: 4, Page: 1, Score: 17.125\n",
"Doc ID: 4, Page: 7, Score: 17.0\n",
"Doc ID: 4, Page: 11, Score: 16.75\n"
]
}
],
"source": [
"results = model.search(query, k=5,filter_metadata={\"filename\":\"attention.pdf\"})\n",
"\n",
"print(\"Metadata information: \",model.model.doc_id_to_metadata)\n",
"print(f\"Search results for '{query}':\")\n",
"for result in results:\n",
" print(f\"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
Expand Down Expand Up @@ -504,7 +448,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.3"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 64aa6e0

Please sign in to comment.