Commit 17dd288 (parent 47ed341): 1 changed file with 354 additions and 0 deletions.
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


class Transformer_vis(nn.Module):

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=2048,
                 dropout=0.1, activation="relu", normalize_before=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        mask = mask.flatten(1)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory.permute(1, 2, 0).view(bs, c, h, w)
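
# --- Illustrative usage sketch (not part of the original file) ---
# A minimal shape check for Transformer_vis, assuming a DETR-style 2D feature map
# input; `_demo_transformer_vis` is a hypothetical helper added for illustration.
def _demo_transformer_vis():
    bs, c, h, w = 2, 256, 16, 16
    model = Transformer_vis(d_model=c, nhead=8, num_encoder_layers=2)
    src = torch.rand(bs, c, h, w)                    # backbone feature map, NxCxHxW
    pos = torch.rand(bs, c, h, w)                    # positional encoding, same shape as src
    mask = torch.zeros(bs, h, w, dtype=torch.bool)   # True marks padded positions
    out = model(src, mask, pos)                      # encoder memory reshaped back to NxCxHxW
    assert out.shape == (bs, c, h, w)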


class Transformer_Decoder(nn.Module):
    def __init__(self, d_model=512, nhead=8,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, tgt, memory, mask, pos_embed, query_embed):
        hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed)
        return hs
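
# --- Illustrative usage sketch (not part of the original file) ---
# A minimal sketch of driving Transformer_Decoder with learned object queries,
# assuming the memory comes from one of the encoders above; `_demo_transformer_decoder`
# is a hypothetical helper added for illustration.
def _demo_transformer_decoder():
    bs, d_model, hw, num_queries = 2, 256, 196, 100
    decoder = Transformer_Decoder(d_model=d_model, nhead=8, num_decoder_layers=1,
                                  return_intermediate_dec=True)
    memory = torch.rand(hw, bs, d_model)                # encoder output, HWxNxC
    pos = torch.rand(hw, bs, d_model)                   # positional encoding for memory
    mask = torch.zeros(bs, hw, dtype=torch.bool)        # key padding mask for memory
    query_embed = torch.rand(num_queries, bs, d_model)  # learned query embeddings
    tgt = torch.zeros_like(query_embed)                 # decoder input starts at zero
    hs = decoder(tgt, memory, mask, pos, query_embed)
    assert hs.shape == (1, num_queries, bs, d_model)    # (num_layers, queries, N, C)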


class Transformer(nn.Module):

    def __init__(self, d_model=256, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, pos_embed):
        # permute NxCxW to WxNxC (1D sequence input; pos_embed arrives as NxWxC)
        src = src.permute(2, 0, 1)
        pos_embed = pos_embed.permute(1, 0, 2)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        return memory
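
# --- Illustrative usage sketch (not part of the original file) ---
# A minimal shape check for the sequence encoder above; note that src is NxCxW
# while pos_embed is expected as NxWxC. `_demo_transformer_seq` is a hypothetical
# helper added for illustration.
def _demo_transformer_seq():
    bs, c, w = 2, 256, 40
    model = Transformer(d_model=c, nhead=8, num_encoder_layers=2)
    src = torch.rand(bs, c, w)                       # sequence features, NxCxW
    pos = torch.rand(bs, w, c)                       # positional encoding, NxWxC
    mask = torch.zeros(bs, w, dtype=torch.bool)      # key padding mask, NxW
    memory = model(src, mask, pos)                   # encoder output, WxNxC
    assert memory.shape == (w, bs, c)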


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,  # mask is not used here
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
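
# --- Illustrative usage sketch (not part of the original file) ---
# A small sketch contrasting the two return modes of TransformerDecoder, assuming
# default (None) masks and positional embeddings; `_demo_decoder_return_modes` is a
# hypothetical helper added for illustration.
def _demo_decoder_return_modes():
    layer = TransformerDecoderLayer(d_model=256, nhead=8)
    tgt = torch.rand(10, 2, 256)                      # (tgt_len, N, C)
    memory = torch.rand(50, 2, 256)                   # (src_len, N, C)
    dec_all = TransformerDecoder(layer, num_layers=3, norm=nn.LayerNorm(256),
                                 return_intermediate=True)
    dec_last = TransformerDecoder(layer, num_layers=3, norm=nn.LayerNorm(256),
                                  return_intermediate=False)
    assert dec_all(tgt, memory).shape == (3, 10, 2, 256)   # one slice per layer
    assert dec_last(tgt, memory).shape == (1, 10, 2, 256)  # last layer only, unsqueezed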


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
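
# --- Illustrative usage sketch (not part of the original file) ---
# A short sketch of a single encoder layer in both normalization modes: post-norm
# applies LayerNorm after each residual connection, pre-norm before each sub-block.
# `_demo_encoder_layer` is a hypothetical helper added for illustration.
def _demo_encoder_layer():
    src = torch.rand(50, 2, 256)                      # (sequence, batch, d_model)
    post = TransformerEncoderLayer(d_model=256, nhead=8, normalize_before=False)
    pre = TransformerEncoderLayer(d_model=256, nhead=8, normalize_before=True)
    assert post(src).shape == src.shape               # forward_post path
    assert pre(src).shape == src.shape                # forward_pre path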


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_vis_transformer(args):
    return Transformer_vis(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        normalize_before=args.pre_norm,
    )


def build_de(args):
    return Transformer_Decoder(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_decoder_layers=1,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True
    )


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        # TODO: return_intermediate_dec
        return_intermediate_dec=True,
    )
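
# --- Illustrative usage sketch (not part of the original file) ---
# The builders read hyperparameters from an `args` object; the field names below
# match what the functions access, but the values are made-up defaults and
# `_demo_build` is a hypothetical helper added for illustration.
def _demo_build():
    from argparse import Namespace
    args = Namespace(hidden_dim=256, dropout=0.1, nheads=8, dim_feedforward=2048,
                     enc_layers=6, dec_layers=6, pre_norm=False)
    vis_encoder = build_vis_transformer(args)   # 2D feature-map encoder
    decoder = build_de(args)                    # single-layer decoder, intermediate outputs kept
    seq_encoder = build_transformer(args)       # 1D sequence encoder
    return vis_encoder, decoder, seq_encoder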


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")