Skip to content

Commit

Permalink
Merge pull request FederatedAI#4179 from FederatedAI/feature-1.9.0-fate_flow-anonymous
Browse files Browse the repository at this point in the history

Fix anonymous handling when using the tag and svmlight data formats
  • Loading branch information
zhihuiwan authored Aug 3, 2022
2 parents 7b8348c + dd03357 commit 3caeedf
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 29 deletions.
28 changes: 18 additions & 10 deletions python/federatedml/util/data_format_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def get_feature_offset(meta):
"""
with_label = meta.get("with_label", False)
with_match_id = meta.get("with_match_id", False)
id_column_num = meta.get("id_column_num", 0)
id_range = meta.get("id_range", 0)

if with_match_id:
if not id_column_num:
id_column_num = 1
if not id_range:
id_range = 1

offset = id_column_num
offset = id_range
if with_label:
offset += 1

Expand Down Expand Up @@ -114,7 +114,7 @@ def generate_header(data, schema):
raise ValueError("Meta not in schema")

meta = schema["meta"]
generated_header = dict(original_index_info=dict())
generated_header = dict(original_index_info=dict(), meta=meta)
input_format = meta.get("input_format")
delimiter = meta.get("delimiter", ",")
if not input_format:
Expand Down Expand Up @@ -178,20 +178,28 @@ def generate_header(data, schema):

with_label = meta.get("with_label", False)
with_match_id = meta.get("with_match_id", False)
id_column_num = meta.get("id_column_num", 0)
id_range = meta.get("id_range", 0)

if id_range and not with_match_id:
raise ValueError(f"id_range {id_range} != 0, with_match_id should be true")

if with_match_id:
if not id_column_num:
id_column_num = 1
if not id_range:
id_range = 1

if id_column_num == 1:
if id_range == 1:
generated_header["match_id_name"] = DEFAULT_MATCH_ID_PREFIX
else:
generated_header["match_id_name"] = [DEFAULT_MATCH_ID_PREFIX + str(i) for i in range(id_column_num)]
generated_header["match_id_name"] = [DEFAULT_MATCH_ID_PREFIX + str(i) for i in range(id_range)]

if with_label:
generated_header["label_name"] = DEFAULT_LABEL_NAME

if id_range:
generated_header["meta"]["id_range"] = id_range

generated_header["is_display"] = False

generated_header["sid"] = schema.get("sid", DEFAULT_SID_NAME)

return generated_header
Expand Down
74 changes: 55 additions & 19 deletions python/federatedml/util/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,23 @@ def __init__(self, data_transform_param):
self.anonymous_generator = None
self.anonymous_header = None

def _update_param(self, meta):
def _update_param(self, schema):
meta = schema["meta"]
self.delimitor = meta.get("delimiter", ",")
self.data_type = meta.get("data_type")
self.with_label = meta.get("with_label", False)
if self.with_label:
self.label_type = meta.get("label_type", "int")
self.label_name = meta.get("label_name", "")
self.with_match_id = meta.get("with_match_id", False)
if self.with_match_id:
match_id_name = schema.get("match_id_name")
if isinstance(match_id_name, list):
self.match_id_name = match_id_name[self.match_id_index]
else:
self.match_id_name = match_id_name

schema["match_id_name"] = self.match_id_name

def read_data(self, input_data, mode="fit"):
LOGGER.info("start to read sparse data and change data to instance")
Expand All @@ -480,13 +489,21 @@ def read_data(self, input_data, mode="fit"):
schema = self.anonymous_generator.generate_anonymous_header(schema)
set_schema(input_data, schema)
else:
self._update_param(schema["meta"])
self._update_param(schema)

if mode == "fit":
self.header = schema["header"]
self.anonymous_header = schema["anonymous_header"]
data_instance = self.fit(input_data)
else:
if not self.anonymous_header:
header_set = set(self.header)
self.anonymous_header = []
for column, anonymous_column in zip(schema["header"], schema["anonymous_header"]):
if column not in header_set:
continue
self.anonymous_header.append(anonymous_column)

schema["header"] = self.header
schema["anonymous_header"] = self.anonymous_header
set_schema(input_data, schema)
Expand All @@ -511,7 +528,7 @@ def transform(self, input_data):
return data_instance

def gen_data_instance(self, input_data, max_feature):
id_range = len(input_data.schema.get("id_list", []))
id_range = input_data.schema["meta"].get("id_range", 0)
params = [self.delimitor, self.data_type,
self.label_type, self.with_match_id,
self.match_id_index, id_range,
Expand Down Expand Up @@ -652,7 +669,8 @@ def __init__(self, data_transform_param):
self.anonymous_generator = None
self.anonymous_header = None

def _update_param(self, meta):
def _update_param(self, schema):
meta = schema["meta"]
self.delimitor = meta.get("delimiter", ",")
self.data_type = meta.get("data_type")
self.tag_with_value = meta.get("tag_with_value")
Expand All @@ -662,6 +680,14 @@ def _update_param(self, meta):
self.label_type = meta.get("label_type", "int")
self.label_name = meta.get("label_name")
self.with_match_id = meta.get("with_match_id", False)
if self.with_match_id:
match_id_name = schema.get("match_id_name")
if isinstance(match_id_name, list):
self.match_id_name = match_id_name[self.match_id_index]
else:
self.match_id_name = match_id_name

schema["match_id_name"] = self.match_id_name

def read_data(self, input_data, mode="fit"):
LOGGER.info("start to read sparse data and change data to instance")
Expand All @@ -682,13 +708,21 @@ def read_data(self, input_data, mode="fit"):
schema = self.anonymous_generator.generate_anonymous_header(schema)
set_schema(input_data, schema)
else:
self._update_param(schema["meta"])
self._update_param(schema)

if mode == "fit":
self.header = schema["header"]
self.anonymous_header = schema["anonymous_header"]
data_instance = self.fit(input_data)
else:
if not self.anonymous_header:
header_set = set(self.header)
self.anonymous_header = []
for column, anonymous_column in zip(schema["header"], schema["anonymous_header"]):
if column not in header_set:
continue
self.anonymous_header.append(anonymous_column)

schema["header"] = self.header
schema["anonymous_header"] = self.anonymous_header
set_schema(input_data, schema)
Expand Down Expand Up @@ -790,7 +824,7 @@ def gen_data_instance(self, input_data, meta, tags_dict):
self.with_label,
self.with_match_id,
self.match_id_index,
meta.get("id_list", []),
meta.get("id_range", 0),
self.label_type,
self.output_format,
tags_dict]
Expand Down Expand Up @@ -820,7 +854,7 @@ def to_instance(param_list, value):
with_label = param_list[4]
with_match_id = param_list[5]
match_id_index = param_list[6]
id_list = param_list[7]
id_range = param_list[7]
label_type = param_list[8]
output_format = param_list[9]
tags_dict = param_list[10]
Expand All @@ -834,7 +868,7 @@ def to_instance(param_list, value):
match_id = None

if with_match_id:
offset = len(id_list)
offset = id_range if id_range else 1
if offset == 0:
offset = 1
match_id = cols[match_id_index]
Expand Down Expand Up @@ -979,39 +1013,41 @@ def load_model(self, model_dict):
if model.endswith("Param"):
self._input_model_param = value[model]

def fit(self, data_inst):
self._load_reader(data_inst.schema)
data_inst = self.transformer.read_data(data_inst, "fit")
def fit(self, data):
self._load_reader(data.schema)
data_inst = self.transformer.read_data(data, "fit")
if isinstance(self.transformer, (DenseFeatureTransformer, SparseTagTransformer)):
summary_buf = self.transformer.get_summary()
if summary_buf:
self.set_summary(summary_buf)

clear_schema(data_inst)
return data_inst

def transform(self, data_inst):
self._load_reader(data_inst.schema)
return self.transformer.read_data(data_inst, "transform")
def transform(self, data):
self._load_reader(data.schema)
data_inst = self.transformer.read_data(data, "transform")
clear_schema(data_inst)
return data_inst

def export_model(self):
model_dict = self.transformer.save_model()
model_dict["DataTransformMeta"].need_run = self.need_run
return model_dict


def clear_schema(schema):
ret_schema = copy.deepcopy(schema)
def clear_schema(data_inst):
ret_schema = copy.deepcopy(data_inst.schema)
key_words = {"sid", "header", "anonymous_header", "label_name",
"anonymous_label", "match_id_name"}
for key in schema:
for key in data_inst.schema:
if key not in key_words:
del ret_schema[key]

return ret_schema
data_inst.schema = ret_schema


def set_schema(data_instance, schema):
schema = clear_schema(schema)
data_instance.schema = schema


Expand Down

0 comments on commit 3caeedf

Please sign in to comment.