Skip to content

Commit

Permalink
add support of match-id for homo-lr & homo-nn
Browse files Browse the repository at this point in the history
Signed-off-by: weijingchen <[email protected]>
  • Loading branch information
talkingwallace committed Dec 16, 2022
1 parent 27d7a2b commit 8a91fce
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
from federatedml.protobuf.generated import lr_model_param_pb2
from federatedml.model_base import MetricMeta
from fate_arch.session import computing_session
from federatedml.nn.backend.utils.data import get_ret_predict_table
from federatedml.nn.backend.utils.data import get_ret_predict_table, add_match_id
from federatedml.nn.loss.weighted_loss import WeightedBCE
from federatedml.statistic.data_overview import check_with_inst_id


def linear_weight_to_torch(model_weights):
Expand Down Expand Up @@ -105,7 +106,7 @@ def __init__(self, header, homo_lr_meta, model_weights, param_name, meta_name, *

def export_model_dict(
self,
model,
model=None,
optimizer=None,
model_define=None,
optimizer_define=None,
Expand Down Expand Up @@ -291,6 +292,7 @@ def predict(self, data_instances):
self.init_schema(data_instances)

data_instances = self.align_data_header(data_instances, self.header)
with_inst_id = check_with_inst_id(data_instances)

dataset = self.get_dataset(data_instances)

Expand All @@ -299,13 +301,15 @@ def predict(self, data_instances):
return predict_result

dataset.set_type('predict')

if self.trainer is None:
self.trainer, torch_model, wrap_optimizer, loss = self.init(
dataset, data_instances.partitions)

trainer_ret = self.trainer.predict(dataset)
id_table, pred_table, classes = trainer_ret()

if with_inst_id:
add_match_id(id_table=id_table, dataset_inst=dataset)

id_dtable, pred_dtable = get_ret_predict_table(
id_table, pred_table, classes, data_instances.partitions, computing_session)
ret_table = self.predict_score_to_output(
Expand Down
6 changes: 6 additions & 0 deletions python/federatedml/nn/backend/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ def get_ret_predict_table(id_table, pred_table, classes, partitions, computing_s
pred_table, partition=partitions, include_key=True)

return id_dtable, pred_dtable


def add_match_id(id_table: list, dataset_inst: TableDataset):
assert isinstance(dataset_inst, TableDataset), 'when using match id your dataset must be a Table Dataset'
for id_inst in id_table:
id_inst[1].inst_id = dataset_inst.match_ids[id_inst[0]]
8 changes: 5 additions & 3 deletions python/federatedml/nn/homo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fate_arch.session import computing_session
from federatedml.nn.backend.utils.data import get_ret_predict_table
from federatedml.nn.dataset.table import TableDataset
from federatedml.nn.backend.utils.data import add_match_id
from federatedml.protobuf.generated.homo_nn_model_param_pb2 import HomoNNParam as HomoNNParamPB
from federatedml.protobuf.generated.homo_nn_model_meta_pb2 import HomoNNMeta as HomoNNMetaPB

Expand Down Expand Up @@ -337,16 +338,17 @@ def predict(self, cpn_input):
return None

id_table, pred_table, classes = trainer_ret()

if with_inst_id: # set match id
assert isinstance(dataset_inst, TableDataset), 'when using match id your dataset must be a Table Dataset'
for id_inst in id_table:
id_inst[1].inst_id = dataset_inst.match_ids[id_inst[0]]
add_match_id(id_table=id_table, dataset_inst=dataset_inst)

id_dtable, pred_dtable = get_ret_predict_table(
id_table, pred_table, classes, self.partitions, computing_session)
ret_table = self.predict_score_to_output(
id_dtable, pred_dtable, classes)
if schema is not None:
self.set_predict_data_schema(ret_table, schema)

return ret_table

def export_model(self):
Expand Down

0 comments on commit 8a91fce

Please sign in to comment.