-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathinference.py
129 lines (116 loc) · 4.97 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import cv2
import torch
import random
import argparse
from glob import glob
from os.path import join
from model.network import Recce
from model.common import freeze_weights
from albumentations import Compose, Normalize, Resize
from albumentations.pytorch.transforms import ToTensorV2
# Seed every RNG the script may touch (Python's `random`, the torch CPU
# generator and the CUDA generators) so repeated runs are reproducible.
seed = 0
for _seed_fn in (random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
# Command-line interface. The parsed namespace becomes the module-level
# ``args`` (assigned under the ``__main__`` guard) that the functions read.
parser = argparse.ArgumentParser(
    description="This code helps you use a trained model to do inference.")

# The four input-path options share one shape: optional string, default None.
# '--weight'/'--bin' and '--image'/'--image_folder' are pairwise exclusive;
# that rule is checked at run time, not by argparse.
for _flags, _help in (
        (("--weight", "-w"),
         "Specify the path to the model weight (the state dict file). "
         "Do not use this argument when '--bin' is set."),
        (("--bin", "-b"),
         "Specify the path to the model bin which ends up with '.bin' "
         "(which is generated by the trainer of this project). "
         "Do not use this argument when '--weight' is set."),
        (("--image", "-i"),
         "Specify the path to the input image. "
         "Do not use this argument when '--image_folder' is set."),
        (("--image_folder", "-f"),
         "Specify the directory to evaluate all the images. "
         "Do not use this argument when '--image' is set."),
):
    parser.add_argument(*_flags, type=str, default=None, help=_help)
del _flags, _help

parser.add_argument("--device", "-d",
                    type=str,
                    default="cpu",
                    help="Specify the device to load the model. Default: 'cpu'.")
parser.add_argument("--image_size", "-s",
                    type=int,
                    default=299,
                    help="Specify the spatial size of the input image(s). Default: 299.")
parser.add_argument("--visualize", "-v",
                    action="store_true",  # store_true defaults to False
                    help="Visualize images.")
def preprocess(file_path):
    """Load an image file and convert it into a normalized model input tensor.

    The image is read with OpenCV (BGR), converted to RGB, resized to
    ``args.image_size`` x ``args.image_size`` (module-level CLI namespace),
    normalized per channel with mean=0.5 / std=0.5, converted to a CHW tensor
    by ``ToTensorV2`` and given a leading batch axis of size 1.

    Args:
        file_path: Path of the image to load.

    Returns:
        A tensor of shape (1, 3, args.image_size, args.image_size).

    Raises:
        OSError: if OpenCV cannot read ``file_path`` (missing file, unreadable
            or unsupported format).
    """
    img = cv2.imread(file_path)
    # cv2.imread reports failure by returning None rather than raising; without
    # this check the failure surfaces later as an opaque cvtColor assertion.
    if img is None:
        raise OSError(f"Could not read image file: {file_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    compose = Compose([Resize(height=args.image_size, width=args.image_size),
                       Normalize(mean=[0.5] * 3, std=[0.5] * 3),
                       ToTensorV2()])
    return compose(image=img)["image"].unsqueeze(0)
def prepare_data():
    """Resolve and preprocess the input image(s) selected on the command line.

    Reads the module-level ``args``. Exactly one of ``--image`` or
    ``--image_folder`` must be set. With ``--image_folder``, every regular,
    non-hidden file directly inside the folder whose extension is ``.jpg``,
    ``.jpeg`` or ``.png`` (case-insensitive) is used, in sorted path order so
    runs are reproducible.

    Returns:
        A ``(paths, images)`` tuple of parallel lists: the image file paths and
        the tensors ``preprocess`` produced for them.

    Raises:
        ValueError: if both or neither of ``--image`` / ``--image_folder`` are
            set, or if the folder contains no matching image.
        OSError: if ``--image_folder`` cannot be listed (e.g. it does not exist).
    """
    import os

    # Compared lower-cased so e.g. "IMG_001.JPG" is picked up as well.
    image_extensions = {".jpg", ".jpeg", ".png"}

    if args.image and args.image_folder:
        raise ValueError("Only one of '--image' or '--image_folder' can be set.")
    elif args.image:
        paths = [args.image]
    elif args.image_folder:
        # scandir avoids glob-metacharacter surprises in the folder name and
        # fails loudly when the folder is missing.
        with os.scandir(args.image_folder) as entries:
            paths = sorted(
                join(args.image_folder, entry.name)
                for entry in entries
                # Dotfiles (e.g. macOS "._photo.jpg" metadata) are not images.
                if not entry.name.startswith(".")
                and entry.is_file()
                and os.path.splitext(entry.name)[1].lower() in image_extensions
            )
        if not paths:
            raise ValueError(f"No .jpg/.jpeg/.png images found in '{args.image_folder}'.")
    else:
        raise ValueError("Neither of '--image' nor '--image_folder' is set. Please specify either "
                         "one of these two arguments to load input image(s) properly.")
    images = [preprocess(path) for path in paths]
    return paths, images
