-
Notifications
You must be signed in to change notification settings - Fork 223
/
Copy pathtest_fedsageplus.py
79 lines (61 loc) · 2.64 KB
/
test_fedsageplus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.logging import update_logger
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.runner_builder import get_runner
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls
class FedSagePlusTest(unittest.TestCase):
    """Standalone end-to-end test of the 'fedsageplus' method on Cora."""

    def setUp(self):
        """Print the class and method name of the test about to run."""
        label = 'Testing %s.%s' % (type(self).__name__, self._testMethodName)
        print(label)

    def set_config_fedsageplus(self, cfg):
        """Configure ``cfg`` in place for a small FedSage+ standalone run.

        Args:
            cfg: configuration object to modify; it must support
                ``clone()`` and expose the nested sections assigned below.

        Returns:
            A clone of ``cfg`` taken before any field was modified.
        """
        snapshot = cfg.clone()

        import torch
        # Use the GPU only when torch reports one as available.
        cfg.use_gpu = torch.cuda.is_available()

        # Section name -> {field: value}; applied in insertion order.
        overrides = {
            'federate': {
                'mode': 'standalone',
                'make_global_eval': True,
                'client_num': 3,
                'total_round_num': 10,
                'method': 'fedsageplus',
            },
            'train': {
                'batch_or_epoch': 'epoch',
            },
            'data': {
                'root': 'test_data/',
                'type': 'cora',
                'splitter': 'louvain',
            },
            'dataloader': {
                'type': 'pyg',
                'batch_size': 1,
            },
            'model': {
                'type': 'sage',
                'hidden': 64,
                'dropout': 0.5,
                'out_channels': 7,
            },
            'fedsageplus': {
                'num_pred': 5,
                'gen_hidden': 64,
                'hide_portion': 0.5,
                'fedgen_epoch': 2,
                'loc_epoch': 1,
                'a': 1.0,
                'b': 1.0,
                'c': 1.0,
            },
            'criterion': {
                'type': 'CrossEntropyLoss',
            },
            'trainer': {
                'type': 'nodefullbatch_trainer',
            },
            'eval': {
                'metrics': ['acc', 'correct'],
            },
        }
        for section_name, fields in overrides.items():
            section = getattr(cfg, section_name)
            for field_name, value in fields.items():
                setattr(section, field_name, value)

        return snapshot

    def test_fedsageplus_standalone(self):
        """Train FedSage+ standalone and require global test accuracy > 0.7."""
        cfg = global_cfg.clone()
        saved_cfg = self.set_config_fedsageplus(cfg)
        setup_seed(cfg.seed)
        update_logger(cfg, True)

        # Data loading may adjust the config; fold those changes back in.
        data, data_cfg = get_data(cfg.clone())
        cfg.merge_from_other_cfg(data_cfg)
        self.assertIsNotNone(data)

        runner = get_runner(data=data,
                            server_class=get_server_cls(cfg),
                            client_class=get_client_cls(cfg),
                            config=cfg.clone())
        self.assertIsNotNone(runner)
        best_results = runner.run()

        # Restore the pre-test values on the local config clone.
        cfg.merge_from_other_cfg(saved_cfg)

        global_test_acc = best_results["server_global_eval"]['test_acc']
        self.assertGreater(global_test_acc, 0.7)
if __name__ == '__main__':
    # Run the tests in this module when the file is executed directly.
    unittest.main()