Skip to content

Commit

Permalink
Fixes to completions
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 3, 2024
1 parent fe321da commit 16be7a1
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
21 changes: 10 additions & 11 deletions dspy/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
from dspy.primitives.prediction import Prediction

from dspy.signatures.signature import ensure_signature
from dspy.signatures.signature import ensure_signature, make_signature


MAX_RETRIES = 3
Expand Down Expand Up @@ -83,7 +83,7 @@ def copy(self) -> "TypedPredictor":
def _make_example(type_) -> str:
# Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
json_object = dspy.Predict(
dspy.Signature(
make_signature(
"json_schema -> json_object",
"Make a very succinct json object that validates with the following schema",
),
Expand Down Expand Up @@ -153,21 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction:
for try_i in range(MAX_RETRIES):
result = self.predictor(**modified_kwargs, new_signature=signature)
errors = {}
parsed_results = defaultdict(list)
parsed_results = []
# Parse the outputs
for i, completion in enumerate(result.completions):
try:
parsed = {}
for name, field in signature.output_fields.items():
value = completion[name]
parser = field.json_schema_extra.get("parser", lambda x: x)
completion[name] = parser(value)
parsed_results[name].append(parser(value))
parsed[name] = parser(value)
# Instantiate the actual signature with the parsed values.
# This allow pydantic to validate the fields defined in the signature.
_dummy = self.signature(
**kwargs,
**{key: value[i] for key, value in parsed_results.items()},
)
_dummy = self.signature(**kwargs, **parsed)
parsed_results.append(parsed)
except (pydantic.ValidationError, ValueError) as e:
errors[name] = _format_error(e)
# If we can, we add an example to the error message
Expand Down Expand Up @@ -199,9 +198,9 @@ def forward(self, **kwargs) -> dspy.Prediction:
)
else:
# If there are no errors, we return the parsed results
# for name, value in parsed_results.items():
# setattr(result, name, value)
return Prediction.from_completions(parsed_results)
return Prediction.from_completions(
{key: [r[key] for r in parsed_results] for key in signature.output_fields}
)
raise ValueError(
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.",
errors,
Expand Down
41 changes: 41 additions & 0 deletions tests/functional/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dspy.functional import predictor, cot, FunctionalModule, TypedPredictor, functional
from dspy.primitives.example import Example
from dspy.teleprompt.bootstrap import BootstrapFewShot
from dspy.teleprompt.vanilla import LabeledFewShot
from dspy.utils.dummies import DummyLM


Expand Down Expand Up @@ -491,3 +492,43 @@ class SimpleOutput(dspy.Signature):
predictor = TypedPredictor(SimpleOutput)

assert predictor().output == 0.5


def test_synthetic_data_gen():
class SyntheticFact(BaseModel):
fact: str = Field(..., description="a statement")
varacity: bool = Field(..., description="is the statement true or false")

class ExampleSignature(dspy.Signature):
"""Generate an example of a synthetic fact."""

fact: SyntheticFact = dspy.OutputField()

lm = DummyLM(
[
'{"fact": "The sky is blue", "varacity": true}',
'{"fact": "The sky is green", "varacity": false}',
'{"fact": "The sky is red", "varacity": true}',
'{"fact": "The earth is flat", "varacity": false}',
'{"fact": "The earth is round", "varacity": true}',
'{"fact": "The earth is a cube", "varacity": false}',
]
)
dspy.settings.configure(lm=lm)

generator = TypedPredictor(ExampleSignature)
examples = generator(config=dict(n=3))
for ex in examples.completions.fact:
assert isinstance(ex, SyntheticFact)
assert examples.completions.fact[0] == SyntheticFact(fact="The sky is blue", varacity=True)

# If you have examples and want more
existing_examples = [
dspy.Example(fact="The sky is blue", varacity=True),
dspy.Example(fact="The sky is green", varacity=False),
]
trained = LabeledFewShot().compile(student=generator, trainset=existing_examples)

augmented_examples = trained(config=dict(n=3))
for ex in augmented_examples.completions.fact:
assert isinstance(ex, SyntheticFact)
2 changes: 2 additions & 0 deletions tests/signatures/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def test_instructions_signature():
def test_signature_instructions():
sig1 = Signature("input1 -> output1", instructions="This is a test")
assert sig1.instructions == "This is a test"
sig2 = Signature("input1 -> output1", "This is a test")
assert sig2.instructions == "This is a test"


def test_signature_instructions_none():
Expand Down

0 comments on commit 16be7a1

Please sign in to comment.