The goal of this project is to provide neural network examples and a simple training codebase for beginners.
Train Models: Open the notebook to train the models from scratch on CIFAR-10/100. Training takes several hours, depending on the complexity of the model and the type of GPU allocated.
Test Models: Open the notebook to measure the validation accuracy of the pretrained models on CIFAR-10/100. This takes only a few seconds.
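The accuracy numbers reported for these models are Top-1 and Top-5 accuracy: a prediction counts as correct when the true class is the highest-scoring class (Top-1) or among the five highest-scoring classes (Top-5). A minimal pure-Python sketch of that metric follows; `topk_accuracy` and the toy scores are illustrative, not the notebook's code.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices of the k largest scores for this sample
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top  # True counts as 1
    return 100.0 * hits / len(labels)


# Toy example: 3 samples, 3 classes (made-up scores)
scores = [
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest score: class 0
    [0.2, 0.3, 0.5],  # highest score: class 2
]
labels = [1, 1, 0]
print(round(topk_accuracy(scores, labels, k=1), 2))  # 33.33
print(round(topk_accuracy(scores, labels, k=2), 2))  # 66.67
```

Top-5 accuracy is the same call with `k=5`; with 10 CIFAR-10 classes it is naturally close to 100%.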
You can use the pretrained models in your own project through the torch.hub API, which automatically downloads the model code and the pretrained weights from GitHub.
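```python
import torch

# Downloads the repository code and the pretrained CIFAR-10 ResNet-20 weights from GitHub
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
```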
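For a CIFAR-10 entry such as `cifar10_resnet20`, the model produces 10 logits per input image, one per class. The sketch below is plain Python so it runs without PyTorch and shows how one row of logits becomes a class name and probability. It assumes the standard CIFAR-10 label order used by `torchvision.datasets.CIFAR10`; `predict_class` and the logit values are hypothetical. With the real model you would call `model.eval()` and take one row of `model(images)` computed under `torch.no_grad()`.

```python
import math
from typing import List, Tuple

# Assumed CIFAR-10 class names in label order 0-9 (torchvision.datasets.CIFAR10 ordering)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def predict_class(logits: List[float]) -> Tuple[str, float]:
    """Return the predicted class name and its softmax probability for one image."""
    # Subtract the largest logit before exponentiating, for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return CIFAR10_CLASSES[best], probs[best]


# Hypothetical logits for a single image
logits = [0.2, -1.0, 0.5, 3.1, 0.0, 1.2, -0.3, 0.4, -0.8, 0.1]
name, p = predict_class(logits)
print(name, round(p, 3))
```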
To list all available model entries, run:
```python
import torch
from pprint import pprint

# force_reload=True re-downloads the repository instead of using the cached copy
pprint(torch.hub.list("chenyaofo/pytorch-cifar-models", force_reload=True))
```
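Entry names such as `cifar10_resnet20` appear to combine a dataset prefix with an architecture name from the Model column of the results tables. Assuming that `<dataset>_<arch>` pattern holds for every entry (so the CIFAR-100 ResNet-20 would be `cifar100_resnet20`), the hypothetical helpers below build an entry name and group a listing by dataset. The sample list is made up; in practice you would pass the output of `torch.hub.list(...)`.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def hub_entry(dataset: str, arch: str) -> str:
    """Build an entry name such as 'cifar100_resnet20' (assumed '<dataset>_<arch>' pattern)."""
    return f"{dataset}_{arch}"


def group_entries(entries: Iterable[str]) -> Dict[str, List[str]]:
    """Group entry names by dataset prefix, e.g. {'cifar10': ['resnet20', ...]}."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for entry in entries:
        if not entry.startswith("cifar"):
            continue  # ignore names that do not follow the assumed pattern
        # Split at the first underscore only, so 'mobilenetv2_x0_5' stays intact
        dataset, _, arch = entry.partition("_")
        groups[dataset].append(arch)
    return dict(groups)


# Hypothetical excerpt of a torch.hub.list(...) result
entries = ["cifar10_resnet20", "cifar100_resnet20", "cifar10_mobilenetv2_x0_5", "cifar100_repvgg_a2"]
print(group_entries(entries))
print(hub_entry("cifar100", "resnet56"))
```

The returned name can be passed as the second argument of `torch.hub.load`, exactly like `"cifar10_resnet20"` above.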

Results of the pretrained models on CIFAR-10:

Model | Top-1 Acc. (%) | Top-5 Acc. (%) | #Params (M) | #MAdds (M) | Weights | Log |
---|---|---|---|---|---|---|
resnet20 | 92.60 | 99.81 | 0.27 | 40.81 | model | log |
resnet32 | 93.53 | 99.77 | 0.47 | 69.12 | model | log |
resnet44 | 94.01 | 99.77 | 0.66 | 97.44 | model | log |
resnet56 | 94.37 | 99.83 | 0.86 | 125.75 | model | log |
vgg11_bn | 92.79 | 99.72 | 9.76 | 153.29 | model | log |
vgg13_bn | 94.00 | 99.77 | 9.94 | 228.79 | model | log |
vgg16_bn | 94.16 | 99.71 | 15.25 | 313.73 | model | log |
vgg19_bn | 93.91 | 99.64 | 20.57 | 398.66 | model | log |
mobilenetv2_x0_5 | 92.88 | 99.86 | 0.70 | 27.97 | model | log |
mobilenetv2_x0_75 | 93.72 | 99.79 | 1.37 | 59.31 | model | log |
mobilenetv2_x1_0 | 93.79 | 99.73 | 2.24 | 87.98 | model | log |
mobilenetv2_x1_4 | 94.22 | 99.80 | 4.33 | 170.07 | model | log |
shufflenetv2_x0_5 | 90.13 | 99.70 | 0.35 | 10.90 | model | log |
shufflenetv2_x1_0 | 92.98 | 99.73 | 1.26 | 45.00 | model | log |
shufflenetv2_x1_5 | 93.55 | 99.77 | 2.49 | 94.26 | model | log |
shufflenetv2_x2_0 | 93.81 | 99.79 | 5.37 | 187.81 | model | log |
repvgg_a0 | 94.39 | 99.82 | 7.84 | 489.08 | model | log |
repvgg_a1 | 94.89 | 99.83 | 12.82 | 851.33 | model | log |
repvgg_a2 | 94.98 | 99.82 | 26.82 | 1850.10 | model | log |
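The CIFAR-10 table supports simple model selection, for example picking the most accurate model that fits a parameter budget. A small sketch follows, with the rows transcribed from the table; `Row` and `best_under_budget` are hypothetical helpers, not part of the repository.

```python
from typing import List, NamedTuple, Optional


class Row(NamedTuple):
    name: str
    top1: float      # Top-1 accuracy (%)
    params_m: float  # parameters (millions)
    madds_m: float   # multiply-adds (millions)


# CIFAR-10 rows transcribed from the table above
CIFAR10: List[Row] = [
    Row("resnet20", 92.60, 0.27, 40.81),
    Row("resnet32", 93.53, 0.47, 69.12),
    Row("resnet44", 94.01, 0.66, 97.44),
    Row("resnet56", 94.37, 0.86, 125.75),
    Row("vgg11_bn", 92.79, 9.76, 153.29),
    Row("vgg13_bn", 94.00, 9.94, 228.79),
    Row("vgg16_bn", 94.16, 15.25, 313.73),
    Row("vgg19_bn", 93.91, 20.57, 398.66),
    Row("mobilenetv2_x0_5", 92.88, 0.70, 27.97),
    Row("mobilenetv2_x0_75", 93.72, 1.37, 59.31),
    Row("mobilenetv2_x1_0", 93.79, 2.24, 87.98),
    Row("mobilenetv2_x1_4", 94.22, 4.33, 170.07),
    Row("shufflenetv2_x0_5", 90.13, 0.35, 10.90),
    Row("shufflenetv2_x1_0", 92.98, 1.26, 45.00),
    Row("shufflenetv2_x1_5", 93.55, 2.49, 94.26),
    Row("shufflenetv2_x2_0", 93.81, 5.37, 187.81),
    Row("repvgg_a0", 94.39, 7.84, 489.08),
    Row("repvgg_a1", 94.89, 12.82, 851.33),
    Row("repvgg_a2", 94.98, 26.82, 1850.10),
]


def best_under_budget(rows: List[Row], max_params_m: float) -> Optional[Row]:
    """Most accurate model whose parameter count fits the budget, or None if none fits."""
    fitting = [r for r in rows if r.params_m <= max_params_m]
    return max(fitting, key=lambda r: r.top1, default=None)


print(best_under_budget(CIFAR10, 1.0))
```

With a 1M-parameter budget this selects `resnet56` (94.37%); the same function works with `madds_m` in place of `params_m` if compute rather than model size is the constraint.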

Results of the pretrained models on CIFAR-100:

Model | Top-1 Acc. (%) | Top-5 Acc. (%) | #Params (M) | #MAdds (M) | Weights | Log |
---|---|---|---|---|---|---|
resnet20 | 68.83 | 91.01 | 0.28 | 40.82 | model | log |
resnet32 | 70.16 | 90.89 | 0.47 | 69.13 | model | log |
resnet44 | 71.63 | 91.58 | 0.67 | 97.44 | model | log |
resnet56 | 72.63 | 91.94 | 0.86 | 125.75 | model | log |
vgg11_bn | 70.78 | 88.87 | 9.80 | 153.34 | model | log |
vgg13_bn | 74.63 | 91.09 | 9.99 | 228.84 | model | log |
vgg16_bn | 74.00 | 90.56 | 15.30 | 313.77 | model | log |
vgg19_bn | 73.87 | 90.13 | 20.61 | 398.71 | model | log |
mobilenetv2_x0_5 | 70.88 | 91.72 | 0.82 | 28.08 | model | log |
mobilenetv2_x0_75 | 73.61 | 92.61 | 1.48 | 59.43 | model | log |
mobilenetv2_x1_0 | 74.20 | 92.82 | 2.35 | 88.09 | model | log |
mobilenetv2_x1_4 | 75.98 | 93.44 | 4.50 | 170.23 | model | log |
shufflenetv2_x0_5 | 67.82 | 89.93 | 0.44 | 10.99 | model | log |
shufflenetv2_x1_0 | 72.39 | 91.46 | 1.36 | 45.09 | model | log |
shufflenetv2_x1_5 | 73.91 | 92.13 | 2.58 | 94.35 | model | log |
shufflenetv2_x2_0 | 75.35 | 92.62 | 5.55 | 188.00 | model | log |
repvgg_a0 | 75.22 | 92.93 | 7.96 | 489.19 | model | log |
repvgg_a1 | 76.12 | 92.71 | 12.94 | 851.44 | model | log |
repvgg_a2 | 77.18 | 93.51 | 26.94 | 1850.22 | model | log |
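To see which CIFAR-100 models are worth considering at a given compute level, one can compute the Pareto front of MAdds versus Top-1 accuracy: a model stays on the front only if every cheaper model is less accurate. The rows below are transcribed from the CIFAR-100 table; `pareto_front` is an illustrative helper, not part of the repository.

```python
from typing import List, Tuple

# (name, Top-1 accuracy %, MAdds in millions), transcribed from the CIFAR-100 table above
CIFAR100: List[Tuple[str, float, float]] = [
    ("resnet20", 68.83, 40.82),
    ("resnet32", 70.16, 69.13),
    ("resnet44", 71.63, 97.44),
    ("resnet56", 72.63, 125.75),
    ("vgg11_bn", 70.78, 153.34),
    ("vgg13_bn", 74.63, 228.84),
    ("vgg16_bn", 74.00, 313.77),
    ("vgg19_bn", 73.87, 398.71),
    ("mobilenetv2_x0_5", 70.88, 28.08),
    ("mobilenetv2_x0_75", 73.61, 59.43),
    ("mobilenetv2_x1_0", 74.20, 88.09),
    ("mobilenetv2_x1_4", 75.98, 170.23),
    ("shufflenetv2_x0_5", 67.82, 10.99),
    ("shufflenetv2_x1_0", 72.39, 45.09),
    ("shufflenetv2_x1_5", 73.91, 94.35),
    ("shufflenetv2_x2_0", 75.35, 188.00),
    ("repvgg_a0", 75.22, 489.19),
    ("repvgg_a1", 76.12, 851.44),
    ("repvgg_a2", 77.18, 1850.22),
]


def pareto_front(rows: List[Tuple[str, float, float]]) -> List[str]:
    """Names on the MAdds/Top-1 Pareto front, cheapest first (exact duplicates keep the first)."""
    front: List[str] = []
    best_acc = float("-inf")
    # Cheapest first; on equal MAdds, visit the more accurate model first
    for name, top1, madds in sorted(rows, key=lambda r: (r[2], -r[1])):
        if top1 > best_acc:  # strictly more accurate than every model visited so far
            front.append(name)
            best_acc = top1
    return front


print(pareto_front(CIFAR100))
```

On these numbers the front consists of ShuffleNetV2 and MobileNetV2 variants at the low end plus `repvgg_a1` and `repvgg_a2` at the top; each ResNet and VGG model is matched by a cheaper model with higher Top-1 accuracy.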