Commit

Fix Param Count
IMvision12 committed Mar 27, 2023
1 parent 99931a7 commit 36833e7
Showing 5 changed files with 125 additions and 115 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -1,3 +1,8 @@
+#[WIP]
+
+- [ ] Add Model
+- [ ] Train on custom dataset
+
 # SegFormer-tf
 
 Paper : https://arxiv.org/pdf/2105.15203
86 changes: 49 additions & 37 deletions models/Attention.py
@@ -1,73 +1,85 @@
 import tensorflow as tf
-from tensorflow.keras import layers
+import math


 class Attention(tf.keras.layers.Layer):
     def __init__(
         self,
         dim,
-        num_heads=8,
+        num_heads,
+        sr_ratio,
         qkv_bias=False,
-        qk_scale=None,
         attn_drop=0.0,
         proj_drop=0.0,
-        sr_ratio=1,
+        **kwargs,
     ):
-        super(Attention, self).__init__()
-        assert (
-            dim % num_heads == 0
-        ), f"dim {dim} should be divisible by num_heads {num_heads}."
-
+        super().__init__(**kwargs)
+        self.dim = dim
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim**-0.5
+        self.head_dim = self.dim // self.num_heads
+
+        self.units = self.num_heads * self.head_dim
+        self.sqrt_of_units = math.sqrt(self.head_dim)
+
+        self.q = tf.keras.layers.Dense(self.units, use_bias=qkv_bias)
+        self.k = tf.keras.layers.Dense(self.units, use_bias=qkv_bias)
+        self.v = tf.keras.layers.Dense(self.units, use_bias=qkv_bias)
+
+        self.attn_drop = tf.keras.layers.Dropout(attn_drop)

-        self.q = layers.Dense(dim, use_bias=qkv_bias)
-        self.k = layers.Dense(dim, use_bias=qkv_bias)
-        self.v = layers.Dense(dim, use_bias=qkv_bias)
-        self.attn_drop = layers.Dropout(attn_drop)
-        self.proj = layers.Dense(dim)
-        self.proj_drop = layers.Dropout(proj_drop)
+        self.proj = tf.keras.layers.Dense(dim)
+        self.proj_drop = tf.keras.layers.Dropout(proj_drop)

         self.sr_ratio = sr_ratio
         if sr_ratio > 1:
-            self.sr = layers.Conv2D(dim, kernel_size=sr_ratio, strides=sr_ratio)
-            self.norm = layers.LayerNormalization()
+            self.sr = tf.keras.layers.Conv2D(
+                filters=dim, kernel_size=sr_ratio, strides=sr_ratio
+            )
+            self.norm = tf.keras.layers.LayerNormalization(epsilon=1e-05)

-    def call(self, x, H, W):
+    def call(
+        self,
+        x,
+        H,
+        W,
+    ):
         get_shape = tf.shape(x)
         B = get_shape[0]
         N = get_shape[1]
         C = get_shape[2]

         q = self.q(x)
-        q = tf.reshape(q, [B, N, self.num_heads, C // self.num_heads])
+        q = tf.reshape(
+            q, shape=(tf.shape(q)[0], -1, self.num_heads, self.head_dim)
+        )
         q = tf.transpose(q, perm=[0, 2, 1, 3])

         if self.sr_ratio > 1:
-            x = tf.transpose(x, [0, 2, 1])
-            x = tf.reshape(x, [B, C, H, W])
+            x = tf.reshape(x, (B, H, W, C))
             x = self.sr(x)
-            x = tf.reshape(x, [B, C, -1])
-            x = tf.transpose(x, [0, 2, 1])
+            x = tf.reshape(x, (B, -1, C))
             x = self.norm(x)

         k = self.k(x)
-        k = tf.reshape(k, [B, -1, self.num_heads, C // self.num_heads])
-        k = tf.transpose(k, [0, 2, 1, 3])
+        k = tf.reshape(
+            k, shape=(tf.shape(k)[0], -1, self.num_heads, self.head_dim)
+        )
+        k = tf.transpose(k, perm=[0, 2, 1, 3])

         v = self.v(x)
-        v = tf.reshape(v, [B, -1, self.num_heads, C // self.num_heads])
-        v = tf.transpose(v, [0, 2, 1, 3])
+        v = tf.reshape(
+            v, shape=(tf.shape(v)[0], -1, self.num_heads, self.head_dim)
+        )
+        v = tf.transpose(v, perm=[0, 2, 1, 3])

-        attn = (q @ tf.transpose(k, [0, 1, 3, 2])) * self.scale
-        attn = tf.nn.softmax(attn, axis=-1)
-        attn = self.attn_drop(attn)
+        attn = tf.matmul(q, k, transpose_b=True)
+        scale = tf.cast(self.sqrt_of_units, dtype=attn.dtype)
+        attn = tf.divide(attn, scale)

-        attn = attn @ v
-        attn = tf.transpose(attn, [0, 2, 1, 3])
-        attn = tf.reshape(attn, shape=[B, N, C])
-        x = self.proj(attn)
+        attn = tf.nn.softmax(logits=attn, axis=-1)
+        attn = self.attn_drop(attn)
+        x = tf.matmul(attn, v)
+        x = tf.transpose(x, perm=[0, 2, 1, 3])
+        x = tf.reshape(x, (B, -1, self.units))
+        x = self.proj(x)
         x = self.proj_drop(x)
         return x
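
For reference, a minimal usage sketch of the layer after this change; the stage settings (dim=32, num_heads=1, sr_ratio=8), the batch/resolution values, and the import path are illustrative assumptions, not part of the commit:

import tensorflow as tf

from models.Attention import Attention  # assumes the repository root is on the import path

# num_heads and sr_ratio are now required arguments rather than defaults.
attn = Attention(dim=32, num_heads=1, sr_ratio=8, qkv_bias=True)

H = W = 56
x = tf.random.normal((2, H * W, 32))  # (batch, tokens, channels) input sequence
y = attn(x, H, W)                     # spatial-reduction self-attention over the tokens
print(y.shape)                        # (2, 3136, 32), same shape as the input
print(attn.count_params())            # parameter count now tracks dim, num_heads and sr_ratio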
34 changes: 17 additions & 17 deletions models/Head.py
@@ -1,12 +1,9 @@
 import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers


-class MLP(layers.Layer):
-    def __init__(self, embed_dim=768):
-        super().__init__()
-        self.proj = layers.Dense(embed_dim)
+class MLP(tf.keras.layers.Layer):
+    def __init__(self, decode_dim=768, **kwargs):
+        super().__init__(**kwargs)
+        self.proj = tf.keras.layers.Dense(decode_dim)

     def forward(self, x):
         get_shape = tf.shape(x)
@@ -15,21 +12,24 @@ def forward(self, x):
         W = get_shape[2]
         dim = get_shape[-1]

-        x = tf.reshape(x, (B, H*W, dim))
+        x = tf.reshape(x, (B, H * W, dim))
         x = self.proj(x)
         return x

-class SegFormerHead(layers.Layer):
-    def __init__(self, num_classes, num_blocks=4, cls_dropout_rate=0.1):
-        super().__init__()
+
+class SegFormerHead(tf.keras.layers.Layer):
+    def __init__(self, num_classes, decode_dim, num_blocks=4, cls_dropout_rate=0.1, **kwargs):
+        super().__init__(**kwargs)

         mlps = []
-        for i in range(num_blocks):
-            mlp = MLP()
+        for _ in range(num_blocks):
+            mlp = MLP(decode_dim)
             mlps.append(mlp)
         self.mlps = mlps

-        self.linear_fuse = tf.keras.layers.Conv2D(filters=256, kernel_size=1, use_bias=False)
+        self.linear_fuse = tf.keras.layers.Conv2D(
+            filters=decode_dim, kernel_size=1, use_bias=False
+        )
         self.norm = tf.keras.layers.BatchNormalization(epsilon=1e-5)
         self.act = tf.keras.layers.Activation("relu")

@@ -43,11 +43,11 @@ def call(self, x):
         outputs = []
         for feat, mlp in zip(x, self.mlps):
             x = mlp(feat)
-            x = tf.image.resize(x, size=(H, W), method='bilinear')
+            x = tf.image.resize(x, size=(H, W), method="bilinear")
             outputs.append(x)

         x = self.linear_fuse(tf.concat(outputs[::-1], axis=3))
         x = self.norm(x)
         x = self.act(x)
         x = self.cls(x)
-        return x
+        return x
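
On the "Fix Param Count" theme, a small sketch of how the fuse convolution's parameter count follows decode_dim now that filters is no longer hard-coded to 256; the decode_dim value and the four-way concatenation below are assumptions taken from the defaults shown above:

import tensorflow as tf

decode_dim = 256              # assumed value; previously fixed as filters=256
in_channels = 4 * decode_dim  # num_blocks=4 MLP outputs concatenated on the channel axis
fuse = tf.keras.layers.Conv2D(filters=decode_dim, kernel_size=1, use_bias=False)
fuse.build((None, None, None, in_channels))
print(fuse.count_params())    # 1 * 1 * 1024 * 256 = 262,144 weights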