forked from hmyao22/GLCF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path config.py
49 lines (39 loc) · 1.55 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import torch
class DefaultConfig(object):
    """Default hyper-parameters and filesystem paths for training/evaluation.

    All settings are class attributes. ``parse`` and ``parse_model_root``
    override them per instance from a dict. The three ``*_raw_data_root``
    paths are derived as ``<data_root>/<class_name>/<split>``.
    """

    # Dataset selection. The data_root default is a machine-specific Windows
    # path; override it with parse({'data_root': ...}).
    class_name = 'screw'
    data_root = r'D:\IMSN-YHM\dataset\mvtec_loco_anomaly_detection'
    train_raw_data_root = os.path.join(data_root, class_name, 'train')
    validate_raw_data_root = os.path.join(data_root, class_name, 'validation')
    test_raw_data_root = os.path.join(data_root, class_name, 'test')

    # Output / checkpoint directories.
    load_model_path = r'./weights/'
    training_state_path = './temp/'
    measure_save_path = './measure/'

    # Model selection; is_STVT is one of 'ST', 'MAE', 'TR' (index 2 -> 'TR').
    backbone_name = 'Resnet34'
    is_STVT = ['ST', 'MAE', 'TR'][2]

    # NOTE(review): `device` is picked from CUDA availability alone and does
    # not consult `use_gpu`; confirm which of the two callers rely on.
    use_gpu = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Training-loop / optimiser settings.
    train_batch_size = 8
    print_freq = 20
    max_epoch = 501
    lr = 0.0001
    lr_decay = 0.90
    weight_decay = 1e-5
    momentum = 0.9

    # NOTE(review): presumably generator sizes (latent dim, image channels,
    # base feature maps) — confirm against the model code.
    nz = 100
    nc = 3
    ngf = 64

    def parse_model_root(self, dicts):
        """Override existing attributes from ``dicts``.

        Keys that are not already attributes of the config are silently
        ignored. Derived data paths are not recomputed; use ``parse`` for that.
        """
        for k, v in dicts.items():
            if hasattr(self, k):
                setattr(self, k, v)

    def parse(self, dicts):
        """Override attributes from ``dicts`` and refresh the derived data paths.

        Unknown keys are ignored. When ``class_name`` or ``data_root`` is among
        the keys, the train/validation/test roots are rebuilt from the updated
        ``data_root`` and ``class_name``; any of those path keys that
        ``dicts`` sets explicitly is kept as given.
        """
        self.parse_model_root(dicts)
        # The split paths depend only on data_root and class_name, so rebuild
        # them from the (possibly just-updated) attributes rather than from
        # an arbitrary dict value or a hard-coded root.
        if 'class_name' in dicts or 'data_root' in dicts:
            self._refresh_data_roots(keep=dicts)

    def _refresh_data_roots(self, keep=()):
        """Set each split path to ``<data_root>/<class_name>/<split>`` unless named in ``keep``."""
        splits = (
            ('train_raw_data_root', 'train'),
            ('validate_raw_data_root', 'validation'),
            ('test_raw_data_root', 'test'),
        )
        for attr, split in splits:
            if attr not in keep:
                setattr(self, attr, os.path.join(self.data_root, self.class_name, split))