forked from WZMIAOMIAO/deep-learning-for-image-processing
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
wz
authored and
wz
committed
Dec 2, 2020
1 parent
fcb94a6
commit e25b186
Showing
5 changed files
with
495 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"0": "daisy", | ||
"1": "dandelion", | ||
"2": "roses", | ||
"3": "sunflowers", | ||
"4": "tulips" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    """Interleave the channels of ``x`` across ``groups`` channel groups.

    The channel axis is split into ``groups`` contiguous groups, the group
    axis and the within-group axis are swapped, and the result is flattened
    back to four dimensions. Channel ``i`` of group ``g`` therefore ends up at
    position ``i * groups + g``.

    :param x: 4-D tensor laid out as [batch_size, num_channels, height, width].
    :param groups: number of channel groups. ``num_channels`` has to be
        divisible by it, otherwise the reshape below fails.
    :return: tensor of the same shape as ``x`` with its channels permuted.
    """
    # Read the size from the tensor itself; going through the legacy
    # ``x.data`` attribute is unnecessary for a shape query.
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    # [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
    x = x.view(batch_size, groups, channels_per_group, height, width)

    # Swap the group axis with the within-group axis. transpose() only changes
    # strides, so contiguous() is required before the next view().
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batch_size, -1, height, width)

    return x
