Commit 75ae7e4 (1 parent: 71cbf8f), committed by Ren Peng on Sep 10, 2021. Showing 26 changed files with 1,641 additions and 0 deletions.
@@ -127,3 +127,4 @@ dmypy.json
 # Pyre type checker
 .pyre/
+.idea/
@@ -0,0 +1,73 @@
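The first new file is a training configuration for the face-parsing task: device selection, CelebAMask-HQ data paths, train/val/test transforms, and model, loss, and optimizer settings, plus a small attribute-style Dict helper. build() fills in the schedule-dependent optimizer fields once the steps per epoch and the class count are known. The joint transforms come from utils.transforms, which is provided elsewhere in the project.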
import torch
from torchvision import transforms

from utils.transforms import *


class Config:
    def __init__(self, task_id=1):
        self.task_id = task_id
        self.output_dir = "output"

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.num_device = torch.cuda.device_count()

        self.num_classes = 19
        self.dataset = "ImageFolder"
        self.data_dir = "/data/face/parsing/dataset/CelebAMask-HQ_processed"
        self.sample_dir = "/data/face/parsing/dataset/testset_210720_aligned"
        self.image_size = (512, 512)
        self.crop_size = (448, 448)
        self.do_val = True

        self.train_transform = Compose([ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
                                        RandomScale((0.75, 1.25)),
                                        RandomRotation(),
                                        RandomCrop(self.crop_size),
                                        ToTensor(),
                                        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        self.val_transform = Compose([ToTensor(), Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
        self.test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

        self.model_name = "U2NET"
        self.model_args = Dict()

        self.loss_name = "Loss"
        self.loss_args = Dict(score_thresh=0.7, ignore_idx=255)

        self.optimizer_name = "SGD"
        self.optimizer_args = Dict(
            momentum=0.9,
            weight_decay=5e-4,
            warmup_start_lr=1e-5,
            power=0.9,
        )

        self.lr = 0.01
        self.batch_size = 8
        self.milestones = Dict()
        self.epochs = 30

    def build(self, steps=None, num_classes=None):
        if "lr0" not in self.optimizer_args:
            self.optimizer_args["lr0"] = self.lr

        if "max_iter" not in self.optimizer_args and steps is not None:
            self.optimizer_args["max_iter"] = self.epochs * steps

        if "warmup_steps" not in self.optimizer_args and steps is not None:
            self.optimizer_args["warmup_steps"] = steps

        if num_classes is not None:
            self.num_classes = num_classes

        self.model_args["out_ch"] = self.num_classes
        return self


class Dict(dict):
    def __getattr__(self, item):
        return self.get(item, None)

    def __setattr__(self, key, value):
        self[key] = value
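A minimal usage sketch (not part of the commit); the step count is a made-up placeholder, and it assumes utils.transforms supplies the joint transforms imported above:

cfg = Config(task_id=1).build(steps=1000, num_classes=19)  # 1000 steps/epoch is hypothetical
print(cfg.device)                          # "cuda" or "cpu"
print(cfg.optimizer_args["max_iter"])      # epochs * steps = 30 * 1000
print(cfg.model_args["out_ch"])            # 19, forwarded to the model head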
@@ -0,0 +1,89 @@
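Next is a one-off preprocessing script for CelebAMask-HQ. It resizes each HQ image to 512x512, merges the per-part annotation PNGs into a single integer label map (later entries in label_names overwrite earlier ones where parts overlap), and splits the files into train/val/test lists using the original CelebA evaluation partition, looked up through the HQ-to-CelebA mapping file.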
import numpy as np
import cv2

import os


def process(src_folder, dst_folder):
    os.makedirs(dst_folder, exist_ok=True)
    os.makedirs(os.path.join(dst_folder, "images"), exist_ok=True)
    os.makedirs(os.path.join(dst_folder, "labels"), exist_ok=True)

    # 1.prepare
    id_map = open(os.path.join(src_folder, "CelebA-HQ-to-CelebA-mapping.txt")).readlines()
    id_map = [line.split() for line in id_map[1:] if line.strip()]
    id_map = {int(idx): {"origin_file": img_file} for (idx, _, img_file) in id_map}

    lines = open(os.path.join(src_folder, "list_eval_partition.txt")).readlines()
    lines = [line.split() for line in lines if line.strip()]
    flags = {line[0]: int(line[1]) for line in lines}

    for k in id_map:
        id_map[k]["flag"] = flags[id_map[k]["origin_file"]]

    mask_map = {}
    mask_folder = os.path.join(src_folder, "CelebAMask-HQ-mask-anno")
    for folder_name in os.listdir(mask_folder):
        folder = os.path.join(mask_folder, folder_name)
        if not os.path.isdir(folder):
            continue

        for file_name in os.listdir(folder):
            if not file_name.endswith(".png"):
                continue

            idx = int(file_name[:5])
            if idx not in mask_map:
                mask_map[idx] = {}

            mask_map[idx][file_name[6:-4]] = os.path.join(folder_name, file_name)

    label_names = ['background', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
    label_map = {name: idx for idx, name in enumerate(label_names)}

    # 2.copy
    train, val, test = [], [], []
    for idx, item in id_map.items():
        src_file = os.path.join(src_folder, "CelebA-HQ-img", f"{idx}.jpg")
        src_img = cv2.imread(src_file)

        dst_img = cv2.resize(src_img, (512, 512))

        dst_file = os.path.join(dst_folder, "images", f"{idx}.jpg")
        cv2.imwrite(dst_file, dst_img)

        mask = None
        for label_name in label_names:
            if label_name in mask_map[idx]:
                m = cv2.imread(os.path.join(mask_folder, mask_map[idx][label_name]), cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    mask = np.zeros_like(m)
                mask[m != 0] = label_map[label_name]

        if mask is not None:
            line = f"{os.path.join('images', f'{idx}.jpg')},{os.path.join('labels', f'{idx}.png')}"
            label_file = os.path.join(dst_folder, "labels", f"{idx}.png")
            cv2.imwrite(label_file, mask)
        else:
            line = f"{os.path.join('images', f'{idx}.jpg')}"

        if item["flag"] == 0:
            train.append(line)
        elif item["flag"] == 1:
            val.append(line)
        else:
            test.append(line)

        print(f"{idx + 1}/{len(id_map)}", end="\r", flush=True)

    # 3.write
    open(os.path.join(dst_folder, "train.txt"), "w").write("\n".join(train))
    open(os.path.join(dst_folder, "val.txt"), "w").write("\n".join(val))
    open(os.path.join(dst_folder, "test.txt"), "w").write("\n".join(test))
    open(os.path.join(dst_folder, "label.txt"), "w").write(",".join(label_names))

    print("Complete!")


if __name__ == '__main__':
    process("/data/face/parsing/dataset/CelebAMask-HQ", "/data/face/parsing/dataset/CelebAMask-HQ_processed2")
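A hypothetical sanity check after running the script; the path is the destination used above, and the exact counts are not part of the commit:

import os

processed = "/data/face/parsing/dataset/CelebAMask-HQ_processed2"
with open(os.path.join(processed, "train.txt")) as f:
    train_lines = [l.strip() for l in f if l.strip()]
print(len(train_lines), train_lines[0])   # each entry reads "images/<idx>.jpg,labels/<idx>.png"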
@@ -0,0 +1,39 @@
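The ImageFolder dataset reads the image/label file lists written by the preprocessing scripts, loads each pair with PIL (nearest-neighbour resampling for the label map so class indices are not interpolated), and applies a joint image-mask transform. The class count is derived from label.txt.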
import PIL.Image
from torch.utils.data import Dataset
from PIL import Image
import os


class ImageFolder(Dataset):
    def __init__(self, root, file_list="train.txt", label_file="label.txt", image_size=None, transform=None):
        self.root = root
        self.image_size = (image_size, image_size) if isinstance(image_size, int) else image_size
        self.transform = transform

        lines = open(os.path.join(root, file_list)).readlines()
        self.files = [line.strip().split(",") for line in lines if line.strip()]

        self.label_names = open(os.path.join(root, label_file)).read().strip().split(",")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        img_file, mask_file = self.files[i]

        img = Image.open(os.path.join(self.root, img_file)).convert("RGB")
        if self.image_size is not None and img.size != self.image_size:
            img = img.resize(self.image_size)

        mask = Image.open(os.path.join(self.root, mask_file)).convert("I")
        if self.image_size is not None and mask.size != self.image_size:
            mask = mask.resize(self.image_size, resample=PIL.Image.NEAREST)

        if self.transform is not None:
            img, mask = self.transform(img, mask)

        return img, mask

    @property
    def num_classes(self):
        return len(self.label_names)
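A hypothetical wiring sketch, not code from this diff; it assumes the processed dataset directory exists and that the joint transform ends in a ToTensor-style step so batches collate to tensors:

from torch.utils.data import DataLoader

cfg = Config()                                  # from the config module above
train_set = ImageFolder(cfg.data_dir, file_list="train.txt",
                        image_size=cfg.image_size, transform=cfg.train_transform)
train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=4, drop_last=True)
images, masks = next(iter(train_loader))        # one 448x448 crop batch per the config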
@@ -0,0 +1,5 @@
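The datasets package __init__ registers ImageFolder under the name used by Config.dataset: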
from .ImageFolder import ImageFolder

DATASETS = {
    "ImageFolder": ImageFolder
}
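Presumably the trainer (not among the files shown here) resolves the configured dataset through this registry and feeds the resulting sizes back into Config.build(); a hedged sketch:

dataset_cls = DATASETS[cfg.dataset]             # "ImageFolder" -> ImageFolder
train_set = dataset_cls(cfg.data_dir, file_list="train.txt", transform=cfg.train_transform)
cfg.build(steps=len(train_set) // cfg.batch_size, num_classes=train_set.num_classes)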
@@ -0,0 +1,22 @@
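A companion script builds the same style of file lists for the iBugMask release: each .jpg is paired with its in-place .png label, the train folder becomes train.txt, the test folder is written out as val.txt, and an 11-class label.txt is emitted.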
import os


def process_folder(src_folder, folder_name, target_name):
    lines = []
    for file_name in os.listdir(os.path.join(src_folder, folder_name)):
        if not file_name.endswith(".jpg"):
            continue

        file_name = file_name[:-4]
        lines.append(f"{folder_name}/{file_name}.jpg,{folder_name}/{file_name}.png")
    open(os.path.join(src_folder, target_name), "w").write("\n".join(lines))


def process(src_folder):
    process_folder(src_folder, "train", "train.txt")
    process_folder(src_folder, "test", "val.txt")
    open(os.path.join(src_folder, "label.txt"), "w").write("background,skin,left_eyebrow,right_eyebrow,left_eye,right_eye,nose,upper_lip,inner_mouth,lower_lip,hair")


if __name__ == '__main__':
    process("/data/face/parsing/dataset/ibugmask_release")
@@ -0,0 +1,18 @@
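Three loss classes follow. FocalLoss down-weights well-classified pixels by scaling the log-probabilities with (1 - p)^gamma before applying NLLLoss; pixels labelled 255 are ignored: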
import torch
from torch import nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.nll = nn.NLLLoss(ignore_index=ignore_lb)

    def forward(self, logits, labels):
        scores = F.softmax(logits, dim=1)
        factor = torch.pow(1. - scores, self.gamma)
        log_score = F.log_softmax(logits, dim=1)
        log_score = factor * log_score
        loss = self.nll(log_score, labels)
        return loss
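A minimal call sketch with dummy tensors (shapes follow the usual (N, C, H, W) logits vs. (N, H, W) integer-label convention; gamma=2.0 is a hypothetical choice, not a value from the commit):

logits = torch.randn(2, 19, 64, 64)            # 19 classes, matching the config above
labels = torch.randint(0, 19, (2, 64, 64))
loss = FocalLoss(gamma=2.0)(logits, labels)
print(loss.item())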
@@ -0,0 +1,23 @@
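OhemCELoss implements online hard-example mining: per-pixel cross-entropy losses are sorted and only the hard pixels are averaged, keeping either every pixel above -log(thresh) or, if there are too few of those, the n_min hardest: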
import torch
import torch.nn as nn

import numpy as np


class OhemCELoss(nn.Module):
    def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
        super(OhemCELoss, self).__init__()
        # self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
        self.thresh = 0. - np.log(thresh)
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')

    def forward(self, logits, labels):
        loss = self.criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)
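A hypothetical instantiation; the n_min heuristic (one sixteenth of the pixels in a batch of 448x448 crops) mirrors common BiSeNet-style setups and is an assumption, not something fixed by this commit:

n_min = 8 * 448 * 448 // 16                    # batch_size * crop_h * crop_w / 16 (heuristic)
criterion = OhemCELoss(thresh=0.7, n_min=n_min)  # 0.7 matches loss_args.score_thresh in the config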
@@ -0,0 +1,23 @@
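WeightedOhemCELoss applies the same OHEM scheme but with ENet-style class weights recomputed from each label batch; the enet_weighing helper it imports lives in the project's utils package, which is not among the files shown on this page: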
import torch
from torch import nn

from utils import enet_weighing


class WeightedOhemCELoss(nn.Module):
    def __init__(self, thresh, n_min, num_classes, ignore_lb=255, *args, **kwargs):
        super(WeightedOhemCELoss, self).__init__()
        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
        self.n_min = n_min
        self.ignore_lb = ignore_lb
        self.num_classes = num_classes

    def forward(self, logits, labels):
        criteria = nn.CrossEntropyLoss(weight=enet_weighing(labels, self.num_classes).cuda(), ignore_index=self.ignore_lb, reduction='none')
        loss = criteria(logits, labels).view(-1)
        loss, _ = torch.sort(loss, descending=True)
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]
        else:
            loss = loss[:self.n_min]
        return torch.mean(loss)
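Unlike OhemCELoss, this variant rebuilds nn.CrossEntropyLoss on every forward pass so the class weights track the current batch, and it moves the threshold tensor to the GPU at construction time, so it assumes CUDA is available.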
@@ -0,0 +1,5 @@
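The losses package exposes a registry matching DATASETS; the Loss class it re-exports is defined in loss.py, presumably among the remaining changed files of this commit that are not shown here: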
from .loss import Loss

LOSSES = {
    "Loss": Loss
}