Skip to content

Commit

Permalink
enh: embed quries with either openai or other embedding models
Browse files Browse the repository at this point in the history
  • Loading branch information
nbqu committed Mar 12, 2024
1 parent 2d845de commit 7921ef8
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions dspy/retrieve/pgvector_rm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import List, Optional

import openai

from typing import List, Optional, Callable
import warnings
import dspy

try:
Expand All @@ -12,6 +10,11 @@
raise ImportError(
"The 'pgvector' extra is required to use PgVectorRM. Install it with `pip install dspy-ai[pgvector]`",
)
try:
import openai
except ImportError:
warnings.warn("`openai` is not installed. Install it with `pip install openai` to use OpenAI embedding models.",
category=ImportWarning)


class PgVectorRM(dspy.Retrieve):
Expand All @@ -26,7 +29,8 @@ class PgVectorRM(dspy.Retrieve):
Args:
db_url (str): A PostgreSQL database URL in psycopg2's DSN format
pg_table_name (Optional[str]): name of the table containing passages
openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings
openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings. Either openai_client or embedding_func must be provided.
embedding_func (Callable): A function to use for computing query embeddings. Either openai_client or embedding_func must be provided.
k (Optional[int]): Default number of top passages to retrieve. Defaults to 20
embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding"
fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to "text"
Expand All @@ -41,10 +45,10 @@ class PgVectorRM(dspy.Retrieve):
openai.api_key = os.environ.get("OPENAI_API_KEY", None)
openai_client = openai.OpenAI()
llm = dspy.OpenAI(model="gpt-3.5-turbo")
DATABASE_URL should be in the format postgresql://user:password@host/database
DATABASE_URL should be in the format postgresql://user:password@host/database
db_url=os.getenv("DATABASE_URL")
retriever_model = PgVectorRM(conn, openai_client=openai_client, "paragraphs", fields=["text", "document_id"], k=20)
Expand All @@ -60,16 +64,19 @@ def __init__(
self,
db_url: str,
pg_table_name: str,
openai_client: openai.OpenAI,
k: Optional[int]=20,
openai_client: Optional[openai.OpenAI] = None,
embedding_func: Optional[Callable] = None,
k: Optional[int] = 20,
embedding_field: str = "embedding",
fields: List[str] = ['text'],
):
"""
k = 20 is the number of paragraphs to retrieve
"""
assert openai_client or embedding_func, "Either openai_client or embedding_func must be provided."
self.openai_client = openai_client

self.embedding_func = embedding_func

self.conn = psycopg2.connect(db_url)
register_vector(self.conn)
self.pg_table_name = pg_table_name
Expand All @@ -80,19 +87,15 @@ def __init__(

def forward(self, query: str, k: Optional[int]=20):
"""Search with PgVector for self.k top passages for query
Args:
query (str): The query to search for
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k
Returns:
Returns:
dspy.Prediction: an object containing the retrieved passages.
"""
# Embed query
query_embedding = self.openai_client.embeddings.create(
model="text-embedding-ada-002",
input=query,
encoding_format="float",
).data[0].embedding
query_embedding = self._get_embeddings(query)

related_paragraphs = []

Expand All @@ -115,4 +118,14 @@ def forward(self, query: str, k: Optional[int]=20):
for row in rows:
related_paragraphs.append(dspy.Example(long_text=row[0], document_id=row[1]))
# Return Prediction
return related_paragraphs
return related_paragraphs

def _get_embeddings(self, query: str) -> List[float]:
if self.openai_client is not None:
return self.openai_client.embeddings.create(
model="text-embedding-ada-002",
input=query,
encoding_format="float",
).data[0].embedding
else:
return self.embedding_func(query)

0 comments on commit 7921ef8

Please sign in to comment.