Skip to content

Commit e02437e

Browse files
committed
rename eggroll to session
1 parent 2c0c275 commit e02437e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+177
-178
lines changed

arch/api/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,4 @@ class NamingPolicy(Enum):
6060

6161

6262
# compatibility
63-
eggroll = session
63+
session = session

arch/api/session.py

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import os
1818
import typing
1919
import uuid
20-
import warnings
2120
from typing import Iterable
2221

2322
from arch.api import RuntimeInstance

contrib/fate_script/cluster/fate_script.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import uuid
88
from arch.api.cluster.eggroll import _DTable, _EggRoll
9-
from arch.api import eggroll
9+
from arch.api import session
1010
from arch.api.cluster import federation
1111
from arch.api.proto import basic_meta_pb2, federation_pb2, federation_pb2_grpc, storage_basic_pb2
1212
from arch.api.utils import file_utils, eggroll_serdes

contrib/fate_script/fate_script.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
import numpy as np
55

6-
from arch.api import eggroll
6+
from arch.api import session
77
from contrib.fate_script import WorkMode
88
from contrib.fate_script import RuntimeInstance
99
from contrib.fate_script.standalone import fate_script as standalone_fate_script
@@ -13,7 +13,7 @@
1313
from federatedml.util.param_checker import AllChecker
1414

1515
def init(job_id, runtime_conf, mode, server_conf_path="arch/conf/server_conf.json"):
16-
eggroll.init(job_id, mode)
16+
session.init(job_id, mode)
1717
print("runtime_conf:{}".format(runtime_conf))
1818
all_checker = AllChecker(runtime_conf)
1919
all_checker.check_all()
@@ -63,7 +63,7 @@ def get_lr_x_table(file_path):
6363
ns = str(uuid.uuid1())
6464
csv_table = pd.read_csv(file_path)
6565
data = pd.read_csv(file_path).values
66-
x = eggroll.table('fata_script_test_data_x_' + str(RuntimeInstance.FEDERATION.role + str(RuntimeInstance.FEDERATION.job_id)), ns, partition=2, persistent=True)
66+
x = session.table('fata_script_test_data_x_' + str(RuntimeInstance.FEDERATION.role + str(RuntimeInstance.FEDERATION.job_id)), ns, partition=2, persistent=True)
6767
if 'y' in list(csv_table.columns.values):
6868
data_index = 2
6969
else:
@@ -77,7 +77,7 @@ def get_lr_y_table(file_path):
7777
ns = str(uuid.uuid1())
7878
csv_table = pd.read_csv(file_path)
7979
data = pd.read_csv(file_path).values
80-
y = eggroll.table('fata_script_test_data_y_' + str(RuntimeInstance.FEDERATION.role) + str(RuntimeInstance.FEDERATION.job_id), ns, partition=2, persistent=True)
80+
y = session.table('fata_script_test_data_y_' + str(RuntimeInstance.FEDERATION.role) + str(RuntimeInstance.FEDERATION.job_id), ns, partition=2, persistent=True)
8181
if 'y' not in list(csv_table.columns.values):
8282
raise RuntimeError("input data must contain y column")
8383
for i in range(np.shape(data)[0]):

contrib/fate_script/standalone/fate_script.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import numpy as np
33
import uuid
4-
from arch.api import eggroll
4+
from arch.api import session
55
from arch.api.standalone import federation
66
from arch.api.standalone.federation import FederationRuntime
77
from arch.api.utils import file_utils

contrib/fate_script/test/test_blas.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import json
66
from numbers import Number
77
from arch.api import federation
8-
from arch.api import eggroll
8+
from arch.api import session
99
from arch.api import RuntimeInstance
1010
from arch.api.standalone.federation import FederationRuntime
1111
from arch.api.utils import file_utils
@@ -20,11 +20,11 @@ def test_plain_lr():
2020
from sklearn.datasets import make_moons
2121
import functools
2222
# 修改flow_id 否则内存表可能被覆盖
23-
eggroll.init(mode=0)
23+
session.init(mode=0)
2424
ns = str(uuid.uuid1())
2525

