Skip to content

Commit

Permalink
Merge branch 'develop-1.5' of https://github.com/FederatedAI/FATE int…
Browse files Browse the repository at this point in the history
…o develop-1.5
  • Loading branch information
talkingwallace committed Oct 10, 2020
2 parents cdc151a + 975e46a commit a559b37
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
10 changes: 6 additions & 4 deletions python/fate_flow/scheduler/dsl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _find_dependencies(self, mode="train", version=1):
self.components[i].set_upstream(self.component_upstream[i])
self.components[i].set_downstream(self.component_downstream[i])

def _init_component_setting(self, setting_conf_prefix, runtime_conf, version=1):
def _init_component_setting(self, setting_conf_prefix, runtime_conf, version=1, redundant_param_check=True):
"""
init top input
"""
Expand All @@ -247,12 +247,14 @@ def _init_component_setting(self, setting_conf_prefix, runtime_conf, version=1):
role_parameters = parameter_util.ParameterUtil.override_parameter(setting_conf_prefix,
runtime_conf,
module,
name)
name,
redundant_param_check=redundant_param_check)
else:
role_parameters = parameter_util.ParameterUtilV2.override_parameter(setting_conf_prefix,
runtime_conf,
module,
name)
name,
redundant_param_check=redundant_param_check)

self.components[idx].set_role_parameters(role_parameters)
else:
Expand Down Expand Up @@ -828,7 +830,7 @@ def run(self, pipeline_dsl=None, pipeline_runtime_conf=None, dsl=None, runtime_c
self._init_component_setting(setting_conf_prefix, self.runtime_conf)
else:
predict_runtime_conf = self.merge_dict(pipeline_runtime_conf, runtime_conf)
self._init_component_setting(setting_conf_prefix, predict_runtime_conf)
self._init_component_setting(setting_conf_prefix, predict_runtime_conf, redundant_param_check=False)

self.args_input, self.args_datakey = parameter_util.ParameterUtil.get_args_input(runtime_conf, module="args")
self._check_args_input()
Expand Down
32 changes: 21 additions & 11 deletions python/fate_flow/utils/parameter_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def change_object_to_dict(cls, obj):

@staticmethod
def _override_parameter(setting_conf_prefix=None, submit_dict=None, module=None,
module_alias=None, version=1):
module_alias=None, version=1, redundant_param_check=True):

_module_setting = ParameterUtil.get_setting_conf(setting_conf_prefix, module, module_alias)

Expand Down Expand Up @@ -90,7 +90,8 @@ def _override_parameter(setting_conf_prefix=None, submit_dict=None, module=None,
role_param_obj,
component=module_alias,
module=module,
version=version)
version=version,
redundant_param_check=redundant_param_check)
runtime_dict[param_class] = merge_dict

if "role_parameters" in submit_dict and role in submit_dict["role_parameters"]:
Expand All @@ -111,7 +112,8 @@ def _override_parameter(setting_conf_prefix=None, submit_dict=None, module=None,
role_num=len(partyid_list),
component=module_alias,
module=module,
version=version)
version=version,
redundant_param_check=redundant_param_check)

runtime_dict[param_class] = merge_dict

Expand All @@ -127,7 +129,8 @@ def _override_parameter(setting_conf_prefix=None, submit_dict=None, module=None,
role_num=len(partyid_list),
component=module_alias,
module=module,
version=version)
version=version,
redundant_param_check=redundant_param_check)
runtime_dict[param_class] = merge_dict

try:
Expand All @@ -149,11 +152,14 @@ def _override_parameter(setting_conf_prefix=None, submit_dict=None, module=None,

@classmethod
def merge_parameters(cls, runtime_dict, role_parameters, param_obj, idx=-1, role=None, role_num=0, component=None,
module=None, version=1):
module=None, version=1, redundant_param_check=True):
param_variables = param_obj.__dict__
for key, val_list in role_parameters.items():
if not redundant_param_check:
if key not in param_variables:
continue

if key not in param_variables:
# continue
raise RedundantParameterError(component=component, module=module, other_info=key)

attr = getattr(param_obj, key)
Expand Down Expand Up @@ -182,7 +188,8 @@ def merge_parameters(cls, runtime_dict, role_parameters, param_obj, idx=-1, role
role_num=role_num,
component=component,
module=module,
version=version)
version=version,
redundant_param_check=redundant_param_check)
setattr(param_obj, key, attr)

return runtime_dict
Expand Down Expand Up @@ -237,16 +244,18 @@ def get_param_object(cls, param_class_path, module, module_alias):

return param_class, param_obj


class ParameterUtil(BaseParameterUtil):
@staticmethod
def override_parameter(setting_conf_prefix=None, submit_dict=None, module=None,
module_alias=None):
module_alias=None, redundant_param_check=True):

return ParameterUtil()._override_parameter(setting_conf_prefix=setting_conf_prefix,
submit_dict=submit_dict,
module=module,
module_alias=module_alias,
version=1)
version=1,
redundant_param_check=redundant_param_check)

@classmethod
def get_args_input(cls, submit_dict, module="args"):
Expand Down Expand Up @@ -295,12 +304,13 @@ def get_args_input(cls, submit_dict, module="args"):
class ParameterUtilV2(BaseParameterUtil):
@classmethod
def override_parameter(cls, setting_conf_prefix=None, submit_dict=None, module=None,
module_alias=None):
module_alias=None, redundant_param_check=True):
return ParameterUtil._override_parameter(setting_conf_prefix=setting_conf_prefix,
submit_dict=submit_dict,
module=module,
module_alias=module_alias,
version=2)
version=2,
redundant_param_check=redundant_param_check)

@classmethod
def get_input_parameters(cls, submit_dict, components=None):
Expand Down

0 comments on commit a559b37

Please sign in to comment.