Skip to content

Commit

Permalink
Fixes for Module-level ChainOfThought
Browse files Browse the repository at this point in the history
  • Loading branch information
okhat committed Jul 8, 2024
1 parent 730fd13 commit 812ddf9
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions dspy/predict/chain_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,25 @@ def __init__(self, signature, rationale_type=None, activated=True, **config):

rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)

self.extended_signature = signature.prepend("rationale", rationale_type, type_=str)
self._predict = dspy.Predict(self.extended_signature, **config)
self._predict.extended_signature = self.extended_signature
extended_signature = signature.prepend("rationale", rationale_type, type_=str)
self._predict = dspy.Predict(extended_signature, **config)
self._predict.extended_signature = extended_signature

def forward(self, **kwargs):
assert self.activated in [True, False]

signature = kwargs.pop("new_signature", self.extended_signature if self.activated else self.signature)
signature = kwargs.pop("new_signature", self._predict.extended_signature if self.activated else self.signature)
return self._predict(signature=signature, **kwargs)
# return super().forward(signature=signature, **kwargs)

@property
def demos(self):
return self._predict.demos

@property
def extended_signature(self):
return self._predict.extended_signature

"""
TODO: In principle, we can update the field's prefix during forward too to fill any thing based on the input args.
Expand Down

0 comments on commit 812ddf9

Please sign in to comment.