Skip to content

Commit

Permalink
Fix some randomness in evolutionary pareto search not coming from given seed. (microsoft#225)
Browse files Browse the repository at this point in the history

Add unit tests to cover this.
  • Loading branch information
lovettchris authored Apr 21, 2023
1 parent e225885 commit 4fc5a6d
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 31 deletions.
9 changes: 6 additions & 3 deletions archai/discrete_search/algos/evolution_pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def sample_models(self, num_models: int, patience: Optional[int] = 5) -> List[Ar

_, valid_indices = self.so.validate_constraints(sample)
valid_sample += [sample[i] for i in valid_indices]
nb_tries += 1

return valid_sample[:num_models]

Expand All @@ -143,6 +144,7 @@ def mutate_parents(
nb_tries = 0

while len(candidates) < mutations_per_parent and nb_tries < patience:
nb_tries += 1
mutated_model = self.search_space.mutate(p)
mutated_model.metadata["parent"] = p.archid

Expand All @@ -152,7 +154,6 @@ def mutate_parents(
if mutated_model.archid not in self.seen_archs:
mutated_model.metadata["generation"] = self.iter_num
candidates[mutated_model.archid] = mutated_model
nb_tries += 1
mutations.update(candidates)

return list(mutations.values())
Expand All @@ -176,7 +177,7 @@ def crossover_parents(
children, children_ids = [], set()

if len(parents) >= 2:
pairs = [random.sample(parents, 2) for _ in range(num_crossovers)]
pairs = [self.rng.sample(parents, 2) for _ in range(num_crossovers)]
for p1, p2 in pairs:
child = self.search_space.crossover([p1, p2])
nb_tries = 0
Expand Down Expand Up @@ -215,7 +216,7 @@ def select_next_population(self, current_pop: List[ArchaiModel]) -> List[ArchaiM
"""

random.shuffle(current_pop)
self.rng.shuffle(current_pop)
return current_pop[: self.max_unseen_population]

@overrides
Expand All @@ -241,6 +242,8 @@ def search(self) -> SearchResults:
logger.info(f"Calculating search objectives {list(self.so.objective_names)} for {len(unseen_pop)} models ...")

results = self.so.eval_all_objs(unseen_pop)
if len(results) == 0:
raise Exception("Search is finding no valid models")
self.search_state.add_iteration_results(
unseen_pop,
results,
Expand Down
89 changes: 79 additions & 10 deletions tests/discrete_search/algos/test_evolution_pareto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,96 @@

import os

from typing import Optional
from random import Random
import pytest

from overrides import overrides
from archai.discrete_search.algos.evolution_pareto import EvolutionParetoSearch
from archai.discrete_search.api.search_objectives import SearchObjectives
from archai.discrete_search.api.archai_model import ArchaiModel
from archai.discrete_search.api.model_evaluator import ModelEvaluator
from archai.discrete_search.search_spaces.config import (
ArchParamTree, ConfigSearchSpace, DiscreteChoice,
)


class DummyEvaluator(ModelEvaluator):
    """Evaluator stub whose score is drawn from a caller-supplied generator.

    Because the only source of randomness is the injected ``rng``, two
    evaluators built from identically seeded generators yield the same
    sequence of scores.
    """

    def __init__(self, rng: Random):
        """Store the random generator used to produce scores.

        Args:
            rng: Generator whose ``random()`` method supplies each score.
        """
        self.rng = rng
        self.dummy = True

    @overrides
    def evaluate(self, model: ArchaiModel, budget: Optional[float] = None) -> float:
        """Return the next value produced by ``self.rng.random()``.

        Args:
            model: Architecture being scored; not inspected here.
            budget: Optional evaluation budget; not used here.

        Returns:
            The value returned by ``self.rng.random()``.
        """
        score = self.rng.random()
        return score


@pytest.fixture(scope="session")
def output_dir(tmp_path_factory):
    """Provide a session-scoped temporary directory created with base name ``"out"``."""
    out_path = tmp_path_factory.mktemp("out")
    return out_path

@pytest.fixture
def tree_c2():
    """Provide a two-parameter config mapping: ``p1`` and ``p2``, each a False/True choice."""
    # Each key gets its own DiscreteChoice built from a fresh [False, True] list.
    return {name: DiscreteChoice([False, True]) for name in ('p1', 'p2')}


def test_evolution_pareto(output_dir, search_space, search_objectives):
algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5)
cache = []
for _ in range(2):
algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5, seed=42)
search_space.rng = algo.rng

search_results = algo.search()
assert len(os.listdir(output_dir)) > 0

df = search_results.get_search_state_df()
assert all(0 <= x <= 0.4 for x in df["Random1"].tolist())

all_models = [m for iter_r in search_results.results for m in iter_r["models"]]

# Checks if all registered models satisfy constraints
_, valid_models = search_objectives.validate_constraints(all_models)
assert len(valid_models) == len(all_models)

cache += [[m.archid for m in all_models]]

# Make sure the returned archids are repeatable so that search jobs can be restarted.
assert cache[0] == cache[1]


def test_evolution_pareto_tree_search(output_dir, tree_c2):
tree = ArchParamTree(tree_c2)

def use_arch(c):
if c.pick('p1'):
return

if c.pick('p2'):
return

seed = 42

cache = []
for _ in range(2):
search_objectives = SearchObjectives()
search_objectives.add_objective(
'Dummy',
DummyEvaluator(Random(seed)),
higher_is_better=False,
compute_intensive=False)
search_space = ConfigSearchSpace(use_arch, tree, seed=seed)
algo = EvolutionParetoSearch(search_space, search_objectives, output_dir, num_iters=3, init_num_models=5, seed=seed, save_pareto_model_weights=False)

search_results = algo.search()
assert len(os.listdir(output_dir)) > 0
search_results = algo.search()
assert len(os.listdir(output_dir)) > 0

df = search_results.get_search_state_df()
assert all(0 <= x <= 0.4 for x in df["Random1"].tolist())
all_models = [m for iter_r in search_results.results for m in iter_r["models"]]

all_models = [m for iter_r in search_results.results for m in iter_r["models"]]
cache += [[m.archid for m in all_models]]

# Checks if all registered models satisfy constraints
_, valid_models = search_objectives.validate_constraints(all_models)
assert len(valid_models) == len(all_models)
# Make sure the returned archids are repeatable so that search jobs can be restarted.
assert cache[0] == cache[1]
46 changes: 28 additions & 18 deletions tests/discrete_search/search_spaces/config/test_config_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def tree_c2():

def test_param_sharing(rng, tree_c1):
tree = ArchParamTree(tree_c1)

for _ in range(10):
config = tree.sample_config(rng)
p1 = config.pick('param1')
Expand All @@ -68,10 +68,10 @@ def test_repeat_config_share(rng, tree_c1):

for _ in range(10):
config = tree.sample_config(rng)

for param_block in config.pick('param_list'):
par4 = param_block.pick('param4')

assert len(set(
p.pick('constant') for p in par4
)) == 1
Expand All @@ -93,35 +93,47 @@ def test_ss(rng, tree_c2, tmp_path_factory):
tmp_path = tmp_path_factory.mktemp('test_ss')

tree = ArchParamTree(tree_c2)

def use_arch(c):
if c.pick('p1'):
return

if c.pick('p2'):
return

ss = ConfigSearchSpace(use_arch, tree, seed=1)
m = ss.random_sample()
ss.save_arch(m, tmp_path / 'arch.json')

m2 = ss.load_arch(tmp_path / 'arch.json')
assert m.archid == m2.archid
cache = []
for _ in range(2):
ids = []

ss = ConfigSearchSpace(use_arch, tree, seed=1)
m = ss.random_sample()
ids += [m.archid]
ss.save_arch(m, tmp_path / 'arch.json')

m2 = ss.load_arch(tmp_path / 'arch.json')
assert m.archid == m2.archid

m3 = ss.mutate(m)
m4 = ss.crossover([m3, m2])
m3 = ss.mutate(m)

ids += [m3.archid]
m4 = ss.crossover([m3, m2])
ids += [m4.archid]
cache += [ids]

# Make sure the returned archids are repeatable so that search jobs can be restarted.
assert cache[0] == cache[1]


def test_ss_archid(rng, tree_c2):
tree = ArchParamTree(tree_c2)

def use_arch(c):
if c.pick('p1'):
return

if c.pick('p2'):
return

ss = ConfigSearchSpace(use_arch, tree, seed=1)
archids = set()

Expand All @@ -130,5 +142,3 @@ def use_arch(c):
archids.add(config.archid)

assert len(archids) == 3 # Will fail with probability approx 1/2^100


0 comments on commit 4fc5a6d

Please sign in to comment.