Skip to content

Commit

Permalink
Fixed CNN (again)
Browse files Browse the repository at this point in the history
  • Loading branch information
semjon00 committed May 18, 2024
1 parent 385b78c commit e3bfb70
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,25 @@ def __init__(self, channels, repeats, kernel_size=5):
self.out_channels = channels[-1]

self.seq = nn.ModuleList()
self.res = nn.ModuleList()
for i in range(len(channels) - 1):
layer = nn.Sequential()
for r in range(repeats):
c_in = channels[i + 1] if r != 0 else channels[i]
c_out = channels[i + 1]
layer.append(CNNBlock(c_in, c_out, kernel_size))
if i != len(channels) - 1 - 1:
layer.append(nn.MaxPool2d(kernel_size=2, stride=2))
self.seq.append(layer)
self.res.append(nn.Conv2d(channels[i], channels[i + 1], kernel_size=1))
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

def res_reduction_factor(self):
return 2 ** (max(0, len(self.seq) - 1))

def forward(self, x: Tensor):
x = einops.rearrange(x, '... l w c -> ... c l w')
for layer in self.seq:
x = x + layer(x)
for i in range(len(self.seq)):
x = self.seq[i](x) + self.res[i](x)
if i != len(self.seq) - 1:
x = self.pool(x)
x = einops.rearrange(x, ' ... c l w -> ... l w c')
return x

0 comments on commit e3bfb70

Please sign in to comment.