|
||
|
||
class InvertedResidual(nn.Module):
    """Two-branch ShuffleNetV2 unit followed by a channel shuffle.

    With ``stride == 1`` the input is split in half along the channel axis:
    the first half is passed through untouched and the second half goes
    through ``branch2``. With ``stride == 2`` the full input is fed to both
    ``branch1`` and ``branch2``, each of which contains a stride-2 depthwise
    convolution. In both cases the two results are concatenated on the
    channel axis and shuffled with two groups.

    :param input_c: number of input channels.
    :param output_c: number of output channels; must be even because each
        branch produces ``output_c // 2`` channels.
    :param stride: 1 or 2.
    :raises ValueError: if ``stride`` is not 1 or 2, if ``output_c`` is odd,
        or if ``stride == 1`` and ``input_c != output_c``.
    """

    def __init__(self, input_c, output_c, stride):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        # Explicit exceptions instead of ``assert`` so that the validation is
        # not stripped when Python runs with -O.
        if output_c % 2 != 0:
            raise ValueError("output_c must be even, got {}.".format(output_c))
        branch_features = output_c // 2
        # When stride is 1 the input is split in two and only one half is
        # transformed, so input_c has to be twice branch_features.
        # ``<< 1`` is a bit shift, i.e. a multiplication by 2.
        if self.stride == 1 and input_c != branch_features << 1:
            raise ValueError(
                "with stride 1, input_c ({}) must equal 2 * (output_c // 2) = {}.".format(
                    input_c, branch_features << 1))

        if self.stride == 2:
            # Down-sampling shortcut: depthwise 3x3 (stride 2) -> BN -> 1x1 -> BN -> ReLU.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            # Empty Sequential; forward() does not call it when stride is 1.
            self.branch1 = nn.Sequential()

        # Main branch: 1x1 -> BN -> ReLU -> depthwise 3x3 -> BN -> 1x1 -> BN -> ReLU.
        # With stride 1 it only sees half of the input channels.
        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def depthwise_conv(input_c, output_c, kernel_s, stride=1, padding=0, bias=False):
        """Build a depthwise ``nn.Conv2d`` (``groups`` equals ``input_c``)."""
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x):
        """Apply the unit to ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            # Split the channels in half; only the second half is transformed.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Mix the channels coming from the two branches.
        out = channel_shuffle(out, 2)

        return out
|
||
|
||
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 image classifier.

    Layout: a stride-2 3x3 convolution (``conv1``), a stride-2 max pool,
    three stages (``stage2``..``stage4``) built from ``inverted_residual``
    units, a 1x1 convolution (``conv5``), a global average pool over the
    spatial axes and a final linear layer (``fc``).

    :param stages_repeats: three ints, the number of units in stage 2, 3
        and 4. The first unit of each stage uses stride 2, the rest stride 1.
    :param stages_out_channels: five ints, the output channels of ``conv1``,
        the three stages and ``conv5`` in that order.
    :param num_classes: number of outputs of the final linear layer.
    :param inverted_residual: callable ``(input_c, output_c, stride)`` that
        builds one unit.
    :raises ValueError: if ``stages_repeats`` does not have 3 entries or
        ``stages_out_channels`` does not have 5 entries.
    """

    def __init__(self, stages_repeats, stages_out_channels,
                 num_classes=1000, inverted_residual=InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # input RGB image
        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages are registered via setattr as stage2, stage3 and stage4.
        # zip() stops after three stages, so the last entry of
        # _stage_out_channels is left for conv5 below.
        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            # The first unit changes the channel count with stride 2; the
            # remaining units keep the channel count and use stride 1.
            seq = [inverted_residual(input_channels, output_channels, 2)]
            seq.extend(inverted_residual(output_channels, output_channels, 1)
                       for _ in range(repeats - 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x):
        """Run the network on ``x`` and return the output of ``fc``."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global pool
        x = self.fc(x)
        return x

    def forward(self, x):
        """Delegate to :meth:`_forward_impl`."""
        return self._forward_impl(x)
|
||
|
||
def shufflenet_v2_x1_0(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

    The weights linked above are not downloaded or loaded by this function.

    :param num_classes: number of outputs of the final fully connected layer.
    :return: a ``ShuffleNetV2`` built with stage repeats ``[4, 8, 4]`` and
        output channels ``[24, 116, 232, 464, 1024]``.
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 116, 232, 464, 1024],
                         num_classes=num_classes)

    return model
|
||
|
||
def shufflenet_v2_x0_5(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth

    The weights linked above are not downloaded or loaded by this function.

    :param num_classes: number of outputs of the final fully connected layer.
    :return: a ``ShuffleNetV2`` built with stage repeats ``[4, 8, 4]`` and
        output channels ``[24, 48, 96, 192, 1024]``.
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 48, 96, 192, 1024],
                         num_classes=num_classes)

    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from PIL import Image | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class MyDataSet(Dataset):
    """Custom dataset that pairs image file paths with class labels.

    :param images_path: list of image file paths.
    :param images_class: list of labels, indexed in parallel with
        ``images_path``.
    :param transform: optional callable applied to the opened PIL image.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.images_path)

    def __getitem__(self, item):
        """Open image ``item`` and return ``(image, label)``.

        :raises ValueError: if the image is not in RGB mode.
        """
        img = Image.open(self.images_path[item])
        # 'RGB' is a colour image, 'L' is a greyscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image, label)`` samples into two batch tensors.

        :param batch: iterable of ``(image_tensor, label)`` pairs; the image
            tensors must share one shape so that they can be stacked.
        :return: ``(images, labels)`` where ``images`` has a new leading
            batch axis and ``labels`` is a tensor built from the labels.
        """
        # The official default_collate implementation can be used as a reference:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import os | ||
import math | ||
import argparse | ||
|
||
import torch | ||
import torch.optim as optim | ||
from torch.utils.tensorboard import SummaryWriter | ||
from torchvision import transforms | ||
import torch.optim.lr_scheduler as lr_scheduler | ||
|
||
from model import shufflenet_v2_x1_0 | ||
from my_dataset import MyDataSet | ||
from utils import read_split_data, train_one_epoch, evaluate | ||
|
||
|
||
def main(args):
    """Train ShuffleNetV2 (x1.0) on an image-folder dataset.

    Builds the train/validation loaders, optionally loads matching pretrained
    weights and freezes everything except the ``fc`` layer, then trains with
    SGD and a cosine learning-rate schedule. After every epoch the loss,
    accuracy and learning rate are written to TensorBoard and the model
    weights are saved to ``./weights/model-<epoch>.pth``.

    :param args: parsed command-line namespace. The attributes read here are
        ``device``, ``data_path``, ``batch_size``, ``num_classes``,
        ``weights``, ``freeze_layers``, ``lr``, ``lrf`` and ``epochs``.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("./weights", exist_ok=True)

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # Instantiate the training dataset
    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    # Instantiate the validation dataset
    val_data_set = MyDataSet(images_path=val_images_path,
                             images_class=val_images_label,
                             transform=data_transform["val"])

    batch_size = args.batch_size
    # os.cpu_count() may return None; fall back to 1 so min() does not fail.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_data_set.collate_fn)

    # Load the pretrained weights if the file exists
    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
    if os.path.exists(args.weights):
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        weights_dict = torch.load(args.weights, map_location=device)
        # Build the state dict once, and skip checkpoint entries that the
        # model does not have (indexing them used to raise KeyError) or whose
        # element count differs (e.g. a classifier with another class count).
        model_state = model.state_dict()
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if k in model_state and model_state[k].numel() == v.numel()}
        print(model.load_state_dict(load_weights_dict, strict=False))
    else:
        print("weights file '{}' not found, "
              "training starts from randomly initialised parameters.".format(args.weights))

    # Optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # Freeze every parameter whose name does not contain "fc",
            # i.e. everything except the final fully connected layer
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.0001)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    def lf(x):
        # Cosine decay of the LR multiplier from 1 at epoch 0 towards args.lrf.
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # Created right before the loop so the try/finally below always closes it.
    tb_writer = SummaryWriter()
    try:
        for epoch in range(args.epochs):
            # train
            mean_loss = train_one_epoch(model=model,
                                        optimizer=optimizer,
                                        data_loader=train_loader,
                                        device=device,
                                        epoch=epoch)

            scheduler.step()

            # validate
            # NOTE(review): evaluate() is presumed to return the number of
            # correctly classified validation samples - confirm in utils.
            sum_num = evaluate(model=model,
                               data_loader=val_loader,
                               device=device)
            acc = sum_num / len(val_data_set)
            print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
            tags = ["loss", "accuracy", "learning_rate"]
            tb_writer.add_scalar(tags[0], mean_loss, epoch)
            tb_writer.add_scalar(tags[1], acc, epoch)
            tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

            torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
    finally:
        # Flush pending events and release the event file even if an epoch fails.
        tb_writer.close()
|
||
|
||
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value, including ``"False"``, becomes ``True``. This parser maps
    the usual spellings explicitly.

    :param value: the raw argument string (a ``bool`` is returned unchanged).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not recognised.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.1)
    parser.add_argument('--lrf', type=float, default=0.1)

    # Root directory of the dataset
    # http://download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str, default="/home/wz/data_set/flower_data/flower_photos")

    # Official shufflenet_v2_x1_0 weights download address
    # https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
    parser.add_argument('--weights', type=str, default='shufflenetv2_x1.pth',
                        help='initial weights path')
    # str2bool instead of bool so that "--freeze-layers False" really disables freezing.
    parser.add_argument('--freeze-layers', type=str2bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
Oops, something went wrong.