Skip to content

Commit

Permalink
add ks,auc
Browse files Browse the repository at this point in the history
  • Loading branch information
Sprate committed Dec 7, 2021
1 parent 524c3dd commit c77425b
Show file tree
Hide file tree
Showing 9 changed files with 840 additions and 16 deletions.
41 changes: 41 additions & 0 deletions core/he/he_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ PYBIND11_MODULE(he_utils, m) {
}
});

// Use a 64-bit scaling factor so the result meets float precision requirements:
// assuming millions of samples, 20 bits (data size) + 23 bits (float fraction) < 64.
m.def("cal_blind_iv", [](const mpz_class & a, const mpz_class &b,
const int64_t &total_pos, const int64_t &total_neg){
double woe = 0.0;
Expand Down Expand Up @@ -162,6 +164,45 @@ PYBIND11_MODULE(he_utils, m) {
mpf_class result = mpf_class(a) / (mpf_class(1) << 128);
return result.get_d();
});

// cal_blind_ks: returns |a * (2^64 / total_pos) - b * (2^64 / total_neg)| as an
// arbitrary-precision integer. The 2^64 factor keeps the result in fixed point;
// cal_max_ks divides it back out.
// NOTE(review): a and b are presumably the blinded positive/negative counts
// for one threshold -- confirm against the caller.
// Raises ValueError on the Python side when total_pos or total_neg is zero.
m.def("cal_blind_ks", [](const mpz_class & a, const mpz_class &b,
                         const int64_t &total_pos, const int64_t &total_neg){
    // GMP does not turn integer division by zero into a C++ exception (it
    // typically terminates the process), so reject it before dividing.
    if (total_pos == 0 || total_neg == 0) {
        throw pybind11::value_error(
            "cal_blind_ks: total_pos and total_neg must be non-zero");
    }

    const mpz_class scaling_factor = mpz_class(1) << 64;
    // Integer (truncating) per-class weights in 2^64 fixed point.
    const mpz_class pos_weight = scaling_factor / total_pos;
    const mpz_class neg_weight = scaling_factor / total_neg;

    mpz_class blind_ks = abs(a * pos_weight - b * neg_weight);
    return blind_ks;
});

// cal_max_ks: picks the largest entry of `ks` and converts it from 2^64
// fixed point to a double by dividing by 2^64.
m.def("cal_max_ks", [](const std::vector<mpz_class> & ks) {
    // Start from the sentinel -1; an empty `ks` therefore yields -1 / 2^64.
    mpz_class largest(-1);
    for (const auto &candidate : ks) {
        if (!(largest > candidate)) {
            largest = candidate;
        }
    }

    const mpf_class divisor = mpf_class(1) << 64;
    const mpf_class scaled = mpf_class(largest) / divisor;
    return scaled.get_d();
});

// cal_blind_auc: returns the sum over i of pos[i] * neg[i] as an
// arbitrary-precision integer, for i in [0, pos.size()).
// Raises ValueError on the Python side when `neg` has fewer elements than
// `pos`, since pos.size() elements are read from both vectors.
m.def("cal_blind_auc", [](const std::vector<mpz_class> &pos,
                          const std::vector<mpz_class> &neg){
    // Reading neg[i] past its end would be undefined behaviour.
    if (neg.size() < pos.size()) {
        throw pybind11::value_error(
            "cal_blind_auc: neg must contain at least as many elements as pos");
    }

    mpz_class auc(0);
    // std::size_t index: a 32-bit counter could wrap on very large inputs.
    for (std::size_t i = 0; i < pos.size(); ++i) {
        auc += pos[i] * neg[i];
    }
    return auc;
});

// cal_unblind_auc: divides every entry of `blind_auc` by
// 2 * total_pos * total_neg and returns the quotients as doubles, in the
// same order as the input.
// Raises ValueError on the Python side when total_pos or total_neg is zero.
m.def("cal_unblind_auc", [](const std::vector<mpz_class> &blind_auc,
                            const int64_t &total_pos,
                            const int64_t &total_neg){
    // A zero divisor would make the GMP division below fail at process level
    // rather than raise a Python exception.
    if (total_pos == 0 || total_neg == 0) {
        throw pybind11::value_error(
            "cal_unblind_auc: total_pos and total_neg must be non-zero");
    }

    // Build the divisor in arbitrary precision: evaluating
    // 2 * total_pos * total_neg in int64_t can overflow for large counts.
    // It is loop-invariant, so compute it once.
    const mpz_class denominator_z = mpz_class(total_pos) * total_neg * 2;
    const mpf_class denominator(denominator_z);

    std::vector<double> auc;
    auc.reserve(blind_auc.size());
    for (const auto &blinded : blind_auc) {
        const mpf_class ratio = mpf_class(blinded) / denominator;
        auc.emplace_back(ratio.get_d());
    }
    return auc;
});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def connect(self, channel):

def get_positive_ratio(self, labels):
"""
reutrn postive ratio to client
return postive ratio to client
params:
labels: a list in the shape of (sample_size, 1)
labels[i] is either 0 or 1, represents negative and positive resp.
Expand All @@ -51,7 +51,7 @@ def get_positive_ratio(self, labels):

def get_woe(self, labels):
"""
reutrn woe to client
return woe to client
params:
labels: a list in the shape of (sample_size, 1)
labels[i] is either 0 or 1, represents negative and positive resp.
Expand All @@ -65,14 +65,51 @@ def get_woe(self, labels):

def get_iv(self, labels):
"""
reutrn iv to client
return iv to client
params:
labels: a list in the shape of (sample_size, 1)
labels[i] is either 0 or 1, represents negative and positive resp.
e.g. [[1], [0], [1],...,[1]]
return:
an list corresponding to the iv of each feature
a list corresponding to the iv of each feature
e.g. [0.56653, 0.56653]
"""
return mc.get_mpc_iv_alice(self._channel, labels, self._paillier)

def get_woe_iv(self, labels):
    """
    return both woe and iv to client
    params:
        labels: a list in the shape of (sample_size, 1)
                labels[i] is either 0 or 1, represents negative and positive resp.
                e.g. [[1], [0], [1],...,[1]]
    return:
        a tuple of woe and iv
    """
    # Same helper as get_iv, with True as the extra fourth positional argument.
    woe_and_iv = mc.get_mpc_iv_alice(
        self._channel, labels, self._paillier, True)
    return woe_and_iv

def get_ks(self, labels):
    """
    return ks to client
    params:
        labels: a list in the shape of (sample_size, 1)
                labels[i] is either 0 or 1, represents negative and positive resp.
                e.g. [[1], [0], [1],...,[1]]
    return:
        a list corresponding to the ks of each feature
        e.g. [0.3, 0.3]
    """
    # Delegates to the mc helper with this object's channel and paillier state.
    ks_per_feature = mc.get_mpc_ks_alice(
        self._channel, labels, self._paillier)
    return ks_per_feature

def get_auc(self, labels):
    """
    return auc to client
    params:
        labels: a list in the shape of (sample_size, 1)
                labels[i] is either 0 or 1, represents negative and positive resp.
                e.g. [[1], [0], [1],...,[1]]
    return:
        a list corresponding to the auc of each feature
        e.g. [0.33, 0.33]
    """
    # Delegates to the mc helper with this object's channel and paillier state.
    auc_per_feature = mc.get_mpc_auc_alice(
        self._channel, labels, self._paillier)
    return auc_per_feature
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_positive_ratio(self, features):

def get_woe(self, features):
"""
reutrn woe to server
return woe to server
params:
features: a feature list in the shape of (sample_size, features_size)
e.g. [[4, 3, 1], [1, 2, 5],...,[2, 3 ,2]] (feature_size = 3)
Expand All @@ -68,12 +68,12 @@ def get_woe(self, features):

def get_iv(self, features):
"""
reutrn iv to server
return iv to server
params:
features: a feature list in the shape of (sample_size, features_size)
e.g. [[4, 3, 1], [1, 2, 5],...,[2, 3 ,2]] (feature_size = 3)
return:
an list corresponding to the iv of each feature
a list corresponding to the iv of each feature
e.g. [0.56653, 0.56653]
"""
iv_list = []
Expand All @@ -84,4 +84,64 @@ def get_iv(self, features):
self._server.start()
stop_event.wait()
self._server.stop(90)
return iv_list
return iv_list

def get_woe_iv(self, features):
    """
    return both woe and iv to server
    params:
        features: a feature list in the shape of (sample_size, features_size)
                  e.g. [[4, 3, 1], [1, 2, 5],...,[2, 3 ,2]] (feature_size = 3)
    return:
        a tuple of woe and iv
    """
    # Result containers handed to the servicer
    # (presumably filled by it before the event is set -- confirm in ms).
    woe_results = []
    iv_results = []
    finished = threading.Event()

    servicer = ms.MpcIVServicer(features, finished, iv_results, woe_results)
    ms.metrics_pb2_grpc.add_MpcIVServicer_to_server(servicer, self._server)

    # Serve until the event is set, then stop the server with argument 90.
    self._server.start()
    finished.wait()
    self._server.stop(90)
    return woe_results, iv_results

def get_ks(self, features):
    """
    return ks to server
    params:
        features: a feature list in the shape of (sample_size, features_size)
                  e.g. [[4, 3, 1], [1, 2, 5],...,[2, 3 ,2]] (feature_size = 3)
    return:
        a list corresponding to the ks of each feature
        e.g. [0.3, 0.3]
    """
    # Result container handed to the servicer
    # (presumably filled by it before the event is set -- confirm in ms).
    ks_results = []
    finished = threading.Event()

    servicer = ms.MpcKSServicer(features, finished, ks_results)
    ms.metrics_pb2_grpc.add_MpcKSServicer_to_server(servicer, self._server)

    # Serve until the event is set, then stop the server with argument 90.
    self._server.start()
    finished.wait()
    self._server.stop(90)
    return ks_results

def get_auc(self, features):
    """
    return auc to server
    params:
        features: a feature list in the shape of (sample_size, features_size)
                  e.g. [[4, 3, 1], [1, 2, 5],...,[2, 3 ,2]] (feature_size = 3)
    return:
        a list corresponding to the auc of each feature
        e.g. [0.33, 0.33]
    """
    # Result container handed to the servicer
    # (presumably filled by it before the event is set -- confirm in ms).
    auc_results = []
    finished = threading.Event()

    servicer = ms.MpcAUCServicer(features, finished, auc_results)
    ms.metrics_pb2_grpc.add_MpcAUCServicer_to_server(servicer, self._server)

    # Serve until the event is set, then stop the server with argument 90.
    self._server.start()
    finished.wait()
    self._server.stop(90)
    return auc_results
Loading

0 comments on commit c77425b

Please sign in to comment.