-
Notifications
You must be signed in to change notification settings - Fork 208
/
Detic_ViLD_200e.py
155 lines (143 loc) · 5.41 KB
/
Detic_ViLD_200e.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import torch
import detectron2.data.transforms as T
from detectron2.config import LazyCall as L
from detectron2.layers import ShapeSpec
from detectron2.data.samplers import RepeatFactorTrainingSampler
from detectron2.evaluation.lvis_evaluation import LVISEvaluator
from detectron2.layers.batch_norm import NaiveSyncBatchNorm
from detectron2.solver import WarmupParamScheduler
from detectron2.solver.build import get_default_optimizer_params
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.roi_heads import FastRCNNConvFCHead
from detectron2.modeling.box_regression import Box2BoxTransform
from detectron2.model_zoo import get_config
from fvcore.common.param_scheduler import CosineParamScheduler
from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier
from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads
from detic.modeling.roi_heads.detic_fast_rcnn import DeticFastRCNNOutputLayers
from detic.data.custom_dataset_mapper import CustomDatasetMapper
from detic.modeling.meta_arch.custom_rcnn import CustomRCNN
from detic.data.custom_dataset_dataloader import build_custom_train_loader
from detic.data.custom_dataset_dataloader import MultiDatasetSampler
from detic.data.custom_dataset_dataloader import get_detection_dataset_dicts_with_source
# Start from the detectron2 model-zoo baseline config and reuse its
# dataloader / model / train sections as the base of this config.
default_configs = get_config('new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py')
dataloader = default_configs['dataloader']
model = default_configs['model']
train = default_configs['train']

# Checkpoint used to initialise the model before training.
train.init_checkpoint = 'models/BoxSup_ViLD_200e.pth'

# Remove the baseline's singular ROI-head entries; the roi_heads.update(...)
# call that follows provides the list-valued box_heads / box_predictors /
# proposal_matchers instead. A plain loop is used because only the side
# effect of pop() is wanted. The underscore-prefixed loop variable keeps the
# leftover module-level name visibly private.
for _stale_key in ("box_head", "box_predictor", "proposal_matcher"):
    model.roi_heads.pop(_stale_key)
# Re-target the ROI heads to DeticCascadeROIHeads. The per-stage arguments
# (box_heads, box_predictors, proposal_matchers) each hold exactly one entry.
model.roi_heads.update(
    _target_=DeticCascadeROIHeads,
    num_classes=1203,
    box_heads=[
        L(FastRCNNConvFCHead)(
            input_shape=ShapeSpec(channels=256, height=7, width=7),
            conv_dims=[256, 256, 256, 256],
            fc_dims=[1024],
            conv_norm=lambda c: NaiveSyncBatchNorm(c, stats_mode="N"),
        ),
    ],
    box_predictors=[
        L(DeticFastRCNNOutputLayers)(
            input_shape=ShapeSpec(channels=1024),
            test_score_thresh=0.0001,
            test_topk_per_image=300,
            box2box_transform=L(Box2BoxTransform)(weights=(10, 10, 5, 5)),
            cls_agnostic_bbox_reg=True,
            # Relative interpolation: resolves to roi_heads.num_classes above.
            num_classes="${...num_classes}",
            cls_score=L(ZeroShotClassifier)(
                input_shape=ShapeSpec(channels=1024),
                num_classes=1203,
                zs_weight_path='datasets/metadata/lvis_v1_clip_a+cname.npy',
                norm_weight=True,
                # use_bias=-4.6,
            ),
            use_zeroshot_cls=True,
            use_sigmoid_ce=True,
            ignore_zero_cats=True,
            cat_freq_path='datasets/lvis/lvis_v1_train_norare_cat_info.json',
            image_label_loss='max_size',
            image_loss_weight=0.1,
        ),
    ],
    proposal_matchers=[
        L(Matcher)(
            thresholds=[0.5],
            labels=[0, 1],
            allow_low_quality_matches=False,
        ),
    ],
    with_image_labels=True,
    ws_num_props=128,
)
# Use CustomRCNN as the meta-architecture and turn on its image-label option.
model.update(_target_=CustomRCNN, with_image_labels=True)

