Skip to content

Commit

Permalink
fix some structure
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith authored Apr 30, 2020
1 parent faffd54 commit 838e698
Showing 1 changed file with 6 additions and 57 deletions.
63 changes: 6 additions & 57 deletions nicolalandro_ntsnet-cub200_ntsnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@ background-class: hub-background
body-class: hub
category: researchers
title: ntsnet
summary: a fined graned model for image classification.
summary: classify birds using this fine-grained image classifier
image: nts-net.png
author: Moreno Carraffini and Nicola Landro
tags: [vision]
github-link: https://github.com/nicolalandro/ntsnet-cub200/archive/master.zip
github-link: https://github.com/nicolalandro/ntsnet-cub200
featured_image_1: Cub200Dataset.png
featured_image_2: no-image
accelerator: "cuda-optional"
---

```python
import torch
model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True, **{'topN': 6, 'device':'cpu', 'num_classes': 200})
model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
**{'topN': 6, 'device':'cpu', 'num_classes': 200})
```

### Example
### Example Usage

```python
from torchvision import transforms
Expand Down Expand Up @@ -52,60 +53,8 @@ with torch.no_grad():
print('bird class:', model.bird_classes[pred_id])
```

### How to train
It is a particular model and if you want to train it use the follow:

```
import torch
PROPOSAL_NUM = 6
LR = 0.001
WD = 1e-4
net = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True, **{'topN': 6, 'device':'cpu', 'num_classes': 200})
creterion = torch.nn.CrossEntropyLoss()
# define optimizers
raw_parameters = list(net.pretrained_model.parameters())
part_parameters = list(net.proposal_net.parameters())
concat_parameters = list(net.concat_net.parameters())
partcls_parameters = list(net.partcls_net.parameters())
raw_optimizer = torch.optim.SGD(raw_parameters, lr=LR, momentum=0.9, weight_decay=WD)
concat_optimizer = torch.optim.SGD(concat_parameters, lr=LR, momentum=0.9, weight_decay=WD)
part_optimizer = torch.optim.SGD(part_parameters, lr=LR, momentum=0.9, weight_decay=WD)
partcls_optimizer = torch.optim.SGD(partcls_parameters, lr=LR, momentum=0.9, weight_decay=WD)
...
for i, data in enumerate(trainloader):
img, label = data[0].cuda(), data[1].cuda()
batch_size = img.size(0)
raw_optimizer.zero_grad()
part_optimizer.zero_grad()
concat_optimizer.zero_grad()
partcls_optimizer.zero_grad()
_, _, raw_logits, concat_logits, part_logits, _, top_n_prob = net(img)
part_loss = net.list_loss(part_logits.view(batch_size * PROPOSAL_NUM, -1),
label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1)).view(batch_size, PROPOSAL_NUM)
raw_loss = creterion(raw_logits, label)
concat_loss = creterion(concat_logits, label)
rank_loss = net.ranking_loss(top_n_prob, part_loss)
partcls_loss = creterion(part_logits.view(batch_size * PROPOSAL_NUM, -1),
label.unsqueeze(1).repeat(1, PROPOSAL_NUM).view(-1))
total_loss = raw_loss + rank_loss + concat_loss + partcls_loss
total_loss.backward()
raw_optimizer.step()
part_optimizer.step()
concat_optimizer.step()
partcls_optimizer.step()
```


### Model Description
This is a nts-net pretrained with CUB200 2011 dataset. A fine graned dataset of birds species.
This is an nts-net pretrained with CUB200 2011 dataset, which is a fine grained dataset of birds species.

### References
You can read the full paper at this [link](http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf).
Expand Down

0 comments on commit 838e698

Please sign in to comment.