26-
X = eggroll.table('testX7', ns, partition=2)
27-
Y = eggroll.table('testY7', ns, partition=2)
26+
X = session.table('testX7', ns, partition=2)
27+
Y = session.table('testY7', ns, partition=2)
2828

2929
b = np.array([0])
3030
eta = 1.2
@@ -88,12 +88,12 @@ def test_paillier_lr():
8888
cipher.generate_key()
8989

9090
# 修改flow_id 否则内存表可能被覆盖
91-
eggroll.init(mode=0)
91+
session.init(mode=0)
9292
ns = str(uuid.uuid1())
9393
p = True
94-
X_G = eggroll.table('testX7', ns, partition=2,persistent=p)
95-
X_H = eggroll.table('testX7_2', ns, partition=2,persistent=p)
96-
Y = eggroll.table('testY7', ns, partition=2,persistent=p)
94+
X_G = session.table('testX7', ns, partition=2, persistent=p)
95+
X_H = session.table('testX7_2', ns, partition=2, persistent=p)
96+
Y = session.table('testY7', ns, partition=2, persistent=p)
9797

9898
b = np.array([0])
9999
eta = 1.2

examples/hetero_ftl/run_arbiter.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import json
1818
import sys
1919
from workflow.hetero_ftl_workflow.hetero_arbiter_workflow import FTLArbiterWorkFlow
20-
from arch.api import eggroll
20+
from arch.api import session
2121
from arch.api import federation
2222
from arch.api.utils import log_utils
2323
LOGGER = log_utils.getLogger()
@@ -37,7 +37,7 @@ def _init_argument(self):
3737

3838
LOGGER.debug("The Arbiter job id is {}".format(job_id))
3939
LOGGER.debug("The Arbiter work mode id is {}".format(self.workflow_param.work_mode))
40-
eggroll.init(job_id, self.workflow_param.work_mode)
40+
session.init(job_id, self.workflow_param.work_mode)
4141
federation.init(job_id, runtime_json)
4242
LOGGER.debug("Finish eggroll and federation init")
4343

examples/hetero_ftl/run_guest.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import sys
1919
from workflow.hetero_ftl_workflow.hetero_guest_workflow import FTLGuestWorkFlow
2020
from federatedml.ftl.data_util.uci_credit_card_util import load_guest_host_dtable_from_UCI_Credit_Card
21-
from arch.api import eggroll
21+
from arch.api import session
2222
from arch.api import federation
2323
from arch.api.utils import log_utils
2424
LOGGER = log_utils.getLogger()
@@ -38,14 +38,14 @@ def _init_argument(self):
3838

3939
LOGGER.debug("The Guest job id is {}".format(job_id))
4040
LOGGER.debug("The Guest work mode id is {}".format(self.workflow_param.work_mode))
41-
eggroll.init(job_id, self.workflow_param.work_mode)
41+
session.init(job_id, self.workflow_param.work_mode)
4242
federation.init(job_id, runtime_json)
4343
LOGGER.debug("Finish eggroll and federation init")
4444

4545
def gen_data_instance(self, table_name, namespace, mode="fit"):
4646
data_model_param = self._get_data_model_param()
4747
if data_model_param.is_read_table:
48-
return eggroll.table(table_name, namespace)
48+
return session.table(table_name, namespace)
4949
else:
5050
data_model_param_dict = dict()
5151
data_model_param_dict["file_path"] = data_model_param.file_path
@@ -61,7 +61,7 @@ def gen_validation_data_instance(self, table_name, namespace):
6161
data_model_param = self._get_data_model_param()
6262
valid_data_model_param = self._get_valid_data_model_param()
6363
if valid_data_model_param.is_read_table:
64-
return eggroll.table(table_name, namespace)
64+
return session.table(table_name, namespace)
6565
else:
6666
data_model_param_dict = dict()
6767
data_model_param_dict["file_path"] = valid_data_model_param.file_path

examples/hetero_ftl/run_host.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from workflow.hetero_ftl_workflow.hetero_host_workflow import FTLHostWorkFlow
2121
from federatedml.ftl.data_util.uci_credit_card_util import load_guest_host_dtable_from_UCI_Credit_Card
22-
from arch.api import eggroll
22+
from arch.api import session
2323
from arch.api import federation
2424
from arch.api.utils import log_utils
2525
LOGGER = log_utils.getLogger()
@@ -39,14 +39,14 @@ def _init_argument(self):
3939

