forked from CharlesNord/centerloss
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0e5d562
commit 32d85e8
Showing
2 changed files
with
201 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import torch | ||
from torch.autograd import Variable | ||
import torch.utils.data as Data | ||
import torchvision | ||
from torchvision import transforms | ||
import torch.nn.functional as F | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
import numpy as np | ||
import matplotlib.patheffects as PathEffects | ||
from centerLoss import CenterLoss | ||
import torch.optim.lr_scheduler as lr_scheduler | ||
import matplotlib.cm as cm | ||
|
||
# Shared MNIST preprocessing: tensorize, then normalize with the dataset's
# canonical per-channel mean/std (0.1307, 0.3081).
mnist_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

trainset = torchvision.datasets.MNIST(root='../mnist', train=True, transform=mnist_transform)

# BUG FIX: the original built the "test" set with train=True, so validation
# accuracy was measured on the training data. The test split needs train=False.
testset = torchvision.datasets.MNIST(root='../mnist', train=False, transform=mnist_transform)

train_loader = Data.DataLoader(dataset=trainset, batch_size=128, shuffle=True, num_workers=4)
test_loader = Data.DataLoader(dataset=testset, batch_size=128, shuffle=True, num_workers=4)
|
||
|
||
class Net(torch.nn.Module):
    """MLP mapping a flattened 28x28 MNIST image to a 2-D embedding plus
    10-way class log-probabilities.

    The feature head bottlenecks to 2 dimensions so the learned embedding
    can be scatter-plotted directly for center-loss visualization.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Progressively narrowing fully-connected stack: 784 -> ... -> 2.
        self.extract = torch.nn.Sequential(
            torch.nn.Linear(784, 512),
            torch.nn.PReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.PReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.PReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.PReLU(),
            torch.nn.Linear(64, 32),
            torch.nn.PReLU(),
            torch.nn.Linear(32, 2),
        )
        # Classifier on top of the 2-D feature.
        self.predict = torch.nn.Sequential(
            torch.nn.PReLU(),
            torch.nn.Linear(2, 10),
        )

    def forward(self, x):
        """Return (feature, log_probs).

        Args:
            x: image batch; flattened to (N, 784) internally.

        Returns:
            feature: (N, 2) embedding for the center loss.
            log_probs: (N, 10) log-softmax class scores (pairs with NLLLoss).
        """
        feature = self.extract(x.view(-1, 784))
        # FIX: pass dim=1 explicitly — the implicit-dim form of log_softmax
        # is deprecated and emits a warning on modern PyTorch.
        pred = F.log_softmax(self.predict(feature), dim=1)
        return feature, pred
|
||
|
||
class ConvNet(torch.nn.Module):
    """Convolutional variant of Net: three conv stages (each halving the
    spatial size) followed by a 2-D feature bottleneck and a 10-way head.

    For 28x28 input the pooling path is 28 -> 14 -> 7 -> 3, so the flatten
    size is 128 * 3 * 3.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        self.extract = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.PReLU(),
            torch.nn.Conv2d(32, 32, kernel_size=5, padding=2),
            torch.nn.PReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(32, 64, kernel_size=5, padding=2),
            torch.nn.PReLU(),
            torch.nn.Conv2d(64, 64, kernel_size=5, padding=2),
            torch.nn.PReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(64, 128, kernel_size=5, padding=2),
            torch.nn.PReLU(),
            torch.nn.Conv2d(128, 128, kernel_size=5, padding=2),
            torch.nn.PReLU(),
            torch.nn.MaxPool2d(2, 2),
        )
        # 2-D feature for center-loss visualization, then 10-way classifier.
        self.feat = torch.nn.Linear(128*3*3, 2)
        self.pred = torch.nn.Sequential(
            torch.nn.Linear(2, 10)
        )

    def forward(self, x):
        """Return (feat, log_probs) for an (N, 1, 28, 28) batch.

        Returns:
            feat: (N, 2) embedding for the center loss.
            log_probs: (N, 10) log-softmax class scores (pairs with NLLLoss).
        """
        x = self.extract(x)
        x = x.view(-1, 128*3*3)
        feat = self.feat(x)
        # FIX: pass dim=1 explicitly — the implicit-dim form of log_softmax
        # is deprecated and emits a warning on modern PyTorch.
        pred = F.log_softmax(self.pred(feat), dim=1)
        return feat, pred
|
||
|
||
# --- model / losses / optimizers (module-level setup) ---

# Swap in ConvNet() here to train the convolutional variant instead.
model = Net().cuda()

# SGD for the network weights, with a step-wise LR decay of 0.8 every 20 epochs.
optimizer4nn = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0005)
scheduler = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.8)

# Center loss holds its own trainable per-class centers; they are updated by a
# separate, higher-LR SGD optimizer. NLLLoss pairs with the model's log_softmax
# output.
centerloss = CenterLoss(10, 2, 0.1).cuda()
nllloss = torch.nn.NLLLoss().cuda()
optimizer4center = torch.optim.SGD(centerloss.parameters(), lr=0.5)
|
||
|
||
def train(train_loader, model, epoch):
    """Run one training epoch, jointly optimizing the classification (NLL)
    loss and the center loss with their respective optimizers."""
    print("Training Epoch: {}".format(epoch))
    model.train()
    for step, (images, labels) in enumerate(train_loader):
        images, labels = Variable(images).cuda(), Variable(labels).cuda()
        features, log_probs = model(images)
        # Total objective = classification loss + weighted center loss.
        loss = nllloss(log_probs, labels) + centerloss(labels, features)
        # Clear both parameter groups, backprop once, then step both.
        optimizer4nn.zero_grad()
        optimizer4center.zero_grad()
        loss.backward()
        optimizer4nn.step()
        optimizer4center.step()
        if step % 100 == 0:
            print("Epoch: {} step: {}".format(epoch, step))
