Skip to content

Commit

Permalink
add vgg11 model
Browse files Browse the repository at this point in the history
  • Loading branch information
ruhyadi committed Mar 10, 2022
1 parent 631b44c commit e8ec778
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions script/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def __init__(self, model=None, bins=2, w=0.4):
self.bins = bins
self.w = w
# extract all layer until [-2]
self.model = nn.Sequential(*(list(model.children())[:-2]))
self.model = model.features

# orientation head, for orientation estimation
self.orientation = nn.Sequential(
nn.Linear(512, 256),
nn.Linear(512 * 7 * 7, 256),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(256, 256),
Expand All @@ -150,15 +150,18 @@ def __init__(self, model=None, bins=2, w=0.4):

# confident head, for orientation estimation
self.confidence = nn.Sequential(
nn.Linear(512, 256),
nn.Linear(512 * 7 * 7, 256),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(256, bins) # 2 bins
nn.Linear(256, 256),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(256, bins)
)

# dimension head
self.dimension = nn.Sequential(
nn.Linear(512, 256),
nn.Linear(512 * 7 * 7, 256),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(256, 256),
Expand All @@ -169,7 +172,7 @@ def __init__(self, model=None, bins=2, w=0.4):

def forward(self, x):
x = self.model(x)
x = x.view(-1, 512)
x = x.view(-1, 512 * 7 * 7)

orientation = self.orientation(x)
orientation = orientation.view(-1, self.bins, 2)
Expand Down

0 comments on commit e8ec778

Please sign in to comment.