# evaltxt.py
import os
# Step up to the parent of the directory the script was launched from and
# anchor every project path at "<parent>/Cowsformer/".
# NOTE: ROOT keeps its trailing slash, so paths later built as
# ROOT + "/data/..." contain a doubled "/".
os.chdir("..")
ROOT = f"{os.getcwd()}/Cowsformer/"
# Echo the resolved root so a wrong launch directory is visible immediately.
print(ROOT)
#os.environ['TORCH_HOME'] = '/home/mautushid/.torch'
from ultralytics import NAS
from models.nas import *
import argparse
import supervision as sv
from API import*
from evaluate import from_sv
import pandas as pd
def _find_best_checkpoint(base_dir):
    """Return the path of a ``ckpt_best.pth`` stored one level below *base_dir*.

    Args:
        base_dir: directory whose immediate sub-directories are searched.

    Returns:
        ``base_dir + "/" + <entry> + "/ckpt_best.pth"`` for the first entry
        (in sorted order) where that file exists.

    Raises:
        FileNotFoundError: if no entry under *base_dir* contains a
            ``ckpt_best.pth`` file (``os.listdir`` raises the same error when
            *base_dir* itself does not exist).
    """
    # os.listdir() yields entries in arbitrary order, so taking element [0]
    # is not reproducible and may land on a stray file or on a run without a
    # checkpoint. Sort for determinism and keep only entries that really hold
    # the checkpoint file.
    for entry in sorted(os.listdir(base_dir)):
        candidate = base_dir + "/" + entry + "/ckpt_best.pth"
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(
        "no ckpt_best.pth found in any sub-directory of " + repr(base_dir)
    )


def main(args):
    """Build the dataset/checkpoint paths for one experiment and write predictions.

    The paths are derived from the command-line values and the module-level
    ``ROOT``; a ``Niche_YOLO_NAS`` instance is then asked for predictions from
    the located checkpoint, and the result is handed to its
    ``write_predictions`` method.

    Args:
        args: parsed argument namespace providing ``yolo_base``, ``config``,
            ``exp_name``, ``n`` and ``iteration``.

    Raises:
        ValueError: if any of the five arguments is ``None``.
        FileNotFoundError: if no ``ckpt_best.pth`` can be located under the
            experiment's checkpoint directory.
    """
    # Fail early on missing values: str(None) would otherwise be baked
    # silently into the generated paths.
    needed = ("yolo_base", "config", "exp_name", "n", "iteration")
    missing = [name for name in needed if getattr(args, name, None) is None]
    if missing:
        raise ValueError("missing required argument(s): " + ", ".join(missing))

    # parse arguments
    yolo_base = args.yolo_base
    config = args.config
    exp_name = args.exp_name
    n = args.n
    iteration = args.iteration

    # Last "_"-separated token of the config name, e.g. "1a_angle_t2s" -> "t2s".
    config_short = config.split("_")[-1]

    # "<yolo_base>_<n>_<iteration>" recurs in several path components.
    run_tag = f"{yolo_base}_{n}_{iteration}"
    data_root = f"{ROOT}/data/{config}"
    # The train/val split directory name contains the run tag twice.
    split_dir = f"{data_root}/tv/{exp_name}_{run_tag}_{config_short}_{run_tag}"

    dir_train = f"{split_dir}/train"
    dir_val = f"{split_dir}/val"
    dir_test = f"{data_root}/test"
    data_yaml_path = f"{split_dir}/data.yaml"

    base_dir = f"{ROOT}/checkpoints/n{n}_{yolo_base}_i{iteration}_{config_short}"
    finetuned_model_path = _find_best_checkpoint(base_dir)
    output_dir = f"{dir_test}/{run_tag}_labelsPred"

    ### Creating instance of Niche_YOLO_NAS class
    # NOTE(review): the meaning of the "cow200" argument is not visible in
    # this file -- confirm against the Niche_YOLO_NAS definition.
    my_nas = Niche_YOLO_NAS(yolo_base, dir_train, dir_val, dir_test, "cow200")
    predictions = my_nas.prediction(data_yaml_path, finetuned_model_path)
    my_nas.write_predictions(predictions, output_dir)
if __name__ == "__main__":
    # parse arguments
    # Every option is consumed unconditionally by main() when it builds the
    # dataset and checkpoint paths, so all of them are mandatory: argparse
    # then reports a missing option with a usage message instead of main()
    # failing later on a None value.
    parser = argparse.ArgumentParser()
    parser.add_argument("--yolo_base", type=str, required=True,
                        help="yolo_nas_l,yolo_nas_m")
    parser.add_argument("--config", type=str, required=True,
                        help="1a_angle_t2s, 1b_angle_s2t, 2_light, 3_breed, 4_all")
    parser.add_argument("--exp_name", type=str, required=True,
                        help="exp")
    parser.add_argument("--n", type=int, required=True,
                        help="16, 32,64, 128...")
    parser.add_argument("--iteration", type=int, required=True,
                        help="1,2,3...")
    args = parser.parse_args()
    main(args)