diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md deleted file mode 100644 index 101235f7..00000000 --- a/.github/ISSUE_TEMPLATE/bug-report.md +++ /dev/null @@ -1,49 +0,0 @@ ---- -name: "\U0001F41B Bug Report" -about: Submit a bug report to help us improve Mask R-CNN Benchmark - ---- - -## 🐛 Bug - - - -## To Reproduce - -Steps to reproduce the behavior: - -1. -1. -1. - - - -## Expected behavior - - - -## Environment - -Please copy and paste the output from the -[environment collection script from PyTorch](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) -(or fill out the checklist below manually). - -You can get the script and run it with: -``` -wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py -# For security purposes, please check the contents of collect_env.py before running it. -python collect_env.py -``` - - - PyTorch Version (e.g., 1.0): - - OS (e.g., Linux): - - How you installed PyTorch (`conda`, `pip`, source): - - Build command you used (if compiling from source): - - Python version: - - CUDA/cuDNN version: - - GPU models and configuration: - - Any other relevant information: - -## Additional context - - diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md deleted file mode 100644 index 6c874564..00000000 --- a/.github/ISSUE_TEMPLATE/feature-request.md +++ /dev/null @@ -1,24 +0,0 @@ ---- -name: "\U0001F680Feature Request" -about: Submit a proposal/request for a new Mask R-CNN Benchmark feature - ---- - -## 🚀 Feature - - -## Motivation - - - -## Pitch - - - -## Alternatives - - - -## Additional context - - diff --git a/.github/ISSUE_TEMPLATE/questions-help-support.md b/.github/ISSUE_TEMPLATE/questions-help-support.md deleted file mode 100644 index 992f1b5f..00000000 --- a/.github/ISSUE_TEMPLATE/questions-help-support.md +++ /dev/null @@ -1,7 +0,0 @@ ---- -name: "❓Questions/Help/Support" -about: Do you need support? - ---- - -## ❓ Questions and Help diff --git a/INSTALL.md b/INSTALL.md index 4db4b5bb..d38e6dc4 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -17,13 +17,13 @@ # for that, check that `which conda`, `which pip` and `which python` points to the # right path. From a clean conda env, this is what you need to do -conda create --name maskrcnn_benchmark -conda activate maskrcnn_benchmark +conda create --name FCOS +conda activate FCOS # this installs the right pip and dependencies for the fresh python conda install ipython -# maskrcnn_benchmark and coco api dependencies +# FCOS and coco api dependencies pip install ninja yacs cython matplotlib tqdm # follow PyTorch installation in https://pytorch.org/get-started/locally/ @@ -40,8 +40,8 @@ python setup.py build_ext install # install PyTorch Detection cd $INSTALL_DIR -git clone https://github.com/facebookresearch/maskrcnn-benchmark.git -cd maskrcnn-benchmark +git clone https://github.com/tianzhi0549/FCOS.git +cd FCOS # the following will install the lib with # symbolic links, so that you can modify @@ -57,6 +57,7 @@ unset INSTALL_DIR ``` ### Option 2: Docker Image (Requires CUDA, Linux only) +*The following steps are for original maskrcnn-benchmark. 
Please change the repository name if needed.* Build image with defaults (`CUDA=9.0`, `CUDNN=7`, `FORCE_CUDA=1`): diff --git a/LICENSE b/LICENSE index 8585e11b..6d4eb2c8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,25 @@ -MIT License +FCOS for non-commercial purposes -Copyright (c) 2018 Facebook +Copyright (c) 2019 the authors +All rights reserved. -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MASKRCNN_README.md b/MASKRCNN_README.md new file mode 100644 index 00000000..780722ed --- /dev/null +++ b/MASKRCNN_README.md @@ -0,0 +1,251 @@ +# Faster R-CNN and Mask R-CNN in PyTorch 1.0 + +This project aims at providing the necessary building blocks for easily +creating detection and segmentation models using PyTorch 1.0. + +![alt text](demo/demo_e2e_mask_rcnn_X_101_32x8d_FPN_1x.png "from http://cocodataset.org/#explore?id=345434") + +## Highlights +- **PyTorch 1.0:** RPN, Faster R-CNN and Mask R-CNN implementations that matches or exceeds Detectron accuracies +- **Very fast**: up to **2x** faster than [Detectron](https://github.com/facebookresearch/Detectron) and **30%** faster than [mmdetection](https://github.com/open-mmlab/mmdetection) during training. See [MODEL_ZOO.md](MODEL_ZOO.md) for more details. 
+- **Memory efficient:** uses roughly 500MB less GPU memory than mmdetection during training +- **Multi-GPU training and inference** +- **Batched inference:** can perform inference using multiple images per batch per GPU +- **CPU support for inference:** runs on CPU in inference time. See our [webcam demo](demo) for an example +- Provides pre-trained models for almost all reference Mask R-CNN and Faster R-CNN configurations with 1x schedule. + +## Webcam and Jupyter notebook demo + +We provide a simple webcam demo that illustrates how you can use `maskrcnn_benchmark` for inference: +```bash +cd demo +# by default, it runs on the GPU +# for best results, use min-image-size 800 +python webcam.py --min-image-size 800 +# can also run it on the CPU +python webcam.py --min-image-size 300 MODEL.DEVICE cpu +# or change the model that you want to use +python webcam.py --config-file ../configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu +# in order to see the probability heatmaps, pass --show-mask-heatmaps +python webcam.py --min-image-size 300 --show-mask-heatmaps MODEL.DEVICE cpu +# for the keypoint demo +python webcam.py --config-file ../configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu +``` + +A notebook with the demo can be found in [demo/Mask_R-CNN_demo.ipynb](demo/Mask_R-CNN_demo.ipynb). + +## Installation + +Check [INSTALL.md](INSTALL.md) for installation instructions. + + +## Model Zoo and Baselines + +Pre-trained models, baselines and comparison with Detectron and mmdetection +can be found in [MODEL_ZOO.md](MODEL_ZOO.md) + +## Inference in a few lines +We provide a helper class to simplify writing inference pipelines using pre-trained models. +Here is how we would do it. Run this from the `demo` folder: +```python +from maskrcnn_benchmark.config import cfg +from predictor import COCODemo + +config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml" + +# update the config options with the config file +cfg.merge_from_file(config_file) +# manual override some options +cfg.merge_from_list(["MODEL.DEVICE", "cpu"]) + +coco_demo = COCODemo( + cfg, + min_image_size=800, + confidence_threshold=0.7, +) +# load image and then run prediction +image = ... +predictions = coco_demo.run_on_opencv_image(image) +``` + +## Perform training on COCO dataset + +For the following examples to work, you need to first install `maskrcnn_benchmark`. + +You will also need to download the COCO dataset. +We recommend to symlink the path to the coco dataset to `datasets/` as follows + +We use `minival` and `valminusminival` sets from [Detectron](https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/data/README.md#coco-minival-annotations) + +```bash +# symlink the coco dataset +cd ~/github/maskrcnn-benchmark +mkdir -p datasets/coco +ln -s /path_to_coco_dataset/annotations datasets/coco/annotations +ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014 +ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014 +ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014 +# or use COCO 2017 version +ln -s /path_to_coco_dataset/annotations datasets/coco/annotations +ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017 +ln -s /path_to_coco_dataset/test2017 datasets/coco/test2017 +ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017 + +# for pascal voc dataset: +ln -s /path_to_VOCdevkit_dir datasets/voc +``` + +P.S. 
`COCO_2017_train` = `COCO_2014_train` + `valminusminival` , `COCO_2017_val` = `minival` + + +You can also configure your own paths to the datasets. +For that, all you need to do is to modify `maskrcnn_benchmark/config/paths_catalog.py` to +point to the location where your dataset is stored. +You can also create a new `paths_catalog.py` file which implements the same two classes, +and pass it as a config argument `PATHS_CATALOG` during training. + +### Single GPU training + +Most of the configuration files that we provide assume that we are running on 8 GPUs. +In order to be able to run it on fewer GPUs, there are a few possibilities: + +**1. Run the following without modifications** + +```bash +python /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "/path/to/config/file.yaml" +``` +This should work out of the box and is very similar to what we should do for multi-GPU training. +But the drawback is that it will use much more GPU memory. The reason is that we set in the +configuration files a global batch size that is divided over the number of GPUs. So if we only +have a single GPU, this means that the batch size for that GPU will be 8x larger, which might lead +to out-of-memory errors. + +If you have a lot of memory available, this is the easiest solution. + +**2. Modify the cfg parameters** + +If you experience out-of-memory errors, you can reduce the global batch size. But this means that +you'll also need to change the learning rate, the number of iterations and the learning rate schedule. + +Here is an example for Mask R-CNN R-50 FPN with the 1x schedule: +```bash +python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1 +``` +This follows the [scheduling rules from Detectron.](https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14-L30) +Note that we have multiplied the number of iterations by 8x (as well as the learning rate schedules), +and we have divided the learning rate by 8x. + +We also changed the batch size during testing, but that is generally not necessary because testing +requires much less memory than training. + + +### Multi-GPU training +We use internally `torch.distributed.launch` in order to launch +multi-gpu training. This utility function from PyTorch spawns as many +Python processes as the number of GPUs we want to use, and each Python +process will only use a single GPU. + +```bash +export NGPUS=8 +python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "path/to/config/file.yaml" +``` + +## Abstractions +For more information on some of the main abstractions in our implementation, see [ABSTRACTIONS.md](ABSTRACTIONS.md). + +## Adding your own dataset + +This implementation adds support for COCO-style datasets. +But adding support for training on a new dataset can be done as follows: +```python +from maskrcnn_benchmark.structures.bounding_box import BoxList + +class MyDataset(object): + def __init__(self, ...): + # as you would do normally + + def __getitem__(self, idx): + # load the image as a PIL Image + image = ... + + # load the bounding boxes as a list of list of boxes + # in this case, for illustrative purposes, we use + # x1, y1, x2, y2 order. 
+ boxes = [[0, 0, 10, 10], [10, 20, 50, 50]] + # and labels + labels = torch.tensor([10, 20]) + + # create a BoxList from the boxes + boxlist = BoxList(boxes, image.size, mode="xyxy") + # add the labels to the boxlist + boxlist.add_field("labels", labels) + + if self.transforms: + image, boxlist = self.transforms(image, boxlist) + + # return the image, the boxlist and the idx in your dataset + return image, boxlist, idx + + def get_img_info(self, idx): + # get img_height and img_width. This is used if + # we want to split the batches according to the aspect ratio + # of the image, as it can be more efficient than loading the + # image from disk + return {"height": img_height, "width": img_width} +``` +That's it. You can also add extra fields to the boxlist, such as segmentation masks +(using `structures.segmentation_mask.SegmentationMask`), or even your own instance type. + +For a full example of how the `COCODataset` is implemented, check [`maskrcnn_benchmark/data/datasets/coco.py`](maskrcnn_benchmark/data/datasets/coco.py). + +Once you have created your dataset, it needs to be added in a couple of places: +- [`maskrcnn_benchmark/data/datasets/__init__.py`](maskrcnn_benchmark/data/datasets/__init__.py): add it to `__all__` +- [`maskrcnn_benchmark/config/paths_catalog.py`](maskrcnn_benchmark/config/paths_catalog.py): `DatasetCatalog.DATASETS` and corresponding `if` clause in `DatasetCatalog.get()` + +### Testing +While the aforementioned example should work for training, we leverage the +cocoApi for computing the accuracies during testing. Thus, test datasets +should currently follow the cocoApi for now. + +To enable your dataset for testing, add a corresponding if statement in [`maskrcnn_benchmark/data/datasets/evaluation/__init__.py`](maskrcnn_benchmark/data/datasets/evaluation/__init__.py): +```python +if isinstance(dataset, datasets.MyDataset): + return coco_evaluation(**args) +``` + +## Finetuning from Detectron weights on custom datasets +Create a script `tools/trim_detectron_model.py` like [here](https://gist.github.com/wangg12/aea194aa6ab6a4de088f14ee193fd968). +You can decide which keys to be removed and which keys to be kept by modifying the script. + +Then you can simply point the converted model path in the config file by changing `MODEL.WEIGHT`. + +For further information, please refer to [#15](https://github.com/facebookresearch/maskrcnn-benchmark/issues/15). + +## Troubleshooting +If you have issues running or compiling this code, we have compiled a list of common issues in +[TROUBLESHOOTING.md](TROUBLESHOOTING.md). If your issue is not present there, please feel +free to open a new issue. + +## Citations +Please consider citing this project in your publications if it helps your research. The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package. +``` +@misc{massa2018mrcnn, +author = {Massa, Francisco and Girshick, Ross}, +title = {{maskrcnn-benchmark: Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch}}, +year = {2018}, +howpublished = {\url{https://github.com/facebookresearch/maskrcnn-benchmark}}, +note = {Accessed: [Insert date here]} +} +``` + +## Projects using maskrcnn-benchmark + +- [RetinaMask: Learning to predict masks improves state-of-the-art single-shot detection for free](https://arxiv.org/abs/1901.03353). + Cheng-Yang Fu, Mykhailo Shvets, and Alexander C. Berg. + Tech report, arXiv,1901.03353. + + + +## License + +maskrcnn-benchmark is released under the MIT license. 
See [LICENSE](LICENSE) for additional details. diff --git a/README.md b/README.md index 780722ed..b5e02c88 100644 --- a/README.md +++ b/README.md @@ -1,251 +1,93 @@ -# Faster R-CNN and Mask R-CNN in PyTorch 1.0 +# FCOS: Fully Convolutional One-Stage Object Detection -This project aims at providing the necessary building blocks for easily -creating detection and segmentation models using PyTorch 1.0. +The codes are used for implementing FCOS for object detection, described in: -![alt text](demo/demo_e2e_mask_rcnn_X_101_32x8d_FPN_1x.png "from http://cocodataset.org/#explore?id=345434") + FCOS: Fully Convolutional One-Stage Object Detection, + Tian, Zhi, Chunhua Shen, Hao Chen, and Tong He, + arXiv preprint arXiv:1904.01355 (2019). + +The full paper is available at: [https://arxiv.org/abs/1904.01355](https://arxiv.org/abs/1904.01355). ## Highlights -- **PyTorch 1.0:** RPN, Faster R-CNN and Mask R-CNN implementations that matches or exceeds Detectron accuracies -- **Very fast**: up to **2x** faster than [Detectron](https://github.com/facebookresearch/Detectron) and **30%** faster than [mmdetection](https://github.com/open-mmlab/mmdetection) during training. See [MODEL_ZOO.md](MODEL_ZOO.md) for more details. -- **Memory efficient:** uses roughly 500MB less GPU memory than mmdetection during training -- **Multi-GPU training and inference** -- **Batched inference:** can perform inference using multiple images per batch per GPU -- **CPU support for inference:** runs on CPU in inference time. See our [webcam demo](demo) for an example -- Provides pre-trained models for almost all reference Mask R-CNN and Faster R-CNN configurations with 1x schedule. - -## Webcam and Jupyter notebook demo - -We provide a simple webcam demo that illustrates how you can use `maskrcnn_benchmark` for inference: -```bash -cd demo -# by default, it runs on the GPU -# for best results, use min-image-size 800 -python webcam.py --min-image-size 800 -# can also run it on the CPU -python webcam.py --min-image-size 300 MODEL.DEVICE cpu -# or change the model that you want to use -python webcam.py --config-file ../configs/caffe2/e2e_mask_rcnn_R_101_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu -# in order to see the probability heatmaps, pass --show-mask-heatmaps -python webcam.py --min-image-size 300 --show-mask-heatmaps MODEL.DEVICE cpu -# for the keypoint demo -python webcam.py --config-file ../configs/caffe2/e2e_keypoint_rcnn_R_50_FPN_1x_caffe2.yaml --min-image-size 300 MODEL.DEVICE cpu -``` +- **Totally anchor-free:** FCOS completely avoids the complicated computation related to anchor boxes and all hyper-parameters of anchor boxes. +- **Memory-efficient:** FCOS uses 2x less training memory footprint than its anchor-based counterpart RetinaNet. +- **Better performance:** Compared to RetinaNet, FCOS has better performance under exactly the same training and testing settings. +- **State-of-the-art performance:** Without bells and whistles, FCOS achieves state-of-the-art performances. +It achieves **41.0%** (ResNet-101-FPN) and **42.1%** (ResNeXt-32x8d-101) in AP on coco test-dev. +- **Faster:** FCOS enjoys faster training and inference speed than RetinaNet. -A notebook with the demo can be found in [demo/Mask_R-CNN_demo.ipynb](demo/Mask_R-CNN_demo.ipynb). +## Required hardware +We use 8 Nvidia V100 GPUs. \ +But 4 1080Ti GPUs can also train a fully-fledged ResNet-50-FPN based FCOS since FCOS is memory-efficient. ## Installation -Check [INSTALL.md](INSTALL.md) for installation instructions. 
- - -## Model Zoo and Baselines - -Pre-trained models, baselines and comparison with Detectron and mmdetection -can be found in [MODEL_ZOO.md](MODEL_ZOO.md) - -## Inference in a few lines -We provide a helper class to simplify writing inference pipelines using pre-trained models. -Here is how we would do it. Run this from the `demo` folder: -```python -from maskrcnn_benchmark.config import cfg -from predictor import COCODemo +This FCOS implementation is based on [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark), so its installation procedure is the same as that of the original maskrcnn-benchmark. -config_file = "../configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml" +Please check [INSTALL.md](INSTALL.md) for installation instructions. +You may also want to see the original [README.md](MASKRCNN_README.md) of maskrcnn-benchmark. -# update the config options with the config file -cfg.merge_from_file(config_file) -# manual override some options -cfg.merge_from_list(["MODEL.DEVICE", "cpu"]) +## Inference +The inference command line for the COCO `minival` split: -coco_demo = COCODemo( - cfg, - min_image_size=800, - confidence_threshold=0.7, -) -# load image and then run prediction -image = ... -predictions = coco_demo.run_on_opencv_image(image) -``` - -## Perform training on COCO dataset - -For the following examples to work, you need to first install `maskrcnn_benchmark`. + python tools/test_net.py \ + --config-file configs/fcos/fcos_R_50_FPN_1x.yaml \ + MODEL.WEIGHT models/FCOS_R_50_FPN_1x.pth \ + TEST.IMS_PER_BATCH 4 -You will also need to download the COCO dataset. -We recommend to symlink the path to the coco dataset to `datasets/` as follows +Please note that: +1) If your model has a different name, please replace `models/FCOS_R_50_FPN_1x.pth` with its path. +2) If you encounter an out-of-memory error, please try reducing `TEST.IMS_PER_BATCH` to 1. +3) If you want to evaluate another model, please change `--config-file` to its config file (in [configs/fcos](configs/fcos)) and `MODEL.WEIGHT` to its weights file. -We use `minival` and `valminusminival` sets from [Detectron](https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/data/README.md#coco-minival-annotations) - -```bash -# symlink the coco dataset -cd ~/github/maskrcnn-benchmark -mkdir -p datasets/coco -ln -s /path_to_coco_dataset/annotations datasets/coco/annotations -ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014 -ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014 -ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014 -# or use COCO 2017 version -ln -s /path_to_coco_dataset/annotations datasets/coco/annotations -ln -s /path_to_coco_dataset/train2017 datasets/coco/train2017 -ln -s /path_to_coco_dataset/test2017 datasets/coco/test2017 -ln -s /path_to_coco_dataset/val2017 datasets/coco/val2017 - -# for pascal voc dataset: -ln -s /path_to_VOCdevkit_dir datasets/voc -``` +For your convenience, we provide the following trained models (more models are coming soon). -P.S. 
`COCO_2017_train` = `COCO_2014_train` + `valminusminival` , `COCO_2017_val` = `minival` - +Model | Total training mem (GB) | Multi-scale training | Testing time / im | AP (minival) | AP (test-dev) | Link +--- |:---:|:---:|:---:|:---:|:--:|:---: +FCOS_R_50_FPN_1x | 29.3 | No | 71ms | 36.6 | 37.0 | [download](https://cloudstor.aarnet.edu.au/plus/s/dDeDPBLEAt19Xrl/download) +FCOS_R_101_FPN_2x | 44.1 | Yes | 74ms | 40.9 | 41.0 | [download](https://cloudstor.aarnet.edu.au/plus/s/vjL3L0AW7vnhRTo/download) +FCOS_X_101_32x8d_FPN_2x | 72.9 | Yes | 122ms | 42.0 | 42.1 | [download](https://cloudstor.aarnet.edu.au/plus/s/U5myBfGF7MviZ97/download) -You can also configure your own paths to the datasets. -For that, all you need to do is to modify `maskrcnn_benchmark/config/paths_catalog.py` to -point to the location where your dataset is stored. -You can also create a new `paths_catalog.py` file which implements the same two classes, -and pass it as a config argument `PATHS_CATALOG` during training. +[1] *1x means the model is trained for 90K iterations.* \ +[2] *2x means the model is trained for 180K iterations.* \ +[3] *We report total training memory footprint on all GPUs instead of the memory footprint per GPU as in maskrcnn-benchmark*. -### Single GPU training +## Training -Most of the configuration files that we provide assume that we are running on 8 GPUs. -In order to be able to run it on fewer GPUs, there are a few possibilities: +The following command line will train FCOS_R_50_FPN_1x on 8 GPUs with Synchronous Stochastic Gradient Descent (SGD): -**1. Run the following without modifications** + python -m torch.distributed.launch \ + --nproc_per_node=8 \ + --master_port=$((RANDOM + 10000)) \ + tools/train_net.py \ + --skip-test \ + --config-file configs/fcos/fcos_R_50_FPN_1x.yaml \ + DATALOADER.NUM_WORKERS 2 \ + OUTPUT_DIR training_dir/fcos_R_50_FPN_1x + +Note that: + +1) If you want to use fewer GPUs, please reduce `--nproc_per_node`. The total batch size does not depend on `nproc_per_node`. If you want to change the total batch size, please change `SOLVER.IMS_PER_BATCH` in [configs/fcos/fcos_R_50_FPN_1x.yaml](configs/fcos/fcos_R_50_FPN_1x.yaml). +2) The model checkpoints will be saved to `OUTPUT_DIR`. +3) If you want to train FCOS with other backbones, please change `--config-file`. +4) Sometimes you may encounter a deadlock with 100% GPU usage, which might be an NCCL issue. Please try `export NCCL_P2P_DISABLE=1` before running the training command line. -```bash -python /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "/path/to/config/file.yaml" -``` -This should work out of the box and is very similar to what we should do for multi-GPU training. -But the drawback is that it will use much more GPU memory. The reason is that we set in the -configuration files a global batch size that is divided over the number of GPUs. So if we only -have a single GPU, this means that the batch size for that GPU will be 8x larger, which might lead -to out-of-memory errors. - -If you have a lot of memory available, this is the easiest solution. - -**2. Modify the cfg parameters** - -If you experience out-of-memory errors, you can reduce the global batch size. But this means that -you'll also need to change the learning rate, the number of iterations and the learning rate schedule. 
- -Here is an example for Mask R-CNN R-50 FPN with the 1x schedule: -```bash -python tools/train_net.py --config-file "configs/e2e_mask_rcnn_R_50_FPN_1x.yaml" SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025 SOLVER.MAX_ITER 720000 SOLVER.STEPS "(480000, 640000)" TEST.IMS_PER_BATCH 1 -``` -This follows the [scheduling rules from Detectron.](https://github.com/facebookresearch/Detectron/blob/master/configs/getting_started/tutorial_1gpu_e2e_faster_rcnn_R-50-FPN.yaml#L14-L30) -Note that we have multiplied the number of iterations by 8x (as well as the learning rate schedules), -and we have divided the learning rate by 8x. - -We also changed the batch size during testing, but that is generally not necessary because testing -requires much less memory than training. - - -### Multi-GPU training -We use internally `torch.distributed.launch` in order to launch -multi-gpu training. This utility function from PyTorch spawns as many -Python processes as the number of GPUs we want to use, and each Python -process will only use a single GPU. - -```bash -export NGPUS=8 -python -m torch.distributed.launch --nproc_per_node=$NGPUS /path_to_maskrcnn_benchmark/tools/train_net.py --config-file "path/to/config/file.yaml" -``` +## Contributing to the project -## Abstractions -For more information on some of the main abstractions in our implementation, see [ABSTRACTIONS.md](ABSTRACTIONS.md). - -## Adding your own dataset - -This implementation adds support for COCO-style datasets. -But adding support for training on a new dataset can be done as follows: -```python -from maskrcnn_benchmark.structures.bounding_box import BoxList - -class MyDataset(object): - def __init__(self, ...): - # as you would do normally - - def __getitem__(self, idx): - # load the image as a PIL Image - image = ... - - # load the bounding boxes as a list of list of boxes - # in this case, for illustrative purposes, we use - # x1, y1, x2, y2 order. - boxes = [[0, 0, 10, 10], [10, 20, 50, 50]] - # and labels - labels = torch.tensor([10, 20]) - - # create a BoxList from the boxes - boxlist = BoxList(boxes, image.size, mode="xyxy") - # add the labels to the boxlist - boxlist.add_field("labels", labels) - - if self.transforms: - image, boxlist = self.transforms(image, boxlist) - - # return the image, the boxlist and the idx in your dataset - return image, boxlist, idx - - def get_img_info(self, idx): - # get img_height and img_width. This is used if - # we want to split the batches according to the aspect ratio - # of the image, as it can be more efficient than loading the - # image from disk - return {"height": img_height, "width": img_width} -``` -That's it. You can also add extra fields to the boxlist, such as segmentation masks -(using `structures.segmentation_mask.SegmentationMask`), or even your own instance type. - -For a full example of how the `COCODataset` is implemented, check [`maskrcnn_benchmark/data/datasets/coco.py`](maskrcnn_benchmark/data/datasets/coco.py). - -Once you have created your dataset, it needs to be added in a couple of places: -- [`maskrcnn_benchmark/data/datasets/__init__.py`](maskrcnn_benchmark/data/datasets/__init__.py): add it to `__all__` -- [`maskrcnn_benchmark/config/paths_catalog.py`](maskrcnn_benchmark/config/paths_catalog.py): `DatasetCatalog.DATASETS` and corresponding `if` clause in `DatasetCatalog.get()` - -### Testing -While the aforementioned example should work for training, we leverage the -cocoApi for computing the accuracies during testing. 
Thus, test datasets -should currently follow the cocoApi for now. - -To enable your dataset for testing, add a corresponding if statement in [`maskrcnn_benchmark/data/datasets/evaluation/__init__.py`](maskrcnn_benchmark/data/datasets/evaluation/__init__.py): -```python -if isinstance(dataset, datasets.MyDataset): - return coco_evaluation(**args) -``` - -## Finetuning from Detectron weights on custom datasets -Create a script `tools/trim_detectron_model.py` like [here](https://gist.github.com/wangg12/aea194aa6ab6a4de088f14ee193fd968). -You can decide which keys to be removed and which keys to be kept by modifying the script. - -Then you can simply point the converted model path in the config file by changing `MODEL.WEIGHT`. - -For further information, please refer to [#15](https://github.com/facebookresearch/maskrcnn-benchmark/issues/15). - -## Troubleshooting -If you have issues running or compiling this code, we have compiled a list of common issues in -[TROUBLESHOOTING.md](TROUBLESHOOTING.md). If your issue is not present there, please feel -free to open a new issue. +Any pull requests or issues are welcome. ## Citations -Please consider citing this project in your publications if it helps your research. The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX package. +Please consider citing our paper in your publications if the project helps your research. The following is a BibTeX reference. ``` -@misc{massa2018mrcnn, -author = {Massa, Francisco and Girshick, Ross}, -title = {{maskrcnn-benchmark: Fast, modular reference implementation of Instance Segmentation and Object Detection algorithms in PyTorch}}, -year = {2018}, -howpublished = {\url{https://github.com/facebookresearch/maskrcnn-benchmark}}, -note = {Accessed: [Insert date here]} +@article{tian2019fcos, + title={FCOS: Fully Convolutional One-Stage Object Detection}, + author={Tian, Zhi and Shen, Chunhua and Chen, Hao and He, Tong}, + journal={arXiv preprint arXiv:1904.01355}, + year={2019} } ``` -## Projects using maskrcnn-benchmark - -- [RetinaMask: Learning to predict masks improves state-of-the-art single-shot detection for free](https://arxiv.org/abs/1901.03353). - Cheng-Yang Fu, Mykhailo Shvets, and Alexander C. Berg. - Tech report, arXiv,1901.03353. - - ## License -maskrcnn-benchmark is released under the MIT license. See [LICENSE](LICENSE) for additional details. +For academic use, this project is licensed under the 2-clause BSD License - see the LICENSE file for details. For commercial use, please contact the authors. 
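For a quick qualitative check of a trained FCOS model (as opposed to the full COCO evaluation run by `tools/test_net.py`), the `COCODemo` helper described in [MASKRCNN_README.md](MASKRCNN_README.md) can be pointed at one of the FCOS configs added below. A minimal sketch, run from the `demo` folder; the image path, the `models/` location of the downloaded `FCOS_R_50_FPN_1x.pth` weights, and the compatibility of the unmodified `predictor.py` with RPN-only models are assumptions made for illustration:

```python
import cv2

from maskrcnn_benchmark.config import cfg
from predictor import COCODemo  # demo/predictor.py from maskrcnn-benchmark

# load the FCOS config and point MODEL.WEIGHT at the downloaded checkpoint
cfg.merge_from_file("../configs/fcos/fcos_R_50_FPN_1x.yaml")
cfg.merge_from_list(["MODEL.WEIGHT", "../models/FCOS_R_50_FPN_1x.pth"])

coco_demo = COCODemo(cfg, min_image_size=800, confidence_threshold=0.5)

image = cv2.imread("test.jpg")                 # any BGR image read with OpenCV
result = coco_demo.run_on_opencv_image(image)  # copy of the image with detections drawn
cv2.imwrite("test_result.jpg", result)
```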
diff --git a/configs/fcos/fcos_R_101_FPN_2x.yaml b/configs/fcos/fcos_R_101_FPN_2x.yaml new file mode 100644 index 00000000..cb76eefd --- /dev/null +++ b/configs/fcos/fcos_R_101_FPN_2x.yaml @@ -0,0 +1,28 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + WEIGHT: "catalog://ImageNetPretrained/MSRA/R-101" + RPN_ONLY: True + FCOS_ON: True + BACKBONE: + CONV_BODY: "R-101-FPN-RETINANET" + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RETINANET: + USE_C5: False # FCOS uses P5 instead of C5 +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +INPUT: + MIN_SIZE_RANGE_TRAIN: (640, 800) + MAX_SIZE_TRAIN: 1333 + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1333 +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.01 + WEIGHT_DECAY: 0.0001 + STEPS: (120000, 160000) + MAX_ITER: 180000 + IMS_PER_BATCH: 16 + WARMUP_METHOD: "constant" \ No newline at end of file diff --git a/configs/fcos/fcos_R_50_FPN_1x.yaml b/configs/fcos/fcos_R_50_FPN_1x.yaml new file mode 100644 index 00000000..ac58571d --- /dev/null +++ b/configs/fcos/fcos_R_50_FPN_1x.yaml @@ -0,0 +1,28 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" + RPN_ONLY: True + FCOS_ON: True + BACKBONE: + CONV_BODY: "R-50-FPN-RETINANET" + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + RETINANET: + USE_C5: False # FCOS uses P5 instead of C5 +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +INPUT: + MIN_SIZE_TRAIN: (800,) + MAX_SIZE_TRAIN: 1333 + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1333 +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.01 + WEIGHT_DECAY: 0.0001 + STEPS: (60000, 80000) + MAX_ITER: 90000 + IMS_PER_BATCH: 16 + WARMUP_METHOD: "constant" diff --git a/configs/fcos/fcos_X_101_32x8d_FPN_2x.yaml b/configs/fcos/fcos_X_101_32x8d_FPN_2x.yaml new file mode 100644 index 00000000..9b5c54b4 --- /dev/null +++ b/configs/fcos/fcos_X_101_32x8d_FPN_2x.yaml @@ -0,0 +1,30 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d" + RPN_ONLY: True + FCOS_ON: True + BACKBONE: + CONV_BODY: "R-101-FPN-RETINANET" + RESNETS: + BACKBONE_OUT_CHANNELS: 256 + NUM_GROUPS: 32 + WIDTH_PER_GROUP: 8 + RETINANET: + USE_C5: False # FCOS uses P5 instead of C5 +DATASETS: + TRAIN: ("coco_2014_train", "coco_2014_valminusminival") + TEST: ("coco_2014_minival",) +INPUT: + MIN_SIZE_RANGE_TRAIN: (640, 800) + MAX_SIZE_TRAIN: 1333 + MIN_SIZE_TEST: 800 + MAX_SIZE_TEST: 1333 +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.01 + WEIGHT_DECAY: 0.0001 + STEPS: (120000, 160000) + MAX_ITER: 180000 + IMS_PER_BATCH: 16 + WARMUP_METHOD: "constant" diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index fc750fd4..6b9ab9ae 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -23,6 +23,7 @@ _C.MODEL = CN() _C.MODEL.RPN_ONLY = False _C.MODEL.MASK_ON = False +_C.MODEL.FCOS_ON = True _C.MODEL.RETINANET_ON = False _C.MODEL.KEYPOINT_ON = False _C.MODEL.DEVICE = "cuda" @@ -41,6 +42,8 @@ _C.INPUT = CN() # Size of the smallest side of the image during training _C.INPUT.MIN_SIZE_TRAIN = (800,) # (800,) +# The range of the smallest side for multi-scale training +_C.INPUT.MIN_SIZE_RANGE_TRAIN = (-1, -1) # -1 means disabled and it will use MIN_SIZE_TRAIN # Maximum size of the side of the image during training _C.INPUT.MAX_SIZE_TRAIN = 1333 # Size of the smallest side of the image during testing @@ -274,6 +277,24 @@ 
_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256 _C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64 +# ---------------------------------------------------------------------------- # +# FCOS Options +# ---------------------------------------------------------------------------- # +_C.MODEL.FCOS = CN() +_C.MODEL.FCOS.NUM_CLASSES = 81 # the number of classes including background +_C.MODEL.FCOS.FPN_STRIDES = [8, 16, 32, 64, 128] +_C.MODEL.FCOS.PRIOR_PROB = 0.01 +_C.MODEL.FCOS.INFERENCE_TH = 0.05 +_C.MODEL.FCOS.NMS_TH = 0.4 +_C.MODEL.FCOS.PRE_NMS_TOP_N = 1000 + +# Focal loss parameter: alpha +_C.MODEL.FCOS.LOSS_ALPHA = 0.25 +# Focal loss parameter: gamma +_C.MODEL.FCOS.LOSS_GAMMA = 2.0 + +# the number of convolutions used in the cls and bbox tower +_C.MODEL.FCOS.NUM_CONVS = 4 # ---------------------------------------------------------------------------- # # RetinaNet Options (Follow the Detectron version) diff --git a/maskrcnn_benchmark/data/transforms/build.py b/maskrcnn_benchmark/data/transforms/build.py index 8645d4df..8260f919 100644 --- a/maskrcnn_benchmark/data/transforms/build.py +++ b/maskrcnn_benchmark/data/transforms/build.py @@ -4,7 +4,15 @@ def build_transforms(cfg, is_train=True): if is_train: - min_size = cfg.INPUT.MIN_SIZE_TRAIN + if cfg.INPUT.MIN_SIZE_RANGE_TRAIN[0] == -1: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + else: + assert len(cfg.INPUT.MIN_SIZE_RANGE_TRAIN) == 2, \ + "MIN_SIZE_RANGE_TRAIN must have two elements (lower bound, upper bound)" + min_size = range( + cfg.INPUT.MIN_SIZE_RANGE_TRAIN[0], + cfg.INPUT.MIN_SIZE_RANGE_TRAIN[1] + 1 + ) max_size = cfg.INPUT.MAX_SIZE_TRAIN flip_prob = 0.5 # cfg.INPUT.FLIP_PROB_TRAIN else: diff --git a/maskrcnn_benchmark/layers/__init__.py b/maskrcnn_benchmark/layers/__init__.py index bab50aba..cf787218 100644 --- a/maskrcnn_benchmark/layers/__init__.py +++ b/maskrcnn_benchmark/layers/__init__.py @@ -13,9 +13,12 @@ from .roi_pool import roi_pool from .smooth_l1_loss import smooth_l1_loss from .sigmoid_focal_loss import SigmoidFocalLoss +from .iou_loss import IOULoss +from .scale import Scale + __all__ = ["nms", "roi_align", "ROIAlign", "roi_pool", "ROIPool", "smooth_l1_loss", "Conv2d", "ConvTranspose2d", "interpolate", - "BatchNorm2d", "FrozenBatchNorm2d", "SigmoidFocalLoss" - ] + "BatchNorm2d", "FrozenBatchNorm2d", "SigmoidFocalLoss", "IOULoss", + "Scale"] diff --git a/maskrcnn_benchmark/layers/iou_loss.py b/maskrcnn_benchmark/layers/iou_loss.py new file mode 100644 index 00000000..af398dd6 --- /dev/null +++ b/maskrcnn_benchmark/layers/iou_loss.py @@ -0,0 +1,36 @@ +import torch +from torch import nn + + +class IOULoss(nn.Module): + def forward(self, pred, target, weight=None): + pred_left = pred[:, 0] + pred_top = pred[:, 1] + pred_right = pred[:, 2] + pred_bottom = pred[:, 3] + + target_left = target[:, 0] + target_top = target[:, 1] + target_right = target[:, 2] + target_bottom = target[:, 3] + + target_aera = (target_left + target_right) * \ + (target_top + target_bottom) + pred_aera = (pred_left + pred_right) * \ + (pred_top + pred_bottom) + + w_intersect = torch.min(pred_left, target_left) + \ + torch.min(pred_right, target_right) + h_intersect = torch.min(pred_bottom, target_bottom) + \ + torch.min(pred_top, target_top) + + area_intersect = w_intersect * h_intersect + area_union = target_aera + pred_aera - area_intersect + + losses = -torch.log((area_intersect + 1.0) / (area_union + 1.0)) + + if weight is not None and weight.sum() > 0: + return (losses * weight).sum() / weight.sum() + else: + assert losses.numel() != 0 + return losses.mean() diff --git 
a/maskrcnn_benchmark/layers/scale.py b/maskrcnn_benchmark/layers/scale.py new file mode 100644 index 00000000..2c25622e --- /dev/null +++ b/maskrcnn_benchmark/layers/scale.py @@ -0,0 +1,11 @@ +import torch +from torch import nn + + +class Scale(nn.Module): + def __init__(self, init_value=1.0): + super(Scale, self).__init__() + self.scale = nn.Parameter(torch.FloatTensor([init_value])) + + def forward(self, input): + return input * self.scale diff --git a/maskrcnn_benchmark/modeling/rpn/fcos/__init__.py b/maskrcnn_benchmark/modeling/rpn/fcos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/maskrcnn_benchmark/modeling/rpn/fcos/fcos.py b/maskrcnn_benchmark/modeling/rpn/fcos/fcos.py new file mode 100644 index 00000000..8d7c8c39 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/fcos/fcos.py @@ -0,0 +1,187 @@ +import math +import torch +import torch.nn.functional as F +from torch import nn + +from .inference import make_fcos_postprocessor +from .loss import make_fcos_loss_evaluator + +from maskrcnn_benchmark.layers import Scale + + +class FCOSHead(torch.nn.Module): + def __init__(self, cfg, in_channels): + """ + Arguments: + in_channels (int): number of channels of the input feature + """ + super(FCOSHead, self).__init__() + # TODO: Implement the sigmoid version first. + num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1 + + cls_tower = [] + bbox_tower = [] + for i in range(cfg.MODEL.FCOS.NUM_CONVS): + cls_tower.append( + nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1 + ) + ) + cls_tower.append(nn.GroupNorm(32, in_channels)) + cls_tower.append(nn.ReLU()) + bbox_tower.append( + nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1 + ) + ) + bbox_tower.append(nn.GroupNorm(32, in_channels)) + bbox_tower.append(nn.ReLU()) + + self.add_module('cls_tower', nn.Sequential(*cls_tower)) + self.add_module('bbox_tower', nn.Sequential(*bbox_tower)) + self.cls_logits = nn.Conv2d( + in_channels, num_classes, kernel_size=3, stride=1, + padding=1 + ) + self.bbox_pred = nn.Conv2d( + in_channels, 4, kernel_size=3, stride=1, + padding=1 + ) + self.centerness = nn.Conv2d( + in_channels, 1, kernel_size=3, stride=1, + padding=1 + ) + + # initialization + for modules in [self.cls_tower, self.bbox_tower, + self.cls_logits, self.bbox_pred, + self.centerness]: + for l in modules.modules(): + if isinstance(l, nn.Conv2d): + torch.nn.init.normal_(l.weight, std=0.01) + torch.nn.init.constant_(l.bias, 0) + + # initialize the bias for focal loss + prior_prob = cfg.MODEL.FCOS.PRIOR_PROB + bias_value = -math.log((1 - prior_prob) / prior_prob) + torch.nn.init.constant_(self.cls_logits.bias, bias_value) + + self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)]) + + def forward(self, x): + logits = [] + bbox_reg = [] + centerness = [] + for l, feature in enumerate(x): + cls_tower = self.cls_tower(feature) + logits.append(self.cls_logits(cls_tower)) + centerness.append(self.centerness(cls_tower)) + bbox_reg.append(torch.exp(self.scales[l]( + self.bbox_pred(self.bbox_tower(feature)) + ))) + return logits, bbox_reg, centerness + + +class FCOSModule(torch.nn.Module): + """ + Module for FCOS computation. Takes feature maps from the backbone and + FCOS outputs and losses. Only Test on FPN now. 
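+    For each FPN level, the head predicts per-location class logits, a 4-vector of +    (l, t, r, b) distances to the box sides and a centerness score; compute_locations() +    maps every feature-map cell back to its (x, y) image coordinates, which the loss +    evaluator and the post-processor treat as the points the predictions are attached to.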
+ """ + + def __init__(self, cfg, in_channels): + super(FCOSModule, self).__init__() + + head = FCOSHead(cfg, in_channels) + + box_selector_test = make_fcos_postprocessor(cfg) + + loss_evaluator = make_fcos_loss_evaluator(cfg) + self.head = head + self.box_selector_test = box_selector_test + self.loss_evaluator = loss_evaluator + self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES + + def forward(self, images, features, targets=None): + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (list[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (list[BoxList): ground-truth boxes present in the image (optional) + + Returns: + boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per + image. + losses (dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + box_cls, box_regression, centerness = self.head(features) + locations = self.compute_locations(features) + + if self.training: + return self._forward_train( + locations, box_cls, + box_regression, + centerness, targets + ) + else: + return self._forward_test( + locations, box_cls, box_regression, + centerness, images.image_sizes + ) + + def _forward_train(self, locations, box_cls, box_regression, centerness, targets): + loss_box_cls, loss_box_reg, loss_centerness = self.loss_evaluator( + locations, box_cls, box_regression, centerness, targets + ) + losses = { + "loss_cls": loss_box_cls, + "loss_reg": loss_box_reg, + "loss_centerness": loss_centerness + } + return None, losses + + def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes): + boxes = self.box_selector_test( + locations, box_cls, box_regression, + centerness, image_sizes + ) + return boxes, {} + + def compute_locations(self, features): + locations = [] + for level, feature in enumerate(features): + h, w = feature.size()[-2:] + locations_per_level = self.compute_locations_per_level( + h, w, self.fpn_strides[level], + feature.device + ) + locations.append(locations_per_level) + return locations + + def compute_locations_per_level(self, h, w, stride, device): + shifts_x = torch.arange( + 0, w * stride, step=stride, + dtype=torch.float32, device=device + ) + shifts_y = torch.arange( + 0, h * stride, step=stride, + dtype=torch.float32, device=device + ) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2 + return locations + +def build_fcos(cfg, in_channels): + return FCOSModule(cfg, in_channels) diff --git a/maskrcnn_benchmark/modeling/rpn/fcos/inference.py b/maskrcnn_benchmark/modeling/rpn/fcos/inference.py new file mode 100644 index 00000000..9e2922b7 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/fcos/inference.py @@ -0,0 +1,203 @@ +import torch + +from ..inference import RPNPostProcessor +from ..utils import permute_and_flatten + +from maskrcnn_benchmark.modeling.box_coder import BoxCoder +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.structures.bounding_box import BoxList +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_nms +from maskrcnn_benchmark.structures.boxlist_ops import remove_small_boxes + + +class FCOSPostProcessor(torch.nn.Module): + """ + Performs post-processing on the 
outputs of the RetinaNet boxes. + This is only used in the testing. + """ + def __init__( + self, + pre_nms_thresh, + pre_nms_top_n, + nms_thresh, + fpn_post_nms_top_n, + min_size, + num_classes, + ): + """ + Arguments: + pre_nms_thresh (float) + pre_nms_top_n (int) + nms_thresh (float) + fpn_post_nms_top_n (int) + min_size (int) + num_classes (int) + box_coder (BoxCoder) + """ + super(FCOSPostProcessor, self).__init__() + self.pre_nms_thresh = pre_nms_thresh + self.pre_nms_top_n = pre_nms_top_n + self.nms_thresh = nms_thresh + self.fpn_post_nms_top_n = fpn_post_nms_top_n + self.min_size = min_size + self.num_classes = num_classes + + def forward_for_single_feature_map( + self, locations, box_cls, + box_regression, centerness, + image_sizes): + """ + Arguments: + anchors: list[BoxList] + box_cls: tensor of size N, A * C, H, W + box_regression: tensor of size N, A * 4, H, W + """ + N, C, H, W = box_cls.shape + + # put in the same format as locations + box_cls = box_cls.view(N, C, H, W).permute(0, 2, 3, 1) + box_cls = box_cls.reshape(N, -1, C).sigmoid() + box_regression = box_regression.view(N, 4, H, W).permute(0, 2, 3, 1) + box_regression = box_regression.reshape(N, -1, 4) + centerness = centerness.view(N, 1, H, W).permute(0, 2, 3, 1) + centerness = centerness.reshape(N, -1).sigmoid() + + candidate_inds = box_cls > self.pre_nms_thresh + pre_nms_top_n = candidate_inds.view(N, -1).sum(1) + pre_nms_top_n = pre_nms_top_n.clamp(max=self.pre_nms_top_n) + + # multiply the classification scores with centerness scores + box_cls = box_cls * centerness[:, :, None] + + results = [] + for i in range(N): + per_box_cls = box_cls[i] + per_candidate_inds = candidate_inds[i] + per_box_cls = per_box_cls[per_candidate_inds] + + per_candidate_nonzeros = per_candidate_inds.nonzero() + per_box_loc = per_candidate_nonzeros[:, 0] + per_class = per_candidate_nonzeros[:, 1] + 1 + + per_box_regression = box_regression[i] + per_box_regression = per_box_regression[per_box_loc] + per_locations = locations[per_box_loc] + + per_pre_nms_top_n = pre_nms_top_n[i] + + if per_candidate_inds.sum().item() > per_pre_nms_top_n.item(): + per_box_cls, top_k_indices = \ + per_box_cls.topk(per_pre_nms_top_n, sorted=False) + per_class = per_class[top_k_indices] + per_box_regression = per_box_regression[top_k_indices] + per_locations = per_locations[top_k_indices] + + detections = torch.stack([ + per_locations[:, 0] - per_box_regression[:, 0], + per_locations[:, 1] - per_box_regression[:, 1], + per_locations[:, 0] + per_box_regression[:, 2], + per_locations[:, 1] + per_box_regression[:, 3], + ], dim=1) + + h, w = image_sizes[i] + boxlist = BoxList(detections, (int(w), int(h)), mode="xyxy") + boxlist.add_field("labels", per_class) + boxlist.add_field("scores", per_box_cls) + boxlist = boxlist.clip_to_image(remove_empty=False) + boxlist = remove_small_boxes(boxlist, self.min_size) + results.append(boxlist) + + return results + + def forward(self, locations, box_cls, box_regression, centerness, image_sizes): + """ + Arguments: + anchors: list[list[BoxList]] + box_cls: list[tensor] + box_regression: list[tensor] + image_sizes: list[(h, w)] + Returns: + boxlists (list[BoxList]): the post-processed anchors, after + applying box decoding and NMS + """ + sampled_boxes = [] + for _, (l, o, b, c) in enumerate(zip(locations, box_cls, box_regression, centerness)): + sampled_boxes.append( + self.forward_for_single_feature_map( + l, o, b, c, image_sizes + ) + ) + + boxlists = list(zip(*sampled_boxes)) + boxlists = [cat_boxlist(boxlist) for boxlist in 
boxlists] + boxlists = self.select_over_all_levels(boxlists) + + return boxlists + + # TODO very similar to filter_results from PostProcessor + # but filter_results is per image + # TODO Yang: solve this issue in the future. No good solution + # right now. + def select_over_all_levels(self, boxlists): + num_images = len(boxlists) + results = [] + for i in range(num_images): + scores = boxlists[i].get_field("scores") + labels = boxlists[i].get_field("labels") + boxes = boxlists[i].bbox + boxlist = boxlists[i] + result = [] + # skip the background + for j in range(1, self.num_classes): + inds = (labels == j).nonzero().view(-1) + + scores_j = scores[inds] + boxes_j = boxes[inds, :].view(-1, 4) + boxlist_for_class = BoxList(boxes_j, boxlist.size, mode="xyxy") + boxlist_for_class.add_field("scores", scores_j) + boxlist_for_class = boxlist_nms( + boxlist_for_class, self.nms_thresh, + score_field="scores" + ) + num_labels = len(boxlist_for_class) + boxlist_for_class.add_field( + "labels", torch.full((num_labels,), j, + dtype=torch.int64, + device=scores.device) + ) + result.append(boxlist_for_class) + + result = cat_boxlist(result) + number_of_detections = len(result) + + # Limit to max_per_image detections **over all classes** + if number_of_detections > self.fpn_post_nms_top_n > 0: + cls_scores = result.get_field("scores") + image_thresh, _ = torch.kthvalue( + cls_scores.cpu(), + number_of_detections - self.fpn_post_nms_top_n + 1 + ) + keep = cls_scores >= image_thresh.item() + keep = torch.nonzero(keep).squeeze(1) + result = result[keep] + results.append(result) + return results + + +def make_fcos_postprocessor(config): + pre_nms_thresh = config.MODEL.FCOS.INFERENCE_TH + pre_nms_top_n = config.MODEL.FCOS.PRE_NMS_TOP_N + nms_thresh = config.MODEL.FCOS.NMS_TH + fpn_post_nms_top_n = config.TEST.DETECTIONS_PER_IMG + + box_selector = FCOSPostProcessor( + pre_nms_thresh=pre_nms_thresh, + pre_nms_top_n=pre_nms_top_n, + nms_thresh=nms_thresh, + fpn_post_nms_top_n=fpn_post_nms_top_n, + min_size=0, + num_classes=config.MODEL.FCOS.NUM_CLASSES + ) + + return box_selector diff --git a/maskrcnn_benchmark/modeling/rpn/fcos/loss.py b/maskrcnn_benchmark/modeling/rpn/fcos/loss.py new file mode 100644 index 00000000..ad9d4cf5 --- /dev/null +++ b/maskrcnn_benchmark/modeling/rpn/fcos/loss.py @@ -0,0 +1,192 @@ +""" +This file contains specific functions for computing losses of FCOS +file +""" + +import torch +from torch.nn import functional as F +from torch import nn + +from ..utils import concat_box_prediction_layers +from maskrcnn_benchmark.layers import IOULoss +from maskrcnn_benchmark.layers import SigmoidFocalLoss +from maskrcnn_benchmark.modeling.matcher import Matcher +from maskrcnn_benchmark.modeling.utils import cat +from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou +from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist + + +INF = 100000000 + + +class FCOSLossComputation(object): + """ + This class computes the FCOS losses. 
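+    Classification is trained with a sigmoid focal loss and box regression with the +    IoU loss, weighted by the centerness targets; the centerness branch is trained +    with binary cross-entropy. Locations are assigned to ground-truth boxes per FPN +    level according to the object_sizes_of_interest ranges in prepare_targets().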
+ """ + + def __init__(self, cfg): + self.cls_loss_func = SigmoidFocalLoss( + cfg.MODEL.FCOS.LOSS_GAMMA, + cfg.MODEL.FCOS.LOSS_ALPHA + ) + # we make use of IOU Loss for bounding boxes regression, + # but we found that L1 in log scale can yield a similar performance + self.box_reg_loss_func = IOULoss() + self.centerness_loss_func = nn.BCEWithLogitsLoss() + + def prepare_targets(self, points, targets): + object_sizes_of_interest = [ + [-1, 64], + [64, 128], + [128, 256], + [256, 512], + [512, INF], + ] + expanded_object_sizes_of_interest = [] + for l, points_per_level in enumerate(points): + object_sizes_of_interest_per_level = \ + points_per_level.new_tensor(object_sizes_of_interest[l]) + expanded_object_sizes_of_interest.append( + object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1) + ) + + expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0) + num_points_per_level = [len(points_per_level) for points_per_level in points] + points_all_level = torch.cat(points, dim=0) + labels, reg_targets = self.compute_targets_for_locations( + points_all_level, targets, expanded_object_sizes_of_interest + ) + + for i in range(len(labels)): + labels[i] = torch.split(labels[i], num_points_per_level, dim=0) + reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0) + + labels_level_first = [] + reg_targets_level_first = [] + for level in range(len(points)): + labels_level_first.append( + torch.cat([labels_per_im[level] for labels_per_im in labels], dim=0) + ) + reg_targets_level_first.append( + torch.cat([reg_targets_per_im[level] for reg_targets_per_im in reg_targets], dim=0) + ) + + return labels_level_first, reg_targets_level_first + + def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest): + labels = [] + reg_targets = [] + xs, ys = locations[:, 0], locations[:, 1] + + for im_i in range(len(targets)): + targets_per_im = targets[im_i] + assert targets_per_im.mode == "xyxy" + bboxes = targets_per_im.bbox + labels_per_im = targets_per_im.get_field("labels") + area = targets_per_im.area() + + l = xs[:, None] - bboxes[:, 0][None] + t = ys[:, None] - bboxes[:, 1][None] + r = bboxes[:, 2][None] - xs[:, None] + b = bboxes[:, 3][None] - ys[:, None] + reg_targets_per_im = torch.stack([l, t, r, b], dim=2) + + is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0 + + max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0] + # limit the regression range for each location + is_cared_in_the_level = \ + (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \ + (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]]) + + locations_to_gt_area = area[None].repeat(len(locations), 1) + locations_to_gt_area[is_in_boxes == 0] = INF + locations_to_gt_area[is_cared_in_the_level == 0] = INF + + # if there are still more than one objects for a location, + # we choose the one with minimal area + locations_to_min_aera, locations_to_gt_inds = locations_to_gt_area.min(dim=1) + + reg_targets_per_im = reg_targets_per_im[range(len(locations)), locations_to_gt_inds] + labels_per_im = labels_per_im[locations_to_gt_inds] + labels_per_im[locations_to_min_aera == INF] = 0 + + labels.append(labels_per_im) + reg_targets.append(reg_targets_per_im) + + return labels, reg_targets + + def compute_centerness_targets(self, reg_targets): + left_right = reg_targets[:, [0, 2]] + top_bottom = reg_targets[:, [1, 3]] + centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \ + (top_bottom.min(dim=-1)[0] / 
top_bottom.max(dim=-1)[0]) + return torch.sqrt(centerness) + + def __call__(self, locations, box_cls, box_regression, centerness, targets): + """ + Arguments: + locations (list[BoxList]) + box_cls (list[Tensor]) + box_regression (list[Tensor]) + centerness (list[Tensor]) + targets (list[BoxList]) + + Returns: + cls_loss (Tensor) + reg_loss (Tensor) + centerness_loss (Tensor) + """ + N = box_cls[0].size(0) + num_classes = box_cls[0].size(1) + labels, reg_targets = self.prepare_targets(locations, targets) + + box_cls_flatten = [] + box_regression_flatten = [] + centerness_flatten = [] + labels_flatten = [] + reg_targets_flatten = [] + for l in range(len(labels)): + box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes)) + box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4)) + labels_flatten.append(labels[l].reshape(-1)) + reg_targets_flatten.append(reg_targets[l].reshape(-1, 4)) + centerness_flatten.append(centerness[l].reshape(-1)) + + box_cls_flatten = torch.cat(box_cls_flatten, dim=0) + box_regression_flatten = torch.cat(box_regression_flatten, dim=0) + centerness_flatten = torch.cat(centerness_flatten, dim=0) + labels_flatten = torch.cat(labels_flatten, dim=0) + reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0) + + pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1) + cls_loss = self.cls_loss_func( + box_cls_flatten, + labels_flatten.int() + ) / (pos_inds.numel() + N) # add N to avoid dividing by a zero + + box_regression_flatten = box_regression_flatten[pos_inds] + reg_targets_flatten = reg_targets_flatten[pos_inds] + centerness_flatten = centerness_flatten[pos_inds] + + if pos_inds.numel() > 0: + centerness_targets = self.compute_centerness_targets(reg_targets_flatten) + reg_loss = self.box_reg_loss_func( + box_regression_flatten, + reg_targets_flatten, + centerness_targets + ) + centerness_loss = self.centerness_loss_func( + centerness_flatten, + centerness_targets + ) + else: + reg_loss = box_regression_flatten.sum() + centerness_loss = centerness_flatten.sum() + + return cls_loss, reg_loss, centerness_loss + + +def make_fcos_loss_evaluator(cfg): + loss_evaluator = FCOSLossComputation(cfg) + return loss_evaluator diff --git a/maskrcnn_benchmark/modeling/rpn/rpn.py b/maskrcnn_benchmark/modeling/rpn/rpn.py index 07997651..9fc6a010 100644 --- a/maskrcnn_benchmark/modeling/rpn/rpn.py +++ b/maskrcnn_benchmark/modeling/rpn/rpn.py @@ -6,6 +6,7 @@ from maskrcnn_benchmark.modeling import registry from maskrcnn_benchmark.modeling.box_coder import BoxCoder from maskrcnn_benchmark.modeling.rpn.retinanet.retinanet import build_retinanet +from maskrcnn_benchmark.modeling.rpn.fcos.fcos import build_fcos from .loss import make_rpn_loss_evaluator from .anchor_generator import make_anchor_generator from .inference import make_rpn_postprocessor @@ -201,6 +202,8 @@ def build_rpn(cfg, in_channels): """ This gives the gist of it. 
Not super important because it doesn't change as much """ + if cfg.MODEL.FCOS_ON: + return build_fcos(cfg, in_channels) if cfg.MODEL.RETINANET_ON: return build_retinanet(cfg, in_channels) diff --git a/maskrcnn_benchmark/structures/segmentation_mask.py b/maskrcnn_benchmark/structures/segmentation_mask.py index bf91d556..928433ca 100644 --- a/maskrcnn_benchmark/structures/segmentation_mask.py +++ b/maskrcnn_benchmark/structures/segmentation_mask.py @@ -195,8 +195,7 @@ def __init__(self, polygons, size): polygons = valid_polygons elif isinstance(polygons, PolygonInstance): - polygons = polygons.polygons.copy() - + polygons = [p.clone() for p in polygons.polygons] else: RuntimeError( "Type of argument `polygons` is not allowed:%s" % (type(polygons)) diff --git a/tools/test_net.py b/tools/test_net.py index d0acd283..f957b0dd 100644 --- a/tools/test_net.py +++ b/tools/test_net.py @@ -84,7 +84,7 @@ def main(): data_loader_val, dataset_name=dataset_name, iou_types=iou_types, - box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY, + box_only=False if cfg.MODEL.FCOS_ON or cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY, device=cfg.MODEL.DEVICE, expected_results=cfg.TEST.EXPECTED_RESULTS, expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
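A closing note on the centerness target used in `maskrcnn_benchmark/modeling/rpn/fcos/loss.py` above: for a location with regression targets (l, t, r, b), `compute_centerness_targets` returns sqrt(min(l, r) / max(l, r) * min(t, b) / max(t, b)), so a location at the exact center of its ground-truth box gets a target of 1 and locations near a box border get targets close to 0. A small self-contained sketch of the same computation (the sample values are made up for illustration):

```python
import torch

def centerness_targets(reg_targets):
    # reg_targets: (N, 4) tensor of (l, t, r, b) distances from each location
    # to the four sides of its assigned ground-truth box
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    return torch.sqrt(
        (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
        (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    )

# a location at the box center -> 1.0; a location near a corner -> ~0.05
print(centerness_targets(torch.tensor([[50., 50., 50., 50.],
                                       [ 5., 95., 95.,  5.]])))
```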