Skip to content

Commit

Permalink
Added files
Browse files Browse the repository at this point in the history
  • Loading branch information
jarvis2324 committed Dec 28, 2023
1 parent 7bdd9eb commit 5642ee5
Show file tree
Hide file tree
Showing 9 changed files with 403 additions and 60 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@
data/
log/
pretrain_model/
logs/*
new/*
pretrain/*
3 changes: 2 additions & 1 deletion datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def __getitem__(self, index):

im_path = os.path.join(self.root, self.dataset + '_images', train_item + '.jpg')
parsing_anno_path = os.path.join(self.root, self.dataset + '_segmentations', train_item + '.png')

#print(im_path)
#print(parsing_anno_path)
im = cv2.imread(im_path, cv2.IMREAD_COLOR)
h, w, _ = im.shape
parsing_anno = np.zeros((h, w), dtype=np.long)
Expand Down
9 changes: 5 additions & 4 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ def main():
state_dict = torch.load(args.model_restore)['state_dict']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
#for k, v in state_dict.items():
# name = k[7:] # remove `module.`
# new_state_dict[name] = v
#model.load_state_dict(new_state_dict)
model.load_state_dict(state_dict)
model.cuda()
model.eval()

Expand Down
2 changes: 1 addition & 1 deletion evaluate_swin.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
name='lip_solider_swin_tiny'
python evaluate.py --arch swin_small --data-dir /home/xianzhe.xxz/datasets/HumanParsing/LIP --model-restore ./logs/${name}/schp_4_checkpoint.pth.tar --input-size 572,384 --multi-scales 0.5,0.75,1.0,1.25,1.5 --flip
python evaluate.py --arch swin_tiny --data-dir /home/ubuntu/PaddleSeg/ --model-restore ./logs/${name}/schp_4_checkpoint.pth.tar --input-size 572,384 --multi-scales 0.5,0.75,1.0,1.25,1.5 --flip --batch-size 1 --save-results
2 changes: 2 additions & 0 deletions evaluate_swin.sh.save
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
name='lip_solider_swin_tiny'
python evaluate.py --arch swin_tiny --data-dir /home/ubuntu/PaddleSeg/ --model-restore ./logs/${name}/schp_2_checkpoint.pth.tar --input-size 572,384 --multi-scales 0.5,0.75,1.0,1.25,1.5 --flip --batch-size 1 --save-results
87 changes: 87 additions & 0 deletions infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import argparse
import os
import torch
from PIL import Image as PILImage
import torchvision.transforms as transforms
import networks
from utils.transforms import BGR2RGB_transform, transform_parsing



def get_arguments():
    """Parse the model and evaluation options from the command line.

    Options this parser does not know are ignored instead of being treated
    as errors. ``load_model`` calls this function while ``sys.argv`` still
    contains the options of the ``__main__`` parser (``--image-path``,
    ``--model-path``, ``--output-dir``); a strict ``parse_args()`` would
    terminate the script with "unrecognized arguments".

    Returns:
        argparse.Namespace: the parsed options, one attribute per argument
        declared below (dashes become underscores, e.g. ``num_classes``).
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Network Structure
    parser.add_argument("--arch", type=str, default='resnet101')
    # Data Preference
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--input-size", type=str, default='473,473')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")
    # Evaluation Preference
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
    parser.add_argument("--save-results", action="store_true", help="whether to save the results.")
    parser.add_argument("--flip", action="store_true", help="random flip during the test.")
    parser.add_argument("--multi-scales", type=str, default='1', help="multiple scales during the test")

    # parse_known_args() returns (namespace, leftover_argv); the leftovers
    # belong to other parsers sharing this command line and are dropped here.
    known_args, _unrecognized = parser.parse_known_args()
    return known_args
def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* whose keys have a leading ``module.`` removed.

    Keys that do not start with ``module.`` are copied unchanged, so the
    result is correct both for checkpoints whose keys carry that prefix and
    for checkpoints whose keys do not.
    """
    from collections import OrderedDict

    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def load_model(model_path):
    """Build the network and load weights from the checkpoint at *model_path*.

    The architecture name and class count come from ``get_arguments()``.
    The checkpoint is expected to be a mapping with a ``'state_dict'`` entry.

    Args:
        model_path: path of the checkpoint file passed to ``torch.load``.

    Returns:
        The model, moved to the GPU and switched to eval mode.
    """
    args = get_arguments()
    # Create an instance of the model
    model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=None)

    # Load the pre-trained weights.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path)

    # Strip `module.` only where it is actually present. Blindly dropping the
    # first 7 characters of every key mangles the keys of a checkpoint that
    # was saved without the prefix, and load_state_dict then fails.
    model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))

    # Move the model to GPU
    model.cuda()
    model.eval()
    return model

def inference(input_image_path, model, output_dir):
    """Predict a parsing mask for one image and save it as a PNG.

    Args:
        input_image_path: path of the image to open (converted to "RGB").
        model: network to run; its ``mean`` and ``std`` attributes are used
            for input normalisation.
        output_dir: directory in which ``predicted_mask.png`` is written.

    The prediction runs with flipping disabled and the single scale 1, and
    the target input size handed to ``transform_parsing`` is fixed at
    [473, 473].
    """
    # NOTE(review): multi_scale_testing is called below but is neither defined
    # nor imported in the part of this file visible here -- confirm where it
    # is supposed to come from before relying on this function.
    source_image = PILImage.open(input_image_path).convert("RGB")

    # Pre-processing pipeline: tensor conversion, channel-order transform,
    # then normalisation with the model's own statistics.
    preprocessing_steps = [
        transforms.ToTensor(),
        BGR2RGB_transform(),
        transforms.Normalize(mean=model.mean, std=model.std),
    ]
    preprocess = transforms.Compose(preprocessing_steps)
    batch = preprocess(source_image)
    batch = batch.unsqueeze(0)  # add the leading batch dimension
    batch = batch.cuda()

    # Only the first returned value is used; the second is discarded.
    parsing, _ = multi_scale_testing(model, batch, flip=False, multi_scales=[1])

    # Map the prediction back to the source image's width and height.
    image_width = source_image.width
    image_height = source_image.height
    mask = transform_parsing(parsing, [0, 0], 1.0, image_width, image_height, [473, 473])
    mask_image = PILImage.fromarray(mask.astype('uint8'))

    output_image_path = os.path.join(output_dir, "predicted_mask.png")
    mask_image.save(output_image_path)

    print(f"Predicted mask saved at: {output_image_path}")

if __name__ == "__main__":
    # Entry point: parse the I/O options, make sure the output directory
    # exists, load the model and run inference on the single input image.
    parser = argparse.ArgumentParser(description="Inference script for semantic segmentation")
    parser.add_argument("--image-path", type=str, help="Path to the input image", required=True)
    parser.add_argument("--model-path", type=str, help="Path to the trained model", required=True)
    parser.add_argument("--output-dir", type=str, help="Directory to save the predicted mask", default="./output")

    # Model options such as --arch and --num-classes are read from this same
    # command line by get_arguments() (called inside load_model), so options
    # unknown to this parser must be tolerated rather than rejected.
    args, _model_options = parser.parse_known_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.output_dir, exist_ok=True)

    model = load_model(args.model_path)
    inference(args.image_path, model, args.output_dir)
74 changes: 21 additions & 53 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#!/usr/bin/env python
# ... (other imports)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

Expand Down Expand Up @@ -83,29 +86,21 @@ def get_arguments():

def main():
args = get_arguments()
local_rank = args.local_rank

start_epoch = 0
cycle_n = 0

if not os.path.exists(args.log_dir):
if local_rank == 0:
os.makedirs(args.log_dir)
if local_rank == 0:
os.makedirs(args.log_dir)

if args.local_rank == 0:
with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
json.dump(vars(args), opt_file)
print(args)
#gpus = [int(i) for i in args.gpu.split(',')]
#if not args.gpu == 'None':
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

dist.init_process_group(backend='nccl')

device = torch.device("cuda", local_rank)

torch.cuda.set_device(device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_size = list(map(int, args.input_size.split(',')))

print("Device",device)
cudnn.enabled = True
cudnn.benchmark = True

Expand All @@ -116,9 +111,7 @@ def main():
convert_weights = False
model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain, convert_weights=convert_weights)
for name, param in model.named_parameters():
#if name.startswith("backbone.patch_embed"):
if "patch_embed" in name:
print(name)
param.requires_grad = False

IMAGE_MEAN = model.mean
Expand All @@ -132,15 +125,8 @@ def main():
model.load_state_dict(checkpoint['state_dict'])
start_epoch = checkpoint['epoch']
model.to(device)
if args.syncbn:
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

schp_model = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain, convert_weights=convert_weights)
#for name, param in schp_model.named_parameters():
#if name.startswith("backbone.patch_embed"):
# if "patch_embed" in name:
# param.requires_grad = False

if os.path.exists(args.schp_restore):
print('Resuming schp checkpoint from {}'.format(args.schp_restore))
Expand All @@ -150,16 +136,10 @@ def main():
schp_model.load_state_dict(schp_model_state_dict)

schp_model.to(device)
if args.syncbn:
print('----use syncBN in model!----')
schp_model = nn.SyncBatchNorm.convert_sync_batchnorm(schp_model)
schp_model = DDP(schp_model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

# Loss Function
criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
num_classes=args.num_classes)
#criterion = DataParallelCriterion(criterion)
#criterion.to(device)

# Data Loader
if INPUT_SPACE == 'BGR':
Expand All @@ -180,20 +160,18 @@ def main():
])

train_dataset = LIPDataSet(args.data_dir, 'train', crop_size=input_size, transform=transform)
dist_sampler = data.distributed.DistributedSampler(train_dataset, shuffle=True)

train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=dist_sampler,
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=8, pin_memory=False, drop_last=True)
print('Total training samples: {}'.format(len(train_dataset)))

# Optimizer Initialization
if args.optimizer == 'sgd':
print("using SGD optimizer")
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
elif args.optimizer == 'adam':
print("using Adam optimizer")
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate, weight_decay=args.weight_decay)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

# Original warmup_epoch=10, changed to 3 for fix backbone finetune
lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
Expand All @@ -207,28 +185,20 @@ def main():

model.train()
for epoch in range(start_epoch, args.epochs):
dist_sampler.set_epoch(epoch)
lr = lr_scheduler.get_lr()[0]

for i_iter, batch in enumerate(train_loader):
i_iter += len(train_loader) * epoch
images, labels, _ = batch
#labels = labels.cuda(non_blocking=True)
labels = labels.to(device)
images = images.to(device)

edges = generate_edge_tensor(labels)
labels = labels.type(torch.cuda.LongTensor)
edges = edges.type(torch.cuda.LongTensor)
# for name, param in model.named_parameters():
# print(name,': ', param.requires_grad)

#print('fixed', model.state_dict()['module.conv1.weight'][0,0,0,0])
#print('update', model.state_dict()['module.decoder.conv4.weight'][0,0,0,0])


preds = model(images)

# Online Self Correction Cycle with Label Refinement
if cycle_n >= 1:
with torch.no_grad():
soft_preds = schp_model(images)
Expand All @@ -245,28 +215,27 @@ def main():
loss.backward()
optimizer.step()

if local_rank == 0 and i_iter % 100 == 0:
if i_iter % 100 == 0:
print('iter = {} of {} completed, lr = {}, loss = {}, time = {}'.format(i_iter, total_iters, lr,
loss.data.cpu().numpy(), (timeit.default_timer()-iter_start)/100))
loss.data.cpu().numpy(), (timeit.default_timer() - iter_start) / 100))
iter_start = timeit.default_timer()

lr_scheduler.step()
if local_rank == 0 and (epoch + 1) % (args.eval_epochs) == 0:
if (epoch + 1) % (args.eval_epochs) == 0:
schp.save_schp_checkpoint({
'epoch': epoch + 1,
'state_dict': model.state_dict(),
}, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

# Self Correction Cycle with Model Aggregation
if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
print('Self-correction cycle number {}'.format(cycle_n))
schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
cycle_n += 1
schp.bn_re_estimate(train_loader, schp_model)
if local_rank == 0:
schp.save_schp_checkpoint({
'state_dict': schp_model.state_dict(),
'cycle_n': cycle_n,
}, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))
schp.save_schp_checkpoint({
'state_dict': schp_model.state_dict(),
'cycle_n': cycle_n,
}, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

torch.cuda.empty_cache()
end = timeit.default_timer()
Expand All @@ -276,6 +245,5 @@ def main():
end = timeit.default_timer()
print('Training Finished in {} seconds'.format(end - start))


if __name__ == '__main__':
main()
Loading

0 comments on commit 5642ee5

Please sign in to comment.