fix GPU CI; fix ETA computation in inference; support nightly build
Summary: Pull Request resolved: fairinternal/detectron2#353

Differential Revision: D19229623

Pulled By: ppwwyyxx

fbshipit-source-id: c370076e7dcbb5d58eed4e580e3603640875c69b
ppwwyyxx authored and facebook-github-bot committed Dec 26, 2019
1 parent 012ffd3 commit e74a00c
Showing 7 changed files with 79 additions and 51 deletions.
6 changes: 3 additions & 3 deletions .circleci/config.yml
@@ -146,15 +146,15 @@ jobs:
name: Build Detectron2
command: |
docker exec -it d2 pip install 'git+https://github.com/facebookresearch/fvcore'
docker exec -it d2 git clone https://github.com/facebookresearch/detectron2
docker cp ~/detectron2 d2:/detectron2
# This will build d2 for the target GPU arch only
docker exec -it d2 pip install -e detectron2
docker exec -it d2 pip install -e /detectron2
docker exec -it d2 python3 -m detectron2.utils.collect_env
- run:
name: Run Unit Tests
command: |
docker exec -it d2 python3 -m unittest discover -v -s detectron2/tests
docker exec -it d2 python3 -m unittest discover -v -s /detectron2/tests
workflows:
version: 2
41 changes: 21 additions & 20 deletions MODEL_ZOO.md
@@ -4,30 +4,34 @@

This file documents a large collection of baselines trained
with detectron2 in Sep-Oct, 2019.
The corresponding configurations for all models can be found under the `configs/` directory.
Unless otherwise noted, the following settings are used for all runs:
All models were trained on [Big Basin](https://engineering.fb.com/data-center-engineering/introducing-big-basin-our-next-generation-ai-hardware/)
servers with 8 NVIDIA V100 GPUs, with data-parallel sync SGD. The software in use was PyTorch 1.3, CUDA 9.2, and cuDNN 7.4.2 or 7.6.3.
You can programmatically access these models using the [detectron2.model_zoo](https://detectron2.readthedocs.io/modules/model_zoo.html) APIs.

#### Common Settings
* All models were trained on [Big Basin](https://engineering.fb.com/data-center-engineering/introducing-big-basin-our-next-generation-ai-hardware/)
servers with 8 NVIDIA V100 GPUs, with data-parallel sync SGD and a total minibatch size of 16 images.
* All models were trained with CUDA 9.2, cuDNN 7.4.2 or 7.6.3 (the difference in speed is found to be negligible).
#### How to Read the Tables
* The "Name" column contains a link to the config file. Running `tools/train_net.py` with this config file
and 8 GPUs will reproduce the model.
* Training speed is averaged across the entire training.
We keep updating the speed with the latest version of detectron2/pytorch/etc.,
so the numbers might differ from those in the `metrics` file.
* Inference speed is measured by `tools/train_net.py --eval-only`, or [inference_on_dataset()](https://detectron2.readthedocs.io/modules/evaluation.html#detectron2.evaluation.inference_on_dataset),
with batch size 1 in detectron2 directly.
Measuring it with your own code will likely introduce other overhead.
Actual deployment in production should in general be faster than the given inference
speed due to more optimizations.
* The *model id* column is provided for ease of reference.
To check downloaded file integrity, any model on this page contains its md5 prefix in its file name.
* Training curves and other statistics can be found in `metrics` for each model.
* The default settings are __not directly comparable__ with Detectron.

#### Common Settings for COCO Models
* All COCO models were trained on `train2017` and evaluated on `val2017`.
* The default settings are __not directly comparable__ with Detectron's standard settings.
For example, our default training data augmentation uses scale jittering in addition to horizontal flipping.

For configs that are comparable to Detectron's settings, see
To make fair comparisons with Detectron's settings, see
[Detectron1-Comparisons](configs/Detectron1-Comparisons/) for accuracy comparison,
and [benchmarks](https://detectron2.readthedocs.io/notes/benchmarks.html)
for speed comparison.
* Inference speed is measured by `tools/train_net.py --eval-only`, or [inference_on_dataset()](https://detectron2.readthedocs.io/modules/evaluation.html#detectron2.evaluation.inference_on_dataset),
with batch size 1 in detectron2 directly.
Measuring it with your own code will likely introduce other overhead.
The actual deployment should in general be faster than the given inference
speed due to more optimizations.
* Training speed is averaged across the entire training.
We keep updating the speed with the latest version of detectron2/pytorch/etc.,
so the numbers might differ from those in the `metrics` file.
* All COCO models were trained on `train2017` and evaluated on `val2017`.
* For Faster/Mask R-CNN, we provide baselines based on __3 different backbone combinations__:
* __FPN__: Use a ResNet+FPN backbone with standard conv and FC heads for mask and box prediction,
respectively. It obtains the best
@@ -39,9 +43,6 @@ Unless otherwise noted, the following settings are used for all runs:
* Most models are trained with the 3x schedule (~37 COCO epochs).
Although 1x models are heavily under-trained, we provide some ResNet-50 models with the 1x (~12 COCO epochs)
training schedule for comparison when doing quick research iteration.
* The *model id* column is provided for ease of reference.
To check downloaded file integrity, any model on this page contains its md5 prefix in its file name.
Each model also comes with a metrics file with all the training statistics and evaluation curves.

#### ImageNet Pretrained Models

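As a hedged illustration of the measurement described above (not part of this diff), here is a minimal Python sketch that times a zoo model with `inference_on_dataset()` and a batch-size-1 loader. The config path, dataset name, and output directory are illustrative assumptions, and the helper functions come from detectron2's public API of this era rather than from this commit.

```python
# Minimal sketch: measure inference speed the way the model zoo tables describe.
# Assumes a CUDA-capable machine (zoo configs default to MODEL.DEVICE="cuda")
# and that COCO val2017 is registered under the name "coco_2017_val".
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset

config_path = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"  # illustrative choice
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(config_path))

model = model_zoo.get(config_path, trained=True)  # model with zoo weights loaded
model.eval()

data_loader = build_detection_test_loader(cfg, "coco_2017_val")  # batch size 1
evaluator = COCOEvaluator("coco_2017_val", cfg, distributed=False, output_dir="./output")
results = inference_on_dataset(model, data_loader, evaluator)  # logs s/img and ETA
print(results)
```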
2 changes: 2 additions & 0 deletions detectron2/__init__.py
@@ -5,4 +5,6 @@
setup_environment()


# This line will be programmatically read and written by setup.py.
# Leave it at the bottom of this file and don't touch it.
__version__ = "0.1"
5 changes: 1 addition & 4 deletions detectron2/evaluation/evaluator.py
@@ -125,11 +125,8 @@ def inference_on_dataset(model, data_loader, evaluator):
evaluator.process(inputs, outputs)

if idx >= num_warmup * 2:
duration = time.perf_counter() - start_time
seconds_per_img = total_compute_time / (idx + 1 - num_warmup)
eta = datetime.timedelta(
seconds=int(seconds_per_img * (total - num_warmup) - duration)
)
eta = datetime.timedelta(seconds=int(seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
logging.INFO,
"Inference done {}/{}. {:.4f} s / img. ETA={}".format(
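To make the ETA fix above concrete, here is a standalone, hedged sketch of the corrected arithmetic; the variable names mirror the snippet, but the helper function and the example numbers are illustrative, not part of the diff.

```python
import datetime

def estimate_eta(total_compute_time, idx, num_warmup, total):
    # Average compute time per image, excluding the warm-up iterations.
    seconds_per_img = total_compute_time / (idx + 1 - num_warmup)
    # The fix: scale by the images still to process, (total - idx - 1),
    # instead of subtracting the elapsed wall-clock duration.
    return datetime.timedelta(seconds=int(seconds_per_img * (total - idx - 1)))

# Example: 5000 images, 5 warm-up iterations, 190 s of compute after 100 images (idx=99):
print(estimate_eta(190.0, 99, 5, 5000))  # 2:43:20, i.e. ~2.0 s/img over the 4900 images left
```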
12 changes: 7 additions & 5 deletions detectron2/model_zoo/model_zoo.py
@@ -87,10 +87,10 @@ def get_config_file(config_path):
Args:
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
Returns:
str: path to the config file.
str: the real path to the config file.
"""
cfg_file = pkg_resources.resource_filename(
"detectron2.model_zoo", os.path.join("configs", config_path)
@@ -102,11 +102,13 @@ def get_config_file(config_path):

def get(config_path, trained: bool = False):
"""
Get a model specified by relative path under Detectron2's official ``configs`` directory.
Get a model specified by relative path under Detectron2's official ``configs/`` directory.
Args:
trained (bool): Whether to initialize with the trained model zoo weights. If False, the
initialization weights specified in the config file's ``MODEL.WEIGHTS`` key are used
config_path (str): config file name relative to detectron2's "configs/"
directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"
trained (bool): If True, will initialize the model with the trained model zoo weights.
If False, the checkpoint specified in the config file's ``MODEL.WEIGHTS`` is used
instead; this will typically (though not always) initialize a subset of weights using
an ImageNet pre-trained model, while randomly initializing the other weights.
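A short usage sketch of the API whose docstrings are edited above; the config path is an illustrative assumption.

```python
from detectron2 import model_zoo

config_path = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml"

# Real path to the config file bundled inside the detectron2 package.
cfg_file = model_zoo.get_config_file(config_path)

# trained=True initializes with the model zoo weights; trained=False falls back to
# the checkpoint named by MODEL.WEIGHTS in the config (typically an ImageNet-pretrained
# backbone, with the remaining weights randomly initialized).
model = model_zoo.get(config_path, trained=True)
```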
17 changes: 10 additions & 7 deletions docs/tutorials/models.md
@@ -7,6 +7,7 @@ from detectron2.modeling import build_model
model = build_model(cfg) # returns a torch.nn.Module
```

Note that `build_model` only builds the model structure and fills it with random parameters.
To load an existing checkpoint to the model, use
`DetectionCheckpointer(model).load(file_path)`.
Detectron2 recognizes models in pytorch's `.pth` format, as well as the `.pkl` files
@@ -33,12 +34,13 @@ The dict may contain the following keys:
+ "gt_keypoints": a [Keypoints](../modules/structures.html#detectron2.structures.Keypoints)
object storing N keypoint sets, one for each instance.
* "proposals": an [Instances](../modules/structures.html#detectron2.structures.Instances)
object used in Fast R-CNN style models, with the following fields:
object used only in Fast R-CNN style models, with the following fields:
+ "proposal_boxes": a [Boxes](../modules/structures.html#detectron2.structures.Boxes) object storing P proposal boxes.
+ "objectness_logits": `Tensor`, a vector of P scores, one for each proposal.
* "height", "width": the **desired** output height and width of the image, not necessarily the same
as the height or width of the `image` when input into the model, which might be after resizing.
For example, it can be the **original** image height and width before resizing.
* "height", "width": the **desired** output height and width, which is not necessarily the same
as the height or width of the `image` input field.
For example, the `image` input field might be a resized image,
but you may want the outputs to be in **original** resolution.

If provided, the model will produce output in this resolution,
rather than in the resolution of the `image` as input into the model. This is more efficient and accurate.
@@ -57,7 +59,8 @@ After the data loader performs batching, it becomes `list[dict]` which the built

When in training mode, the builtin models output a `dict[str->ScalarTensor]` with all the losses.

When in inference mode, the builtin models output a `list[dict]`, one dict for each image. Each dict may contain:
When in inference mode, the builtin models output a `list[dict]`, one dict for each image.
Based on the tasks the model is doing, each dict may contain the following fields:

* "instances": [Instances](../modules/structures.html#detectron2.structures.Instances)
object with the following fields:
@@ -83,8 +86,8 @@ When in inference mode, the builtin models output a `list[dict]`, one dict for e

### How to use a model in your code:

Construct your own `list[dict]`, with the necessary keys.
For example, for inference, provide dicts with "image", and optionally "height" and "width".
Construct your own `list[dict]` as inputs, with the necessary keys. Then call `outputs = model(inputs)`.
For example, in order to do inference, provide dicts with "image", and optionally "height" and "width".

Note that when in training mode, all models are required to be used under an `EventStorage`.
The training statistics will be put into the storage:
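To tie the input/output format above together, a minimal hedged sketch of calling a builtin model directly for inference; the image path is an assumption, and the preprocessing mirrors what `DefaultPredictor` does (BGR, float32, CHW).

```python
import cv2
import torch
from detectron2 import model_zoo

# Assumes a CUDA device, since zoo configs default to MODEL.DEVICE="cuda".
model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True)
model.eval()

img = cv2.imread("input.jpg")                                       # HWC, BGR, uint8
image = torch.as_tensor(img.astype("float32").transpose(2, 0, 1))   # CHW float tensor
# "height"/"width" set to the original size, so outputs come back in that resolution.
inputs = [{"image": image, "height": img.shape[0], "width": img.shape[1]}]

with torch.no_grad():
    outputs = model(inputs)              # list[dict], one dict per input image
instances = outputs[0]["instances"]      # pred_boxes, scores, pred_classes, ...
```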
47 changes: 35 additions & 12 deletions setup.py
@@ -4,6 +4,7 @@
import glob
import os
import shutil
from os import path
from setuptools import find_packages, setup
from typing import List
import torch
@@ -13,14 +14,36 @@
assert torch_ver >= [1, 3], "Requires PyTorch >= 1.3"


def get_version():
init_py_path = path.join(path.abspath(path.dirname(__file__)), "detectron2", "__init__.py")
init_py = open(init_py_path, "r").readlines()
version_line = [l.strip() for l in init_py if l.startswith("__version__")][0]
version = version_line.split("=")[-1].strip().strip("'\"")

# Used by CI to build nightly packages. Users should never use it.
# To build a nightly wheel, run:
# FORCE_CUDA=1 BUILD_NIGHTLY=1 TORCH_CUDA_ARCH_LIST=All python setup.py bdist_wheel
if os.getenv("BUILD_NIGHTLY", "0") == "1":
from datetime import datetime

date_str = datetime.today().strftime("%y%m%d")
version = version + ".post" + date_str

new_init_py = [l for l in init_py if not l.startswith("__version__")]
new_init_py.append('__version__ = "{}"\n'.format(version))
with open(init_py_path, "w") as f:
f.write("".join(new_init_py))
return version


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "detectron2", "layers", "csrc")
this_dir = path.dirname(path.abspath(__file__))
extensions_dir = path.join(this_dir, "detectron2", "layers", "csrc")

main_source = os.path.join(extensions_dir, "vision.cpp")
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
os.path.join(extensions_dir, "*.cu")
main_source = path.join(extensions_dir, "vision.cpp")
sources = glob.glob(path.join(extensions_dir, "**", "*.cpp"))
source_cuda = glob.glob(path.join(extensions_dir, "**", "*.cu")) + glob.glob(
path.join(extensions_dir, "*.cu")
)

sources = [main_source] + sources
@@ -67,14 +90,14 @@ def get_model_zoo_configs() -> List[str]:
"""

# Use absolute paths while symlinking.
source_configs_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
destination = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "detectron2", "model_zoo", "configs"
source_configs_dir = path.join(path.dirname(path.realpath(__file__)), "configs")
destination = path.join(
path.dirname(path.realpath(__file__)), "detectron2", "model_zoo", "configs"
)
# Symlink the config directory inside package to have a cleaner pip install.
if os.path.exists(destination):
if path.exists(destination):
# Remove stale symlink/directory from a previous build.
if os.path.islink(destination):
if path.islink(destination):
os.unlink(destination)
else:
shutil.rmtree(destination)
@@ -91,7 +114,7 @@ def get_model_zoo_configs() -> List[str]:

setup(
name="detectron2",
version="0.1",
version=get_version(),
author="FAIR",
url="https://github.com/facebookresearch/detectron2",
description="Detectron2 is FAIR's next-generation research "
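For clarity, a standalone sketch of the nightly-version string transform that the new `get_version()` performs; the helper name here is hypothetical, and only the suffix logic is taken from the diff.

```python
import os
from datetime import datetime

def nightly_version(base_version: str) -> str:
    # Mirrors the BUILD_NIGHTLY branch of get_version(): append ".post" + yymmdd.
    if os.getenv("BUILD_NIGHTLY", "0") == "1":
        return base_version + ".post" + datetime.today().strftime("%y%m%d")
    return base_version

# With BUILD_NIGHTLY=1 on the day of this commit, "0.1" becomes "0.1.post191226".
```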
