Skip to content

Commit

Permalink
Merge pull request stanfordnlp#1370 from stanfordnlp/extend-save-info
Browse files Browse the repository at this point in the history
Save field metadata in `module.save()` with `add_field_meta` feature flag
  • Loading branch information
arnavsinghvi11 authored Aug 8, 2024
2 parents 780203e + f844059 commit 877ff28
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 11 deletions.
2 changes: 2 additions & 0 deletions docs/docs/building-blocks/6-optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ optimized_program.save(YOUR_SAVE_PATH)

The resulting file is in plain-text JSON format. It contains all the parameters and steps in the source program. You can always read it and see what the optimizer generated.

You can add `save_field_meta` to additionally save the list of fields with the keys, `name`, `field_type`, `description`, and `prefix` with: `optimized_program.save(YOUR_SAVE_PATH, save_field_meta=True).

### Loading a program

To load a program from a file, you can instantiate an object from that class and then call the load method on it.
Expand Down
19 changes: 15 additions & 4 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def reset(self):
self.train = []
self.demos = []

def dump_state(self):
def dump_state(self, save_field_meta=False):
state_keys = ["lm", "traces", "train"]
state = {k: getattr(self, k) for k in state_keys}

Expand All @@ -37,6 +37,19 @@ def dump_state(self):

state["demos"].append(demo)

# If `save_field_meta` save all field metadata as well.
if save_field_meta:
fields = []
for field_key in self.signature.fields.keys():
field_metadata = self.signature.fields[field_key]
fields.append({
"name": field_key,
"field_type": field_metadata.json_schema_extra["__dspy_field_type"],
"description": field_metadata.json_schema_extra["desc"],
"prefix": field_metadata.json_schema_extra["prefix"]
})
state["fields"] = fields

# Cache the signature instructions and the last field's name.
*_, last_key = self.signature.fields.keys()
state["signature_instructions"] = self.signature.instructions
Expand Down Expand Up @@ -196,10 +209,8 @@ def new_generate(lm, signature, example, max_depth=6, **kwargs):

return completions



# TODO: get some defaults during init from the context window?
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
# affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates.
# Generally, unless overwritten, we'd see n=None, temperature=None.
# That will eventually mean we have to learn them.
# That will eventually mean we have to learn them.
10 changes: 5 additions & 5 deletions dspy/primitives/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def reset_copy(self):

return obj

def dump_state(self):
def dump_state(self, save_field_meta):
print(self.named_parameters())
return {name: param.dump_state() for name, param in self.named_parameters()}
return {name: param.dump_state(save_field_meta) for name, param in self.named_parameters()}

def load_state(self, state):
for name, param in self.named_parameters():
Expand All @@ -127,9 +127,9 @@ def load_state(self, state):
# else:
# raise

def save(self, path):
def save(self, path, save_field_meta=False):
with open(path, "w") as f:
f.write(ujson.dumps(self.dump_state(), indent=2))
f.write(ujson.dumps(self.dump_state(save_field_meta), indent=2))

def load(self, path):
with open(path) as f:
Expand All @@ -147,4 +147,4 @@ def postprocess_parameter_name(name, value):
if name == "_predict":
return "self"

return name
return name
2 changes: 1 addition & 1 deletion dspy/primitives/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ def activate_assertions(self, handler=backtrack_handler, **handler_args):
def set_attribute_by_name(obj, name, value):
magicattr.set(obj, name, value)

Program = Module
Program = Module
2 changes: 1 addition & 1 deletion dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,4 @@ def infer_prefix(attribute_name: str) -> str:
else:
title_cased_words.append(word.capitalize())

return " ".join(title_cased_words)
return " ".join(title_cased_words)

0 comments on commit 877ff28

Please sign in to comment.