-
Notifications
You must be signed in to change notification settings - Fork 11
/
main_hegedus_2021.py
71 lines (64 loc) · 2.61 KB
/
main_hegedus_2021.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from torch.nn.modules.loss import CrossEntropyLoss
from networkx import to_numpy_array
from networkx.generators.random_graphs import random_regular_graph
from gossipy import set_seed
from gossipy.core import UniformDelay, AntiEntropyProtocol, CreateModelMode, StaticP2PNetwork
from gossipy.node import GossipNode, PartitioningBasedNode, SamplingBasedNode
from gossipy.model.handler import PartitionedTMH, SamplingTMH, TorchModelHandler
from gossipy.model.sampling import TorchModelPartition
from gossipy.model.nn import LogisticRegression
from gossipy.data import load_classification_dataset, DataDispatcher
from gossipy.data.handler import ClassificationDataHandler
from gossipy.simul import GossipSimulator, SimulationReport, TokenizedGossipSimulator
from gossipy.flow_control import RandomizedTokenAccount
from gossipy.utils import plot_evaluation
# AUTHORSHIP
__version__ = "0.0.1"
__author__ = "Mirko Polato"
__copyright__ = "Copyright 2022, gossipy"
__license__ = "MIT"
__maintainer__ = "Mirko Polato, PhD"
__email__ = "[email protected]"
__status__ = "Development"
#
# Experiment configuration. The node count is needed in three places below
# (data dispatcher, P2P network, random graph) and all of them must agree,
# so it is defined once here instead of being repeated as a literal.
N_NODES = 100      # number of gossip nodes
NODE_DEGREE = 20   # neighbours per node in the random regular topology
N_ROUNDS = 1000    # number of simulated rounds

set_seed(98765)

# Load the "spambase" dataset and hold out 10% of it as the test set.
X, y = load_classification_dataset("spambase", as_tensor=True)
data_handler = ClassificationDataHandler(X, y, test_size=.1)

# Spread the training data over N_NODES nodes.
# NOTE(review): eval_on_user=False presumably means evaluation uses the shared
# test set rather than per-node local data — confirm against gossipy docs.
dispatcher = DataDispatcher(data_handler, n=N_NODES, eval_on_user=False, auto_assign=True)

# Static topology: a random NODE_DEGREE-regular graph over N_NODES nodes,
# built with a fixed seed so the same graph is generated on every run.
topology = StaticP2PNetwork(
    N_NODES,
    to_numpy_array(random_regular_graph(NODE_DEGREE, N_NODES, seed=42))
)

# Logistic regression with one input per training feature and 2 outputs.
net = LogisticRegression(data_handler.Xtr.shape[1], 2)

nodes = PartitioningBasedNode.generate(
    data_dispatcher=dispatcher,
    p2p_net=topology,
    round_len=100,
    model_proto=PartitionedTMH(
        net=net,
        # The model is split into 4 partitions.
        tm_partition=TorchModelPartition(net, 4),
        optimizer=torch.optim.SGD,
        optimizer_params={
            "lr": 1,
            "weight_decay": .001
        },
        criterion=CrossEntropyLoss(),
        # Alternative mode: CreateModelMode.MERGE_UPDATE
        create_model_mode=CreateModelMode.UPDATE
    ),
    sync=True
)

simulator = TokenizedGossipSimulator(
    nodes=nodes,
    data_dispatcher=dispatcher,
    token_account=RandomizedTokenAccount(C=20, A=10),
    # The utility function is always 1, i.e. message utility is not used.
    utility_fun=lambda mh1, mh2, msg: 1,
    delta=100,
    protocol=AntiEntropyProtocol.PUSH,
    delay=UniformDelay(0, 10),
    # Optional settings, disabled in this run:
    #   online_prob=.2  approximates the average online rate of the STUNner
    #                   smartphone traces
    #   drop_prob=.1    simulates the possibility of message dropping
    # NOTE(review): sampling_eval=.1 presumably evaluates on a 10% sample of
    # the nodes — confirm against gossipy docs.
    sampling_eval=.1
)

# Attach a SimulationReport as a receiver, initialise the nodes and run.
report = SimulationReport()
simulator.add_receiver(report)
simulator.init_nodes(seed=42)
simulator.start(n_rounds=N_ROUNDS)

# Plot the overall test results collected by the report.
# NOTE(review): the positional False passed to get_evaluation is assumed to
# select global rather than local evaluation — confirm against gossipy.
plot_evaluation([[ev for _, ev in report.get_evaluation(False)]], "Overall test results")