better metrics logging for plain_train_net
Summary:
Fix its spaces, and add a very rough estimate of ETA.
Pull Request resolved: fairinternal/detectron2#392

Reviewed By: rbgirshick

Differential Revision: D20479048

Pulled By: ppwwyyxx

fbshipit-source-id: 7f6d78c172867570b244b4aec7559c44770c0d36
ppwwyyxx authored and facebook-github-bot committed Mar 17, 2020
1 parent 6c0ff7e commit 45808c0
Showing 10 changed files with 58 additions and 29 deletions.
24 changes: 12 additions & 12 deletions GETTING_STARTED.md
@@ -18,7 +18,8 @@ For more advanced tutorials, refer to our [documentation](https://detectron2.rea
for example, `mask_rcnn_R_50_FPN_3x.yaml`.
2. We provide `demo.py` that is able to run builtin standard models. Run it with:
```
python demo/demo.py --config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
cd demo/
python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
--input input1.jpg input2.jpg \
[--other-options]
--opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
@@ -45,29 +46,28 @@ setup the corresponding datasets following
[datasets/README.md](https://github.com/facebookresearch/detectron2/blob/master/datasets/README.md),
then run:
```
python tools/train_net.py --num-gpus 8 \
--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
cd tools/
./train_net.py --num-gpus 8 \
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
```

The configs are made for 8-GPU training. To train on 1 GPU, change the batch size with:
The configs are made for 8-GPU training.
To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g.:
```
python tools/train_net.py \
--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
./train_net.py \
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
```

For most models, CPU training is not supported.

(Note that we applied the [linear learning rate scaling rule](https://arxiv.org/abs/1706.02677)
when changing the batch size.)
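The 1-GPU command above is an instance of the linear scaling rule: the learning rate is multiplied by the same factor as the batch size. A minimal sketch of that arithmetic follows. The function name `scale_lr` is hypothetical, and the reference values (batch size 16, learning rate 0.02) are an assumption about the 8-GPU config, not something stated in this diff.

```python
def scale_lr(ref_lr, ref_batch_size, new_batch_size):
    """Scale a learning rate linearly with the batch size."""
    if ref_batch_size <= 0 or new_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return ref_lr * new_batch_size / ref_batch_size


# Assumed reference setting: IMS_PER_BATCH 16 with BASE_LR 0.02.
# Going down to IMS_PER_BATCH 2 divides the learning rate by 8.
lr = scale_lr(ref_lr=0.02, ref_batch_size=16, new_batch_size=2)
print("SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR {:.4f}".format(lr))
```

Under those assumed reference values, the result matches the `SOLVER.BASE_LR 0.0025` override used in the command.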

To evaluate a model's performance, use
```
python tools/train_net.py \
--config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
./train_net.py \
--config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
--eval-only MODEL.WEIGHTS /path/to/checkpoint_file
```
For more options, see `python tools/train_net.py -h`.
For more options, see `./train_net.py -h`.

### Use Detectron2 APIs in Your Code

2 changes: 2 additions & 0 deletions INSTALL.md
@@ -94,6 +94,8 @@ python -c 'import torch; from torch.utils.cpp_extension import CUDA_HOME; print(
```

print valid outputs at the time you build detectron2.

Most models can run inference (but not training) without GPU support. To use CPUs, set `MODEL.DEVICE='cpu'` in the config.
</details>

<details>
demo/demo.py: file mode changed 100644 → 100755 (made executable; no content changes).
1 change: 1 addition & 0 deletions detectron2/evaluation/cityscapes_evaluation.py
@@ -20,6 +20,7 @@ class CityscapesEvaluator(DatasetEvaluator):
Note:
* It does not work in multi-machine distributed training.
* It contains a synchronization, therefore has to be used on all ranks.
* Only the main process runs evaluation.
"""

def __init__(self, dataset_name):
3 changes: 2 additions & 1 deletion detectron2/evaluation/coco_evaluation.py
@@ -42,7 +42,8 @@ def __init__(self, dataset_name, cfg, distributed, output_dir=None):
Or it must be in detectron2's standard dataset format
so it can be converted to COCO format automatically.
cfg (CfgNode): config instance
distributed (True): if True, will collect results from all ranks for evaluation.
distributed (True): if True, will collect results from all ranks and run evaluation
in the main process.
Otherwise, will evaluate the results in the current process.
output_dir (str): optional, an output directory to dump all
results predicted on the dataset. The dump contains two files:
3 changes: 2 additions & 1 deletion detectron2/export/api.py
@@ -62,7 +62,8 @@ def export_onnx_model(cfg, model, inputs):
"""
Export a detectron2 model to ONNX format.
Note that the exported model contains custom ops only available in caffe2, therefore it
cannot be directly executed by other runtime.
cannot be directly executed by other runtime. Post-processing or transformation passes
may be applied on the model to accommodate different runtimes.
Args:
cfg (CfgNode): a detectron2 config, with extra export-related options
6 changes: 6 additions & 0 deletions detectron2/layers/mask_ops.py
@@ -70,6 +70,11 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):
The location, height, and width for pasting each mask is determined by their
corresponding bounding boxes in boxes.
Note:
This is a complicated but more accurate implementation. In actual deployment, it is
often enough to use a faster but less accurate implementation.
See :func:`paste_mask_in_image_old` in this file for an alternative implementation.
Args:
masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
detected object instances in the image and Hmask, Wmask are the mask width and mask
@@ -85,6 +90,7 @@ def paste_masks_in_image(masks, boxes, image_shape, threshold=0.5):
number of detected object instances and Himage, Wimage are the image width
and height. img_masks[i] is a binary mask for object instance i.
"""

assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
N = len(masks)
if N == 0:
31 changes: 23 additions & 8 deletions detectron2/utils/events.py
@@ -3,6 +3,7 @@
import json
import logging
import os
import time
from collections import defaultdict
from contextlib import contextmanager
import torch
@@ -156,21 +157,35 @@ def __init__(self, max_iter):
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
self._last_write = None

def write(self):
storage = get_event_storage()
iteration = storage.iter

data_time, time = None, None
eta_string = "N/A"
try:
data_time = storage.history("data_time").avg(20)
time = storage.history("time").global_avg()
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None

eta_string = "N/A"
try:
iter_time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError: # they may not exist in the first few iterations (due to warmup)
pass
except KeyError:
iter_time = None
# estimate eta on our own - more noisy
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())

try:
lr = "{:.6f}".format(storage.history("lr").latest())
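The fallback path in the hunk above estimates ETA from wall-clock time between two `write()` calls when no `time` history is available. The following is a standalone sketch of that idea, not the committed code: the class name `FallbackEtaEstimator` and the injectable `clock` argument are hypothetical, and the guard against zero elapsed iterations is an addition made only in this sketch.

```python
import datetime
import time


class FallbackEtaEstimator:
    """Hypothetical helper mirroring the wall-clock ETA fallback."""

    def __init__(self, max_iter, clock=time.perf_counter):
        self._max_iter = max_iter
        self._clock = clock  # injectable so the logic can be exercised deterministically
        self._last_write = None  # (iteration, timestamp) of the previous call

    def eta_string(self, iteration):
        now = self._clock()
        eta = "N/A"
        if self._last_write is not None:
            last_iter, last_time = self._last_write
            elapsed_iters = iteration - last_iter
            # Guard added in this sketch only: skip the estimate when no
            # iterations have elapsed, to avoid dividing by zero.
            if elapsed_iters > 0:
                iter_time = (now - last_time) / elapsed_iters
                eta_seconds = iter_time * (self._max_iter - iteration)
                eta = str(datetime.timedelta(seconds=int(eta_seconds)))
        self._last_write = (iteration, now)
        return eta


# Deterministic usage with a fake clock: 20 iterations take 10 seconds.
fake_clock = iter([0.0, 10.0]).__next__
estimator = FallbackEtaEstimator(max_iter=100, clock=fake_clock)
first = estimator.eta_string(0)    # no previous write yet
second = estimator.eta_string(20)  # 0.5 s per iteration, 80 iterations left
```

Because the estimate uses only the interval since the previous call, it is noisier than a median over recorded iteration times, which is consistent with the "more noisy" comment in the diff.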
@@ -184,7 +199,7 @@ def write(self):

# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" eta: {eta} iter: {iter} {losses} {time} {data_time} lr: {lr} {memory}".format(
" eta: {eta} iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
eta=eta_string,
iter=iteration,
losses=" ".join(
@@ -194,8 +209,8 @@ def write(self):
if "loss" in k
]
),
time="time: {:.4f}".format(time) if time is not None else "",
data_time="data_time: {:.4f}".format(data_time) if data_time is not None else "",
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
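The spacing fix in the format string can be illustrated in isolation: when an optional field is empty, the old template still emits the separators around it, while the new template lets each optional field carry its own trailing space. The sketch below uses the two templates as they appear in the diff above; the `render` helper and the sample values are hypothetical.

```python
OLD_FMT = " eta: {eta} iter: {iter} {losses} {time} {data_time} lr: {lr} {memory}"
NEW_FMT = " eta: {eta} iter: {iter} {losses} {time}{data_time}lr: {lr} {memory}"


def render(fmt, iter_time, data_time, trailing_space):
    """Fill a log template; optional fields collapse to "" when missing."""
    pad = " " if trailing_space else ""
    return fmt.format(
        eta="N/A",
        iter=5,
        losses="total_loss: 1.000",
        time="time: {:.4f}{}".format(iter_time, pad) if iter_time is not None else "",
        data_time="data_time: {:.4f}{}".format(data_time, pad) if data_time is not None else "",
        lr="0.001000",
        memory="",
    )


# Both timing fields are missing, as in the first iterations of training.
old_line = render(OLD_FMT, None, None, trailing_space=False)
new_line = render(NEW_FMT, None, None, trailing_space=True)
print(repr(old_line))  # contains a run of spaces where the fields would be
print(repr(new_line))  # single spaces only
```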
14 changes: 8 additions & 6 deletions docs/tutorials/deployment.md
@@ -10,7 +10,7 @@ Caffe2 conversion requires PyTorch ≥ 1.4 and ONNX ≥ 1.6.
### Coverage

It supports 3 most common meta architectures: `GeneralizedRCNN`, `RetinaNet`, `PanopticFPN`,
and almost all official models under these 3 meta architectures.
and most official models under these 3 meta architectures.

Users' custom extensions under these architectures (added through registration) are supported
as long as they do not contain control flow or operators not available in Caffe2 (e.g. deformable convolution).
@@ -25,7 +25,7 @@ these APIs to convert a standard model.
To convert an official Mask R-CNN trained on COCO, first
[prepare the COCO dataset](../../datasets/), then pick the model from [Model Zoo](../../MODEL_ZOO.md), and run:
```
python tools/caffe2_converter.py --config-file configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
cd tools/ && ./caffe2_converter.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
--output ./caffe2_model --run-eval \
MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl \
MODEL.DEVICE cpu
@@ -50,7 +50,7 @@ You can also load `model.pb` to tools such as [netron](https://github.com/lutzro

### Inputs & Outputs

All converted models take two input tensors:
All converted models (the .pb file) take two input tensors:
"data" which is an NCHW image, and "im_info" which is a Nx3 tensor of (height, width, unused legacy parameter) for
each image (the shape of "data" might be larger than that in "im_info" due to padding).
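To make the two-input contract concrete, the sketch below computes the shape of a padded "data" batch and the matching "im_info" rows from a list of image sizes. It is illustrative only: `make_input_specs` is a hypothetical helper, and the 3-channel default, padding to the largest image in the batch, and the placeholder value 1.0 for the unused third column are all assumptions rather than behavior taken from the converter.

```python
def make_input_specs(image_sizes, channels=3):
    """Return (data_shape, im_info) for a batch of images.

    image_sizes: list of (height, width) pairs, one per image.
    """
    if not image_sizes:
        raise ValueError("need at least one image")
    # Assumption: the batch is padded to the largest height and width.
    padded_h = max(h for h, _ in image_sizes)
    padded_w = max(w for _, w in image_sizes)
    data_shape = (len(image_sizes), channels, padded_h, padded_w)  # NCHW
    # One row per image: (height, width, placeholder for the unused column).
    im_info = [(float(h), float(w), 1.0) for h, w in image_sizes]
    return data_shape, im_info


data_shape, im_info = make_input_specs([(480, 640), (600, 400)])
print(data_shape)
print(im_info)
```

Here the "data" shape is larger than either image's own size in "im_info", which is the padding effect the paragraph above describes.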

@@ -60,6 +60,8 @@ The models only produce raw outputs from the final
layers that are not post-processed, because in actual deployment, an application often needs
its custom lightweight post-processing (e.g. full-image masks for every detected object is often not necessary).

Due to different inputs & outputs formats, the `Caffe2Model.__call__` method includes
pre/post-processing code in order to match the formats of original detectron2 models.
They can serve as a reference for pre/post-processing in actual deployment.
Due to different inputs & outputs formats,
we provide a wrapper around the converted model, in the [Caffe2Model.__call__](../modules/export.html#detectron2.export.Caffe2Model.__call__) method.
It has an interface that's identical to the [format of pytorch versions of models](models.html),
and it internally applies pre/post-processing code to match the formats.
They can serve as a reference for pre/post-processing in actual deployment.
3 changes: 2 additions & 1 deletion tools/benchmark.py
@@ -100,8 +100,9 @@ def benchmark_train(args):
dummy_data = list(itertools.islice(data_loader, 100))

def f():
data = DatasetFromList(dummy_data, copy=False)
while True:
yield from DatasetFromList(dummy_data, copy=False)
yield from data

max_iter = 400
trainer = SimpleTrainer(model, f(), optimizer)
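The change to `f()` builds the dataset wrapper once and re-yields from it, rather than constructing a new `DatasetFromList` on every pass of the loop. The same pattern can be sketched in plain Python, with a list standing in for the dataset (an assumption made so the example runs without detectron2); `infinite_stream` is a hypothetical name.

```python
import itertools


def infinite_stream(samples):
    """Yield the given samples forever, cycling back to the start."""
    data = list(samples)  # built once, outside the loop
    if not data:
        raise ValueError("samples must not be empty")  # otherwise the loop would spin forever
    while True:
        yield from data


# Take a finite prefix of the endless stream.
first_five = list(itertools.islice(infinite_stream([1, 2, 3]), 5))
print(first_five)
```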