4040
LOGGER.debug("The Host job id is {}".format(job_id))
4141
LOGGER.debug("The Host work mode id is {}".format(self.workflow_param.work_mode))
42-
eggroll.init(job_id, self.workflow_param.work_mode)
42+
session.init(job_id, self.workflow_param.work_mode)
4343
federation.init(job_id, runtime_json)
4444
LOGGER.debug("Finish eggroll and federation init")
4545

4646
def gen_data_instance(self, table_name, namespace, mode="fit"):
4747
data_model_param = self._get_data_model_param()
4848
if data_model_param.is_read_table:
49-
return eggroll.table(table_name, namespace)
49+
return session.table(table_name, namespace)
5050
else:
5151
data_model_param_dict = dict()
5252
data_model_param_dict["file_path"] = data_model_param.file_path
@@ -62,7 +62,7 @@ def gen_validation_data_instance(self, table_name, namespace):
6262
data_model_param = self._get_data_model_param()
6363
valid_data_model_param = self._get_valid_data_model_param()
6464
if valid_data_model_param.is_read_table:
65-
return eggroll.table(table_name, namespace)
65+
return session.table(table_name, namespace)
6666
else:
6767
data_model_param_dict = dict()
6868
data_model_param_dict["file_path"] = valid_data_model_param.file_path

fate_flow/utils/download.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#
1616
import os
1717

18-
from arch.api import eggroll,storage
18+
from arch.api import session,storage
1919

2020
from arch.api.utils import log_utils, dtable_utils
2121

@@ -35,7 +35,7 @@ def run(self, component_parameters=None, args=None):
3535
table_name, namespace = dtable_utils.get_table_info(config=self.parameters,
3636
create=False)
3737
job_id = "_".join(self.taskid.split("_")[:2])
38-
eggroll.init(job_id, self.parameters["work_mode"])
38+
session.init(job_id, self.parameters["work_mode"])
3939
with open(os.path.abspath(self.parameters["output_path"]), "w") as fout:
4040
data_table = storage.get_data_table(name=table_name, namespace=namespace)
4141
print('===== begin to export data =====')

federatedml/feature/feature_scale/test/min_max_scale_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from sklearn.preprocessing import MinMaxScaler as MMS
88

9-
from arch.api import eggroll
9+
from arch.api import session
1010
from federatedml.feature.feature_scale.min_max_scale import MinMaxScale
1111
from federatedml.feature.instance import Instance
1212
from federatedml.param.scale_param import ScaleParam
@@ -41,8 +41,8 @@ def print_table(self, table):
4141
print("id:{}, value:{}".format(v[0], v[1].features))
4242

4343
def data_to_eggroll_table(self, data, jobid, partition=1, work_mode=0):
44-
eggroll.init(jobid, mode=work_mode)
45-
data_table = eggroll.parallelize(data, include_key=False)
44+
session.init(jobid, mode=work_mode)
45+
data_table = session.parallelize(data, include_key=False)
4646
return data_table
4747

4848
def sklearn_attribute_format(self, scaler, feature_range):

federatedml/feature/feature_scale/test/standard_scale_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from sklearn.preprocessing import StandardScaler as SSL
88

9-
from arch.api import eggroll
9+
from arch.api import session
1010
from federatedml.feature.feature_scale.standard_scale import StandardScale
1111
from federatedml.feature.instance import Instance
1212
from federatedml.param.scale_param import ScaleParam
@@ -40,8 +40,8 @@ def print_table(self, table):
4040
print(v[1].features)
4141

4242
def data_to_eggroll_table(self, data, jobid, partition=1, work_mode=0):
43-
eggroll.init(jobid, mode=work_mode)
44-
data_table = eggroll.parallelize(data, include_key=False, partition=10)
43+
session.init(jobid, mode=work_mode)
44+
data_table = session.parallelize(data, include_key=False, partition=10)
4545
return data_table
4646

