Skip to content

Commit

Permalink
suppost eval params config in train/eval/train+val
Browse files Browse the repository at this point in the history
  • Loading branch information
shensheng272 committed Sep 20, 2022
1 parent 5ac5344 commit fc687ca
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
28 changes: 27 additions & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,27 @@ def get_args_parser(add_help=True):
parser.add_argument('--plot_curve', default=True, type=boolean_string, help='whether to save plots in savedir when do pr metric, set False to close')
parser.add_argument('--plot_confusion_matrix', default=False, action='store_true', help='whether to save confusion matrix plots when do pr metric, might cause no harm warning print')
parser.add_argument('--verbose', default=False, action='store_true', help='whether to print metric on each class')
parser.add_argument('--config-file', default='', type=str, help='experiments description file, lower priority than reproduce_640_eval')
args = parser.parse_args()

if args.config_file:
assert os.path.exists(args.config_file), print("Config file {} does not exist".format(args.config_file))
cfg = Config.fromfile(args.config_file)
if not hasattr(cfg, 'eval_params'):
LOGGER.info("Config file doesn't has eval params config.")
else:
eval_params=cfg.eval_params
for key, value in eval_params.items():
if key not in args.__dict__:
LOGGER.info(f"Unrecognized config {key}, continue")
continue
if isinstance(value, list):
if value[1] is not None:
args.__dict__[key] = value[1]
else:
if value is not None:
args.__dict__[key] = value

# load params for reproduce 640 eval result
if args.reproduce_640_eval:
assert os.path.exists(args.eval_config_file), print("Reproduce config file {} does not exist".format(args.eval_config_file))
Expand All @@ -59,6 +78,12 @@ def get_args_parser(add_help=True):
args.scale_exact = eval_params[eval_model_name]["scale_exact"]
args.force_no_pad = eval_params[eval_model_name]["force_no_pad"]
args.not_infer_on_rect = eval_params[eval_model_name]["not_infer_on_rect"]
#force params
args.img_size = 640
args.conf_thres = 0.03
args.iou_thres = 0.65
args.task = "val"
args.do_coco_metric = True

LOGGER.info(args)
return args
Expand Down Expand Up @@ -89,7 +114,8 @@ def run(data,
do_coco_metric=True,
do_pr_metric=False,
plot_curve=False,
plot_confusion_matrix=False
plot_confusion_matrix=False,
config_file=None,
):
""" Run the evaluation process
Expand Down
7 changes: 5 additions & 2 deletions yolov6/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ def eval_model(self):
task='train')
else:
def get_cfg_value(cfg_dict, value_str, default_value):
if value_str in cfg_dict and cfg_dict[value_str] is not None:
return cfg_dict[value_str]
if value_str in cfg_dict:
if isinstance(cfg_dict[value_str], list):
return cfg_dict[value_str][0] if cfg_dict[value_str][0] is not None else default_value
else:
return cfg_dict[value_str] if cfg_dict[value_str] is not None else default_value
else:
return default_value
eval_img_size = get_cfg_value(self.cfg.eval_params, "img_size", self.img_size)
Expand Down

0 comments on commit fc687ca

Please sign in to comment.