Skip to content

Commit

Permalink
Merge pull request stanfordnlp#634 from nbqu/pgvector-openai-dep
Browse files Browse the repository at this point in the history
Add an option to choose embedding model in pgvector
  • Loading branch information
arnavsinghvi11 authored Mar 18, 2024
2 parents 6295d98 + 0672f69 commit eb2dd73
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions dspy/retrieve/pgvector_rm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Optional

import openai
import warnings
from typing import Callable, Optional

import dspy

Expand All @@ -12,6 +11,11 @@
raise ImportError(
"The 'pgvector' extra is required to use PgVectorRM. Install it with `pip install dspy-ai[pgvector]`",
)
try:
import openai
except ImportError:
warnings.warn("`openai` is not installed. Install it with `pip install openai` to use OpenAI embedding models.",
category=ImportWarning)


class PgVectorRM(dspy.Retrieve):
Expand All @@ -26,7 +30,8 @@ class PgVectorRM(dspy.Retrieve):
Args:
db_url (str): A PostgreSQL database URL in psycopg2's DSN format
pg_table_name (Optional[str]): name of the table containing passages
openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings
openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings. Either openai_client or embedding_func must be provided.
embedding_func (Callable): A function to use for computing query embeddings. Either openai_client or embedding_func must be provided.
k (Optional[int]): Default number of top passages to retrieve. Defaults to 20
embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding"
fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to "text"
Expand All @@ -41,10 +46,10 @@ class PgVectorRM(dspy.Retrieve):
openai.api_key = os.environ.get("OPENAI_API_KEY", None)
openai_client = openai.OpenAI()
llm = dspy.OpenAI(model="gpt-3.5-turbo")
DATABASE_URL should be in the format postgresql://user:password@host/database
DATABASE_URL should be in the format postgresql://user:password@host/database
db_url=os.getenv("DATABASE_URL")
retriever_model = PgVectorRM(db_url, "paragraphs", openai_client=openai_client, fields=["text", "document_id"], k=20)
Expand All @@ -60,16 +65,19 @@ def __init__(
self,
db_url: str,
pg_table_name: str,
openai_client: openai.OpenAI,
k: Optional[int]=20,
openai_client: Optional[openai.OpenAI] = None,
embedding_func: Optional[Callable] = None,
k: Optional[int] = 20,
embedding_field: str = "embedding",
fields: List[str] = ['text'],
fields: list[str] = ['text'],
):
"""
k = 20 is the number of paragraphs to retrieve
"""
assert openai_client or embedding_func, "Either openai_client or embedding_func must be provided."
self.openai_client = openai_client

self.embedding_func = embedding_func

self.conn = psycopg2.connect(db_url)
register_vector(self.conn)
self.pg_table_name = pg_table_name
Expand All @@ -80,19 +88,15 @@ def __init__(

def forward(self, query: str, k: Optional[int]=20):
"""Search with PgVector for self.k top passages for query
Args:
query (str): The query to search for
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k
Returns:
Returns:
dspy.Prediction: an object containing the retrieved passages.
"""
# Embed query
query_embedding = self.openai_client.embeddings.create(
model="text-embedding-ada-002",
input=query,
encoding_format="float",
).data[0].embedding
query_embedding = self._get_embeddings(query)

related_paragraphs = []

Expand All @@ -115,4 +119,14 @@ def forward(self, query: str, k: Optional[int]=20):
for row in rows:
related_paragraphs.append(dspy.Example(long_text=row[0], document_id=row[1]))
# Return Prediction
return related_paragraphs
return related_paragraphs

def _get_embeddings(self, query: str) -> list[float]:
    """Compute the embedding vector for ``query``.

    The OpenAI client takes precedence when it is configured; otherwise the
    user-supplied ``embedding_func`` is called with the raw query string and
    its result is returned unchanged.

    Args:
        query (str): The text to embed.

    Returns:
        list[float]: The query embedding.

    Raises:
        ValueError: If neither ``openai_client`` nor ``embedding_func`` is set.
    """
    if self.openai_client is not None:
        response = self.openai_client.embeddings.create(
            model="text-embedding-ada-002",
            input=query,
            encoding_format="float",
        )
        # A single input string was sent, so only the first result is read.
        return response.data[0].embedding

    if self.embedding_func is None:
        # The constructor guards this with `assert`, which is removed under
        # `python -O`; fail with the same message instead of the opaque
        # "'NoneType' object is not callable".
        raise ValueError("Either openai_client or embedding_func must be provided.")

    return self.embedding_func(query)

0 comments on commit eb2dd73

Please sign in to comment.