# coding: utf-8
import tools
import math
import copy
import torch
from torch import nn
import time
# ---------------------------------------------------------------------------- #
class LocalUpdate_FedAvg(object):
    """Client-side FedAvg update.

    Wraps one client's data and model and provides:
      * ``local_training``   -- local SGD over all parameters,
      * ``local_fine_tuning``-- SGD over the classifier head only,
      * ``local_test``       -- accuracy evaluation,
      * ``aggregate_weight`` -- this client's aggregation weight (dataset size).

    The wrapped model's ``forward`` is expected to return a
    ``(features, logits)`` pair; only the logits are used here.
    """

    def __init__(self, idx, args, train_set, test_set, model):
        self.idx = idx                    # client index
        self.args = args                  # expects: device, lr, local_iter
        self.train_data = train_set       # DataLoader over this client's train split
        self.test_data = test_set         # DataLoader over this client's test split
        self.device = args.device
        self.criterion = nn.CrossEntropyLoss()
        self.local_model = model
        # independent copy kept for fine-tuning experiments
        self.local_model_finetune = copy.deepcopy(model)
        # parameter names of the classifier head (only these stay trainable
        # during fine-tuning)
        self.w_local_keys = self.local_model.classifier_weight_keys
        self.agg_weight = self.aggregate_weight()

    def aggregate_weight(self):
        """Return this client's aggregation weight -- its local training-set
        size -- as a tensor on ``self.device``."""
        data_size = len(self.train_data.dataset)
        return torch.tensor(data_size).to(self.device)

    def local_test(self, test_loader, test_model=None):
        """Evaluate ``test_model`` (default: the local model) on ``test_loader``.

        Returns accuracy as a percentage in [0, 100].
        """
        model = self.local_model if test_model is None else test_model
        model.eval()
        device = self.device
        correct = 0
        total = len(test_loader.dataset)
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                _, outputs = model(inputs)               # (features, logits)
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
        return 100.0 * correct / total

    def update_local_model(self, global_weight):
        """Overwrite the local model with the aggregated global weights."""
        self.local_model.load_state_dict(global_weight)

    def _sgd_step(self, model, optimizer, images, labels):
        """One optimizer step on a single mini-batch; returns the loss value."""
        images, labels = images.to(self.device), labels.to(self.device)
        model.zero_grad()
        _, output = model(images)
        loss = self.criterion(output, labels)
        loss.backward()
        optimizer.step()
        return loss.item()

    def _run_sgd(self, model, optimizer, local_epoch):
        """Run the local SGD schedule shared by training and fine-tuning.

        ``local_epoch > 0``: run that many full passes over the train loader.
        Otherwise: run ``args.local_iter`` mini-batch steps (less than one
        epoch; raises StopIteration if local_iter exceeds the loader length,
        matching the original behavior).

        Returns the list of per-iteration loss values.
        """
        iter_loss = []
        if local_epoch > 0:
            for _ in range(local_epoch):
                for images, labels in self.train_data:
                    iter_loss.append(self._sgd_step(model, optimizer, images, labels))
        else:
            batches = iter(self.train_data)
            for _ in range(self.args.local_iter):
                images, labels = next(batches)
                iter_loss.append(self._sgd_step(model, optimizer, images, labels))
        return iter_loss

    def local_training(self, local_epoch, round=0):
        """Train all parameters of the local model with SGD.

        Returns ``(state_dict, first_iter_loss, last_iter_loss,
        acc_before, acc_after)`` where accuracies are on the local test set.
        """
        model = self.local_model
        model.train()
        model.zero_grad()
        acc1 = self.local_test(self.test_data)
        # Default local optimizer: SGD with momentum and weight decay.
        optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                    momentum=0.5, weight_decay=0.0005)
        iter_loss = self._run_sgd(model, optimizer, local_epoch)
        acc2 = self.local_test(self.test_data)
        return model.state_dict(), iter_loss[0], iter_loss[-1], acc1, acc2

    def local_fine_tuning(self, local_epoch, round=0):
        """Fine-tune only the classifier head (``w_local_keys``).

        All non-head parameters are frozen via ``requires_grad = False`` --
        note this freezing persists on the model after the call, as in the
        original implementation.

        NOTE(review): this trains ``self.local_model``, not
        ``self.local_model_finetune`` -- confirm that is intended.

        Returns ``(first_iter_loss, last_iter_loss, acc_before, acc_after)``.
        """
        model = self.local_model
        model.train()
        model.zero_grad()
        acc1 = self.local_test(self.test_data)
        # Freeze everything except the classifier head.
        for name, param in model.named_parameters():
            param.requires_grad = name in self.w_local_keys
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=self.args.lr, momentum=0.5, weight_decay=0.0005)
        iter_loss = self._run_sgd(model, optimizer, local_epoch)
        acc2 = self.local_test(self.test_data)
        return iter_loss[0], iter_loss[-1], acc1, acc2