forked from shaoxiongji/federated-learning
Commit d731a2d (parent: 40a88a6). Showing 5 changed files with 355 additions and 0 deletions.

FedNets.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.norm1d = nn.BatchNorm1d(num_features=dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # flatten (N, C, H, W) images into (N, H*W) vectors
        x = x.view(-1, x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.norm1d(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_hidden(x)
        # return log-probabilities: the local trainer uses nn.NLLLoss, which
        # expects log_softmax outputs (plain softmax would train poorly)
        return F.log_softmax(x, dim=1)


# An earlier CNN variant, kept commented out in the commit:
# class CNN(nn.Module):
#     def __init__(self, args):
#         super(CNN, self).__init__()
#         self.num_classes = args.num_classes
#
#         self.layer1 = nn.Sequential(
#             nn.Conv2d(1, 16, kernel_size=5, padding=2),
#             nn.BatchNorm2d(16),
#             nn.ReLU(),
#             nn.MaxPool2d(2))
#         self.layer2 = nn.Sequential(
#             nn.Conv2d(16, 32, kernel_size=5, padding=2),
#             nn.BatchNorm2d(32),
#             nn.ReLU(),
#             nn.MaxPool2d(2))
#         self.fc = nn.Linear(32, self.num_classes)
#         self.softmax = nn.Softmax(dim=1)
#
#     def forward(self, x):
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = torch.mean(x, -1)   # global average pooling over width...
#         x = torch.mean(x, -1)   # ...and over height
#         x = self.fc(x)
#         return self.softmax(x)


class CNN(nn.Module):
    def __init__(self, args):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 320 = 20 channels * 4 * 4: two 5x5 convs and two 2x2 max-pools
        # reduce a 28x28 MNIST image to 20 feature maps of size 4x4
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
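A quick smoke test, not part of the commit, confirms the wiring; argparse.Namespace stands in for the parsed options, and the 320 in fc1 falls out of the shapes: two 5x5 convolutions and two 2x2 max-pools turn a 1x28x28 input into 20 feature maps of 4x4, and 20*4*4 = 320.

import argparse
import torch
from FedNets import MLP, CNN   # assuming the file above is saved as FedNets.py

args = argparse.Namespace(num_classes=10)   # hypothetical stand-in for args_parser()
x = torch.randn(4, 1, 28, 28)               # dummy batch of four MNIST-sized images

cnn = CNN(args=args)
cnn.eval()                                  # disable dropout for a deterministic pass
print(cnn(x).shape)                         # torch.Size([4, 10]), rows of log-probabilities

mlp = MLP(dim_in=28*28, dim_hidden=64, dim_out=10)
mlp.eval()                                  # BatchNorm1d switches to running statistics
print(mlp(x).shape)                         # torch.Size([4, 10])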
Update.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn import metrics


class DatasetSplit(Dataset):
    # wraps a dataset so a client only sees its own subset of indices
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset, idxs, tb):
        self.args = args
        self.loss_func = nn.NLLLoss()
        self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
        self.tb = tb

    def train_val_test(self, dataset, idxs):
        # split into train, validation, and test; the hard-coded boundaries
        # assume 600 indices per user (420 train, 60 val, 120 test)
        idxs_train = idxs[:420]
        idxs_val = idxs[420:480]
        idxs_test = idxs[480:]
        train = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=False)
        val = DataLoader(DatasetSplit(dataset, idxs_val), batch_size=len(idxs_val), shuffle=False)
        test = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=len(idxs_test), shuffle=False)
        return train, val, test

    def update_weights(self, net):
        # train locally; return the updated weights and the average epoch loss
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, weight_decay=2)

        epoch_loss = []
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                if self.args.gpu != -1:
                    images, labels = images.cuda(), labels.cuda()
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        ep, batch_idx * len(images), len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                self.tb.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def test(self, net):
        # fine-tune on the local training split, then evaluate on the local test split
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, weight_decay=2)
        for ep in range(self.args.local_ep):
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                if self.args.gpu != -1:
                    images, labels = images.cuda(), labels.cuda()
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()

        net.eval()
        for batch_idx, (images, labels) in enumerate(self.ldr_test):
            if self.args.gpu != -1:
                images, labels = images.cuda(), labels.cuda()
            log_probs = net(images)
            loss = self.loss_func(log_probs, labels)
            if self.args.gpu != -1:
                log_probs = log_probs.cpu()
                labels = labels.cpu()
            y_pred = np.argmax(log_probs.data.numpy(), axis=1)
            acc = metrics.accuracy_score(y_true=labels.data.numpy(), y_pred=y_pred)
        return acc, loss.item()
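The 420/480 boundaries above only line up if each user holds exactly 600 samples, which is what the main script below produces under its defaults: 60,000 MNIST training images split evenly across 100 users. A small illustrative check:

# illustrative arithmetic behind the hard-coded split (assumed defaults)
mnist_train_size = 60000
num_users = 100                            # default --num_users
per_user = mnist_train_size // num_users   # 600 indices per user
n_train, n_val, n_test = 420, 480 - 420, per_user - 480
assert (n_train, n_val, n_test) == (420, 60, 120)
assert n_train + n_val + n_test == per_user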
averaging.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import copy
import torch


def average_weights(w):
    # FedAvg aggregation: element-wise average of a list of state dicts
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]
        w_avg[k] = torch.div(w_avg[k], len(w))
    return w_avg
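A minimal usage sketch, not from the commit: average two identically shaped linear layers and verify each entry is the element-wise mean. Note this is an unweighted mean; the FedAvg paper weights each client by its sample count, but since every user here holds the same number of samples the two coincide.

import torch
from torch import nn
from averaging import average_weights   # assuming the file above is saved as averaging.py

# two models with the same architecture but independent random weights
a, b = nn.Linear(4, 2), nn.Linear(4, 2)
w_avg = average_weights([a.state_dict(), b.state_dict()])

# every entry should equal the element-wise mean of the two inputs
for k in w_avg:
    assert torch.allclose(w_avg[k], (a.state_dict()[k] + b.state_dict()[k]) / 2)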
Main training script
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import copy
import numpy as np
from torchvision import datasets, transforms
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

from options import args_parser
from Update import LocalUpdate
from FedNets import MLP, CNN
from averaging import average_weights


def test(net_g, data_loader, args):
    # evaluate net_g over every batch of data_loader
    net_g.eval()
    test_loss = 0
    correct = 0
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data)
        test_loss += F.nll_loss(log_probs, target, reduction='sum').item()  # sum up batch loss
        y_pred = log_probs.data.max(1, keepdim=True)[1]  # index of the max log-probability
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum().item()

    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))
    return correct, test_loss


if __name__ == '__main__':
    # parse args
    args = args_parser()

    # define paths
    path_project = os.path.abspath('..')

    summary = SummaryWriter('local')

    # load dataset and split users
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape[-1]

    # sample users: each user gets an equal-sized, disjoint subset of indices
    num_users = args.num_users
    num_items = int(len(dataset_train)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset_train))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])

    # build model
    if args.model == 'cnn':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = CNN(args=args).cuda()
        else:
            net_glob = CNN(args=args)
    elif args.model == 'mlp':
        if args.gpu != -1:
            torch.cuda.set_device(args.gpu)
            net_glob = MLP(dim_in=img_size*img_size, dim_hidden=64, dim_out=args.num_classes).cuda()
        else:
            net_glob = MLP(dim_in=img_size*img_size, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # copy weights
    w_glob = net_glob.state_dict()

    # training
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    val_acc_list, net_list = [], []
    for rnd in tqdm(range(args.epochs)):
        w_locals, loss_locals = [], []
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
            # train a copy so every selected user starts from the same global weights
            w, loss = net_local.update_weights(net=copy.deepcopy(net_glob))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))

        # update global weights
        w_glob = average_weights(w_locals)

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)

        # print average training loss every 10 rounds
        loss_avg = sum(loss_locals) / len(loss_locals)
        if rnd % 10 == 0:
            print('\nTrain loss:', loss_avg)
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    os.makedirs('./save', exist_ok=True)
    plt.savefig('./save/fed_{}_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs, args.frac))

    # testing: each user fine-tunes a copy of the global model and evaluates locally
    list_acc, list_loss = [], []
    for c in tqdm(range(num_users)):
        net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
        # deepcopy so one user's fine-tuning does not leak into the next user's evaluation
        acc, loss = net_local.test(net=copy.deepcopy(net_glob))
        list_acc.append(acc)
        list_loss.append(loss)
    print("average acc:", sum(list_acc)/len(list_acc))
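As committed, the script never calls the test() helper defined above; accuracy comes only from per-user local fine-tuning. A sketch of how the global model could also be scored centrally on the held-out MNIST test split, appended at the end of __main__ (names reuse what the script already defines):

    # hypothetical: score the global model on the official MNIST test split
    dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,))
                                  ]))
    test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    correct, test_loss = test(net_glob, test_loader, args)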
options.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import argparse


def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=1, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.0001, help="learning rate")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help="model name")
    parser.add_argument('--kernel_num', type=int, default=9, help="number of each kind of kernel")
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help="comma-separated kernel sizes to use for convolution")
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for MiniImageNet, 64 for Omniglot")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="whether to use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', type=int, default=1, help="whether i.i.d. or not: 1 for iid, 0 for non-iid")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--verbose', action='store_true', default=False, help="print metrics during training and testing")
    parser.add_argument('--gpu', type=int, default=1, help="GPU ID; -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help="rounds of early stopping")

    args = parser.parse_args()
    return args
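A short sketch of how a run could be configured; patching sys.argv just emulates the command line (the entry script's filename is not shown in this diff, so the argv[0] below is a placeholder):

import sys
from options import args_parser   # assuming the file above is saved as options.py

# emulate: <entry script> --model cnn --epochs 10 --gpu 0 --frac 0.5
sys.argv = ['main', '--model', 'cnn', '--epochs', '10', '--gpu', '0', '--frac', '0.5']
args = args_parser()
print(args.model, args.epochs, args.gpu, args.frac)   # cnn 10 0 0.5

Passing --gpu -1 selects the CPU path, since the training code branches on args.gpu != -1.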