4747
def get_table_instance_feature(self, table_instance):

federatedml/feature/sampler.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17-
from arch.api import eggroll
17+
from arch.api import session
1818
from arch.api import federation
1919
from sklearn.utils import resample
2020
from fate_flow.entity.metric import Metric
@@ -119,7 +119,7 @@ def __sample(self, data_inst, sample_ids=None):
119119
n_samples=sample_num,
120120
random_state=self.random_state)
121121

122-
sample_dtable = eggroll.parallelize(zip(sample_ids, range(len(sample_ids))),
122+
sample_dtable = session.parallelize(zip(sample_ids, range(len(sample_ids))),
123123
include_key=True,
124124
partition=data_inst._partitions)
125125
new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1)
@@ -152,7 +152,7 @@ def __sample(self, data_inst, sample_ids=None):
152152
index = id_maps[sample_ids[i]]
153153
new_data.append((i, data_set[index][1]))
154154

155-
new_data_inst = eggroll.parallelize(new_data,
155+
new_data_inst = session.parallelize(new_data,
156156
include_key=True,
157157
partition=data_inst._partitions)
158158

@@ -299,7 +299,7 @@ def __sample(self, data_inst, sample_ids=None):
299299

300300
callback(self.tracker, "stratified", callback_metrics)
301301

302-
sample_dtable = eggroll.parallelize(zip(sample_ids, range(len(sample_ids))),
302+
sample_dtable = session.parallelize(zip(sample_ids, range(len(sample_ids))),
303303
include_key=True,
304304
partition=data_inst._partitions)
305305
new_data_inst = data_inst.join(sample_dtable, lambda v1, v2: v1)
@@ -358,7 +358,7 @@ def __sample(self, data_inst, sample_ids=None):
358358
index = id_maps[sample_ids[i]]
359359
new_data.append((i, data_set[index][1]))
360360

361-
new_data_inst = eggroll.parallelize(new_data,
361+
new_data_inst = session.parallelize(new_data,
362362
include_key=True,
363363
partition=data_inst._partitions)
364364

federatedml/feature/test/bucket_binning_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
import numpy as np
2020

21-
from arch.api import eggroll
21+
from arch.api import session
2222

23-
eggroll.init("123")
23+
session.init("123")
2424

2525
from federatedml.feature.binning.bucket_binning import BucketBinning
2626
from federatedml.feature.instance import Instance
@@ -43,7 +43,7 @@ def setUp(self):
4343
tmp_pair = (str(i), inst)
4444
final_result.append(tmp_pair)
4545
numpy_array.append(tmp)
46-
table = eggroll.parallelize(final_result,
46+
table = session.parallelize(final_result,
4747
include_key=True,
4848
partition=10)
4949

federatedml/feature/test/imputer_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import unittest
44

5-
from arch.api import eggroll
5+
from arch.api import session
66
from federatedml.feature.imputer import Imputer
77

88

@@ -42,8 +42,8 @@ def print_table(self, table):
4242
print(v[1].features)
4343

4444
def data_to_eggroll_table(self, data, jobid, partition=10, work_mode=0):
45-
eggroll.init(jobid, mode=work_mode)
46-
data_table = eggroll.parallelize(data, include_key=False, partition=partition)
45+
session.init(jobid, mode=work_mode)
46+
data_table = session.parallelize(data, include_key=False, partition=partition)
4747
return data_table
4848

4949
def table_to_list(self, table_instance):

federatedml/feature/test/one_hot_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
import unittest
2121

22-
from arch.api import eggroll
22+
from arch.api import session
2323

24-
eggroll.init("123")
24+
session.init("123")
2525
from federatedml.feature.one_hot_encoder import OneHotEncoder
2626
from federatedml.feature.instance import Instance
2727
import numpy as np
@@ -44,7 +44,7 @@ def setUp(self):
4444
tmp_pair = (str(i), inst)
4545
final_result.append(tmp_pair)
4646

47-
table = eggroll.parallelize(final_result,
47+
table = session.parallelize(final_result,
4848
include_key=True,
4949
partition=10)
5050
table.schema = {"header": self.header}

0 commit comments

Comments
 (0)