From 30c4d8a1400875489d08cdd9f0e83cc04684898b Mon Sep 17 00:00:00 2001 From: wz <605169423@qq.com> Date: Tue, 18 Feb 2020 16:33:57 +0800 Subject: [PATCH] update code --- pytorch_learning/Test4_googlenet/model.py | 5 +++-- pytorch_learning/Test4_googlenet/train.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_learning/Test4_googlenet/model.py b/pytorch_learning/Test4_googlenet/model.py index e0bfbf51a..2282c56e9 100644 --- a/pytorch_learning/Test4_googlenet/model.py +++ b/pytorch_learning/Test4_googlenet/model.py @@ -10,6 +10,7 @@ def __init__(self, num_classes=1000, aux_logits=True, init_weights=False): self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3) self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.conv2 = BasicConv2d(64, 64, kernel_size=1) self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) @@ -23,12 +24,12 @@ def __init__(self, num_classes=1000, aux_logits=True, init_weights=False): self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) - self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True) self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) - if aux_logits: + if self.aux_logits: self.aux1 = InceptionAux(512, num_classes) self.aux2 = InceptionAux(528, num_classes) diff --git a/pytorch_learning/Test4_googlenet/train.py b/pytorch_learning/Test4_googlenet/train.py index 54caaa582..cee872097 100644 --- a/pytorch_learning/Test4_googlenet/train.py +++ b/pytorch_learning/Test4_googlenet/train.py @@ -47,8 +47,8 @@ batch_size=batch_size, shuffle=False, num_workers=0) -test_data_iter = iter(validate_loader) -test_image, test_label = test_data_iter.next() +# test_data_iter = iter(validate_loader) +# test_image, test_label = test_data_iter.next() # net = torchvision.models.googlenet(num_classes=5) # model_dict = net.state_dict() @@ -66,7 +66,7 @@ best_acc = 0.0 save_path = './googleNet.pth' -for epoch in range(2): +for epoch in range(30): # train net.train() running_loss = 0.0