Skip to content

Commit

Permalink
Fix the "no inputs" case
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 6, 2024
1 parent 1f6e0cb commit a2fb536
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 44 deletions.
73 changes: 32 additions & 41 deletions dsp/templates/template_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def query(self, example: Example, is_demo: bool = False) -> str:
"""Retrieves the input variables from the example and formats them into a query string."""
result: list[str] = []

# If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
# This creates the "Output:" prefix at the end of the prompt.
if not is_demo:
has_value = [
field.input_variable in example
Expand All @@ -78,40 +80,40 @@ def query(self, example: Example, is_demo: bool = False) -> str:
for field in self.fields
]

for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break
# If there are no inputs, set the first field to ""
if not any(has_value):
example[self.fields[0].input_variable] = ""
# Otherwise find the first field without a value.
else:
for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break

for field in self.fields:
if (
field.input_variable in example
and example[field.input_variable] is not None
):
if field.input_variable in example and example[field.input_variable] is not None:
if field.input_variable in self.format_handlers:
format_handler = self.format_handlers[field.input_variable]
else:

def format_handler(x):
assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}"
return " ".join(x.split())

formatted_value = format_handler(example[field.input_variable])
separator = '\n' if field.separator == ' ' and '\n' in formatted_value else field.separator
separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator

result.append(
f"{field.name}{separator}{formatted_value}",
)

if self._has_augmented_guidelines() and (example.get('augmented', False)):
if self._has_augmented_guidelines() and (example.get("augmented", False)):
return "\n\n".join([r for r in result if r])
return "\n".join([r for r in result if r])

def guidelines(self, show_guidelines=True) -> str:
"""Returns the task guidelines as described in the lm prompt"""
if (not show_guidelines) or (
hasattr(dsp.settings, "show_guidelines")
and not dsp.settings.show_guidelines
):
if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
return ""

result = "Follow the following format.\n\n"
Expand All @@ -126,11 +128,13 @@ def guidelines(self, show_guidelines=True) -> str:

def _has_augmented_guidelines(self):
return len(self.fields) > 3 or any(
("\n" in field.separator) or ('\n' in field.description) for field in self.fields
("\n" in field.separator) or ("\n" in field.description) for field in self.fields
)

def extract(
self, example: Union[Example, dict[str, Any]], raw_pred: str,
self,
example: Union[Example, dict[str, Any]],
raw_pred: str,
) -> Example:
"""Extracts the answer from the LM raw prediction using the template structure
Expand All @@ -147,10 +151,7 @@ def extract(

idx = 0
while idx < len(self.fields):
if (
self.fields[idx].input_variable not in example
or example[self.fields[idx].input_variable] is None
):
if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None:
break
idx += 1

Expand All @@ -164,16 +165,16 @@ def extract(

if offset >= 0:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip('---').strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip()

idx += 1
else:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred.strip()

Expand All @@ -185,7 +186,7 @@ def extract(
assert idx == len(self.fields) - 1, (idx, len(self.fields))

if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
example[self.fields[idx].output_variable] = raw_pred.strip()

Expand All @@ -196,7 +197,7 @@ def extract(
def __call__(self, example, show_guidelines=True) -> str:
example = dsp.Example(example)

if hasattr(dsp.settings, 'query_only') and dsp.settings.query_only:
if hasattr(dsp.settings, "query_only") and dsp.settings.query_only:
return self.query(example)

# The training data should not contain the output variable
Expand All @@ -207,29 +208,20 @@ def __call__(self, example, show_guidelines=True) -> str:
self.query(demo, is_demo=True)
for demo in example.demos
if (
(not demo.get('augmented', False))
(not demo.get("augmented", False))
and ( # validate that the training example has the same primitive input var as the template
self.fields[-1].input_variable in demo
and demo[self.fields[-1].input_variable] is not None
self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None
)
)
]

ademos = [
self.query(demo, is_demo=True)
for demo in example.demos
if demo.get('augmented', False)
]
ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]

# Move the rdemos to ademos if rdemo has all the fields filled in
rdemos_ = []
new_ademos = []
for rdemo in rdemos:
if all(
(field.name in rdemo)
for field in self.fields
if field.input_variable in example
):
if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
import dspy

if dspy.settings.release >= 20230928:
Expand All @@ -242,7 +234,6 @@ def __call__(self, example, show_guidelines=True) -> str:
ademos = new_ademos + ademos
rdemos = rdemos_


long_query = self._has_augmented_guidelines()

if long_query:
Expand All @@ -251,10 +242,10 @@ def __call__(self, example, show_guidelines=True) -> str:
query = self.query(example)

# if it has more lines than fields
if len(query.split('\n')) > len(self.fields):
if len(query.split("\n")) > len(self.fields):
long_query = True

if not example.get('augmented', False):
if not example.get("augmented", False):
example["augmented"] = True
query = self.query(example)

Expand Down
5 changes: 2 additions & 3 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def forward(self, **kwargs):
# print(f"#> Setting temperature to 0.7 since n={num_generations} and prior temperature={temperature}.")

# All of the other kwargs are presumed to fit a prefix of the signature.

# That is, they are input variables for the bottom most generation, so
# we place them inside the input - x - together with the demos.
x = dsp.Example(demos=demos, **kwargs)

if new_signature is not None:
Expand All @@ -86,8 +87,6 @@ def forward(self, **kwargs):

# Switch to legacy format for dsp.generate
template = signature_to_template(signature)
# print("Created template", template)
# print("From Signature", signature)

if self.lm is None:
x, C = dsp.generate(template, **config)(x, stage=self.stage)
Expand Down
25 changes: 25 additions & 0 deletions tests/predict/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dspy import Predict, Signature
from dspy.utils.dummies import DummyLM
import copy
import textwrap


def test_initialization_with_string_signature():
Expand Down Expand Up @@ -119,3 +120,27 @@ def __init__(self):
# Check that it also works the second time.
program2 = copy.deepcopy(program)
assert program2.named_predictors() == [("inner", program2.inner)]


def test_output_only():
class OutputOnlySignature(dspy.Signature):
output = dspy.OutputField()

predictor = Predict(OutputOnlySignature)

lm = DummyLM(["short answer"])
dspy.settings.configure(lm=lm)
assert predictor().output == "short answer"

assert lm.get_convo(-1) == textwrap.dedent("""\
Given the fields , produce the fields `output`.
---
Follow the following format.
Output: ${output}
---
Output: short answer""")

0 comments on commit a2fb536

Please sign in to comment.