通过Pytorch
实现主流的图像分类模型如下:
- AlextNet
- VGGNet
- GoogLeNet
- ResNet
- DenseNet
- MobileNetV2
- MobileNetV3
- ShuffleNetV1
- ShuffleNetV2
- GhostNet
|--models :各个模型pytorch实现代码
| |--alexnet.py
| |--vggnet.py
| |--googlenet.py
| |--resnet.py
| |--densenet.py
| |--mobilenetv2.py
| |--mobilenetv3.py
| |--shufflenet.py
| |--shufflenetv2.py
| |--ghostnet.py
| |--base_model.py : 基模型
|--utils :配置文件
| |--data_utils.py :数据预处理配置
| |--train_val_utils.py : 模型训练配置
|--train.py: 训练脚本
|--predict.py: 预测脚本
花分类数据集下载地址: http://download.tensorflow.org/example_images/flower_photos.tgz
numpy==1.21.2
torch==1.9.1
torchvision==0.11.1
pillow==8.3.1
opencv-python==4.5.4.58
scipy==1.7.2
matplotlib==3.4.3
tqdm==4.62.3
下图为ResNet34的预测结果:
-
数据集的准备
- 每一个文件夹对应于一个类别的图像文件
- 以花分类数据集为例:
|--flower_photos | |--daisy | |--dandelion | |--roses | |--sunflowers | |--tulips
- 以花分类数据集为例:
- 每一个文件夹对应于一个类别的图像文件
-
运行
train.py
开始训练模型:- 必须修改的参数:
num_classes
和data_path
分别对应你的数据集类别数以及路径 - 可选修改的参数:
model_name
对应你要训练哪个图像分类模型
- 必须修改的参数:
-
训练结果预测
model_weight_path
指向训练好的权重文件model_name
表示使用哪一个模型进行预测,对应于上面的权重文件- 修改后就可以运行
predict.py
进行预测了