Skip to content

Commit

Permalink
Merge pull request FederatedAI#3067 from FederatedAI/feature-1.6.1-ds…
Browse files Browse the repository at this point in the history
…l_parser_support_single_predict

Feature 1.6.1 dsl parser support single predict
  • Loading branch information
mgqa34 authored Sep 2, 2021
2 parents c1220e1 + 55f6987 commit 4404222
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 5 deletions.
7 changes: 2 additions & 5 deletions python/fate_flow/scheduler/dsl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,10 +479,7 @@ def prepare_graph_dependency_info(self):
"component_module": component_module,
"component_need_run": {}}

if self.mode == "train":
runtime_conf = self.runtime_conf
else:
runtime_conf = self.pipeline_runtime_conf
runtime_conf = self.runtime_conf

self.graph_dependency = {}
for role in runtime_conf["role"]:
Expand Down Expand Up @@ -993,7 +990,7 @@ def run(self, pipeline_runtime_conf=None, dsl=None, runtime_conf=None,
self._init_component_setting(setting_conf_prefix, self.runtime_conf, version=2)
self.job_parameters = parameter_util.ParameterUtilV2.get_job_parameters(self.runtime_conf)
else:
predict_runtime_conf = parameter_util.ParameterUtilV2.merge_dict(pipeline_runtime_conf, runtime_conf)
predict_runtime_conf = parameter_util.ParameterUtilV2.get_predict_runtime_conf(pipeline_runtime_conf, runtime_conf)
self._init_component_setting(setting_conf_prefix, predict_runtime_conf, version=2)
self.job_parameters = parameter_util.ParameterUtilV2.get_job_parameters(predict_runtime_conf)

Expand Down
65 changes: 65 additions & 0 deletions python/fate_flow/utils/parameter_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,68 @@ def get_job_parameters(submit_dict):
ret[role][partyid_list[idx]] = parameters

return ret

@classmethod
def get_predict_runtime_conf(cls, train_conf, predict_conf):
runtime_conf = copy.deepcopy(train_conf)
train_role = train_conf.get("role")
predict_role = predict_conf.get("role")
if len(train_conf) < len(predict_role):
raise ValueError(f"Predict roles is {predict_role}, train roles is {train_conf}, "
"predict roles should be subset of train role")

for role in train_role:
if role not in predict_role:
del runtime_conf["role"][role]

if runtime_conf.get("job_parameters", {}).get("role", {}).get(role):
del runtime_conf["job_parameters"]["role"][role]

if runtime_conf.get("component_parameters", {}).get("role", {}).get(role):
del runtime_conf["component_parameters"]["role"][role]

continue

train_party_ids = train_role[role]
predict_party_ids = predict_role[role]

diff = False
for idx, party_id in enumerate(predict_party_ids):
if party_id not in train_party_ids:
raise ValueError(f"Predict role: {role} party_id: {party_id} not occurs in training")
if train_party_ids[idx] != party_id:
diff = True

if not diff and len(train_party_ids) == len(predict_party_ids):
continue

for p_type in ["job_parameters", "component_parameters"]:
if not runtime_conf.get(p_type, {}).get("role", {}).get(role):
continue

conf = runtime_conf[p_type]["role"][role]
party_keys = conf.keys()
new_conf = {}
for party_key in party_keys:
party_list = party_key.split("|", -1)
new_party_list = []
for party in party_list:
party_id = train_party_ids[int(party)]
if party_id in predict_party_ids:
new_idx = predict_party_ids.index(party_id)
new_party_list.append(str(new_idx))

if not new_party_list:
continue

new_party_key = new_party_list[0] if len(new_party_list) == 1 else "|".join(new_party_list)

if new_party_key not in new_conf:
new_conf[new_party_key] = {}
new_conf[new_party_key].update(conf[party_key])

runtime_conf[p_type]["role"][role] = new_conf

runtime_conf = cls.merge_dict(runtime_conf, predict_conf)

return runtime_conf

0 comments on commit 4404222

Please sign in to comment.