-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathinfer.py
50 lines (43 loc) · 1.5 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import cv2
import random
import argparse
import datetime
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from network import MFF_MoE
# Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES value the user already
# exported. The variable only takes effect if it is set before CUDA is first
# initialised in this process.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Command-line options; `args.local_weight` is the default weights path used by
# NetInference.
parser = argparse.ArgumentParser()
parser.add_argument('--local_weight', type=str, default='weights/', help='trained weights path')
args = parser.parse_args()
class NetInference:
    """Wrap an ``MFF_MoE`` network for single-image inference on CUDA.

    The network is built with ``pretrained=False``, loaded from a weights path,
    wrapped in ``nn.DataParallel``, moved to the GPU and switched to eval mode.
    """

    def __init__(self, weight_path=None):
        """Build the network and the preprocessing pipeline.

        Args:
            weight_path: Path handed to ``MFF_MoE.load``. When ``None`` (the
                default), the ``--local_weight`` command-line option
                (``args.local_weight``) is used, as before.
        """
        if weight_path is None:
            weight_path = args.local_weight
        self.net = MFF_MoE(pretrained=False)
        self.net.load(path=weight_path)
        self.net = nn.DataParallel(self.net).cuda()
        self.net.eval()
        # ToTensor -> float CHW tensor in [0, 1]; per-channel normalisation with
        # the ImageNet mean/std values; then a fixed 512x512 resize.
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((512, 512), antialias=True),
        ])

    def infer(self, input_path=''):
        """Run the network on one image file and return its raw output.

        Args:
            input_path: Path of an image readable by ``cv2.imread``.

        Returns:
            The network output for a batch of one image, as a NumPy array on
            the CPU. Its exact shape is whatever ``MFF_MoE`` produces; the
            ``__main__`` loop formats it as a single deepfake score, so it is
            presumably one value (TODO confirm).

        Raises:
            OSError: if OpenCV cannot read or decode ``input_path``.
        """
        bgr = cv2.imread(input_path)
        if bgr is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise OSError('cannot read image: %r' % (input_path,))
        # OpenCV decodes to BGR; convert to RGB before the PIL/torchvision pipeline.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        x = self.transform_val(Image.fromarray(rgb)).unsqueeze(0).cuda()
        # Pure inference: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            pred = self.net(x)
        return pred.detach().cpu().numpy()
if __name__ == '__main__':
    # Interactive loop: read an image path per line and print the model's score.
    model = NetInference()
    while True:
        print('Please input the image path:')
        try:
            input_path = input()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin / Ctrl-D / Ctrl-C at the prompt: exit cleanly.
            break
        try:
            res = model.infer(input_path)
            # .item() turns a single-element array into a Python scalar without
            # relying on the deprecated implicit float(ndarray) conversion; a
            # multi-element result raises and is reported below.
            score = res.item()
        except Exception as err:
            # Report the failure reason instead of silently swallowing it, but
            # let KeyboardInterrupt propagate so the user can stop the program.
            print('Error: Image [%s]: %s' % (input_path, err))
            continue
        print('Prediction of [%s] being Deepfake: %10.9f' % (input_path, score))