Skip to content

Commit

Permalink
Merge pull request stanfordnlp#654 from thomasahle/main
Browse files Browse the repository at this point in the history
Formatting for new optimizers
  • Loading branch information
thomasahle authored Mar 15, 2024
2 parents cc8b193 + aee6baf commit 649ba32
Show file tree
Hide file tree
Showing 12 changed files with 659 additions and 333 deletions.
69 changes: 43 additions & 26 deletions dspy/teleprompt/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,24 @@


class BootstrapFewShot(Teleprompter):
def __init__(
    self,
    metric=None,
    metric_threshold=None,
    teacher_settings=None,
    max_bootstrapped_demos=4,
    max_labeled_demos=16,
    max_rounds=1,
    max_errors=5,
):
    """Store the bootstrapping configuration and reset the error counter.

    Args:
        metric: Optional callable. When set, it is invoked as
            ``metric(example, prediction, trace)`` for each bootstrapped run.
        metric_threshold: Optional threshold that is checked after the
            metric value has been computed.
        teacher_settings: Optional dict stored as ``self.teacher_settings``
            (presumably overrides applied while the teacher runs -- confirm
            against the code that reads it). ``None`` means "no settings"
            and gives this instance its own empty dict.
        max_bootstrapped_demos: Upper bound on the bootstrapped (augmented)
            demos kept per predictor.
        max_labeled_demos: Budget for labeled demos; also used as ``k``
            when the teacher is compiled with ``LabeledFewShot``.
        max_rounds: Stored as ``self.max_rounds``.
        max_errors: Once ``self.error_count`` reaches this value, a failure
            while bootstrapping an example is re-raised instead of logged.
    """
    self.metric = metric
    self.metric_threshold = metric_threshold

    # A literal ``{}`` default is evaluated once, at definition time, and
    # would be shared by every instance created without this argument.
    # Use a None sentinel so each instance gets an independent dict.
    self.teacher_settings = {} if teacher_settings is None else teacher_settings

    self.max_bootstrapped_demos = max_bootstrapped_demos
    self.max_labeled_demos = max_labeled_demos
    self.max_rounds = max_rounds
    self.max_errors = max_errors

    # Failure bookkeeping; the lock guards ``error_count`` updates.
    self.error_count = 0
    self.error_lock = threading.Lock()

Expand All @@ -59,37 +68,41 @@ def compile(self, student, *, teacher=None, trainset, valset=None):
self.student._suggest_failures = 0

return self.student

def _prepare_student_and_teacher(self, student, teacher):
    """Set ``self.student`` and ``self.teacher`` from the given programs.

    The student is always a reset copy of ``student``. The teacher is a deep
    copy of ``teacher`` when one is supplied, otherwise a second reset copy
    of ``student``. If labeled demos are enabled and the teacher is not yet
    compiled, the teacher is replaced by a ``LabeledFewShot`` compilation of
    a reset copy of itself over ``self.trainset``.

    Raises:
        AssertionError: if the student copy is already marked as compiled.
    """
    self.student = student.reset_copy()
    if teacher is not None:
        self.teacher = teacher.deepcopy()
    else:
        self.teacher = student.reset_copy()

    assert getattr(self.student, "_compiled", False) is False, "Student must be uncompiled."

    # Nothing more to do when labeled demos are disabled ...
    if not self.max_labeled_demos:
        return
    # ... or when the teacher already carries a compiled state.
    if getattr(self.teacher, "_compiled", False) is not False:
        return

    labeled_few_shot = LabeledFewShot(k=self.max_labeled_demos)
    self.teacher = labeled_few_shot.compile(self.teacher.reset_copy(), trainset=self.trainset)

def _prepare_predictor_mappings(self):
    """Build the name <-> predictor lookup tables for student and teacher.

    Walks the student's and teacher's named predictors in lockstep, checks
    that both programs line up (same count, same names, equal signatures,
    distinct objects), and then publishes two dicts on ``self``:

    * ``name2predictor``: predictor name -> ``None`` (placeholder value).
    * ``predictor2name``: ``id(predictor)`` -> name, for both the student's
      and the teacher's predictor objects.

    Raises:
        AssertionError: if the two programs do not have matching structure.
    """
    student_program = self.student
    teacher_program = self.teacher
    by_name = {}
    by_id = {}

    assert len(student_program.predictors()) == len(
        teacher_program.predictors(),
    ), "Student and teacher must have the same number of predictors."

    paired = zip(student_program.named_predictors(), teacher_program.named_predictors())
    for (name1, predictor1), (name2, predictor2) in paired:
        assert name1 == name2, "Student and teacher must have the same program structure."
        assert predictor1.signature.equals(
            predictor2.signature,
        ), f"Student and teacher must have the same signatures. {type(predictor1.signature)} != {type(predictor2.signature)}"
        assert predictor1 is not predictor2, "Student and teacher must be different objects."

        # Only the key matters here; the value is a placeholder.
        by_name[name1] = None  # dict(student=predictor1, teacher=predictor2)

        # Register both objects under the shared name so that a trace step
        # can be resolved from the identity of either predictor.
        by_id[id(predictor1)] = name1

        # FIXME(shangyint): This is an ugly hack to bind traces of
        # retry.module to retry
        # if isinstance(predictor1, Retry):
        #     predictor2name[id(predictor1.module)] = name1

        by_id[id(predictor2)] = name2

    # Publish only after every pair has passed validation.
    self.name2predictor = by_name
    self.predictor2name = by_id
Expand All @@ -111,8 +124,8 @@ def _bootstrap(self, *, max_bootstraps=None):
if success:
bootstrapped[example_idx] = True

print(f'Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.')
print(f"Bootstrapped {len(bootstrapped)} full traces after {example_idx+1} examples in round {round_idx}.")

# Unbootstrapped training examples

self.validation = [x for idx, x in enumerate(self.trainset) if idx not in bootstrapped]
Expand All @@ -123,10 +136,10 @@ def _bootstrap(self, *, max_bootstraps=None):
# NOTE: Can't yet use evaluate because we need to trace *per example*
# evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12)
# score = evaluate(self.metric, display_table=False, display_progress=True)

def _bootstrap_one_example(self, example, round_idx=0):
name2traces = self.name2traces
teacher = self.teacher #.deepcopy()
teacher = self.teacher # .deepcopy()
predictor_cache = {}

try:
Expand All @@ -145,7 +158,7 @@ def _bootstrap_one_example(self, example, round_idx=0):

for name, predictor in teacher.named_predictors():
predictor.demos = predictor_cache[name]

if self.metric:
metric_val = self.metric(example, prediction, trace)
if self.metric_threshold:
Expand All @@ -162,13 +175,13 @@ def _bootstrap_one_example(self, example, round_idx=0):
current_error_count = self.error_count
if current_error_count >= self.max_errors:
raise e
print(f'Failed to run or to evaluate example {example} with {self.metric} due to {e}.')
print(f"Failed to run or to evaluate example {example} with {self.metric} due to {e}.")

if success:
for step in trace:
predictor, inputs, outputs = step

if 'dspy_uuid' in example:
if "dspy_uuid" in example:
demo = Example(augmented=True, dspy_uuid=example.dspy_uuid, **inputs, **outputs)
else:
# TODO: FIXME: This is a hack. RandomSearch will complain for now in this edge case.
Expand All @@ -177,30 +190,34 @@ def _bootstrap_one_example(self, example, round_idx=0):
try:
predictor_name = self.predictor2name[id(predictor)]
except KeyError as e:
continue # FIXME: !
continue # FIXME: !

# TODO: Look closer into this. It's a bit tricky to reproduce.
print(f'Failed to find predictor {predictor} in {self.predictor2name}.')
print('Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.')
print('Try restarting the notebook, or open an issue.')
raise KeyError(f'Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.') from e
print(f"Failed to find predictor {predictor} in {self.predictor2name}.")
print(
"Are you doing this in a notebook (Jupyter)? This might be caused by redefining values by rerunning cells.",
)
print("Try restarting the notebook, or open an issue.")
raise KeyError(
f"Failed to find predictor {id(predictor)} {predictor} in {self.predictor2name}.",
) from e

name2traces[predictor_name].append(demo)

return success

def _train(self):
rng = random.Random(0)
raw_demos = self.validation

for name, predictor in self.student.named_predictors():
augmented_demos = self.name2traces[name][:self.max_bootstrapped_demos]
augmented_demos = self.name2traces[name][: self.max_bootstrapped_demos]

sample_size = min(self.max_labeled_demos - len(augmented_demos), len(raw_demos))
sample_size = max(0, sample_size)

raw_demos = rng.sample(raw_demos, sample_size)

if dspy.settings.release >= 20230928:
predictor.demos = raw_demos + augmented_demos
else:
Expand Down
Loading

0 comments on commit 649ba32

Please sign in to comment.