Skip to content

Commit

Permalink
feedback driven generation
Browse files Browse the repository at this point in the history
  • Loading branch information
krypticmouse committed Mar 12, 2024
1 parent 5a2c714 commit c26d5ba
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 49 deletions.
25 changes: 10 additions & 15 deletions dspy/experimental/synthesizer/config.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
from typing import Optional, Union

from pydantic import BaseModel, field_validator

import dspy

from typing import Optional, Any
from pydantic import BaseModel, model_validator

class SynthesizerArguments(BaseModel):
    """Configuration options for the synthesizer.

    Attributes are all optional and default to ``None``. ``feedback_mode``
    and ``num_example_for_feedback`` are validated together after the model
    has been built (see ``validate_feedback_mode``).
    """

    # [TODO]
    # Feedback source; only "human" and "llm" pass validation.
    feedback_mode: Optional[str] = None
    # Required (and non-zero) whenever ``feedback_mode`` is set.
    num_example_for_feedback: Optional[int] = None

    # Typed as ``Any`` so pydantic does not try to build a schema for
    # arbitrary LM / module objects.
    input_lm_model: Optional[Any] = None
    output_lm_model: Optional[Any] = None
    output_teacher_module: Optional[Any] = None

    num_example_for_optim: Optional[int] = None

    @model_validator(mode='after')
    def validate_feedback_mode(self):
        """Cross-field check of the feedback settings.

        Runs in ``after`` mode so that both fields are already populated on
        the instance; a per-field validator cannot reliably see the sibling
        field's value.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If ``feedback_mode`` is set to something other than
                ``"human"`` or ``"llm"``, or if it is set without
                ``num_example_for_feedback``.
        """
        if self.feedback_mode and self.feedback_mode not in ["human", "llm"]:
            raise ValueError("Feedback mode should be either 'human' or 'llm'.")

        if self.feedback_mode and not self.num_example_for_feedback:
            raise ValueError("Number of examples for feedback is required when feedback mode is provided.")

        return self
3 changes: 3 additions & 0 deletions dspy/experimental/synthesizer/instruction_suffixes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Suffixes appended to the input-generation instructions. Each one starts with
# a blank line ("\n\n") so it can be concatenated directly onto an existing
# instruction string.

# Used when previously generated examples are supplied alongside the task.
INPUT_GENERATION_TASK_WITH_EXAMPLES_SUFFIX = (
    "\n\nI'll also be providing you some data I generated beforehand, "
    "make sure the data you generate is consistent with the task I provided "
    "but different from the data I provided in every way possible."
)

# Used when feedback on generated data will be supplied.
INPUT_GENERATION_TASK_WITH_FEEDBACK_SUFFIX = (
    "\n\nAdditionally, I'll be providing you with feedback on the data you generate, "
    "while generating the data make sure to take into account the feedback I provide "
    "and try to improve the data you generate based on the feedback I provide."
)
5 changes: 0 additions & 5 deletions dspy/experimental/synthesizer/instructions.py

This file was deleted.

16 changes: 16 additions & 0 deletions dspy/experimental/synthesizer/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ class ExplainTask(dspy.Signature):
desc="Explanation of the task.",
)

class UpdateTaskDescriptionBasedOnFeedback(dspy.Signature):
    """Update the task description based on the feedback provided. Ensure that the revised task description incorporates the feedback to improve its overall clarity and effectiveness. Focus on enhancing the task's goal and basic premise, without delving into specific data points, models, examples, algorithms, or technical intricacies. Your explanation should aim to clarify the task's fundamental objective and purpose."""

    # Signature with two inputs (the current task description and feedback
    # text) and one output (the revised task description).

    # Input: the task description to be revised.
    task_description = dspy.InputField(
        prefix="Task Description:",
        desc="Description of the task.",
    )
    # Input: the feedback to incorporate.
    feedback = dspy.InputField(
        prefix="Feedback:",
        desc="Feedback on the task description.",
    )
    # Output: the revised description.
    # NOTE(review): this output shares the "Task Description:" prefix with the
    # input field above; confirm that the duplicate prefix is intentional.
    updated_task_description = dspy.OutputField(
        prefix="Task Description:",
        desc="Updated description of the task.",
    )

class GenerateFieldDescription(dspy.Signature):
"""Generate a concise and informative description for a given field based on the provided name and task description. This description should be no longer than 10 words and should be in simple english."""

Expand Down
83 changes: 61 additions & 22 deletions dspy/experimental/synthesizer/synthesizer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
import dspy
import random
from collections.abc import Mapping
from typing import List, Optional, Union

from datasets import Dataset
from tqdm import tqdm, trange

import dspy
from rich import print as rprint
from collections.abc import Mapping
from typing import List, Optional, Union

from .config import SynthesizerArguments
from .instructions import INPUT_GENERATION_TASK_WITH_EXAMPLES
from .instruction_suffixes import (
INPUT_GENERATION_TASK_WITH_EXAMPLES_SUFFIX,
INPUT_GENERATION_TASK_WITH_FEEDBACK_SUFFIX,
)
from .signatures import (
ExplainTask,
GenerateFieldDescription,
GenerateInputFieldsData,
GenerateOutputFieldsData,
UnderstandTask,
UpdateTaskDescriptionBasedOnFeedback,
)
from .utils import format_examples

__all__ = ["Synthesizer"]
__all__ = [
"Synthesizer",
"SynthesizerArguments",
]

