DDC init
syoya1997 committed Oct 8, 2018
1 parent c18eaa7 commit 1eeb854
Showing 4,117 changed files with 450 additions and 0 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
157 changes: 157 additions & 0 deletions DDC.py
@@ -0,0 +1,157 @@
from __future__ import print_function

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import warnings
warnings.filterwarnings('ignore')

import math
import model
import torch
import dataloader

from torch import nn
from torch import optim
from torch.autograd import Variable

cuda = torch.cuda.is_available()

def step_decay(epoch, learning_rate):
"""
learning rate step decay
:param epoch: current training epoch
:param learning_rate: initial learning rate
:return: learning rate after step decay
"""
initial_lrate = learning_rate
drop = 0.8
epochs_drop = 10.0
lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
return lrate
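# Example (assuming the main script's default learning_rate = 1e-3): epochs 1-8 use 1e-3,
# epochs 9-18 use 8e-4, epochs 19-28 use 6.4e-4, i.e. a 0.8x drop every 10 epochs.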

def train_ddcnet(epoch, model, learning_rate, source_loader, target_loader):
"""
train DDCNet on the labeled source domain and the unlabeled target domain
:param epoch: current training epoch
:param model: defined DDCNet
:param learning_rate: initial learning rate
:param source_loader: source train loader
:param target_loader: target train loader
:return:
"""
log_interval = 10
LEARNING_RATE = step_decay(epoch, learning_rate)
print(f'Learning Rate: {LEARNING_RATE}')
optimizer = optim.Adam([
{'params': model.features.parameters()},
{'params': model.classifier.parameters()},
{'params': model.bottleneck.parameters(), 'lr': LEARNING_RATE},
{'params': model.final_classifier.parameters(), 'lr': LEARNING_RATE}
], lr=LEARNING_RATE / 10)
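# The pretrained AlexNet layers (features, classifier) are updated at LEARNING_RATE / 10,
# while the newly added bottleneck and final_classifier layers use the full LEARNING_RATE.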

# enter training mode
model.train()

iter_source = iter(source_loader)
iter_target = iter(target_loader)
num_iter = len(source_loader)

correct = 0
total_loss = 0
clf_criterion = nn.CrossEntropyLoss()

for i in range(1, num_iter):
source_data, source_label = next(iter_source)
target_data, _ = next(iter_target)
if i % len(target_loader) == 0:
iter_target = iter(target_loader)
if cuda:
source_data, source_label = source_data.cuda(), source_label.cuda()
target_data = target_data.cuda()

source_data, source_label = Variable(source_data), Variable(source_label)
target_data = Variable(target_data)

optimizer.zero_grad()

source_preds, mmd_loss = model(source_data, target_data)
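# The forward pass takes a source batch and a target batch and returns the source predictions
# plus an MMD loss (presumably computed between the two domains' bottleneck features).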
preds = source_preds.data.max(1, keepdim=True)[1]

correct += preds.eq(source_label.data.view_as(preds)).sum()
clf_loss = clf_criterion(source_preds, source_label)
loss = clf_loss + 0.25 * mmd_loss
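# DDC objective: source classification cross-entropy plus the domain-confusion (MMD) term,
# weighted here by a fixed factor of 0.25.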
total_loss += clf_loss

loss.backward()
optimizer.step()

if i % log_interval == 0:
print('Train Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tsoft_Loss: {:.6f}\tmmd_Loss: {:.6f}'.format(
epoch, i * len(source_data), len(source_loader) * BATCH_SIZE,
100. * i / len(source_loader), loss.item(), clf_loss.item(), mmd_loss.item()))

total_loss /= len(source_loader)
acc_train = float(correct) * 100. / (len(source_loader) * BATCH_SIZE)

print('{} set: Average classification loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
SOURCE_NAME, total_loss.item(), correct, len(source_loader.dataset), acc_train))


def test_ddcnet(model, target_loader):
"""
test target data on the trained DDCNet
:param model: DDCNet trained on source (labeled) and target (unlabeled) data
:param target_loader: target test dataloader
:return: number of correctly classified target samples
"""
# enter evaluation mode
clf_criterion = nn.CrossEntropyLoss()

model.eval()
test_loss = 0
correct = 0

# iterate over the target loader passed in as a parameter
for data, target in target_loader:
if cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
# the forward pass expects two batches, so the target batch is passed twice; the MMD term is discarded
target_preds, _ = model(data, data)
test_loss += clf_criterion(target_preds, target) # sum up batch loss
pred = target_preds.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum()

test_loss /= len(target_loader)
print('{} set: Average classification loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
TARGET_NAME, test_loss.item(), correct, len(target_loader.dataset),
100. * correct / len(target_loader.dataset)))
return correct

if __name__ == '__main__':

ROOT_PATH = '../../data/Office31'
SOURCE_NAME = 'amazon'
TARGET_NAME = 'webcam'
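# Office-31 adaptation task: amazon (labeled source) -> webcam (target), 31 classes.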

BATCH_SIZE = 256
TRAIN_EPOCHS = 200
learning_rate = 1e-3

source_loader = dataloader.load_training(ROOT_PATH, SOURCE_NAME, BATCH_SIZE)
target_train_loader = dataloader.load_training(ROOT_PATH, TARGET_NAME, BATCH_SIZE)
target_test_loader = dataloader.load_testing(ROOT_PATH, TARGET_NAME, BATCH_SIZE)
print('Load data complete')

ddcnet = model.DCCNet(num_classes=31)
print('Construct model complete')

# load pretrained alexnet model
ddcnet = model.load_pretrained_alexnet(ddcnet)
print('Load pretrained alexnet parameters complete\n')

if cuda: ddcnet.cuda()

for epoch in range(1, TRAIN_EPOCHS + 1):
print(f'Train Epoch {epoch}:')
train_ddcnet(epoch, ddcnet, learning_rate, source_loader, target_train_loader)
correct = test_ddcnet(ddcnet, target_test_loader)
148 changes: 148 additions & 0 deletions alextnet_finetune.py
@@ -0,0 +1,148 @@
from __future__ import print_function

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import warnings
warnings.filterwarnings('ignore')

import math
import model
import torch
import dataloader

from torch import nn
from torch import optim
from torch.autograd import Variable

cuda = torch.cuda.is_available()

def step_decay(epoch, learning_rate):
"""
learning rate step decay
:param epoch: current training epoch
:param learning_rate: initial learning rate
:return: learning rate after step decay
"""
initial_lrate = learning_rate
drop = 0.8
epochs_drop = 10.0
lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
return lrate

def train_alexnet(epoch, model, learning_rate, source_loader):
"""
fine-tune alexnet on the labeled source domain
:param epoch: current training epoch
:param model: defined alexnet
:param learning_rate: initial learning rate
:param source_loader: source loader
:return:
"""
log_interval = 10
LEARNING_RATE = step_decay(epoch, learning_rate)
print(f'Learning Rate: {LEARNING_RATE}')
optimizer = optim.Adam([
{'params': model.features.parameters()},
{'params': model.classifier.parameters()},
{'params': model.final_classifier.parameters(), 'lr': LEARNING_RATE}
], lr=LEARNING_RATE / 10)
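# As in DDC.py, the pretrained layers are fine-tuned at LEARNING_RATE / 10 while the new
# final_classifier layer is trained at the full LEARNING_RATE.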

# enter training mode
model.train()

iter_source = iter(source_loader)
num_iter = len(source_loader)

correct = 0
total_loss = 0
clf_criterion = nn.CrossEntropyLoss()

for i in range(1, num_iter):
source_data, source_label = next(iter_source)
if cuda:
source_data, source_label = source_data.cuda(), source_label.cuda()
source_data, source_label = Variable(source_data), Variable(source_label)

optimizer.zero_grad()

source_preds = model(source_data)
preds = source_preds.data.max(1, keepdim=True)[1]
correct += preds.eq(source_label.data.view_as(preds)).sum()

loss = clf_criterion(source_preds, source_label)
total_loss += loss

loss.backward()
optimizer.step()

if i % log_interval == 0:
print('Train Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, i * len(source_data), len(source_loader) * BATCH_SIZE,
100. * i / len(source_loader), loss.item()))

total_loss /= len(source_loader)
acc_train = float(correct) * 100. / (len(source_loader) * BATCH_SIZE)

print('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
SOURCE_NAME, total_loss.item(), correct, len(source_loader.dataset), acc_train))


def test_alexnet(model, target_loader):
"""
test target data on fine-tuned alexnet
:param model: trained alexnet on source data set
:param target_loader: target dataloader
:return: correct num
"""
# enter evaluation mode
clf_criterion = nn.CrossEntropyLoss()

model.eval()
test_loss = 0
correct = 0

# iterate over the target loader passed in as a parameter
for data, target in target_loader:
if cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data, volatile=True), Variable(target)
target_preds = model(data)
test_loss += clf_criterion(target_preds, target) # sum up batch loss
pred = target_preds.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum()

test_loss /= len(target_loader)
print('{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
TARGET_NAME, test_loss.item(), correct, len(target_loader.dataset),
100. * correct / len(target_loader.dataset)))
return correct


if __name__ == '__main__':

ROOT_PATH = '../../data/Office31'
SOURCE_NAME = 'amazon'
TARGET_NAME = 'webcam'

BATCH_SIZE = 256
TRAIN_EPOCHS = 200
learning_rate = 1e-3

source_loader = dataloader.load_training(ROOT_PATH, SOURCE_NAME, BATCH_SIZE)
target_train_loader = dataloader.load_training(ROOT_PATH, TARGET_NAME, BATCH_SIZE)
target_test_loader = dataloader.load_testing(ROOT_PATH, TARGET_NAME, BATCH_SIZE)
print('Load data complete')

alexnet = model.Alexnet_finetune(num_classes=31)
print('Construct model complete')

# load pretrained alexnet model
alexnet = model.load_pretrained_alexnet(alexnet)
print('Load pretrained alexnet parameters complete\n')

if cuda: alexnet.cuda()

for epoch in range(1, TRAIN_EPOCHS + 1):
print(f'Train Epoch {epoch}:')
train_alexnet(epoch, alexnet, learning_rate, source_loader)
correct = test_alexnet(alexnet, target_test_loader)
Binary file added data/.DS_Store
Binary file not shown.
Binary file added data/Office31.tar
Binary file not shown.
Binary file added data/Office31/amazon/images/bike/frame_0001.jpg
Binary file added data/Office31/amazon/images/bike/frame_0002.jpg
Binary file added data/Office31/amazon/images/bike/frame_0003.jpg
Binary file added data/Office31/amazon/images/bike/frame_0004.jpg
Binary file added data/Office31/amazon/images/bike/frame_0005.jpg
Binary file added data/Office31/amazon/images/bike/frame_0006.jpg
Binary file added data/Office31/amazon/images/bike/frame_0007.jpg
Binary file added data/Office31/amazon/images/bike/frame_0008.jpg
Binary file added data/Office31/amazon/images/bike/frame_0009.jpg
Binary file added data/Office31/amazon/images/bike/frame_0010.jpg
Binary file added data/Office31/amazon/images/bike/frame_0011.jpg
Binary file added data/Office31/amazon/images/bike/frame_0012.jpg
Binary file added data/Office31/amazon/images/bike/frame_0013.jpg
Binary file added data/Office31/amazon/images/bike/frame_0014.jpg
Binary file added data/Office31/amazon/images/bike/frame_0015.jpg
Binary file added data/Office31/amazon/images/bike/frame_0016.jpg
Binary file added data/Office31/amazon/images/bike/frame_0017.jpg
Binary file added data/Office31/amazon/images/bike/frame_0018.jpg
Binary file added data/Office31/amazon/images/bike/frame_0019.jpg
Binary file added data/Office31/amazon/images/bike/frame_0020.jpg
Binary file added data/Office31/amazon/images/bike/frame_0021.jpg
Binary file added data/Office31/amazon/images/bike/frame_0022.jpg
Binary file added data/Office31/amazon/images/bike/frame_0023.jpg
Binary file added data/Office31/amazon/images/bike/frame_0024.jpg
Binary file added data/Office31/amazon/images/bike/frame_0025.jpg
Binary file added data/Office31/amazon/images/bike/frame_0026.jpg
Binary file added data/Office31/amazon/images/bike/frame_0027.jpg
Binary file added data/Office31/amazon/images/bike/frame_0028.jpg
Binary file added data/Office31/amazon/images/bike/frame_0029.jpg
Binary file added data/Office31/amazon/images/bike/frame_0030.jpg
Binary file added data/Office31/amazon/images/bike/frame_0031.jpg
Binary file added data/Office31/amazon/images/bike/frame_0032.jpg
Binary file added data/Office31/amazon/images/bike/frame_0034.jpg
Binary file added data/Office31/amazon/images/bike/frame_0035.jpg
Binary file added data/Office31/amazon/images/bike/frame_0036.jpg
Binary file added data/Office31/amazon/images/bike/frame_0037.jpg
Binary file added data/Office31/amazon/images/bike/frame_0038.jpg
Binary file added data/Office31/amazon/images/bike/frame_0039.jpg
Binary file added data/Office31/amazon/images/bike/frame_0040.jpg
Binary file added data/Office31/amazon/images/bike/frame_0041.jpg
Binary file added data/Office31/amazon/images/bike/frame_0042.jpg
Binary file added data/Office31/amazon/images/bike/frame_0043.jpg
Binary file added data/Office31/amazon/images/bike/frame_0044.jpg
Binary file added data/Office31/amazon/images/bike/frame_0046.jpg
Binary file added data/Office31/amazon/images/bike/frame_0047.jpg
Binary file added data/Office31/amazon/images/bike/frame_0048.jpg
Binary file added data/Office31/amazon/images/bike/frame_0049.jpg
Binary file added data/Office31/amazon/images/bike/frame_0050.jpg
Binary file added data/Office31/amazon/images/bike/frame_0051.jpg
Binary file added data/Office31/amazon/images/bike/frame_0052.jpg
Binary file added data/Office31/amazon/images/bike/frame_0053.jpg
Binary file added data/Office31/amazon/images/bike/frame_0054.jpg
Binary file added data/Office31/amazon/images/bike/frame_0055.jpg
Binary file added data/Office31/amazon/images/bike/frame_0056.jpg
Binary file added data/Office31/amazon/images/bike/frame_0057.jpg
Binary file added data/Office31/amazon/images/bike/frame_0058.jpg
Binary file added data/Office31/amazon/images/bike/frame_0059.jpg
Binary file added data/Office31/amazon/images/bike/frame_0061.jpg
Binary file added data/Office31/amazon/images/bike/frame_0062.jpg
Binary file added data/Office31/amazon/images/bike/frame_0064.jpg
Binary file added data/Office31/amazon/images/bike/frame_0065.jpg
Binary file added data/Office31/amazon/images/bike/frame_0066.jpg
Binary file added data/Office31/amazon/images/bike/frame_0067.jpg
Binary file added data/Office31/amazon/images/bike/frame_0068.jpg
Binary file added data/Office31/amazon/images/bike/frame_0069.jpg
Binary file added data/Office31/amazon/images/bike/frame_0070.jpg
Binary file added data/Office31/amazon/images/bike/frame_0071.jpg
Binary file added data/Office31/amazon/images/bike/frame_0072.jpg
Binary file added data/Office31/amazon/images/bike/frame_0073.jpg
Binary file added data/Office31/amazon/images/bike/frame_0074.jpg
Binary file added data/Office31/amazon/images/bike/frame_0075.jpg
Binary file added data/Office31/amazon/images/bike/frame_0076.jpg
Binary file added data/Office31/amazon/images/bike/frame_0077.jpg
Binary file added data/Office31/amazon/images/bike/frame_0078.jpg
Binary file added data/Office31/amazon/images/bike/frame_0079.jpg
Binary file added data/Office31/amazon/images/bike/frame_0080.jpg
Binary file added data/Office31/amazon/images/bike/frame_0081.jpg
Binary file added data/Office31/amazon/images/bike/frame_0082.jpg