Skip to content

Commit

Permalink
Merge pull request FederatedAI#527 from FederatedAI/develop
Browse files Browse the repository at this point in the history
1.0.1
  • Loading branch information
dylan-fan authored Sep 3, 2019
2 parents 0032e08 + 4e22c0e commit 787c67f
Show file tree
Hide file tree
Showing 31 changed files with 235 additions and 1,292 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ Then all you need to do is running the following command:
Please note this works only if you have finished the training task.

### Obtain Model and Check Out Results
We provide functions such as tracking component output models and logs through a tool called fate-flow. The deployment and usage of fate-flow can be found [here](./fate_flow/README.md).


## Doc
## API doc
Expand Down
3 changes: 2 additions & 1 deletion arch/api/utils/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def generate_table_name(input_file_path):
table_name = _table_name
eggroll.init(job_id=args.job_id, mode=work_mode)
input_data = read_data(input_file_path, table_name, namespace, head)
data_table = storage.save_data(input_data, name=table_name, namespace=namespace, partition=partition)
in_version = job_config.get('in_version', False)
data_table = storage.save_data(input_data, name=table_name, namespace=namespace, partition=partition, in_version=in_version)
print("------------load data finish!-----------------")
print("file: {}".format(input_file_path))
print("total data_count: {}".format(data_table.count()))
Expand Down
4 changes: 0 additions & 4 deletions arch/api/utils/version_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,3 @@ def get_commit_tmp_table(data_table_namespace):
partition=1, create_if_missing=True, error_if_exist=False)
return version_tmp_table


def get_id_library_table_name():
id_library_info = eggroll.table('info', 'id_library', partition=10, create_if_missing=True, error_if_exist=False)
return id_library_info.get("use_data_id")
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,62 @@ private void init() {
}
}