|
||
|
||
def test(test_loader, model, epoch):
    """Evaluate the model over test_loader, print accuracy, and scatter-plot
    the collected 2-D features via scatter()."""
    print("Predicting Epoch: {}".format(epoch))
    model.eval()
    total_pred_label = []
    total_target = []
    total_feature = []
    # FIX: run inference under no_grad() — the original tracked gradients
    # during evaluation, wasting memory on autograd graphs it never used.
    with torch.no_grad():
        for step, (data, target) in enumerate(test_loader):
            data = Variable(data).cuda()
            target = Variable(target).cuda()
            feature, pred = model(data)
            _, pred_label = pred.max(dim=1)
            total_pred_label.append(pred_label.data.cpu())
            total_target.append(target.data.cpu())
            total_feature.append(feature.data.cpu())

    total_pred_label = torch.cat(total_pred_label, dim=0)
    total_target = torch.cat(total_target, dim=0)
    total_feature = torch.cat(total_feature, dim=0)

    # FIX: coerce the correct-count to a Python float — on modern PyTorch
    # torch.sum returns a 0-dim tensor, so the original printed a tensor repr.
    correct = torch.sum(total_pred_label == total_target)
    precision = float(correct) / float(total_target.shape[0])
    print("Validation accuracy: {}%".format(precision * 100))
    scatter(total_feature.numpy(), total_target.numpy(), epoch)
|
||
|
||
def scatter(feat, label, epoch):
    """Scatter-plot the 2-D features colored by digit class, annotate each
    cluster with its digit, and save the figure to ./benchmark/."""
    plt.ion()
    plt.clf()
    palette = np.array(sns.color_palette('hls', 10))
    ax = plt.subplot(aspect='equal')
    # One plot call per class so each digit gets its own color/legend entry.
    for digit in range(10):
        mask = label == digit
        plt.plot(feat[mask, 0], feat[mask, 1], '.', c=palette[digit])
    plt.legend([str(digit) for digit in range(10)], loc='upper right')
    ax.axis('tight')
    # Label each cluster at its median position with a white-outlined digit.
    for digit in range(10):
        cx, cy = np.median(feat[label == digit, :], axis=0)
        txt = ax.text(cx, cy, str(digit), fontsize=18)
        txt.set_path_effects([PathEffects.Stroke(linewidth=5, foreground="w"), PathEffects.Normal()])

    plt.draw()
    plt.savefig('./benchmark/centerloss_{}.png'.format(epoch))
    plt.pause(0.001)
|
||
|
||
# Main loop: train, decay the LR, then evaluate each epoch.
# FIX: since PyTorch 1.1 the LR scheduler must be stepped AFTER the
# optimizer steps taken inside train(); the original stepped it first,
# which skips the initial learning rate.
for epoch in range(50):
    train(train_loader, model, epoch)
    scheduler.step()
    test(test_loader, model, epoch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# coding: utf8 | ||
import torch | ||
from torch.autograd import Variable | ||
|
||
|
||
class CenterLoss(torch.nn.Module):
    """Center loss (Wen et al., ECCV 2016): pulls each sample's feature
    toward a learnable center for its class.

    Args:
        num_classes: number of classes (one center per class).
        feat_dim: dimensionality of the feature vectors.
        loss_weight: scalar multiplier applied to the loss (default 1.0).
    """

    def __init__(self, num_classes, feat_dim, loss_weight=1.0):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.loss_weight = loss_weight
        # Trainable class centers, shape (num_classes, feat_dim).
        self.centers = torch.nn.Parameter(torch.randn(num_classes, feat_dim))
        # Tracks whether cuda() was called; histc is routed through the CPU
        # in that case (histc had no CUDA kernel in old PyTorch).
        self.use_cuda = False

    def forward(self, y, feat):
        """Compute the weighted center loss.

        Args:
            y: (batch,) tensor of integer class labels.
            feat: (batch, feat_dim) features; trailing singleton dims are
                flattened away.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if the flattened feature dim differs from feat_dim.
        """
        # Per-class sample counts for this batch; the +1 avoids division by
        # zero for classes absent from the batch (and mildly dampens the
        # per-sample weight).
        if self.use_cuda:
            hist = Variable(
                torch.histc(y.cpu().data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1).cuda()
        else:
            hist = Variable(torch.histc(y.data.float(), bins=self.num_classes, min=0, max=self.num_classes) + 1)

        centers_count = hist.index_select(0, y.long())  # count for each sample's own class

        batch_size = feat.size()[0]
        # BUG FIX: the original used feat.view(batch, 1, 1, -1).squeeze(),
        # which also squeezed the BATCH dimension when batch_size == 1 and
        # then crashed on feat.size()[1]. view(batch_size, -1) flattens any
        # trailing dims while always preserving the batch axis.
        feat = feat.view(batch_size, -1)
        if feat.size()[1] != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}".format(self.feat_dim,
                                                                                                    feat.size()[1]))
        # Squared distance of each sample to its class center, normalized by
        # the (biased) per-class count, summed over the batch.
        centers_pred = self.centers.index_select(0, y.long())
        diff = feat-centers_pred
        loss = self.loss_weight * 1/2.0 * (diff.pow(2).sum(1) / centers_count).sum()
        return loss

    def cuda(self, device_id=None):
        """Move parameters to the GPU; also flips use_cuda so forward() runs
        histc on the CPU copy of the labels."""
        self.use_cuda = True
        return self._apply(lambda t: t.cuda(device_id))
|
||
|
||
|