# ABCP [Paper]
Automatic Block-wise and Channel-wise Network Pruning (ABCP) jointly searches the block pruning policy and the channel pruning policy of a network with deep reinforcement learning (DRL). A joint sample algorithm is proposed to simultaneously generate the pruning choice of each residual block and the channel pruning ratio of each convolutional layer, drawn from the discrete and the continuous search space respectively. This repository contains the code for pruning YOLOv3 with ABCP.
The code is based on enas and tensorflow-yolov3.
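For intuition, here is a minimal toy sketch of the joint sampling idea (a hypothetical illustration, not the repository's actual DRL controller): each residual block gets a discrete keep/prune decision, and each convolutional layer gets a continuous channel pruning ratio.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_pruning_policy(num_blocks, layers_per_block,
                          block_logits, ratio_means, ratio_std=0.1):
    """Jointly sample a block-wise and a channel-wise pruning policy.

    block_logits : (num_blocks,) logits for pruning each residual block
                   (discrete search space; Bernoulli here for brevity).
    ratio_means  : (num_blocks, layers_per_block) means of the channel
                   pruning ratios (continuous search space; Gaussian here).
    """
    # Discrete action: prune the whole residual block or keep it.
    prune_prob = 1.0 / (1.0 + np.exp(-np.asarray(block_logits)))
    block_actions = rng.random(num_blocks) < prune_prob

    # Continuous action: one channel pruning ratio per conv layer,
    # clipped to a valid range; only meaningful for blocks that are kept.
    ratios = np.clip(rng.normal(ratio_means, ratio_std), 0.0, 0.95)
    ratios[block_actions] = 1.0  # a pruned block loses all its channels

    return block_actions, ratios

# Example: 4 residual blocks with 2 conv layers each.
blocks, ratios = sample_pruning_policy(
    num_blocks=4, layers_per_block=2,
    block_logits=[0.0, -1.0, 2.0, 0.5],
    ratio_means=np.full((4, 2), 0.3))
print("prune block? ", blocks)
print("channel pruning ratios:\n", ratios)
```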
- Clone this repository to your local machine and install the dependencies listed in requirement.txt.
- Prepare the dataset; please refer to Detection Datasets for ABCP for detailed instructions.
- Open `/code/src/yolo/config.py` and modify the variables `YOLO.CLASSES`, `YOLO.ANCHORS`, `TRAIN.ANNOT_PATH`, and `TEST.ANNOT_PATH` (a sketch of typical config edits appears after this list).
- Open `/code/src/yolo/config.py` and modify the variables `TRAIN.LEARN_RATE_INIT` and `TRAIN.LEARN_RATE_END` according to the chosen dataset.
- Open `/code/src/yolo/config.py` and set `TRAIN.FISRT_STAGE_EPOCHS` to 20 and `TRAIN.SECOND_STAGE_EPOCHS` to a value greater than 30.
- Open `/code/src/yolo/pretraining.py` and modify the variable `output_dir`, which is the path of the pre-trained weights used during the search. Then run `python2 pretraining.py`.
- Open `/code/src/yolo/config.py` and modify the variables `TRAIN.LEARN_RATE_INIT` and `TRAIN.LEARN_RATE_END` according to the chosen dataset.
- Open `/code/src/yolo/config.py` and set `TRAIN.FISRT_STAGE_EPOCHS` to 1 and `TRAIN.SECOND_STAGE_EPOCHS` to 0.
- Open `/code/src/yolo/main_cal_yolo_multitask_continuous.py` and modify the variables `output_dir`, `weight_path`, and `num_epochs`. Then run `nohup python2 main_cal_yolo_multitask_continuous.py`.
- When the search is finished, parse the log `nohup.out` to get the pruning action with the best reward: `python read_nohup_yolo.py`.
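For reference, here is a minimal sketch of what the `config.py` edits might look like. The repository builds on a tensorflow-yolov3-style `easydict` config; all paths and values below are illustrative assumptions, not the repository's actual defaults.

```python
# Illustrative sketch only -- the paths, anchors, and learning rates are
# assumed placeholders; use the values that match your dataset and stage.
from easydict import EasyDict as edict

__C = edict()
__C.YOLO, __C.TRAIN, __C.TEST = edict(), edict(), edict()

__C.YOLO.CLASSES     = "./data/classes/ucsd.names"       # assumed path
__C.YOLO.ANCHORS     = "./data/anchors/anchors.txt"      # assumed path
__C.TRAIN.ANNOT_PATH = "./data/dataset/search_train.txt" # assumed path
__C.TEST.ANNOT_PATH  = "./data/dataset/search_test.txt"  # assumed path

__C.TRAIN.LEARN_RATE_INIT = 1e-3  # tune per dataset
__C.TRAIN.LEARN_RATE_END  = 1e-6  # tune per dataset

__C.TRAIN.FISRT_STAGE_EPOCHS  = 20  # pre-training; set to 1 for the search
__C.TRAIN.SECOND_STAGE_EPOCHS = 40  # >30 for pre-training; 0 for the search
```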
To calculate the mAP of the YOLOv3 models, we use the code from PyTorch-YOLOv3.
YOLOv3 is adopted to illustrate the performance of the proposed ABCP framework. The following three datasets are collected for the evaluation of ABCP.
The UCSD dataset is a small dataset captured from freeway surveillance videos collected by UCSD. It covers three traffic densities, each making up about one-third of the data: sparse, medium-density, and dense traffic. We define three classes, truck, car, and bus, and the vehicles in the images are labeled for the detection task. All images have a resolution of 320×240. The training and testing sets contain 683 and 76 images respectively.
The mobile robot detection dataset is collected by robot-mounted cameras to meet the demand for fast and lightweight detection algorithms on mobile robots, and is inspired by the RoboMaster University AI Challenge. Two kinds of ordinary color cameras with different resolutions, 1024×512 and 640×480, are used. Five classes are defined: red robot, red armor, blue robot, blue armor, and dead robot. The training and testing sets contain 13,914 and 5,969 images respectively. During collection, we vary the exposure settings and the distances and angles of the robots to improve robustness.
The sim2real detection dataset is divided into two sub-datasets: a real-world dataset and a simulation dataset. We search and train the model on the simulation dataset and test it on the real-world dataset. First, we collect the real-world dataset with surveillance-view ordinary color cameras in the field; the field and the mobile robots are the same as those in the mobile robot detection dataset. Second, we leverage Gazebo to simulate the robots and the field from the surveillance view, and capture images of the simulation environment to build the simulation dataset. All images in the sim2real dataset have a resolution of 640×480. There is only one object class in both datasets: robot. The training and testing sets of the simulation dataset contain 5,760 and 1,440 images respectively, and the testing set of the real-world dataset contains 3,019 images.
The labels use relative xywh coordinates. The files train.txt and test.txt list the image paths of the training and testing datasets respectively, and are used for YOLOv3 training on Darknet. The files search_train.txt and search_test.txt list the image paths and the labels of the training and testing datasets respectively, and are used for the pruning policy search; note that the labels in these files use absolute xxyy coordinates.
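Since the two file families use different coordinate conventions, a small converter can help when moving between them. The helper below is a hypothetical sketch (not part of the repository) that maps a Darknet-style relative `(cx, cy, w, h)` box to absolute corner coordinates; check the actual column order against search_train.txt before relying on it.

```python
def xywh_rel_to_abs_corners(box, img_w, img_h):
    """Convert a Darknet-style relative (cx, cy, w, h) box to absolute
    corner coordinates (x_min, y_min, x_max, y_max) in pixels.

    box   : (cx, cy, w, h), all normalized to [0, 1]
    img_w : image width in pixels
    img_h : image height in pixels
    """
    cx, cy, w, h = box
    x_min = (cx - w / 2.0) * img_w
    y_min = (cy - h / 2.0) * img_h
    x_max = (cx + w / 2.0) * img_w
    y_max = (cy + h / 2.0) * img_h
    return x_min, y_min, x_max, y_max

# Example: a centered box covering half of a 640x480 image.
print(xywh_rel_to_abs_corners((0.5, 0.5, 0.5, 0.5), 640, 480))
# -> (160.0, 120.0, 480.0, 360.0)
```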
The data can be downloaded from Baidu Netdisk (Pwd: redc) or OneDrive.
Please download the compressed models for the three datasets from compressed_models.
We search the pruning policy of YOLOv3 on the UCSD dataset and re-train the pruned model.
Models | mAP (%) | FLOPs (G) | Params (M) | Inference Time (s) |
---|---|---|---|---|
YOLOv3 | 61.4 | 65.496 | 61.535 | 0.110 |
YOLOv4 | 63.1 | 59.659 | 63.948 | 0.132 |
YOLO-tiny | 57.4 | 5.475 | 8.674 | 0.014 |
RBCP | 66.5 | 17.973 | 4.844 | 0.042 |
ABCP (Ours) | 69.6 | 4.485 | 4.685 | 0.016 |
The detection results of the pruned YOLOv3:
We search the pruning policy of YOLOv3 on the mobile robot detection dataset and re-train the pruned model.
Models | mAP (%) | FLOPs (G) | Params (M) | Inference Time (s) |
---|---|---|---|---|
YOLOv3 | 94.9 | 65.510 | 61.545 | 0.227 |
YOLOv4 | 92.1 | 59.673 | 63.959 | 0.141 |
YOLO-tiny | 85.3 | 5.478 | 8.679 | 0.014 |
RBCP | 89.9 | 2.842 | 1.879 | 0.012 |
ABCP (Ours) | 92.1 | 0.327 | 0.299 | 0.003 |
The detection results of the pruned YOLOv3:
We search the pruning policy of YOLOv3 on the simulation dataset and test the pruned model on the real-world dataset.
Models | sim mAP (%) | real mAP (%) | FLOPs (G) | Params (M) | Inference Time (s) |
---|---|---|---|---|---|
YOLOv3 | 95.6 | 66.5 | 65.481 | 61.524 | 0.117 |
YOLOv4 | 98.3 | 28.8 | 59.644 | 63.938 | 0.141 |
YOLO-tiny | 98.3 | 42.3 | 5.472 | 8.670 | 0.014 |
RBCP | 97.9 | 71.2 | 2.321 | 1.237 | 0.009 |
ABCP (Ours) | 98.0 | 76.1 | 1.581 | 2.545 | 0.008 |
The detection results of the pruned YOLOv3 on the real-world dataset: