Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
WZMIAOMIAO committed Feb 18, 2020
1 parent 0eacea3 commit 30c4d8a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 3 additions & 2 deletions pytorch_learning/Test4_googlenet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):

self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

self.conv2 = BasicConv2d(64, 64, kernel_size=1)
self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
Expand All @@ -23,12 +24,12 @@ def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

if aux_logits:
if self.aux_logits:
self.aux1 = InceptionAux(512, num_classes)
self.aux2 = InceptionAux(528, num_classes)

Expand Down
6 changes: 3 additions & 3 deletions pytorch_learning/Test4_googlenet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
batch_size=batch_size, shuffle=False,
num_workers=0)

test_data_iter = iter(validate_loader)
test_image, test_label = test_data_iter.next()
# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()

# net = torchvision.models.googlenet(num_classes=5)
# model_dict = net.state_dict()
Expand All @@ -66,7 +66,7 @@

best_acc = 0.0
save_path = './googleNet.pth'
for epoch in range(2):
for epoch in range(30):
# train
net.train()
running_loss = 0.0
Expand Down

0 comments on commit 30c4d8a

Please sign in to comment.