Commit

Merge pull request WZMIAOMIAO#83 from WZMIAOMIAO/dev
Dev
WZMIAOMIAO authored Nov 9, 2020
2 parents cf0ff39 + 1c75241 commit c4d44e0
Showing 11 changed files with 448 additions and 123 deletions.
11 changes: 10 additions & 1 deletion pytorch_classification/train_multi_GPU/README.md
@@ -1,4 +1,13 @@
## Multi-GPU launch command

- ```python -m torch.distributed.launch --nproc_per_node=8 --use_env train_multi_GPU.py```
- where ```nproc_per_node``` is the number of GPUs to run in parallel (see the sketch below)
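
With `--use_env`, `torch.distributed.launch` exports the process topology as environment variables instead of passing a `--local_rank` argument. A minimal sketch of how a training script can read them (this mirrors what `init_distributed_mode` in this repo does; the standalone snippet below is illustrative, not part of the commit):

```python
import os
import torch

rank = int(os.environ["RANK"])              # global rank of this process
world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this machine
torch.cuda.set_device(local_rank)           # bind this process to its GPU
```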

## Training time comparison
![training time](./training_time.png)

## With vs. without SyncBatchNorm
![syncbn](./syncbn.png)

## Single-GPU vs. multi-GPU training curves
![accuracy](./accuracy.png)
pytorch_classification/train_multi_GPU/multi_train_utils/distributed_utils.py
@@ -23,9 +23,13 @@ def init_distributed_mode(args):
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
dist.barrier()
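# dist.barrier() blocks each process until every rank has reached this point,
# so no worker continues until the whole process group is initialized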


def cleanup():
dist.destroy_process_group()


def is_dist_avail_and_initialized():
@@ -55,11 +59,12 @@ def is_main_process():

def reduce_value(value, average=True):
world_size = get_world_size()
if world_size < 2:  # single-GPU case: nothing to reduce
return value

with torch.no_grad():
dist.all_reduce(value)
if average:
value /= world_size

return value
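
`reduce_value` wraps `dist.all_reduce`, which sums a tensor across all ranks in place; dividing by the world size turns that sum into a mean. A hypothetical two-rank illustration (`device` is assumed to be the current rank's GPU; the values are made up):

```python
# rank 0 holds loss=2.0, rank 1 holds loss=4.0; after the call both
# ranks see the averaged value 3.0
loss = torch.tensor(2.0 + 2.0 * dist.get_rank(), device=device)
loss = reduce_value(loss, average=True)  # -> tensor(3.) on every rank
```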

pytorch_classification/train_multi_GPU/multi_train_utils/train_eval_utils.py
@@ -23,7 +23,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch):

loss = loss_function(pred, labels.to(device))
loss.backward()
loss = reduce_value(loss, average=False)
loss = reduce_value(loss, average=True)
mean_loss = (mean_loss * step + loss.detach()) / (step + 1) # update mean losses

# print the mean loss on process 0
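
With `average=True`, `mean_loss` now tracks the loss averaged across all GPUs instead of their sum. The running-mean update itself is easy to sanity-check in isolation (illustrative values only):

```python
# losses 1.0, 0.5, 0.3 produce running means 1.0, 0.75, 0.6
mean_loss = 0.0
for step, loss in enumerate([1.0, 0.5, 0.3]):
    mean_loss = (mean_loss * step + loss) / (step + 1)
    print(step, mean_loss)
```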
54 changes: 54 additions & 0 deletions pytorch_classification/train_multi_GPU/plot_results.py
@@ -0,0 +1,54 @@
import matplotlib.pyplot as plt

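# Fig. 1: per-epoch training time versus number of GPUs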
x = [0, 1, 2, 3]
y = [9, 5.5, 3, 2]

plt.bar(x, y, align='center')
plt.xticks(range(len(x)), ['1 GPU', '2 GPUs', '4 GPUs', '8 GPUs'])
plt.ylim((0, 10))
for i, v in enumerate(y):
plt.text(x=i, y=v + 0.1, s=str(v) + ' s', ha='center')
plt.xlabel('Number of GPUs')
plt.ylabel('Training time per epoch (seconds)')
plt.show()
plt.close()

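# Fig. 2: accuracy curves with and without SyncBatchNorm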
x = list(range(30))
no_SyncBatchNorm = [0.348, 0.495, 0.587, 0.554, 0.637,
0.622, 0.689, 0.673, 0.702, 0.717,
0.717, 0.69, 0.716, 0.696, 0.738,
0.75, 0.75, 0.66, 0.713, 0.758,
0.777, 0.777, 0.769, 0.792, 0.802,
0.807, 0.807, 0.804, 0.812, 0.811]

SyncBatchNorm = [0.283, 0.514, 0.531, 0.654, 0.671,
0.591, 0.621, 0.685, 0.701, 0.732,
0.701, 0.74, 0.667, 0.723, 0.745,
0.679, 0.738, 0.772, 0.764, 0.765,
0.764, 0.791, 0.818, 0.791, 0.807,
0.806, 0.811, 0.821, 0.833, 0.81]

plt.plot(x, no_SyncBatchNorm, label="No SyncBatchNorm")
plt.plot(x, SyncBatchNorm, label="SyncBatchNorm")
plt.xlabel('Training epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
plt.close()


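# Fig. 3: single-GPU baseline compared against the two multi-GPU runs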
x = list(range(30))
single_gpu = [0.569, 0.576, 0.654, 0.648, 0.609,
0.637, 0.699, 0.709, 0.715, 0.715,
0.717, 0.724, 0.722, 0.731, 0.721,
0.774, 0.751, 0.787, 0.78, 0.77,
0.763, 0.803, 0.754, 0.796, 0.799,
0.815, 0.793, 0.808, 0.811, 0.806]
plt.plot(x, single_gpu, color="black", label="Single GPU")
plt.plot(x, no_SyncBatchNorm, label="No SyncBatchNorm")
plt.plot(x, SyncBatchNorm, label="SyncBatchNorm")
plt.xlabel('Training epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
plt.close()
Binary file not shown.
Binary file added pytorch_classification/train_multi_GPU/syncbn.png
pytorch_classification/train_multi_GPU/train_multi_GPU.py
@@ -1,4 +1,6 @@
import os
import math
import tempfile
import argparse

import torch
@@ -10,12 +12,9 @@
from model import resnet34
from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image
from multi_train_utils.distributed_utils import init_distributed_mode, dist
from multi_train_utils.distributed_utils import init_distributed_mode, dist, cleanup
from multi_train_utils.train_eval_utils import train_one_epoch, evaluate

# http://download.tensorflow.org/example_images/flower_photos.tgz
root = "/home/w180662/my_project/my_github/data_set/flower_data/flower_photos"  # dataset root directory


def main(args):
if torch.cuda.is_available() is False:
@@ -29,7 +28,7 @@ def main(args):
batch_size = args.batch_size
num_classes = args.num_classes
weights_path = args.weights
lr = args.lr
args.lr *= args.world_size  # scale the learning rate by the number of parallel GPUs
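# e.g. with the default lr=0.001 and world_size=4 the effective lr is 0.004,
# matching the 4x larger effective batch size (the linear scaling rule)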

if rank == 0:  # print info and instantiate the tensorboard writer on the first process only
print(args)
@@ -38,7 +37,7 @@ def main(args):
if os.path.exists("./weights") is False:
os.makedirs("./weights")

train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(224),
@@ -93,12 +92,14 @@ def main(args):
if model.state_dict()[k].numel() == v.numel()}
model.load_state_dict(load_weights_dict, strict=False)
else:
checkpoint_path = os.path.join(tempfile.gettempdir(), "initial_weights.pt")
# if no pretrained weights are given, save the first process's weights and have
# the other processes load them, so every process starts from the same initialization
if rank == 0:
torch.save(model.state_dict(), "./initial_weights.pt")
torch.save(model.state_dict(), checkpoint_path)

dist.barrier()
model.load_state_dict(torch.load("./initial_weights.pt"))
# note: map_location must be specified here, otherwise every process deserializes
# onto the first GPU and it ends up using extra memory
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# optionally freeze weights
if args.freeze_layers:
@@ -107,14 +108,20 @@
if "fc" not in name:
para.requires_grad_(False)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
# SyncBatchNorm is only meaningful when the network being trained contains BN layers
if args.syncBN:
# training is slower with SyncBatchNorm enabled
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
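# note: convert_sync_batchnorm must be called before wrapping the model in
# DDP below, and SyncBatchNorm only supports GPU training (e.g. the NCCL backend)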

# wrap the model with DistributedDataParallel
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

# optimizer
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(pg, lr=lr)
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
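# quick check of the cosine factor lf(x) with the defaults epochs=30, lrf=0.1:
#   lf(0)  = ((1 + cos(0))    / 2) * 0.9 + 0.1 = 1.0   (start at the base lr)
#   lf(15) = ((1 + cos(pi/2)) / 2) * 0.9 + 0.1 = 0.55
#   lf(30) = ((1 + cos(pi))   / 2) * 0.9 + 0.1 = 0.1   (end at base lr * lrf)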

for epoch in range(args.epochs):
train_sampler.set_epoch(epoch)
@@ -125,25 +132,46 @@
device=device,
epoch=epoch)

scheduler.step()

sum_num = evaluate(model=model,
data_loader=val_loader,
device=device)
acc = sum_num / val_sampler.total_size

if rank == 0:
tags = ["loss", "accuracy"]
print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
tags = ["loss", "accuracy", "learning_rate"]
tb_writer.add_scalar(tags[0], mean_loss, epoch)
tb_writer.add_scalar(tags[1], acc, epoch)
tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

# remove the temporary checkpoint file
if rank == 0:
if os.path.exists(checkpoint_path) is True:
os.remove(checkpoint_path)

cleanup()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--num_classes', type=int, default=5)
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--lrf', type=float, default=0.1)
# whether to enable SyncBatchNorm
parser.add_argument('--syncBN', type=bool, default=True)
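# note: argparse's type=bool is misleading -- bool('False') is True, so any
# non-empty command-line value enables the flag; only the default is reliable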

# dataset root directory
# http://download.tensorflow.org/example_images/flower_photos.tgz
parser.add_argument('--data-path', type=str, default="/home/wz/data_set/flower_data/flower_photos")

# download URL for the official resnet34 weights
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
parser.add_argument('--weights', type=str, default='resNet34.pth',
help='initial weights path')
parser.add_argument('--freeze-layers', type=bool, default=False)