Skip to content

Commit

Permalink
Merging train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
amdegroot committed Aug 3, 2017
2 parents c1ff164 + 80e3e70 commit 06d503e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
5 changes: 3 additions & 2 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from data import VOC_CLASSES as labelmap
import torch.utils.data as data

from data import AnnotationTransform, VOCDetection, BaseTransform
from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
from ssd import build_ssd

import sys
Expand Down Expand Up @@ -408,7 +408,8 @@ def evaluate_detections(box_list, output_dir, dataset):

if __name__ == '__main__':
# load net
net = build_ssd('test', 300, 21) # initialize SSD
num_classes = len(VOC_CLASSES) + 1 # +1 background
net = build_ssd('test', 300, num_classes) # initialize SSD
net.load_state_dict(torch.load(args.trained_model))
net.eval()
print('Finished loading model!')
Expand Down
5 changes: 3 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.autograd import Variable
from data import VOCroot, VOC_CLASSES as labelmap
from PIL import Image
from data import AnnotationTransform, VOCDetection, BaseTransform
from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
import torch.utils.data as data
from ssd import build_ssd

Expand Down Expand Up @@ -73,7 +73,8 @@ def test_net(save_folder, net, cuda, testset, transform, thresh):

if __name__ == '__main__':
# load net
net = build_ssd('test', 300, 21) # initialize SSD
num_classes = len(VOC_CLASSES) + 1 # +1 background
net = build_ssd('test', 300, num_classes) # initialize SSD
net.load_state_dict(torch.load(args.trained_model))
net.eval()
print('Finished loading model!')
Expand Down
6 changes: 3 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import argparse
from torch.autograd import Variable
import torch.utils.data as data
from data import v2, v1, AnnotationTransform, VOCDetection, detection_collate, VOCroot
from data import v2, v1, AnnotationTransform, VOCDetection, detection_collate, VOCroot, VOC_CLASSES
from utils.augmentations import SSDAugmentation
from layers.modules import MultiBoxLoss
from ssd import build_ssd
Expand Down Expand Up @@ -54,7 +54,7 @@ def str2bool(v):
# train_sets = 'train'
ssd_dim = 300 # only support 300 now
means = (104, 117, 123) # only support voc now
num_classes = 21
num_classes = len(VOC_CLASSES) + 1
batch_size = args.batch_size
accum_batch_size = 32
iter_size = accum_batch_size / batch_size
Expand All @@ -68,7 +68,7 @@ def str2bool(v):
import visdom
viz = visdom.Visdom()

ssd_net = build_ssd('train', 300, 21)
ssd_net = build_ssd('train', 300, num_classes)
net = ssd_net

if args.cuda:
Expand Down

0 comments on commit 06d503e

Please sign in to comment.