def inference(model, images, paths, device):
    """Run the model on each image, print the verdict and optionally show it.

    Args:
        model: Network expected to return a single logit for a batch-of-one
            input (``.item()`` below fails otherwise); sigmoid is applied here.
        images: Preprocessed input tensors, one per entry of ``paths``.
        paths: Source file paths, used in the printed report and re-read from
            disk when ``args.visualize`` is set.
        device: Device the inputs are moved to before the forward pass.

    A probability >= 0.5 is reported as "fake", otherwise "real". When
    ``args.visualize`` is set, each image is shown in an OpenCV window
    annotated with the result (red text for fake, blue for real) until a key
    is pressed.
    """
    # Pure inference: no gradients are needed, so don't build an autograd graph.
    with torch.no_grad():
        for img, pt in zip(images, paths):
            logit = model(img.to(device))
            probability = torch.sigmoid(logit).item()
            fake = probability >= 0.5
            label = "fake" if fake else "real"
            print(f"path: {pt} \t\t| fake probability: {probability:.4f} \t| "
                  f"prediction: {label}")
            if args.visualize:
                cvimg = cv2.imread(pt)
                # Colors are BGR: (0, 0, 255) is red, (255, 0, 0) is blue.
                cvimg = cv2.putText(cvimg, f"p: {probability:.2f}, {label}",
                                    (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    (0, 0, 255) if fake else (255, 0, 0), 2)
                cv2.imshow("image", cvimg)
                cv2.waitKey(0)
                cv2.destroyWindow("image")
def main():
    """Load the Recce model, preprocess the requested image(s) and run inference.

    Reads the module-level ``args`` namespace (set by the ``__main__`` guard).
    Exactly one of ``--weight`` (a bare state dict) or ``--bin`` (a trainer
    checkpoint whose state dict is stored under the ``"model"`` key) must be set.

    Raises:
        ValueError: if both or neither of ``--weight`` / ``--bin`` are set, or
            (via ``prepare_data``) if the image arguments are invalid.
    """
    print("Arguments:\n", args, end="\n\n")
    device = torch.device(args.device)

    # Resolve the weight source first so a bad command line fails before the
    # network is built. Weights are loaded on CPU and moved with the model below.
    # NOTE(review): torch.load unpickles the file; unless weights_only=True is
    # in effect for the installed PyTorch version, only load trusted checkpoints.
    if args.weight and args.bin:
        raise ValueError("Only one of '--weight' or '--bin' can be set.")
    elif args.weight:
        weights = torch.load(args.weight, map_location="cpu")
    elif args.bin:
        weights = torch.load(args.bin, map_location="cpu")["model"]
    else:
        raise ValueError("Neither of '--weight' nor '--bin' is set. Please specify either "
                         "one of these two arguments to load model's weight properly.")

    # Single output logit: the score is turned into a fake-probability by sigmoid.
    model = Recce(num_classes=1)
    model.load_state_dict(weights)
    model = model.to(device)
    # freeze_weights is project-defined; presumably disables parameter gradients.
    freeze_weights(model)
    model.eval()

    paths, images = prepare_data()
    print("Inference:")
    inference(model, images=images, paths=paths, device=device)
if __name__ == "__main__":
    # Parse sys.argv once and publish the result as the module-level ``args``
    # global that preprocess(), prepare_data(), inference() and main() read.
    args = parser.parse_args(args=None)
    main()