Skip to content

Commit

Permalink
fix bug in smac that configspace cannot deal with complex searchspace (
Browse files Browse the repository at this point in the history
…microsoft#716)

fix bug in smac that configspace cannot deal with complex searchspace. That is converting categorical
  • Loading branch information
QuanluZhang authored Feb 10, 2019
1 parent 60223b2 commit 8e732f2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
13 changes: 10 additions & 3 deletions src/sdk/pynni/nni/smac_tuner/convert_ss_to_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,22 @@ def generate_pcs(nni_search_space_content):
# parameter_name real [min_value, max_value] [default value] log
# https://automl.github.io/SMAC3/stable/options.html
'''
categorical_dict = {}
search_space = nni_search_space_content
with open('param_config_space.pcs', 'w') as pcs_fd:
if isinstance(search_space, dict):
for key in search_space.keys():
if isinstance(search_space[key], dict):
try:
if search_space[key]['_type'] == 'choice':
choice_len = len(search_space[key]['_value'])
pcs_fd.write('%s categorical {%s} [%s]\n' % (
key,
json.dumps(search_space[key]['_value'])[1:-1],
json.dumps(search_space[key]['_value'][0])))
json.dumps(list(range(choice_len)))[1:-1],
json.dumps(0)))
if key in categorical_dict:
raise RuntimeError('%s has already existed, please make sure search space has no duplicate key.' % key)
categorical_dict[key] = search_space[key]['_value']
elif search_space[key]['_type'] == 'randint':
# TODO: support lower bound in randint
pcs_fd.write('%s integer [0, %d] [%d]\n' % (
Expand Down Expand Up @@ -83,6 +88,8 @@ def generate_pcs(nni_search_space_content):
raise RuntimeError('_type or _value error.')
else:
raise RuntimeError('incorrect search space.')
return categorical_dict
return None

def generate_scenario(ss_content):
'''
Expand Down Expand Up @@ -119,7 +126,7 @@ def generate_scenario(ss_content):
sce_fd.write('paramfile = param_config_space.pcs\n')
sce_fd.write('run_obj = quality\n')

generate_pcs(ss_content)
return generate_pcs(ss_content)

if __name__ == '__main__':
generate_scenario('search_space.json')
25 changes: 18 additions & 7 deletions src/sdk/pynni/nni/smac_tuner/smac_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, optimize_mode):
self.first_one = True
self.update_ss_done = False
self.loguniform_key = set()
self.categorical_dict = {}

def _main_cli(self):
'''
Expand Down Expand Up @@ -128,7 +129,9 @@ def update_search_space(self, search_space):
NOTE: updating search space is not supported.
'''
if not self.update_ss_done:
generate_scenario(search_space)
self.categorical_dict = generate_scenario(search_space)
if self.categorical_dict is None:
raise RuntimeError('categorical dict is not correctly returned after parsing search space.')
self.optimizer = self._main_cli()
self.smbo_solver = self.optimizer.solver
self.loguniform_key = {key for key in search_space.keys() if search_space[key]['_type'] == 'loguniform'}
Expand All @@ -152,13 +155,21 @@ def receive_trial_result(self, parameter_id, parameters, value):
else:
self.smbo_solver.nni_smac_receive_runs(self.total_data[parameter_id], reward)

def convert_loguniform(self, challenger_dict):
def convert_loguniform_categorical(self, challenger_dict):
'''
convert the values of type `loguniform` back to their initial range
Convert the values of type `loguniform` back to their initial range
Also, we convert categorical:
categorical values in search space are changed to list of numbers before,
those original values will be changed back in this function
'''
for key, value in challenger_dict.items():
# convert to loguniform
if key in self.loguniform_key:
challenger_dict[key] = np.exp(challenger_dict[key])
# convert categorical back to original value
if key in self.categorical_dict:
idx = challenger_dict[key]
challenger_dict[key] = self.categorical_dict[key][idx]
return challenger_dict

def generate_parameters(self, parameter_id):
Expand All @@ -169,13 +180,13 @@ def generate_parameters(self, parameter_id):
init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[parameter_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary())
return self.convert_loguniform(init_challenger.get_dictionary())
return self.convert_loguniform_categorical(init_challenger.get_dictionary())
else:
challengers = self.smbo_solver.nni_smac_request_challengers()
for challenger in challengers:
self.total_data[parameter_id] = challenger
json_tricks.dumps(challenger.get_dictionary())
return self.convert_loguniform(challenger.get_dictionary())
return self.convert_loguniform_categorical(challenger.get_dictionary())

def generate_multiple_parameters(self, parameter_id_list):
'''
Expand All @@ -189,7 +200,7 @@ def generate_multiple_parameters(self, parameter_id_list):
init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[one_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary())
params.append(self.convert_loguniform(init_challenger.get_dictionary()))
params.append(self.convert_loguniform_categorical(init_challenger.get_dictionary()))
else:
challengers = self.smbo_solver.nni_smac_request_challengers()
cnt = 0
Expand All @@ -199,6 +210,6 @@ def generate_multiple_parameters(self, parameter_id_list):
break
self.total_data[parameter_id_list[cnt]] = challenger
json_tricks.dumps(challenger.get_dictionary())
params.append(self.convert_loguniform(challenger.get_dictionary()))
params.append(self.convert_loguniform_categorical(challenger.get_dictionary()))
cnt += 1
return params

0 comments on commit 8e732f2

Please sign in to comment.