
Merge branch 'stanfordnlp:main' into main
thomasahle authored Mar 18, 2024
2 parents ba550d0 + eb2dd73 commit f181adb
Showing 8 changed files with 109 additions and 127 deletions.
32 changes: 15 additions & 17 deletions .pre-commit-config.yaml
@@ -5,15 +5,12 @@ default_stages: [commit]
 default_install_hook_types: [pre-commit, commit-msg]

 repos:
-  # - repo: https://github.com/astral-sh/ruff-pre-commit
-  #   # Ruff version.
-  #   rev: v0.1.11
-  #   hooks:
-  #     # Run the linter.
-  #     - id: ruff
-  #       args: [--fix]
-  #     # Run the formatter.
-  #     - id: ruff-format
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.11
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format

   - repo: https://github.com/timothycrosley/isort
     rev: 5.12.0
@@ -50,14 +47,15 @@ repos:
         args:
           - "--autofix"
           - "--indent=2"
-  # - repo: local
-  #   hooks:
-  #     - id: validate-commit-msg
-  #       name: Commit Message is Valid
-  #       language: pygrep
-  #       entry: ^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix|release|maint|init|enh|revert)\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)
-  #       stages: [commit-msg]
-  #       args: [--negate]

+  - repo: local
+    hooks:
+      - id: validate-commit-msg
+        name: Commit Message is Valid
+        language: pygrep
+        entry: ^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix|release|maint|init|enh|revert)\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)
+        stages: [commit-msg]
+        args: [--negate]
+
   - repo: https://github.com/pre-commit/mirrors-prettier
     rev: v3.0.3
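
The validate-commit-msg hook is now active: pygrep flags any commit message that matches the entry pattern, and --negate inverts that, so messages that do not match are rejected. A minimal standalone sketch of the same check (illustrative only; the real check runs through pre-commit, and COMMIT_MSG_PATTERN and is_valid_commit_msg are names invented here):

import re

# Same pattern the hook greps for: type(scope)!?: subject
COMMIT_MSG_PATTERN = re.compile(
    r"^(break|build|ci|docs|feat|fix|perf|refactor|style|test|ops|hotfix"
    r"|release|maint|init|enh|revert)"
    r"\([\w,\.,\-,\(,\),\/]+\)(!?)(:)\s{1}([\w,\W,:]+)"
)

def is_valid_commit_msg(msg: str) -> bool:
    """True when the message follows the type(scope): subject convention."""
    return COMMIT_MSG_PATTERN.match(msg) is not None

assert is_valid_commit_msg("feat(retrieve): support custom embedding functions")
assert not is_valid_commit_msg("fixed some stuff")
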
2 changes: 1 addition & 1 deletion dsp/modules/cache_utils.py
@@ -6,7 +6,7 @@

 from dsp.utils import dotdict

-cache_turn_on = True
+cache_turn_on = os.environ.get('DSP_CACHEBOOL', 'True').lower() != 'false'


 def noop_decorator(arg=None, *noop_args, **noop_kwargs):
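
The cache flag now honors the DSP_CACHEBOOL environment variable: any value other than "false" (case-insensitive) leaves caching on. A small sketch of disabling the cache; note the variable must be set before the module is first imported, since the flag is evaluated at import time:

import os

os.environ["DSP_CACHEBOOL"] = "false"  # any casing of "false" disables the cache

from dsp.modules.cache_utils import cache_turn_on  # import after setting the variable

print(cache_turn_on)  # False
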
6 changes: 6 additions & 0 deletions dsp/modules/gpt3.py
@@ -62,12 +62,15 @@ def __init__(
         api_provider: Literal["openai"] = "openai",
         api_base: Optional[str] = None,
         model_type: Literal["chat", "text"] = None,
+        system_prompt: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(model)
         self.provider = "openai"
         openai.api_type = api_provider

+        self.system_prompt = system_prompt
+
         assert (
             api_provider != "azure"
         ), "Azure functionality with base OpenAI has been deprecated, please use dspy.AzureOpenAI instead."
@@ -118,6 +121,9 @@ def basic_request(self, prompt: str, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
         if self.model_type == "chat":
             # caching mechanism requires hashable kwargs
+            messages = [{"role": "user", "content": prompt}]
+            if self.system_prompt:
+                messages.insert(0, {"role": "system", "content": self.system_prompt})
             kwargs["messages"] = [{"role": "user", "content": prompt}]
             kwargs = {"stringify_request": json.dumps(kwargs)}
             response = chat_request(**kwargs)
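
The new system_prompt parameter is stored on the client, and for chat models the added lines build a message list with the system message ahead of the user turn. A standalone sketch of that construction (build_messages is a name invented here for illustration, not a function in the codebase):

from typing import Optional

def build_messages(prompt: str, system_prompt: Optional[str] = None) -> list[dict]:
    # Mirrors the added logic: one user message, optionally preceded by a system message.
    messages = [{"role": "user", "content": prompt}]
    if system_prompt:
        messages.insert(0, {"role": "system", "content": system_prompt})
    return messages

print(build_messages("What is DSPy?", system_prompt="Answer concisely."))
# [{'role': 'system', 'content': 'Answer concisely.'}, {'role': 'user', 'content': 'What is DSPy?'}]
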
13 changes: 11 additions & 2 deletions dsp/primitives/predict.py
@@ -98,11 +98,20 @@ def do_generate(
         completion[field_names[last_field_idx]] = ""

         # Recurse with greedy decoding and a shorter length.
-        max_tokens = kwargs.get("max_tokens", dsp.settings.lm.kwargs["max_tokens"])
+        max_tokens = (kwargs.get("max_tokens") or
+                      kwargs.get("max_output_tokens") or
+                      dsp.settings.lm.kwargs.get("max_tokens") or
+                      dsp.settings.lm.kwargs.get('max_output_tokens'))
+
+
+        if max_tokens is None:
+            raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
         max_tokens = min(max(75, max_tokens // 2), max_tokens)
+        keys = list(kwargs.keys()) + list(dsp.settings.lm.kwargs.keys())
+        max_tokens_key = "max_tokens" if "max_tokens" in keys else "max_output_tokens"
         new_kwargs = {
             **kwargs,
-            "max_tokens": max_tokens,
+            max_tokens_key: max_tokens,
             "n": 1,
             "temperature": 0.0,
         }
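
The retry path now accepts either max_tokens or max_output_tokens (the name some providers use), preferring the call's kwargs over the LM's defaults, and halves the budget with a floor of 75 tokens. A standalone sketch of the resolution logic (resolve_max_tokens is a name invented here for illustration):

def resolve_max_tokens(call_kwargs: dict, lm_kwargs: dict) -> int:
    # First non-None value wins; call kwargs are checked before the LM defaults.
    max_tokens = (call_kwargs.get("max_tokens") or
                  call_kwargs.get("max_output_tokens") or
                  lm_kwargs.get("max_tokens") or
                  lm_kwargs.get("max_output_tokens"))
    if max_tokens is None:
        raise ValueError("Required 'max_tokens' or 'max_output_tokens' not specified in settings.")
    # Halve the budget for the retry, but never below 75 and never above the original.
    return min(max(75, max_tokens // 2), max_tokens)

assert resolve_max_tokens({}, {"max_output_tokens": 1024}) == 512
assert resolve_max_tokens({"max_tokens": 100}, {}) == 75
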
2 changes: 1 addition & 1 deletion dsp/utils/settings.py
@@ -36,7 +36,7 @@ def __new__(cls):
             force_reuse_cached_compilation=False,
             compiling=False,
             skip_logprobs=False,
-            trace=None,
+            trace=[],
             release=0,
             log_openai_usage=False,
             bypass_assert=False,
4 changes: 3 additions & 1 deletion dspy/experimental/synthesizer/synthesizer.py
@@ -202,7 +202,9 @@ def generate(
         }

         if self.config.num_example_for_optim:
-            kwargs["ground_source"] = random.sample(ground_source, self.config.num_example_for_optim)
+            if not isinstance(ground_source, list):
+                raise ValueError("Ground source must be a list of examples when `num_example_for_optim` is provided.")
+            kwargs["ground_source"] = random.sample(ground_source, k=self.config.num_example_for_optim)

         with dspy.context(lm=self.input_lm):
             inputs = self.input_predictor(**kwargs)
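
random.sample requires a sequence, so a non-list ground_source (a generator, say) would previously surface as an opaque TypeError inside random.sample; the guard raises a descriptive ValueError instead. A small illustration with made-up values:

import random

ground_source = (x for x in ["a", "b", "c"])  # a generator, not a list

if not isinstance(ground_source, list):
    raise ValueError("Ground source must be a list of examples when `num_example_for_optim` is provided.")
random.sample(ground_source, k=2)  # unreachable for non-list input
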
52 changes: 33 additions & 19 deletions dspy/retrieve/pgvector_rm.py
@@ -1,6 +1,5 @@
-from typing import List, Optional
-
-import openai
+import warnings
+from typing import Callable, Optional

 import dspy

@@ -12,6 +11,11 @@
     raise ImportError(
         "The 'pgvector' extra is required to use PgVectorRM. Install it with `pip install dspy-ai[pgvector]`",
     )
+try:
+    import openai
+except ImportError:
+    warnings.warn("`openai` is not installed. Install it with `pip install openai` to use OpenAI embedding models.",
+                  category=ImportWarning)


 class PgVectorRM(dspy.Retrieve):
@@ -26,7 +30,8 @@ class PgVectorRM(dspy.Retrieve):
     Args:
         db_url (str): A PostgreSQL database URL in psycopg2's DSN format
         pg_table_name (Optional[str]): name of the table containing passages
-        openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings
+        openai_client (openai.OpenAI): OpenAI client to use for computing query embeddings. Either openai_client or embedding_func must be provided.
+        embedding_func (Callable): A function to use for computing query embeddings. Either openai_client or embedding_func must be provided.
         k (Optional[int]): Default number of top passages to retrieve. Defaults to 20
         embedding_field (str = "embedding"): Field containing passage embeddings. Defaults to "embedding"
         fields (List[str] = ['text']): Fields to retrieve from the table. Defaults to "text"
@@ -41,10 +46,10 @@ class PgVectorRM(dspy.Retrieve):
         openai.api_key = os.environ.get("OPENAI_API_KEY", None)
         openai_client = openai.OpenAI()

         llm = dspy.OpenAI(model="gpt-3.5-turbo")

-        DATABASE_URL should be in the format postgresql://user:password@host/database
+        DATABASE_URL should be in the format postgresql://user:password@host/database
         db_url=os.getenv("DATABASE_URL")

         retriever_model = PgVectorRM(conn, openai_client=openai_client, "paragraphs", fields=["text", "document_id"], k=20)
@@ -60,16 +65,19 @@ def __init__(
         self,
         db_url: str,
         pg_table_name: str,
-        openai_client: openai.OpenAI,
-        k: Optional[int]=20,
+        openai_client: Optional[openai.OpenAI] = None,
+        embedding_func: Optional[Callable] = None,
+        k: Optional[int] = 20,
         embedding_field: str = "embedding",
-        fields: List[str] = ['text'],
+        fields: list[str] = ['text'],
     ):
         """
         k = 20 is the number of paragraphs to retrieve
         """
+        assert openai_client or embedding_func, "Either openai_client or embedding_func must be provided."
         self.openai_client = openai_client
+        self.embedding_func = embedding_func

         self.conn = psycopg2.connect(db_url)
         register_vector(self.conn)
         self.pg_table_name = pg_table_name

def forward(self, query: str, k: Optional[int]=20):
"""Search with PgVector for self.k top passages for query
Args:
query (str): The query to search for
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k
Returns:
Returns:
dspy.Prediction: an object containing the retrieved passages.
"""
# Embed query
query_embedding = self.openai_client.embeddings.create(
model="text-embedding-ada-002",
input=query,
encoding_format="float",
).data[0].embedding
query_embedding = self._get_embeddings(query)

related_paragraphs = []

Expand All @@ -115,4 +119,14 @@ def forward(self, query: str, k: Optional[int]=20):
for row in rows:
related_paragraphs.append(dspy.Example(long_text=row[0], document_id=row[1]))
# Return Prediction
return related_paragraphs
return related_paragraphs

def _get_embeddings(self, query: str) -> list[float]:
if self.openai_client is not None:
return self.openai_client.embeddings.create(
model="text-embedding-ada-002",
input=query,
encoding_format="float",
).data[0].embedding
else:
return self.embedding_func(query)
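
With embedding_func in place, PgVectorRM no longer requires an OpenAI client. A sketch of wiring in a local embedder instead (the sentence-transformers model and the DSN below are illustrative assumptions, not part of this diff):

from sentence_transformers import SentenceTransformer

from dspy.retrieve.pgvector_rm import PgVectorRM

encoder = SentenceTransformer("all-MiniLM-L6-v2")  # hypothetical local embedding model

def embed(query: str) -> list[float]:
    # Return a plain list of floats, matching what _get_embeddings produces on the OpenAI path.
    return encoder.encode(query).tolist()

retriever = PgVectorRM(
    db_url="postgresql://user:password@localhost/db",  # illustrative DSN
    pg_table_name="paragraphs",
    embedding_func=embed,
    k=5,
)
passages = retriever("What is a vector index?")
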