forked from IMvision12/SegFormer-tf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
99931a7
commit 36833e7
Showing
5 changed files
with
125 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
#[WIP] | ||
|
||
- [ ] Add Model | ||
- [ ] Train on custom dataset | ||
|
||
# SegFormer-tf | ||
|
||
Paper : https://arxiv.org/pdf/2105.15203 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,85 @@ | ||
import tensorflow as tf | ||
from tensorflow.keras import layers | ||
import math | ||
|
||
|
||
class Attention(tf.keras.layers.Layer): | ||
def __init__( | ||
self, | ||
dim, | ||
num_heads=8, | ||
num_heads, | ||
sr_ratio, | ||
qkv_bias=False, | ||
qk_scale=None, | ||
attn_drop=0.0, | ||
proj_drop=0.0, | ||
sr_ratio=1, | ||
**kwargs, | ||
): | ||
super(Attention, self).__init__() | ||
assert ( | ||
dim % num_heads == 0 | ||
), f"dim {dim} should be divisible by num_heads {num_heads}." | ||
|
||
super().__init__(**kwargs) | ||
self.dim = dim | ||
self.num_heads = num_heads | ||
head_dim = dim // num_heads | ||
self.scale = qk_scale or head_dim**-0.5 | ||
self.head_dim = self.dim // self.num_heads | ||
|
||
self.units = self.num_heads * self.head_dim | ||
self.sqrt_of_units = math.sqrt(self.head_dim) | ||
|
||
self.q = tf.keras.layers.Dense(self.units, use_bias=qkv_bias) | ||
self.k = tf.keras.layers.Dense(self.units, use_bias=qkv_bias) | ||
self.v = tf.keras.layers.Dense(self.units, use_bias=qkv_bias) | ||
|
||
self.attn_drop = tf.keras.layers.Dropout(attn_drop) | ||
|
||
self.q = layers.Dense(dim, use_bias=qkv_bias) | ||
self.k = layers.Dense(dim, use_bias=qkv_bias) | ||
self.v = layers.Dense(dim, use_bias=qkv_bias) | ||
self.attn_drop = layers.Dropout(attn_drop) | ||
self.proj = layers.Dense(dim) | ||
self.proj_drop = layers.Dropout(proj_drop) | ||
self.proj = tf.keras.layers.Dense(dim) | ||
self.proj_drop = tf.keras.layers.Dropout(proj_drop) | ||
|
||
self.sr_ratio = sr_ratio | ||
if sr_ratio > 1: | ||
self.sr = layers.Conv2D(dim, kernel_size=sr_ratio, strides=sr_ratio) | ||
self.norm = layers.LayerNormalization() | ||
self.sr = tf.keras.layers.Conv2D( | ||
filters=dim, kernel_size=sr_ratio, strides=sr_ratio | ||
) | ||
self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-05) | ||
|
||
def call(self, x, H, W): | ||
def call( | ||
self, | ||
x, | ||
H, | ||
W, | ||
): | ||
get_shape = tf.shape(x) | ||
B = get_shape[0] | ||
N = get_shape[1] | ||
C = get_shape[2] | ||
|
||
q = self.q(x) | ||
q = tf.reshape(q, [B, N, self.num_heads, C // self.num_heads]) | ||
q = tf.reshape( | ||
q, shape=(tf.shape(q)[0], -1, self.num_heads, self.head_dim) | ||
) | ||
q = tf.transpose(q, perm=[0, 2, 1, 3]) | ||
|
||
if self.sr_ratio > 1: | ||
x = tf.transpose(x, [0, 2, 1]) | ||
x = tf.reshape(x, [B, C, H, W]) | ||
x = tf.reshape(x, (B, H, W, C)) | ||
x = self.sr(x) | ||
x = tf.reshape(x, [B, C, -1]) | ||
x = tf.transpose(x, [0, 2, 1]) | ||
x = tf.reshape(x, (B, -1, C)) | ||
x = self.norm(x) | ||
|
||
k = self.k(x) | ||
k = tf.reshape(k, [B, -1, self.num_heads, C // self.num_heads]) | ||
k = tf.transpose(k, [0, 2, 1, 3]) | ||
k = tf.reshape( | ||
k, shape=(tf.shape(k)[0], -1, self.num_heads, self.head_dim) | ||
) | ||
k = tf.transpose(k, perm=[0, 2, 1, 3]) | ||
|
||
v = self.v(x) | ||
v = tf.reshape(v, [B, -1, self.num_heads, C // self.num_heads]) | ||
v = tf.transpose(v, [0, 2, 1, 3]) | ||
v = tf.reshape( | ||
v, shape=(tf.shape(v)[0], -1, self.num_heads, self.head_dim) | ||
) | ||
v = tf.transpose(v, perm=[0, 2, 1, 3]) | ||
|
||
attn = (q @ tf.transpose(k, [0, 1, 3, 2])) * self.scale | ||
attn = tf.nn.softmax(attn, axis=-1) | ||
attn = self.attn_drop(attn) | ||
attn = tf.matmul(q, k, transpose_b=True) | ||
scale = tf.cast(self.sqrt_of_units, dtype=attn.dtype) | ||
attn = tf.divide(attn, scale) | ||
|
||
attn = attn @ v | ||
attn = tf.transpose(attn, [0, 2, 1, 3]) | ||
attn = tf.reshape(attn, shape=[B, N, C]) | ||
x = self.proj(attn) | ||
attn = tf.nn.softmax(logits=attn, axis=-1) | ||
attn = self.attn_drop(attn) | ||
x = tf.matmul(attn, v) | ||
x = tf.transpose(x, perm=[0, 2, 1, 3]) | ||
x = tf.reshape(x, (B, -1, self.units)) | ||
x = self.proj(x) | ||
x = self.proj_drop(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.