forked from khanhha/crack_segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate_unet.py
71 lines (57 loc) · 2.19 KB
/
evaluate_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from pathlib import Path
import argparse
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
def dice(y_true, y_pred):
    """Return the smoothed Dice coefficient of two masks.

    Computes ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
    with ``eps = 1e-15``. The epsilon keeps the ratio defined when both masks
    sum to zero, in which case the result is 1.0.

    Args:
        y_true: Ground-truth array supporting ``*`` and ``.sum()`` (the script
            passes 0/1 ``uint8`` masks).
        y_pred: Predicted array of the same shape.

    Returns:
        The Dice score as a float.
    """
    eps = 1e-15
    overlap = (y_true * y_pred).sum()
    total = y_true.sum() + y_pred.sum()
    return (2 * overlap + eps) / (total + eps)
def general_dice(y_true, y_pred):
    """Return the Dice coefficient with an explicit rule for empty ground truth.

    When ``y_true`` sums to zero, the result is the int ``1`` if ``y_pred`` also
    sums to zero and ``0`` otherwise. In all other cases this delegates to
    ``dice``.
    """
    if y_true.sum() != 0:
        return dice(y_true, y_pred)
    return 1 if y_pred.sum() == 0 else 0
def jaccard(y_true, y_pred):
    """Return the smoothed Jaccard index (intersection over union) of two masks.

    The intersection is ``sum(y_true * y_pred)`` and the union is
    ``sum(y_true) + sum(y_pred) - intersection``. A ``1e-15`` epsilon is added to
    both, so two all-zero masks score 1.0.

    Args:
        y_true: Ground-truth array supporting ``*`` and ``.sum()`` (the script
            passes 0/1 ``uint8`` masks).
        y_pred: Predicted array of the same shape.

    Returns:
        The Jaccard score as a float.
    """
    smooth = 1e-15
    shared = (y_true * y_pred).sum()
    # Keep the add-then-subtract order: the sums may be unsigned integers.
    combined = y_true.sum() + y_pred.sum() - shared
    return (shared + smooth) / (combined + smooth)
def general_jaccard(y_true, y_pred):
    """Return the Jaccard index with an explicit rule for empty ground truth.

    When ``y_true`` sums to zero, the result is the int ``1`` if ``y_pred`` also
    sums to zero and ``0`` otherwise. In all other cases this delegates to
    ``jaccard``.
    """
    if y_true.sum() == 0:
        # int() of the comparison gives exactly 1 or 0.
        return int(y_pred.sum() == 0)
    return jaccard(y_true, y_pred)
if __name__ == '__main__':
    # Score predicted masks against ground-truth masks.
    #
    # Every regular file in -ground_truth_dir is paired with the file of the same
    # name in -pred_dir. Ground-truth pixels > 0 count as foreground. Prediction
    # pixels > 255 * threshold count as foreground, so the threshold is a fraction
    # of the 8-bit grayscale range. Per-image Dice and Jaccard scores are printed
    # as mean and standard deviation.
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('-ground_truth_dir', type=str, required=True, help='path where ground truth images are located')
    arg('-pred_dir', type=str, required=True, help='path with predictions')
    arg('-threshold', type=float, default=0.2, required=False, help='crack threshold detection')
    args = parser.parse_args()

    pred_dir = Path(args.pred_dir)
    # Sort for a deterministic order. Skip directories, which cv2.imread cannot
    # read (it would return None and crash the comparison below).
    paths = sorted(path for path in Path(args.ground_truth_dir).glob('*') if path.is_file())

    result_dice = []
    result_jaccard = []
    for file_name in tqdm(paths):
        pred_file_name = pred_dir / file_name.name
        if not pred_file_name.exists():
            print(f'missing prediction for file {file_name.name}')
            continue

        # cv2.imread signals an unreadable or non-image file by returning None
        # rather than raising, so check explicitly.
        gt_image = cv2.imread(str(file_name), 0)
        if gt_image is None:
            print(f'could not read ground truth file {file_name.name}')
            continue
        pred_image = cv2.imread(str(pred_file_name), 0)
        if pred_image is None:
            print(f'could not read prediction file {pred_file_name.name}')
            continue
        # Mismatched sizes would either raise or silently broadcast to a wrong score.
        if gt_image.shape != pred_image.shape:
            print(f'shape mismatch for file {file_name.name}: '
                  f'ground truth {gt_image.shape}, prediction {pred_image.shape}')
            continue

        y_true = (gt_image > 0).astype(np.uint8)
        y_pred = (pred_image > 255 * args.threshold).astype(np.uint8)

        result_dice.append(dice(y_true, y_pred))
        result_jaccard.append(jaccard(y_true, y_pred))

    # np.mean of an empty list yields nan plus a RuntimeWarning, so fail clearly instead.
    if not result_dice:
        raise SystemExit('no image pairs were evaluated')

    print('Dice = ', np.mean(result_dice), np.std(result_dice))
    print('Jaccard = ', np.mean(result_jaccard), np.std(result_jaccard))