From 171567d818f578429e84bb37de3a5ee27e08c2bd Mon Sep 17 00:00:00 2001 From: Arnav Singhvi Date: Fri, 23 Feb 2024 15:23:31 -0800 Subject: [PATCH] updated vector search response handling --- dspy/retrieve/databricks_rm.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 530fcf02d..2431585e8 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -2,6 +2,8 @@ import os import requests from typing import Union, List, Optional +from collections import defaultdict +from dspy.primitives.prediction import Prediction class DatabricksRM(dspy.Retrieve): """ @@ -113,7 +115,16 @@ def forward(self, query: Union[str, List[float]], query_type: str = 'vector') -> headers=headers ) results = response.json() - docs = [] + + docs = defaultdict(float) + text, score = None, None for data_row in results["result"]["data_array"]: - docs.append({col: val for col, val in zip(results["manifest"]["columns"], data_row)}) - return dspy.Prediction(embeddings=docs) \ No newline at end of file + for col, val in zip(results["manifest"]["columns"], data_row): + if col["name"] == 'text': + text = val + if col["name"] == 'score': + score = val + docs[text] += score + + sorted_docs = sorted(docs.items(), key=lambda x: x[1], reverse=True)[:self.k] + return Prediction(docs=[doc for doc, _ in sorted_docs])