public String getMySiteLocalAddress() {
if (siteLocalAddress == null) {
Enumeration<NetworkInterface> networkInterfaces = null;
try {
networkInterfaces = NetworkInterface.getNetworkInterfaces();

for (NetworkInterface ni : Collections.list(networkInterfaces)) {
Enumeration<InetAddress> inetAddresses = ni.getInetAddresses();
for (InetAddress ia : Collections.list(inetAddresses)) {
if (ia.isSiteLocalAddress()) {
siteLocalAddress = StringUtils.substringAfterLast(ia.toString(), "/");
}
}
}
} catch (SocketException e) {
siteLocalAddress = "127.0.0.1";

/**
 * Check whether an address is usable as the local site address.
 *
 * An address qualifies when it is non-null, not a loopback address and,
 * unless the interface is point-to-point, not a link-local address.
 *
 * @param nif network interface the address belongs to
 * @param adr Internet Protocol (IP) address to validate
 * @return true if the address is valid, false otherwise
 * @throws SocketException if the interface properties cannot be queried
 */
private boolean checkAddrValid(NetworkInterface nif, InetAddress adr)
        throws SocketException {
    return adr != null && !adr.isLoopbackAddress() && (nif.isPointToPoint() || !adr.isLinkLocalAddress());
}

/**
 * Resolve a usable local IP address.
 *
 * On Linux/Unix platforms (os.name contains "nix"/"nux") the network
 * interfaces are scanned and the first address accepted by
 * {@code checkAddrValid} is returned; on other platforms the JDK default
 * {@code InetAddress.getLocalHost()} is used.
 *
 * @return a valid local InetAddress, or null if none could be found
 * @throws SocketException if the interface enumeration fails
 * @throws UnknownHostException if the local host name cannot be resolved
 */
private InetAddress getLocalAdr() throws SocketException, UnknownHostException {
    String os = System.getProperty("os.name").toLowerCase();
    if (os.contains("nix") || os.contains("nux")) {
        // getNetworkInterfaces() may return null when no interfaces exist;
        // the redundant "= null" initialization before the call was removed.
        Enumeration<NetworkInterface> nifs = NetworkInterface.getNetworkInterfaces();
        if (null == nifs) {
            return null;
        }
        while (nifs.hasMoreElements()) {
            NetworkInterface nif = nifs.nextElement();
            Enumeration<InetAddress> adrs = nif.getInetAddresses();
            while (adrs.hasMoreElements()) {
                InetAddress adr = adrs.nextElement();
                if (checkAddrValid(nif, adr)) {
                    return adr;
                }
            }
        }
    } else {
        return InetAddress.getLocalHost();
    }
    return null;
}


return siteLocalAddress;
/**
 * Lazily resolve and cache the local site address.
 * Falls back to "127.0.0.1" when resolution fails or yields nothing.
 *
 * @return the cached local IP address string
 */
public String getMySiteLocalAddress() {
    if (siteLocalAddress != null) {
        return siteLocalAddress;
    }
    try {
        InetAddress resolved = getLocalAdr();
        if (resolved == null) {
            siteLocalAddress = "127.0.0.1";
        } else {
            siteLocalAddress = resolved.getHostAddress();
        }
    } catch (SocketException | UnknownHostException e) {
        siteLocalAddress = "127.0.0.1";
    }
    return siteLocalAddress;
}

public String getMySiteLocalIpAndPort() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ coordinator=webank
ip=127.0.0.1
port=8000
workMode=0
serviceRoleName=serving
inferenceWorkerThreadNum=10
#storage
# maybe python/data/
Expand Down
15 changes: 6 additions & 9 deletions doc/upload_data_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,12 @@ sorting the occurred tags by lexicographic order, then fill the occur item with
Here is an example showing how to create an upload config file:
```
{
"file": "examples/data/breast_b.csv",
"head": 1,
"partition": 10,
"local": {
"party_id": 10000,
"role": "guest"
},
"table_name": "hetero_feature_selection_host",
"namespace": "hetero_feature_selection"
"file": "examples/data/breast_b.csv",
"head": 1,
"partition": 10,
"work_mode": 0,
"table_name": "hetero_breast_b",
"namespace": "hetero_guest_breast"
}
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ protected ReturnResult getFederatedPredictFromRemote(Context context,FederatedP

metaDataBuilder.setSrc(
topicBuilder.setPartyId(String.valueOf(srcParty.getPartyId())).
setRole("serving")
setRole(Configuration.getProperty("serviceRoleName", "serving"))
.setName("partnerPartyName")
.build());
metaDataBuilder.setDst(
topicBuilder.setPartyId(String.valueOf(dstParty.getPartyId()))
.setRole("serving")
.setRole(Configuration.getProperty("serviceRoleName", "serving"))
.setName("partyName")
.build());
metaDataBuilder.setCommand(Proxy.Command.newBuilder().setName("federatedInference").build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ coordinator=webank
ip=127.0.0.1
port=8000
workMode=0
serviceRoleName=serving
inferenceWorkerThreadNum=10
#storage
# maybe python/data/
Expand Down
100 changes: 54 additions & 46 deletions federatedml/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,57 +107,65 @@ def _run_data(self, data_sets=None, stage=None):
else:
LOGGER.warning("Evaluation has not transform, return")

def split_data_with_type(self, data: list) -> dict:
    """Group evaluation records by the mode tag stored at index 4.

    Each record is a (key, value) pair whose value is a tuple/list with the
    mode (e.g. "train"/"validate") at position 4. Returns a defaultdict
    mapping each mode to the list of records carrying that mode.
    """
    grouped = defaultdict(list)
    for record in data:
        grouped[record[1][4]].append(record)
    return grouped

def evaluate_metircs(self, mode: str, data: list) -> dict:
    """Compute the configured metrics for one mode's worth of records.

    NOTE(review): the method name contains a typo ("metircs"); it is kept
    unchanged because callers reference it by this name.

    :param mode: data mode tag (e.g. "train"/"validate") attached to results
    :param data: list of (key, value) records where value holds
                 (label, predict_label, predict_score, ...) at indices 0-2
    :return: defaultdict mapping metric name -> [mode, metric_result]
    """
    labels = []
    pred_scores = []
    pred_labels = []

    # Unpack label, predicted label and predicted score from each record.
    for d in data:
        labels.append(d[1][0])
        pred_labels.append(d[1][1])
        pred_scores.append(d[1][2])

    if self.eval_type == consts.BINARY or self.eval_type == consts.REGRESSION:
        # For binary tasks with a configured positive label, remap labels
        # to 1 (positive) / 0 (negative) before scoring.
        if self.pos_label and self.eval_type == consts.BINARY:
            new_labels = []
            for label in labels:
                if self.pos_label == label:
                    new_labels.append(1)
                else:
                    new_labels.append(0)
            labels = new_labels

        # Binary/regression metrics consume raw scores ...
        pred_results = pred_scores
    else:
        # ... while multi-class metrics consume predicted labels.
        pred_results = pred_labels

    eval_result = defaultdict(list)

    # Look up the metric list registered for this eval type; unknown types
    # are logged and produce an empty result rather than raising.
    if self.eval_type in self.metrics:
        metrics = self.metrics[self.eval_type]
    else:
        LOGGER.warning("Unknown eval_type of {}".format(self.eval_type))
        metrics = []

    # Each metric name doubles as the name of a method on self.
    for eval_metric in metrics:
        res = getattr(self, eval_metric)(labels, pred_results)
        if res:
            eval_result[eval_metric].append(mode)
            eval_result[eval_metric].append(res)

    return eval_result

def fit(self, data):
    """Evaluate every data set in ``data`` and store the metric results.

    The scraped diff left both the removed pre-1.0.1 body and the new body
    concatenated, evaluating everything twice; this is the post-merge
    implementation that delegates per-mode evaluation to
    ``split_data_with_type`` / ``evaluate_metircs``.

    :param data: dict mapping a data-set name to a distributed table whose
                 values are tuples of at least
                 (label, predict_label, predict_score, ..., mode)
    Side effects: clears and refills ``self.eval_results`` and pushes the
    metrics out via ``self.callback_metric_data()``.
    """
    if len(data) <= 0:
        return

    self.eval_results.clear()
    for key, eval_data in data.items():
        # Pull the distributed table to the local side once per data set.
        eval_data_local = list(eval_data.collect())

        # Evaluate each mode ("train", "validate", ...) separately.
        # Renamed the loop variable (was "data") so it no longer shadows
        # the method parameter.
        split_data_with_label = self.split_data_with_type(eval_data_local)
        for mode, mode_data in split_data_with_label.items():
            eval_result = self.evaluate_metircs(mode, mode_data)
            # NOTE(review): with multiple modes only the last mode's result
            # survives for this key — confirm this is intended upstream.
            self.eval_results[key] = eval_result

    self.callback_metric_data()

Expand Down Expand Up @@ -220,7 +228,7 @@ def __save_roc(self, data_type, metric_name, metric_namespace, metric_res):
# set roc edge value
fpr.append(1.0)
tpr.append(1.0)

fpr, tpr, idx_list = self.__filt_override_unit_ordinate_coordinate(fpr, tpr)
edge_idx = idx_list[-1]
if edge_idx == len(thresholds):
Expand Down Expand Up @@ -498,7 +506,6 @@ def roc(self, labels, pred_scores):
fpr, tpr, thresholds = roc_curve(np.array(labels), np.array(pred_scores), drop_intermediate=1)
fpr, tpr, thresholds = list(map(float, fpr)), list(map(float, tpr)), list(map(float, thresholds))


filt_thresholds, cuts = self.__filt_threshold(thresholds=thresholds, step=0.01)
new_thresholds = []
new_tpr = []
Expand Down Expand Up @@ -839,6 +846,7 @@ class BiClassPrecision(object):
"""
Compute binary classification precision
"""

def __init__(self):
self.total_positives = 0

Expand Down
24 changes: 15 additions & 9 deletions federatedml/feature/binning/base_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,20 @@ def convert_feature_to_bin(self, data_instances, transform_cols_idx=-1, split_po
split_points = self.split_points

is_sparse = data_overview.is_sparse_data(data_instances)

if is_sparse:
f = functools.partial(self._convert_sparse_data,
transform_cols_idx=transform_cols_idx,
split_points_dict=split_points,
header=self.header)
header=self.header,
abnormal_list=self.abnormal_list)
new_data = data_instances.mapValues(f)
else:
f = functools.partial(self._convert_dense_data,
transform_cols_idx=transform_cols_idx,
split_points_dict=split_points,
header=self.header)
header=self.header,
abnormal_list=self.abnormal_list)
new_data = data_instances.mapValues(f)
new_data.schema = {"header": self.header}
bin_sparse = self.get_sparse_bin(transform_cols_idx, split_points)
Expand All @@ -263,16 +266,17 @@ def convert_feature_to_bin(self, data_instances, transform_cols_idx=-1, split_po
return new_data, split_points_result, bin_sparse

@staticmethod
def _convert_sparse_data(instances, transform_cols_idx, split_points_dict, header):
def _convert_sparse_data(instances, transform_cols_idx, split_points_dict, header, abnormal_list):
all_data = instances.features.get_all_data()
data_shape = instances.features.get_shape()
indice = []
sparse_value = []
# print("In _convert_sparse_data, transform_cols_idx: {}, header: {}, split_points_dict: {}".format(
# transform_cols_idx, header, split_points_dict
# ))

for col_idx, col_value in all_data:
if col_idx in transform_cols_idx:
if col_value in abnormal_list:
indice.append(col_idx)
sparse_value.append(col_value)
elif col_idx in transform_cols_idx:
col_name = header[col_idx]
split_points = split_points_dict[col_name]
bin_num = Binning.get_bin_num(col_value, split_points)
Expand All @@ -299,10 +303,12 @@ def get_sparse_bin(self, transform_cols_idx, split_points_dict):
return result

@staticmethod
def _convert_dense_data(instances, transform_cols_idx, split_points_dict, header):
def _convert_dense_data(instances, transform_cols_idx, split_points_dict, header, abnormal_list):
features = instances.features
for col_idx, col_value in enumerate(features):
if col_idx in transform_cols_idx:
if col_value in abnormal_list:
features[col_idx] = col_value
elif col_idx in transform_cols_idx:
col_name = header[col_idx]
split_points = split_points_dict[col_name]
bin_num = Binning.get_bin_num(col_value, split_points)
Expand Down
2 changes: 2 additions & 0 deletions federatedml/feature/quantile_summaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def set_total_count(self, total_count):
self._total_count = total_count

def insert(self, x):
if x in self.abnormal_list:
return
if x < consts.FLOAT_ZERO:
self.smaller_num += 1
elif x >= consts.FLOAT_ZERO:
Expand Down
4 changes: 2 additions & 2 deletions federatedml/feature/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,15 +510,15 @@ def run_sample(self, data_inst, task_type, task_role):
return sample_data_inst

def fit(self, data_inst):
    """Run sampling on ``data_inst`` in the fit stage.

    Delegates to ``run_sample`` with the configured task type and role.
    The diff artifact left a removed line passing ``self, task_role``
    (an undefined name, raising NameError); only the corrected call using
    ``self.task_role`` is kept, matching ``transform``.
    """
    return self.run_sample(data_inst, self.task_type, self.task_role)

def transform(self, data_inst):
    """Run sampling on ``data_inst`` in the transform stage, delegating to
    ``run_sample`` with the configured task type and role."""
    return self.run_sample(data_inst, self.task_type, self.task_role)

def run(self, component_parameters, args=None):
self._init_runtime_parameters(component_parameters)
self._init_role(component_parameters)
stage = None
stage = "fit"
if args.get("data", None) is None:
return

Expand Down
Loading

0 comments on commit 787c67f

Please sign in to comment.