Skip to content

Commit

Permalink
Added a new method to Module
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasahle committed Mar 6, 2024
1 parent a2fb536 commit 797f1cd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
8 changes: 8 additions & 0 deletions dspy/primitives/module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from typing import Generator
import ujson


Expand Down Expand Up @@ -41,6 +42,13 @@ def add_parameter(param_name, param_value):

return named_parameters

def named_sub_modules(self) -> Generator[tuple[str, "BaseModule"], None, None]:
yield "", self
for name, value in self.__dict__.items():
if isinstance(value, BaseModule):
for sub_name, sub_value in value.named_sub_modules():
yield f"{name}.{sub_name}", sub_value

def parameters(self):
return [param for _, param in self.named_parameters()]

Expand Down
61 changes: 49 additions & 12 deletions tests/primitives/test_program.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import dspy
from dspy.primitives.module import BaseModule
from dspy.primitives.program import (
Module,
set_attribute_by_name,
Expand All @@ -19,35 +20,27 @@ def forward(self, question):

def test_module_initialization():
module = Module()
assert (
module._compiled is False
), "Module _compiled attribute should be False upon initialization"
assert module._compiled is False, "Module _compiled attribute should be False upon initialization"


def test_named_predictors():
module = HopModule()
named_preds = module.named_predictors()
assert len(named_preds) == 2, "Should identify correct number of Predict instances"
names, preds = zip(*named_preds)
assert (
"predict1" in names and "predict2" in names
), "Named predictors should include 'predict1' and 'predict2'"
assert "predict1" in names and "predict2" in names, "Named predictors should include 'predict1' and 'predict2'"


def test_predictors():
module = HopModule()
preds = module.predictors()
assert len(preds) == 2, "Should return correct number of Predict instances"
assert all(
isinstance(p, dspy.Predict) for p in preds
), "All returned items should be instances of PredictMock"
assert all(isinstance(p, dspy.Predict) for p in preds), "All returned items should be instances of PredictMock"


def test_forward():
program = HopModule()
dspy.settings.configure(
lm=DummyLM({"What is 1+1?": "let me check", "let me check": "2"})
)
dspy.settings.configure(lm=DummyLM({"What is 1+1?": "let me check", "let me check": "2"}))
result = program(question="What is 1+1?").answer
assert result == "2"

Expand All @@ -64,3 +57,47 @@ def __init__(self):
names, _preds = zip(*named_preds)
assert "hop.predict1" in names
assert "hop.predict2" in names


class SubModule(BaseModule):
pass


class AnotherSubModule(BaseModule):
pass


def test_empty_module():
module = BaseModule()
assert list(module.named_sub_modules()) == []


def test_single_level():
module = BaseModule()
module.sub = SubModule()
expected = [("sub", module.sub)]
assert list(module.named_sub_modules()) == expected


def test_multiple_levels():
module = BaseModule()
module.sub = SubModule()
module.sub.subsub = SubModule()
expected = [("sub", module.sub), ("sub.subsub", module.sub.subsub)]
assert list(module.named_sub_modules()) == expected


def test_multiple_sub_modules():
module = BaseModule()
module.sub1 = SubModule()
module.sub2 = SubModule()
expected = [("sub1", module.sub1), ("sub2", module.sub2)]
assert sorted(list(module.named_sub_modules())) == sorted(expected)


def test_non_base_module_attributes():
module = BaseModule()
module.sub = SubModule()
module.not_a_sub = "Not a BaseModule"
expected = [("sub", module.sub)]
assert list(module.named_sub_modules()) == expected

0 comments on commit 797f1cd

Please sign in to comment.