Welcome to our PyTorch Image Classification Training Framework. This project provides a robust and flexible framework for training image classification models using PyTorch and Weights & Biases (wandb) for visualization.
Here's a high-level overview of the project's structure:
.
├── .github # Directory for GitHub-specific files, used for commitlint
├── .husky # Directory for Husky hooks, used for commitlint
├── checkpoints # Directory for saved model weights, resumable training states, and wandb logs
├── configs # Directory for configuration files
├── dataset # Directory for the dataset
├── submissions # Directory for submission files
├── package-lock.json # File used to set up the conventional commits environment
├── package.json # File used to set up the conventional commits environment
├── run_predict.sh # Script for making predictions on new data
├── run_train.sh # Script for training the model on the dataset
├── LICENSE # MIT License file
├── README.md # This file, a concise description of the project
└── src # Source code directory
    └── main
        ├── data # Code for loading and preprocessing the dataset
        ├── engine # Code for defining the training and validation loops
        ├── model # Code for defining the model architecture
        ├── options # Code for parsing command line arguments
        ├── predict.py # Python script for making predictions on new data
        ├── reorganize_paddy_dataset.py # Python script for reorganizing the dataset
        └── train.py # Python script for training the model
Ensure you have the following installed on your local machine:
- Python 3.7+
- PyTorch 1.7+
- Weights & Biases
To install the necessary dependencies, run the following commands:
conda create -n pytorch python=3.9
conda activate pytorch
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda install pyyaml pandas tqdm wandb scikit-learn -c conda-forge
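To confirm that the activated environment meets the prerequisites above, a quick check such as the following can be run with python. This is a minimal sketch, not a script shipped with the project: the names REQUIREMENTS, version_at_least and check_prerequisites are hypothetical, only the leading numeric parts of each version string are compared, and it relies on importlib.metadata, which needs Python 3.8+ (the environment created above uses 3.9).

```python
import importlib.metadata
import re
import sys

# Minimum versions taken from the prerequisites list; (0,) means "any version".
REQUIREMENTS = {"torch": (1, 7), "wandb": (0,)}


def version_at_least(version, minimum):
    """Return True if a version string such as '2.1.0+cu121' is >= minimum.

    Only the leading numeric components are compared; local tags such as
    '+cu121' and pre-release suffixes such as 'rc1' are ignored.
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    # Pad with zeros so that '2' compares like '2.0'.
    parts += [0] * (len(minimum) - len(parts))
    return tuple(parts) >= tuple(minimum)


def check_prerequisites():
    """Return a list of human-readable problems (empty if all checks pass)."""
    problems = []
    if sys.version_info < (3, 7):
        problems.append("Python 3.7+ is required, found " + sys.version.split()[0])
    for package, minimum in REQUIREMENTS.items():
        try:
            installed = importlib.metadata.version(package)
        except importlib.metadata.PackageNotFoundError:
            problems.append(package + " is not installed")
            continue
        if not version_at_least(installed, minimum):
            wanted = ".".join(str(n) for n in minimum)
            problems.append(f"{package} {wanted}+ is required, found {installed}")
    return problems


if __name__ == "__main__":
    issues = check_prerequisites()
    print("\n".join(issues) if issues else "All prerequisites found.")
```

Saved under any file name and run after conda activate pytorch, it prints each missing or outdated package, or a single confirmation line if everything is present.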
1. Prepare Your Dataset: Place your dataset in the ./dataset directory. The dataset should be arranged in the following format:
dataset_root/train_split.csv
dataset_root/val_split.csv
dataset_root/images_train/xxx.jpg
dataset_root/images_train/xxx.jpg
...
dataset_root/images_train/xxx.jpg
dataset_root/images_train/xxx.jpg
...
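A quick way to confirm that a dataset matches this layout before training is sketched below. It is a hypothetical helper, not part of the project: check_dataset_layout only checks that the two split files and the images_train directory exist, counts the data rows in each CSV on the assumption that the first row is a header, and counts the .jpg files in images_train. The actual CSV columns are whatever the loading code in src/main/data expects.

```python
import csv
from pathlib import Path

# File and directory names taken from the layout above.
EXPECTED_CSVS = ("train_split.csv", "val_split.csv")
IMAGE_DIR = "images_train"


def check_dataset_layout(root):
    """Summarise a dataset root, raising FileNotFoundError if parts are missing.

    Returns a dict mapping each split CSV to its number of data rows
    (assuming one header row) and IMAGE_DIR to its number of .jpg files.
    """
    root = Path(root)
    missing = [name for name in (*EXPECTED_CSVS, IMAGE_DIR) if not (root / name).exists()]
    if missing:
        raise FileNotFoundError(f"Missing under {root}: {', '.join(missing)}")

    summary = {}
    for name in EXPECTED_CSVS:
        with open(root / name, newline="") as f:
            rows = list(csv.reader(f))
        # Assumption: the first row is a header, so it is not counted.
        summary[name] = max(len(rows) - 1, 0)
    summary[IMAGE_DIR] = sum(1 for _ in (root / IMAGE_DIR).glob("*.jpg"))
    return summary
```

For example, passing the dataset root directory to check_dataset_layout and printing the result shows the row and image counts, or raises FileNotFoundError naming whatever is missing.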
2. Train the Model: Run the run_train.sh script to train the model on your dataset.
3. Visualize the Training Process: Log in to your Weights & Biases account to visualize the training process and model performance.
If you are using PyCharm, mark src/main as the Sources Root.
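Marking src/main as the Sources Root lets PyCharm resolve imports of the modules under that directory. When train.py or predict.py is run directly, Python already puts the script's own directory (src/main) first on sys.path, so the same imports work from the command line. For code run from elsewhere, such as a notebook or a helper script at the repository root, a minimal sketch of the equivalent is shown below; add_sources_root is a hypothetical helper, not part of the project.

```python
import sys
from pathlib import Path


def add_sources_root(repo_root, relative="src/main"):
    """Prepend <repo_root>/<relative> to sys.path (only once) and return it."""
    sources_root = str((Path(repo_root) / relative).resolve())
    if sources_root not in sys.path:
        # Insert at the front so project modules take precedence.
        sys.path.insert(0, sources_root)
    return sources_root
```

Calling add_sources_root with the repository root before importing any project modules has the same effect as setting PYTHONPATH=src/main for that session.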
On Windows, set Git Bash as your default shell so that the .sh scripts can be run.
This project is licensed under the MIT License.
For any questions or concerns, please open an issue on GitHub.
This project is based on: