Skip to content

Commit

Permalink
update.
Browse files Browse the repository at this point in the history
  • Loading branch information
sijin-dm committed Jul 12, 2021
1 parent 2ae96ae commit 74038ae
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 7 deletions.
18 changes: 11 additions & 7 deletions pytorch2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,15 @@
import time
import torch


from runx.logx import logx
from config import assert_and_infer_cfg, update_epoch, cfg
from loss.optimizer import get_optimizer, restore_opt, restore_net

import datasets
import network
from torch2trt import trt, torch2trt
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))

sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))

# Argument Parser
parser = argparse.ArgumentParser(description='Semantic Segmentation')
Expand Down Expand Up @@ -407,9 +406,11 @@
type=int,
default=None,
help='mapillary evaluation size.')
parser.add_argument('--ddrnet_augment', action='store_true', default=False,
parser.add_argument('--ddrnet_augment',
action='store_true',
default=False,
help='use multi output for ddrnet.')

args = parser.parse_args()
args.best_record = {
'epoch': -1,
Expand Down Expand Up @@ -470,7 +471,6 @@ def main():

auto_resume_details = None


if auto_resume_details:
checkpoint_fn = auto_resume_details.get("RESUME_FILE", None)
checkpoint = torch.load(checkpoint_fn,
Expand Down Expand Up @@ -533,7 +533,7 @@ def main():

def save_pred(y, out_name="color_mask.png"):
colorize_mask_fn = cfg.DATASET_INST.colorize_mask
output = torch.nn.functional.softmax(y_trt, dim=1)
output = torch.nn.functional.softmax(y, dim=1)
prob_mask, predictions = output.data.max(1)
# Image.fromarray(predictions[0].cpu().numpy().astype(np.uint8)).convert('P').save("label_id.png")
color_mask = colorize_mask_fn(predictions[0].cpu().numpy())
Expand All @@ -544,11 +544,15 @@ def save_pred(y, out_name="color_mask.png"):
model_trt = torch2trt(
model,
[x],
input_names=["input"],
output_names=["output"],
fp16_mode=True,
log_level=trt.Logger.ERROR # VERBOSE
)
save_engine(model_trt, '{}_trt.engine'.format(args.arch))
torch.save(model_trt.state_dict(), '{}_trt.pth'.format(args.arch))

y = model(x)
save_engine(model_trt)
time_list = []
for i in range(10):
start_time = time.time()
Expand Down
77 changes: 77 additions & 0 deletions utils/select_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import argparse
from glob import glob
import os
import cv2
import tqdm

def parse_args():
# Argument Parser
parser = argparse.ArgumentParser(description='Select Samples for Annotation')
parser.add_argument('--mode', type=str, default="folder")
parser.add_argument('--image_extensions', type=str, default=".jpg")
parser.add_argument('--prob_extensions', type=str, default="_prob.png")
parser.add_argument('--prob_path', type=str, default="", help="path for output prob image.")
parser.add_argument('--image_path', type=str, default="",help="path for output prob image.")
parser.add_argument('--selected_num', type=int, default=10, help="How many images that we wanna select.")
parser.add_argument('--prob', type=float, default=0.9, help="Probability value")
parser.add_argument('--output_path', type=str, default=None, help="Output_path for selected images.")

return parser.parse_args()



def cal_lower_prob_ratio(preds, prob_threshold=0.7):
masks = np.where(preds < prob_threshold)
ratio = masks[0].size/preds.size
return ratio

def main():
args = parse_args()
assert os.path.exists(args.image_path)
if args.mode == "folder":
os.path.exists(args.prob_path)
else:
raise ValueError("Unsupported mode.")

hist = []
prob_filenames = []
image_names = sorted(glob(os.path.join(args.image_path,'*'+ args.image_extensions)))
for image_fn in tqdm.tqdm(image_names):
if args.mode == "folder":
filename = os.path.basename(image_fn)
prob_filename = filename.replace(args.image_extensions, args.prob_extensions)
prob_filename = os.path.join(args.prob_path, prob_filename)
prob_image = cv2.imread(prob_filename)
prob_filenames.append(prob_filename)
ratio = cal_lower_prob_ratio(prob_image, args.prob*255)
hist.append(ratio)

selected_num = args.selected_num
if len(hist)<selected_num:
selected_num = len(hist)
hist_arr = np.asarray(hist)
ind = list(np.argpartition(hist_arr, -selected_num)[-selected_num:])
print('prob_threshold={} ratio_bound={},{} filter_num={}'.format(
args.prob, hist[ind[-1]], hist[ind[0]], selected_num))


for idx in ind[:selected_num]:
if args.output_path is not None:
os.system("cp {} {}".format(image_names[idx], args.output_path))
os.system("cp {} {}".format(prob_filenames[idx], args.output_path))
print("Copy: {}".format(image_names[idx]))
else:
print(image_names[idx])



if __name__ == '__main__':
main()







0 comments on commit 74038ae

Please sign in to comment.