class Synthesizer:
def __init__(self, config: SynthesizerArguments):
Expand All @@ -29,10 +36,33 @@ def __init__(self, config: SynthesizerArguments):
self.explain_task = dspy.Predict(ExplainTask)
self.understand_task = dspy.Predict(UnderstandTask)
self.generate_field_description = dspy.Predict(GenerateFieldDescription)
self.update_task_description = dspy.Predict(UpdateTaskDescriptionBasedOnFeedback)

self.generate_input_data = GenerateInputFieldsData
self.generate_output_data = GenerateOutputFieldsData

def _gather_feedback(self, examples: dspy.Example) -> str:
    """Collect feedback on a generated example.

    In ``"human"`` mode the example's input fields are pretty-printed to the
    terminal and the user is prompted for free-text feedback on stdin.

    Args:
        examples: The generated example; its ``inputs()`` keys are the
            fields shown to the user.

    Returns:
        The feedback text entered by the user.

    Raises:
        NotImplementedError: If ``feedback_mode`` is ``"llm"``.
        ValueError: If ``feedback_mode`` is neither ``"human"`` nor ``"llm"``.
    """
    mode = self.config.feedback_mode

    if mode == "human":
        # Local import keeps this block self-contained; ``rich`` is already a
        # dependency of this module (``rprint``).
        from rich.markup import escape

        separator = "-" * 75

        # Every style tag is explicitly closed so it cannot leak into the
        # text that follows. Keys and values are escaped because generated
        # data may contain square brackets that rich would otherwise parse
        # as markup.
        lines = [
            "[bold blue]Generated Data:[/bold blue]",
            "[bold red]Inputs:[/bold red]",
        ]
        lines.extend(
            f"\t[bold yellow]{escape(str(key))}[/bold yellow]: "
            f"[green]{escape(str(examples[key]))}[/green]"
            for key in examples.inputs().keys()
        )

        print(separator)
        rprint("\n".join(lines) + "\n")
        feedback = input("Provide feedback on the generated data: ")
        print(separator)

        return feedback

    if mode == "llm":
        raise NotImplementedError("Feedback mode 'llm' is not implemented yet.")

    raise ValueError("Feedback mode should be either 'human' or 'llm'.")

def _get_field_data(self, key: str, keys_dict: Mapping[str, str]):
if key.startswith("$"):
field_details = self.generate_field_description(
Expand Down Expand Up @@ -137,7 +167,11 @@ def generate(
task_description, input_keys, output_keys = self._get_dataset_metadata(ground_source)

if self.config.num_example_for_optim:
self.generate_input_data.__doc__ = INPUT_GENERATION_TASK_WITH_EXAMPLES
self.generate_input_data.__doc__ += INPUT_GENERATION_TASK_WITH_EXAMPLES_SUFFIX

if self.config.feedback_mode:
self.generate_input_data.__doc__ += INPUT_GENERATION_TASK_WITH_FEEDBACK_SUFFIX

self.generate_output_data.__doc__ = task_description

self.input_predictor, self.output_predictor = self._prepare_synthetic_data_predictors(
Expand All @@ -147,28 +181,23 @@ def generate(
)

data = []
feedback = ""

for idx in trange(0, num_data, batch_size, desc="Generating Synthetic Data"):
iter_temperature = 0.7+0.01*idx
iter_seed = random.randint(0, 1000000)

inputs = None
kwargs = {
"task_description": task_description,
"knowledge_seed": iter_seed,
"config": dict(temperature=iter_temperature, n=batch_size),
}

if self.config.num_example_for_optim:
kwargs["ground_source"] = random.sample(ground_source, self.config.num_example_for_optim)

with dspy.context(lm=self.input_lm):
if self.config.num_example_for_optim:
example_for_optimization = random.sample(ground_source, self.config.num_example_for_optim)
inputs = self.input_predictor(
task_description=task_description,
knowledge_seed=iter_seed,
ground_source=example_for_optimization,
config=dict(temperature=iter_temperature, n=batch_size),
)
else:
inputs = self.input_predictor(
task_description=task_description,
knowledge_seed=iter_seed,
config=dict(temperature=iter_temperature, n=batch_size),
)
inputs = self.input_predictor(**kwargs)

input_kwargs = [{
key: getattr(completions, key)
Expand All @@ -191,6 +220,16 @@ def generate(
}

data.append(dspy.Example(**kwargs, **output_kwargs).with_inputs(*input_keys))

if self.config.feedback_mode and idx < self.config.num_example_for_feedback:
feedback = self._gather_feedback(data[-1])

task_description = self.update_task_description(
task_description=task_description,
feedback=feedback,
).updated_task_description

self.output_predictor.signature.__doc__ = task_description

return data

Expand Down
4 changes: 1 addition & 3 deletions dspy/experimental/synthesizer/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import List

import dspy

from typing import List

def format_examples(examples: List[dspy.Example]) -> str:
if isinstance(examples, str):
Expand Down
25 changes: 21 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ sphinx_rtd_theme = { version = "*", optional = true }
autodoc_pydantic = { version = "*", optional = true }
sphinx-reredirects = { version = "^0.1.2", optional = true }
sphinx-automodapi = { version = "0.16.0", optional = true }
rich = "^13.7.1"


[tool.poetry.group.test.dependencies]
Expand Down

0 comments on commit c26d5ba

Please sign in to comment.