-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
25 lines (23 loc) · 936 Bytes
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""
This code is modified based on Jin-Hwa Kim's repository (Bilinear Attention Networks - https://github.com/jnhwkim/ban-vqa) by Xuan B. Nguyen
"""
import torch.nn as nn
from torch.nn.utils.weight_norm import weight_norm
class SimpleClassifier(nn.Module):
    """Two-layer feed-forward classifier head.

    The network is ``Linear -> activation -> Dropout -> Linear``; both linear
    layers are wrapped in ``weight_norm`` with ``dim=None``.

    Args:
        in_dim: Size of the input features.
        hid_dim: Size of the hidden layer.
        out_dim: Size of the output (number of logits).
        args: Configuration object providing ``activation`` (the name of the
            activation; only ``'relu'`` is currently supported) and
            ``dropout`` (forwarded to ``nn.Dropout`` as its first argument).

    Raises:
        AssertionError: If ``args.activation`` is not a supported activation.
    """

    def __init__(self, in_dim, hid_dim, out_dim, args):
        super(SimpleClassifier, self).__init__()
        activation_dict = {'relu': nn.ReLU()}
        # Read the attribute once, outside the try, so that a missing
        # attribute surfaces as its own AttributeError.
        activation_name = args.activation
        try:
            activation_func = activation_dict[activation_name]
        except KeyError:
            # str() keeps the message construction from failing when the
            # configured value is not a string (e.g. None).
            raise AssertionError(str(activation_name) + " is not supported yet!")
        layers = [
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            activation_func,
            nn.Dropout(args.dropout, inplace=True),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the classifier stack and return its output."""
        logits = self.main(x)
        return logits