add FedAvg
shaoxiongji committed Mar 31, 2018
1 parent 40a88a6 commit d731a2d
Showing 5 changed files with 355 additions and 0 deletions.
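This commit adds an implementation of federated averaging (FedAvg, McMahan et al., 2017) on MNIST: FedNets.py defines the MLP and CNN models, Update.py runs each client's local training, averaging.py averages the client weights on the server, main_fed.py drives the federated training loop, and options.py defines the command-line arguments.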
73 changes: 73 additions & 0 deletions FedNets.py
@@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
def __init__(self, dim_in, dim_hidden, dim_out):
super(MLP, self).__init__()
self.layer_input = nn.Linear(dim_in, dim_hidden)
self.norm1d = nn.BatchNorm1d(num_features=dim_hidden)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.2)
self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.log_softmax = nn.LogSoftmax(dim=1)  # NLLLoss in Update.py expects log-probabilities

def forward(self, x):
x = x.view(-1, x.shape[-2]*x.shape[-1])
x = self.layer_input(x)
x = self.norm1d(x)
x = self.relu(x)
x = self.dropout(x)
x = self.layer_hidden(x)
        return self.log_softmax(x)


# class CNN(nn.Module):
# def __init__(self, args):
# super(CNN, self).__init__()
# self.num_classes = args.num_classes
#
# self.layer1 = nn.Sequential(
# nn.Conv2d(1, 16, kernel_size=5, padding=2),
# nn.BatchNorm2d(16),
# nn.ReLU(),
# nn.MaxPool2d(2))
# self.layer2 = nn.Sequential(
# nn.Conv2d(16, 32, kernel_size=5, padding=2),
# nn.BatchNorm2d(32),
# nn.ReLU(),
# nn.MaxPool2d(2))
# self.fc = nn.Linear(32, self.num_classes)
# self.softmax = nn.Softmax(dim=1)
#
# def forward(self, x):
# x = self.layer1(x)
# x = self.layer2(x)
# x = torch.mean(x, -1)
# x = torch.mean(x, -1)
# x = self.fc(x)
# return self.softmax(x)


class CNN(nn.Module):
def __init__(self, args):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, args.num_classes)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
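
As a sanity check on the fully connected layer's input size: a 1×28×28 MNIST image passed through two 5×5 convolutions, each followed by 2×2 max pooling, leaves 20 feature maps of size 4×4, i.e. 20·4·4 = 320 features. A minimal sketch of the shape flow, assuming MNIST-sized input and a stand-in args object (SimpleNamespace here is illustrative, not part of the commit):

import torch
from torch import autograd
from types import SimpleNamespace

from FedNets import CNN

args = SimpleNamespace(num_classes=10)            # stand-in for the parsed CLI args
net = CNN(args=args)
x = autograd.Variable(torch.randn(4, 1, 28, 28))  # batch of 4 MNIST-sized images
out = net(x)
print(out.size())                                 # expected: (4, 10) log-probabilities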
94 changes: 94 additions & 0 deletions Update.py
@@ -0,0 +1,94 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn import metrics


class DatasetSplit(Dataset):
def __init__(self, dataset, idxs):
self.dataset = dataset
self.idxs = list(idxs)

def __len__(self):
return len(self.idxs)

def __getitem__(self, item):
image, label = self.dataset[self.idxs[item]]
return image, label


class LocalUpdate(object):
def __init__(self, args, dataset, idxs, tb):
self.args = args
self.loss_func = nn.NLLLoss()
self.ldr_train, self.ldr_val, self.ldr_test = self.train_val_test(dataset, list(idxs))
self.tb = tb

def train_val_test(self, dataset, idxs):
        # split each user's shard into train/val/test as 420/60/120; assumes
        # 600 samples per user (the default 100-user IID split of 60,000 MNIST images)
        idxs_train = idxs[:420]
        idxs_val = idxs[420:480]
        idxs_test = idxs[480:]
train = DataLoader(DatasetSplit(dataset, idxs_train), batch_size=self.args.local_bs, shuffle=False)
val = DataLoader(DatasetSplit(dataset, idxs_val), batch_size=len(idxs_val), shuffle=False)
test = DataLoader(DatasetSplit(dataset, idxs_test), batch_size=len(idxs_test), shuffle=False)
return train, val, test

def update_weights(self, net):
        # train and update
        # note: weight_decay=2 is an unusually strong L2 penalty
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, weight_decay=2)

epoch_loss = []
for iter in range(self.args.local_ep):
batch_loss = []
for batch_idx, (images, labels) in enumerate(self.ldr_train):
if self.args.gpu != -1:
images, labels = images.cuda(), labels.cuda()
images, labels = autograd.Variable(images), autograd.Variable(labels)
net.zero_grad()
log_probs = net(images)
loss = self.loss_func(log_probs, labels)
loss.backward()
optimizer.step()
if self.args.gpu != -1:
loss = loss.cpu()
                if batch_idx % 1 == 0:  # a modulus of 1 logs every batch; raise it to log less often
print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
iter, batch_idx * len(images), len(self.ldr_train.dataset),
100. * batch_idx / len(self.ldr_train), loss.data[0]))
self.tb.add_scalar('loss', loss.data[0])
batch_loss.append(loss.data[0])
epoch_loss.append(sum(batch_loss)/len(batch_loss))
return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

    def test(self, net):
        # fine-tunes locally for local_ep epochs, then evaluates on the client's held-out test split
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, weight_decay=2)
for iter in range(self.args.local_ep):
for batch_idx, (images, labels) in enumerate(self.ldr_train):
if self.args.gpu != -1:
images, labels = images.cuda(), labels.cuda()
images, labels = autograd.Variable(images), autograd.Variable(labels)
net.zero_grad()
log_probs = net(images)
loss = self.loss_func(log_probs, labels)
loss.backward()
optimizer.step()

for batch_idx, (images, labels) in enumerate(self.ldr_test):
if self.args.gpu != -1:
images, labels = images.cuda(), labels.cuda()
images, labels = autograd.Variable(images), autograd.Variable(labels)
log_probs = net(images)
loss = self.loss_func(log_probs, labels)
if self.args.gpu != -1:
loss = loss.cpu()
log_probs = log_probs.cpu()
labels = labels.cpu()
y_pred = np.argmax(log_probs.data, axis=1)
acc = metrics.accuracy_score(y_true=labels.data, y_pred=y_pred)
return acc, loss.data[0]
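
A single client's contribution to one communication round then looks like the following minimal sketch, assuming args, dataset_train, dict_users, the summary writer, and the global model net_glob are built as in main_fed.py below:

import copy

local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[0], tb=summary)
w_local, loss_local = local.update_weights(net=copy.deepcopy(net_glob))
# w_local is a state_dict ready to be averaged with the other clients' updates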
15 changes: 15 additions & 0 deletions averaging.py
@@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch


def average_weights(w):
w_avg = copy.deepcopy(w[0])
for k in w_avg.keys():
for i in range(1, len(w)):
w_avg[k] += w[i][k]
w_avg[k] = torch.div(w_avg[k], len(w))
return w_avg
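
average_weights returns the unweighted element-wise mean of the client state dicts, w = (1/K) * sum_k w_k. This coincides with FedAvg's weighted average sum_k (n_k/n) * w_k here because every client holds the same number of samples; with unbalanced clients each term would need the weight n_k/n. A toy check with hypothetical single-parameter state dicts:

import torch
from averaging import average_weights

w1 = {'fc.weight': torch.ones(2, 2)}      # toy "client 1" weights
w2 = {'fc.weight': 3 * torch.ones(2, 2)}  # toy "client 2" weights
w_avg = average_weights([w1, w2])
print(w_avg['fc.weight'])                 # every entry is 2.0, the element-wise mean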
136 changes: 136 additions & 0 deletions main_fed.py
@@ -0,0 +1,136 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import copy
import numpy as np
from torchvision import datasets, transforms
from tqdm import tqdm
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import nn, autograd
from sklearn import metrics
from tensorboardX import SummaryWriter

from options import args_parser
from Update import LocalUpdate
from FedNets import MLP, CNN
from averaging import average_weights


def test(net_g, data_loader, args):
# testing
test_loss = 0
correct = 0
l = len(data_loader)
for idx, (data, target) in enumerate(data_loader):
if args.gpu != -1:
data, target = data.cuda(), target.cuda()
data, target = autograd.Variable(data), autograd.Variable(target)
log_probs = net_g(data)
test_loss += F.nll_loss(log_probs, target, size_average=False).data[0] # sum up batch loss
y_pred = log_probs.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

test_loss /= len(data_loader.dataset)
print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
test_loss, correct, len(data_loader.dataset),
100. * correct / len(data_loader.dataset)))
return correct, test_loss


if __name__ == '__main__':
# parse args
args = args_parser()

# define paths
path_project = os.path.abspath('..')

summary = SummaryWriter('local')

# load dataset and split users
if args.dataset == 'mnist':
dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
else:
exit('Error: unrecognized dataset')
img_size = dataset_train[0][0].shape[-1]
    # sample users: IID partition, each user gets num_items images uniformly at random
num_users = args.num_users
num_items = int(len(dataset_train)/num_users)
dict_users, all_idxs = {}, [i for i in range(len(dataset_train))]
for i in range(num_users):
dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
all_idxs = list(set(all_idxs) - dict_users[i])

# build model
if args.model == 'cnn':
if args.gpu != -1:
torch.cuda.set_device(args.gpu)
net_glob = CNN(args=args).cuda()
else:
net_glob = CNN(args=args)
elif args.model == 'mlp':
if args.gpu != -1:
torch.cuda.set_device(args.gpu)
net_glob = MLP(dim_in=img_size*img_size, dim_hidden=64, dim_out=args.num_classes).cuda()
else:
net_glob = MLP(dim_in=img_size*img_size, dim_hidden=64, dim_out=args.num_classes)
else:
exit('Error: unrecognized model')
print(net_glob)

# copy weights
w_glob = net_glob.state_dict()

# training
loss_train = []
cv_loss, cv_acc = [], []
val_loss_pre, counter = 0, 0
net_best = None
val_acc_list, net_list = [], []
for iter in tqdm(range(args.epochs)):
w_locals, loss_locals = [], []
m = max(int(args.frac * args.num_users), 1)
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
for idx in idxs_users:
net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx], tb=summary)
            w, loss = net_local.update_weights(net=copy.deepcopy(net_glob))  # each client starts from the current global weights
w_locals.append(copy.deepcopy(w))
loss_locals.append(copy.deepcopy(loss))
# update global weights
w_glob = average_weights(w_locals)

# copy weight to net_glob
net_glob.load_state_dict(w_glob)

        # print average local loss every 10 rounds
        loss_avg = sum(loss_locals) / len(loss_locals)
        if iter % 10 == 0:
            print('\nTrain loss:', loss_avg)
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    os.makedirs('./save', exist_ok=True)  # ensure the output directory exists
    plt.savefig('./save/fed_{}_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs, args.frac))

# testing
list_acc, list_loss = [], []
for c in tqdm(range(num_users)):
net_local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[c], tb=summary)
acc, loss = net_local.test(net=net_glob)
list_acc.append(acc)
list_loss.append(loss)
print("average acc:", sum(list_acc)/len(list_acc))

37 changes: 37 additions & 0 deletions options.py
@@ -0,0 +1,37 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse

def args_parser():
parser = argparse.ArgumentParser()
# federated arguments
parser.add_argument('--epochs', type=int, default=1, help="rounds of training")
parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
parser.add_argument('--frac', type=float, default=0.1, help='the fraction of clients: C')
parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')

# model arguments
parser.add_argument('--model', type=str, default='mlp', help='model name')
parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
help='comma-separated kernel size to use for convolution')
parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32,
                        help="number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="whether to use max pooling rather than strided convolutions")

# other arguments
parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', type=int, default=1, help='whether the data split is i.i.d.: 1 for IID, 0 for non-IID')
parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
parser.add_argument('--verbose', action='store_true', default=False, help="watch metrics during training and testing")
    parser.add_argument('--gpu', type=int, default=1, help="GPU ID; set to -1 for CPU")
parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')

args = parser.parse_args()
return args
