Skip to content

Commit

Permalink
Merge pull request stanfordnlp#574 from thomasahle/main
Browse files Browse the repository at this point in the history
New Typed Signature Optimizer
  • Loading branch information
thomasahle authored Mar 6, 2024
2 parents b2816c4 + 26bb358 commit c2d639f
Show file tree
Hide file tree
Showing 17 changed files with 1,374 additions and 189 deletions.
73 changes: 32 additions & 41 deletions dsp/templates/template_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def query(self, example: Example, is_demo: bool = False) -> str:
"""Retrieves the input variables from the example and formats them into a query string."""
result: list[str] = []

# If not a demo, find the first field that comes right after the last field with a value in `example`
# (i.e. the first of the trailing unset fields) and set it to "".
# This creates the "Output:" prefix at the end of the prompt.
if not is_demo:
has_value = [
field.input_variable in example
Expand All @@ -80,40 +82,40 @@ def query(self, example: Example, is_demo: bool = False) -> str:
for field in self.fields
]

for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break
# If there are no inputs, set the first field to ""
if not any(has_value):
example[self.fields[0].input_variable] = ""
# Otherwise find the first field without a value.
else:
for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break

for field in self.fields:
if (
field.input_variable in example
and example[field.input_variable] is not None
):
if field.input_variable in example and example[field.input_variable] is not None:
if field.input_variable in self.format_handlers:
format_handler = self.format_handlers[field.input_variable]
else:

def format_handler(x):
assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}"
return " ".join(x.split())

formatted_value = format_handler(example[field.input_variable])
separator = '\n' if field.separator == ' ' and '\n' in formatted_value else field.separator
separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator

result.append(
f"{field.name}{separator}{formatted_value}",
)

if self._has_augmented_guidelines() and (example.get('augmented', False)):
if self._has_augmented_guidelines() and (example.get("augmented", False)):
return "\n\n".join([r for r in result if r])
return "\n".join([r for r in result if r])

def guidelines(self, show_guidelines=True) -> str:
"""Returns the task guidelines as described in the lm prompt"""
if (not show_guidelines) or (
hasattr(dsp.settings, "show_guidelines")
and not dsp.settings.show_guidelines
):
if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
return ""

result = "Follow the following format.\n\n"
Expand All @@ -128,11 +130,13 @@ def guidelines(self, show_guidelines=True) -> str:

def _has_augmented_guidelines(self):
return len(self.fields) > 3 or any(
("\n" in field.separator) or ('\n' in field.description) for field in self.fields
("\n" in field.separator) or ("\n" in field.description) for field in self.fields
)

def extract(
self, example: Union[Example, dict[str, Any]], raw_pred: str,
self,
example: Union[Example, dict[str, Any]],
raw_pred: str,
) -> Example:
"""Extracts the answer from the LM raw prediction using the template structure
Expand All @@ -149,10 +153,7 @@ def extract(

idx = 0
while idx < len(self.fields):
if (
self.fields[idx].input_variable not in example
or example[self.fields[idx].input_variable] is None
):
if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None:
break
idx += 1

Expand All @@ -166,16 +167,16 @@ def extract(

if offset >= 0:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip('---').strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip()

idx += 1
else:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred.strip()

Expand All @@ -187,7 +188,7 @@ def extract(
assert idx == len(self.fields) - 1, (idx, len(self.fields))

if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred.strip()

Expand All @@ -198,7 +199,7 @@ def extract(
def __call__(self, example, show_guidelines=True) -> str:
example = dsp.Example(example)

if hasattr(dsp.settings, 'query_only') and dsp.settings.query_only:
if hasattr(dsp.settings, "query_only") and dsp.settings.query_only:
return self.query(example)

# The training data should not contain the output variable
Expand All @@ -209,29 +210,20 @@ def __call__(self, example, show_guidelines=True) -> str:
self.query(demo, is_demo=True)
for demo in example.demos
if (
(not demo.get('augmented', False))
(not demo.get("augmented", False))
and ( # validate that the training example has the same primitive input var as the template
self.fields[-1].input_variable in demo
and demo[self.fields[-1].input_variable] is not None
self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None
)
)
]

ademos = [
self.query(demo, is_demo=True)
for demo in example.demos
if demo.get('augmented', False)
]
ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]

# Move the rdemos to ademos if rdemo has all the fields filled in
rdemos_ = []
new_ademos = []
for rdemo in rdemos:
if all(
(field.name in rdemo)
for field in self.fields
if field.input_variable in example
):
if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
import dspy

if dspy.settings.release >= 20230928:
Expand All @@ -244,7 +236,6 @@ def __call__(self, example, show_guidelines=True) -> str:
ademos = new_ademos + ademos
rdemos = rdemos_


long_query = self._has_augmented_guidelines()

if long_query:
Expand All @@ -253,10 +244,10 @@ def __call__(self, example, show_guidelines=True) -> str:
query = self.query(example)

# if it has more lines than fields
if len(query.split('\n')) > len(self.fields):
if len(query.split("\n")) > len(self.fields):
long_query = True

if not example.get('augmented', False):
if not example.get("augmented", False):
example["augmented"] = True
query = self.query(example)

Expand Down
5 changes: 3 additions & 2 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# from .evaluation import *
# FIXME:
import dsp
from dsp.modules.hf_client import ChatModuleClient, HFClientSGLang, HFClientVLLM, HFServerTGI

Expand All @@ -8,6 +6,9 @@
from .retrieve import *
from .signatures import *

# Functional must be imported after primitives, predict and signatures
from .functional import * # isort: skip

settings = dsp.settings

AzureOpenAI = dsp.AzureOpenAI
Expand Down
Loading

0 comments on commit c2d639f

Please sign in to comment.