# Mask head is configured with a single class.
model.roi_heads.mask_head.num_classes = 1

# Enable DDP's find_unused_parameters option.
train.ddp.find_unused_parameters = True
# Scaling knob shared by batch size, learning rate and schedule in this file.
num_nodes = 4
# Square crop sizes for the two training datasets (first / second entry).
image_size = 896
image_size_weak = 448


def _scale_crop_flip(size, min_scale, max_scale):
    """Build the ResizeScale -> FixedSizeCrop -> RandomFlip list for a square `size`."""
    return [
        L(T.ResizeScale)(
            min_scale=min_scale,
            max_scale=max_scale,
            target_height=size,
            target_width=size,
        ),
        L(T.FixedSizeCrop)(crop_size=(size, size)),
        L(T.RandomFlip)(horizontal=True),
    ]


# Two-dataset training loader. The per-dataset lists below (dataset_ann,
# dataset_augs, dataset_ratio, use_rfs, dataset_bs) each have two entries,
# presumably index-aligned with dataset_names -- confirm against
# CustomDatasetMapper / MultiDatasetSampler.
dataloader.train = L(build_custom_train_loader)(
    dataset=L(get_detection_dataset_dicts_with_source)(
        dataset_names=['lvis_v1_train_norare', 'imagenet_lvis_v1'],
        filter_empty=False,
    ),
    mapper=L(CustomDatasetMapper)(
        is_train=True,
        augmentations=[],
        with_ann_type=True,
        dataset_ann=['box', 'image'],
        use_diff_bs_size=True,
        dataset_augs=[
            _scale_crop_flip(image_size, min_scale=0.1, max_scale=2.0),
            _scale_crop_flip(image_size_weak, min_scale=0.5, max_scale=1.5),
        ],
        image_format="BGR",
        use_instance_mask=True,
    ),
    sampler=L(MultiDatasetSampler)(
        # Absolute interpolations: reuse the dataset and annotation-type
        # settings declared above instead of repeating them.
        dataset_dicts="${dataloader.train.dataset}",
        dataset_ratio=[1, 4],
        use_rfs=[True, False],
        dataset_ann="${dataloader.train.mapper.dataset_ann}",
        repeat_threshold=0.001,
    ),
    total_batch_size=64 * num_nodes,
    multi_dataset_grouping=True,
    use_diff_bs_size=True,
    dataset_bs=[8, 8 * 4],
    num_datasets=2,
    num_workers=8,
)
# Evaluate on lvis_v1_val; the evaluator picks up the dataset name from the
# test loader through a relative interpolation, so the two stay in sync.
dataloader.test.dataset.names = "lvis_v1_val"
dataloader.evaluator = L(LVISEvaluator)(dataset_name="${..test.dataset.names}")
# Iteration budget. The division by num_nodes presumably compensates for the
# total batch size being multiplied by num_nodes -- confirm.
# NOTE(review): 184375 looks like a per-schedule iteration count taken from
# the base config; verify before changing.
train.max_iter = 184375 * 2 // num_nodes

# Cosine decay from 1.0 to 0.0 with a 500-iteration warmup (expressed as a
# fraction of max_iter). The inner scheduler is declared lazily like every
# other component here, so it stays a plain config node (overridable by key
# path and dumpable) instead of a live object embedded in the config.
lr_multiplier = L(WarmupParamScheduler)(
    scheduler=L(CosineParamScheduler)(start_value=1.0, end_value=0.0),
    warmup_length=500 / train.max_iter,
    warmup_factor=0.067,
)

# AdamW with the learning rate scaled by num_nodes; weight_decay_norm=0.0 is
# forwarded to get_default_optimizer_params.
optimizer = L(torch.optim.AdamW)(
    params=L(get_default_optimizer_params)(weight_decay_norm=0.0),
    lr=0.0002 * num_nodes,
    weight_decay=1e-4,
)

# Checkpoint period is scaled by num_nodes as well.
train.checkpointer.period = 20000 // num_nodes

# Output directory is named after this config file, minus its extension.
# splitext is used rather than slicing off a fixed three characters.
train.output_dir = './output/Lazy/{}'.format(
    os.path.splitext(os.path.basename(__file__))[0]
)