forked from WZMIAOMIAO/deep-learning-for-image-processing
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
wz
authored and
wz
committed
Nov 5, 2020
1 parent
fbdaf0f
commit d1761df
Showing
3 changed files
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import os | ||
|
||
import torch | ||
from torchvision import transforms | ||
|
||
from my_dataset import MyDataSet | ||
from utils import read_split_data | ||
|
||
root = "/home/w180662/my_project/my_github/data_set/flower_data/flower_photos" # 数据集所在根目录 | ||
|
||
|
||
def main(): | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
print("using {} device.".format(device)) | ||
|
||
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root) | ||
|
||
data_transform = { | ||
"train": transforms.Compose([transforms.RandomResizedCrop(224), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), | ||
"val": transforms.Compose([transforms.Resize(256), | ||
transforms.CenterCrop(224), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} | ||
|
||
train_data_set = MyDataSet(images_path=train_images_path, | ||
images_class=train_images_label, | ||
transform=data_transform["train"]) | ||
|
||
batch_size = 8 | ||
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers | ||
print('Using {} dataloader workers'.format(nw)) | ||
train_loader = torch.utils.data.DataLoader(train_data_set, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
num_workers=nw, | ||
collate_fn=train_data_set.collate_fn) | ||
|
||
for step, data in enumerate(train_loader): | ||
images, labels = data | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from PIL import Image | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class MyDataSet(Dataset): | ||
"""自定义数据集""" | ||
|
||
def __init__(self, images_path: list, images_class: list, transform=None): | ||
self.images_path = images_path | ||
self.images_class = images_class | ||
self.transform = transform | ||
|
||
def __len__(self): | ||
return len(self.images_path) | ||
|
||
def __getitem__(self, item): | ||
img = Image.open(self.images_path[item]) | ||
# RGB为彩色图片,L为灰度图片 | ||
if img.mode != 'RGB': | ||
raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) | ||
label = self.images_class[item] | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
return img, label | ||
|
||
@staticmethod | ||
def collate_fn(batch): | ||
# 官方实现的default_collate可以参考 | ||
# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py | ||
images, labels = tuple(zip(*batch)) | ||
|
||
images = torch.stack(images, dim=0) | ||
labels = torch.as_tensor(labels) | ||
return images, labels | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
import json | ||
import pickle | ||
import random | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
|
||
def read_split_data(root: str, val_rate: float = 0.2): | ||
random.seed(0) # 保证随机结果可复现 | ||
assert os.path.exists(root), "dataset root: {} does not exist.".format(root) | ||
|
||
# 遍历文件夹,一个文件夹对应一个类别 | ||
flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] | ||
# 排序,保证顺序一致 | ||
flower_class.sort() | ||
# 生成类别名称以及对应的数字索引 | ||
class_indices = dict((k, v) for v, k in enumerate(flower_class)) | ||
json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) | ||
with open('class_indices.json', 'w') as json_file: | ||
json_file.write(json_str) | ||
|
||
train_images_path = [] # 存储训练集的所有图片路径 | ||
train_images_label = [] # 存储训练集图片对应索引信息 | ||
val_images_path = [] # 存储验证集的所有图片路径 | ||
val_images_label = [] # 存储验证集图片对应索引信息 | ||
every_class_num = [] # 存储每个类别的样本总数 | ||
supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 | ||
# 遍历每个文件夹下的文件 | ||
for cla in flower_class: | ||
cla_path = os.path.join(root, cla) | ||
# 遍历获取supported支持的所有文件路径 | ||
images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) | ||
if os.path.splitext(i)[-1] in supported] | ||
# 获取该类别对应的索引 | ||
image_class = class_indices[cla] | ||
# 记录该类别的样本数量 | ||
every_class_num.append(len(images)) | ||
# 按比例随机采样验证样本 | ||
eval_path = random.sample(images, k=int(len(images) * val_rate)) | ||
|
||
for img_path in images: | ||
if img_path in eval_path: # 如果该路径在采样的验证集样本中则存入验证集 | ||
val_images_path.append(img_path) | ||
val_images_label.append(image_class) | ||
else: # 否则存入训练集 | ||
train_images_path.append(img_path) | ||
train_images_label.append(image_class) | ||
|
||
print("{} images were found in the dataset.".format(sum(every_class_num))) | ||
|
||
plot_image = False | ||
if plot_image: | ||
# 绘制每种类别个数柱状图 | ||
plt.bar(range(len(flower_class)), every_class_num, align='center') | ||
# 将横坐标0,1,2,3,4替换为相应的类别名称 | ||
plt.xticks(range(len(flower_class)), flower_class) | ||
# 在柱状图上添加数值标签 | ||
for i, v in enumerate(every_class_num): | ||
plt.text(x=i, y=v + 5, s=str(v), ha='center') | ||
# 设置x坐标 | ||
plt.xlabel('image class') | ||
# 设置y坐标 | ||
plt.ylabel('number of images') | ||
# 设置柱状图的标题 | ||
plt.title('flower class distribution') | ||
plt.show() | ||
|
||
return train_images_path, train_images_label, val_images_path, val_images_label | ||
|
||
|
||
def write_pickle(list_info: list, file_name: str): | ||
with open(file_name, 'wb') as f: | ||
pickle.dump(list_info, f) | ||
|
||
|
||
def read_pickle(file_name: str) -> list: | ||
with open(file_name, 'rb') as f: | ||
info_list = pickle.load(f) | ||
return info_list |