-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_layer_decay_optimizer_constructor.py
149 lines (135 loc) · 6.06 KB
/
custom_layer_decay_optimizer_constructor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
"""
Mostly copy-paste from BEiT library:
https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py
"""
import json
from mmengine.optim import OPTIM_WRAPPER_CONSTRUCTORS, DefaultOptimWrapperConstructor
from mmengine.dist import get_dist_info
from mmengine import MMLogger
def get_num_layer_for_swin(var_name, num_max_layer, depths):
    """Map a backbone parameter name to its layer-decay layer id.

    Layer ids: 0 for the patch embedding and level embeddings,
    ``1 .. sum(depths)`` for transformer blocks (counted cumulatively across
    stages), and ``num_max_layer - 1`` for everything outside the backbone.

    Args:
        var_name (str): Dotted parameter name, e.g.
            ``"backbone.levels.2.blocks.5.attn.proj.weight"``.
        num_max_layer (int): Total number of layer-id buckets; parameters
            outside the backbone fall into the last one.
        depths (Sequence[int]): Number of blocks per backbone stage,
            e.g. ``[2, 2, 18, 2]`` for Swin-Large.

    Returns:
        int: The layer id used to compute the lr decay scale.
    """
    if var_name.startswith("backbone.patch_embed"):
        return 0
    if "level_embeds" in var_name:
        return 0
    if var_name.startswith("backbone.layers") or var_name.startswith(
        "backbone.levels"
    ):
        parts = var_name.split(".")
        stage_id = int(parts[2])
        if parts[3] not in ("downsample", "norm"):
            # A block inside a stage: 1 (patch_embed) + all blocks of the
            # earlier stages + the block index within this stage.
            block_id = int(parts[4])
            return block_id + 1 + sum(depths[:stage_id])
        # Stage-level downsample/norm: grouped with the stage's last block.
        # The final stage is capped to the same id as the previous stage's
        # downsample, matching the original hard-coded table.
        return 1 + sum(depths[: min(stage_id, len(depths) - 2) + 1])
    return num_max_layer - 1
@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class CustomLayerDecayOptimizerConstructor(DefaultOptimWrapperConstructor):
    """Optimizer constructor applying layer-wise learning-rate decay.

    Parameters are bucketed by backbone layer id (via
    ``get_num_layer_for_swin``) and weight-decay status; each bucket's lr is
    ``base_lr * layer_decay_rate ** (num_layers - layer_id - 1)``, with
    optional extra multipliers for backbone, DINO-head, and offset params.
    """

    def add_params(self, params, module, prefix="", is_dcn_module=None):
        """Add all parameters of module to the params list.
        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg.
        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module
            is_dcn_module (int|float|None): If the current module is a
                submodule of DCN, `is_dcn_module` will be passed to
                control conv_offset layer's learning rate. Defaults to None.
        """
        parameter_groups = {}
        logger = MMLogger.get_current_instance()
        logger.info(self.paramwise_cfg)
        # Optional switches from paramwise_cfg:
        #   backbone_small_lr: extra 0.1x on groups whose decay scale < 1
        #       (i.e. backbone layers below the top).
        #   dino_head: 0.1x lr for sampling_offsets/reference_points params
        #       that land in the last (non-backbone) layer bucket.
        backbone_small_lr = self.paramwise_cfg.get("backbone_small_lr", False)
        dino_head = self.paramwise_cfg.get("dino_head", False)
        # +2 adds a bucket for patch_embed (id 0) and one for everything
        # outside the backbone (id num_layers - 1).
        # NOTE(review): raises TypeError if "num_layers" is absent from
        # paramwise_cfg — presumably it is always configured; verify.
        num_layers = self.paramwise_cfg.get("num_layers") + 2
        layer_decay_rate = self.paramwise_cfg.get("layer_decay_rate")
        depths = self.paramwise_cfg.get("depths")
        offset_lr_scale = self.paramwise_cfg.get("offset_lr_scale", 1.0)
        logger.info(
            "Build CustomLayerDecayOptimizerConstructor %f - %d"
            % (layer_decay_rate, num_layers)
        )
        weight_decay = self.base_wd
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue  # frozen weights
            # No weight decay for 1-D params (norm scales), biases, and
            # relative-position / norm / sampling-offset tables.
            if (
                len(param.shape) == 1
                or name.endswith(".bias")
                or "relative_position" in name
                or "norm" in name
                or "sampling_offsets" in name
            ):
                group_name = "no_decay"
                this_weight_decay = 0.0
            else:
                group_name = "decay"
                this_weight_decay = weight_decay
            layer_id = get_num_layer_for_swin(name, num_layers, depths)
            # The "_0.1x" / "_offset_lr_scale" suffixes added here are later
            # matched by substring to pick the group's extra lr multiplier,
            # so the exact spelling is load-bearing.
            if (
                layer_id == num_layers - 1
                and dino_head
                and ("sampling_offsets" in name or "reference_points" in name)
            ):
                group_name = "layer_%d_%s_0.1x" % (layer_id, group_name)
            elif "sampling_offsets" in name or "reference_points" in name:
                group_name = "layer_%d_%s_offset_lr_scale" % (layer_id, group_name)
            else:
                group_name = "layer_%d_%s" % (layer_id, group_name)
            if group_name not in parameter_groups:
                # Deeper layers get scales closer to 1; the top bucket
                # (layer_id == num_layers - 1) gets exactly 1.0.
                scale = layer_decay_rate ** (num_layers - layer_id - 1)
                if scale < 1 and backbone_small_lr == True:
                    scale = scale * 0.1
                if "0.1x" in group_name:
                    scale = scale * 0.1
                if "offset_lr_scale" in group_name:
                    scale = scale * offset_lr_scale
                parameter_groups[group_name] = {
                    "weight_decay": this_weight_decay,
                    "params": [],
                    "param_names": [],
                    "lr_scale": scale,
                    "group_name": group_name,
                    "lr": scale * self.base_lr,
                }
            parameter_groups[group_name]["params"].append(param)
            parameter_groups[group_name]["param_names"].append(name)
        rank, _ = get_dist_info()
        if rank == 0:
            # Log the full grouping once, on rank 0 only (params themselves
            # are omitted from the dump).
            to_display = {}
            for key in parameter_groups:
                to_display[key] = {
                    "param_names": parameter_groups[key]["param_names"],
                    "lr_scale": parameter_groups[key]["lr_scale"],
                    "lr": parameter_groups[key]["lr"],
                    "weight_decay": parameter_groups[key]["weight_decay"],
                }
            logger.info("Param groups = %s" % json.dumps(to_display, indent=2))
        # state_dict = module.state_dict()
        # for group_name in parameter_groups:
        #     group = parameter_groups[group_name]
        #     for name in group["param_names"]:
        #         group["params"].append(state_dict[name])
        params.extend(parameter_groups.values())