Skip to content

Commit

Permalink
edit intersect param
Browse files Browse the repository at this point in the history
Signed-off-by: nemirorox <[email protected]>
  • Loading branch information
nemirorox committed Mar 17, 2021
1 parent fe60f09 commit 8fe2ae7
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion examples/dsl/v2/intersect/test_intersect_job_rsa_conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"hash_method": "sha256",
"final_hash_method": "sha256",
"split_calculation": false,
"key_bit": 2048
"key_length": 2048
}
}
},
Expand Down
2 changes: 1 addition & 1 deletion examples/pipeline/intersect/pipeline-intersect-rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def main(config="../../config.yaml", namespace=""):
"rsa_params": {
"hash_method": "sha256",
"final_hash_method": "sha256",
"key_bit": 2048
"key_length": 2048
}
}
intersect_0 = Intersection(name="intersect_0", **param)
Expand Down
10 changes: 5 additions & 5 deletions python/fate_client/pipeline/param/intersect_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,19 @@ class RSAParam(BaseParam):
random_base_fraction: positive float, if not None, generate (fraction * public key id count) of r for encryption and reuse generated r;
note that value greater than 0.99 will be taken as 1, and value less than 0.01 will be rounded up to 0.01
key_bit: positive int, bit count of rsa key, default 1024
key_length: positive int, bit count of rsa key, default 1024
"""

def __init__(self, salt='', hash_method='sha256', final_hash_method='sha256',
split_calculation=False, random_base_fraction=None, key_bit=1024):
split_calculation=False, random_base_fraction=None, key_length=1024):
super().__init__()
self.salt = salt
self.hash_method = hash_method
self.final_hash_method = final_hash_method
self.split_calculation = split_calculation
self.random_base_fraction = random_base_fraction
self.key_bit = key_bit
self.key_length = key_length

def check(self):
if type(self.salt).__name__ != "str":
Expand All @@ -117,8 +117,8 @@ def check(self):
self.check_positive_number(self.random_base_fraction, descr)
self.check_decimal_float(self.random_base_fraction, descr)

descr = "rsa param's key_bit"
self.check_positive_integer(self.key_bit, descr)
descr = "rsa param's key_length"
self.check_positive_integer(self.key_length, descr)

return True

Expand Down
10 changes: 5 additions & 5 deletions python/federatedml/param/intersect_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,19 @@ class RSAParam(BaseParam):
random_base_fraction: positive float, if not None, generate (fraction * public key id count) of r for encryption and reuse generated r;
note that value greater than 0.99 will be taken as 1, and value less than 0.01 will be rounded up to 0.01
key_bit: positive int, bit count of rsa key, default 1024
key_length: positive int, bit count of rsa key, default 1024
"""

def __init__(self, salt='', hash_method='sha256', final_hash_method='sha256',
split_calculation=False, random_base_fraction=None, key_bit=1024):
split_calculation=False, random_base_fraction=None, key_length=1024):
super().__init__()
self.salt = salt
self.hash_method = hash_method
self.final_hash_method = final_hash_method
self.split_calculation = split_calculation
self.random_base_fraction = random_base_fraction
self.key_bit=key_bit
self.key_length=key_length

def check(self):
if type(self.salt).__name__ != "str":
Expand Down Expand Up @@ -122,8 +122,8 @@ def check(self):
self.check_positive_number(self.random_base_fraction, descr)
self.check_decimal_float(self.random_base_fraction, descr)

descr = "rsa param's key_bit"
self.check_positive_integer(self.key_bit, descr)
descr = "rsa param's key_length"
self.check_positive_integer(self.key_length, descr)

LOGGER.debug("Finish RSAParam parameter check!")
return True
Expand Down
4 changes: 2 additions & 2 deletions python/federatedml/statistic/intersect/intersect.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ def generate_rsa_key(rsa_bit=1024):

def generate_protocol_key(self):
if self.role == consts.HOST:
e, d, n = self.generate_rsa_key(self.rsa_params.key_bit)
e, d, n = self.generate_rsa_key(self.rsa_params.key_length)
else:
e, d, n = [], [], []
for i in range(len(self.host_party_id_list)):
e_i, d_i, n_i = self.generate_rsa_key(self.rsa_params.key_bit)
e_i, d_i, n_i = self.generate_rsa_key(self.rsa_params.key_length)
e.append(e_i)
d.append(d_i)
n.append(n_i)
Expand Down

0 comments on commit 8fe2ae7

Please sign in to comment.