Skip to content

Commit

Permalink
Merge pull request xingyizhou#26 from Sundrops/master
Browse files Browse the repository at this point in the history
make the function signature consistent
  • Loading branch information
xingyizhou authored May 1, 2019
2 parents 3a1e9c6 + c9fa20b commit d4700b5
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/lib/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
'hourglass': get_large_hourglass_net,
}

def create_model(arch, head, head_conv):
def create_model(arch, heads, head_conv):
num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
arch = arch[:arch.find('_')] if '_' in arch else arch
get_model = _model_factory[arch]
model = get_model(num_layers, head, head_conv)
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
return model

def load_model(model, model_path, optimizer=None, resume=False,
Expand Down
12 changes: 6 additions & 6 deletions src/lib/models/networks/dlav0.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def fill_fc_weights(layers):

class DLASeg(nn.Module):
def __init__(self, base_name, heads,
pretrained=True, down_ratio=4, add_conv=256):
pretrained=True, down_ratio=4, head_conv=256):
super(DLASeg, self).__init__()
assert down_ratio in [2, 4, 8, 16]
self.heads = heads
Expand All @@ -551,12 +551,12 @@ def __init__(self, base_name, heads,

for head in self.heads:
classes = self.heads[head]
if add_conv > 0:
if head_conv > 0:
fc = nn.Sequential(
nn.Conv2d(channels[self.first_level], add_conv,
nn.Conv2d(channels[self.first_level], head_conv,
kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(add_conv, classes,
nn.Conv2d(head_conv, classes,
kernel_size=1, stride=1,
padding=0, bias=True))
if 'hm' in head:
Expand Down Expand Up @@ -639,9 +639,9 @@ def dla169up(classes, pretrained_base=None, **kwargs):
return model
'''

def get_pose_net(heads, down_ratio=4, add_conv=256):
def get_pose_net(heads, down_ratio=4, head_conv=256):
model = DLASeg('dla34', heads,
pretrained=True,
down_ratio=down_ratio,
add_conv=add_conv)
head_conv=head_conv)
return model
2 changes: 1 addition & 1 deletion src/lib/models/networks/large_hourglass.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,6 @@ def __init__(self, heads, num_stacks=2):
kp_layer=residual, cnv_dim=256
)

def get_large_hourglass_net(_, heads, __):
def get_large_hourglass_net(num_layers, heads, head_conv):
model = HourglassNet(heads, 2)
return model
2 changes: 1 addition & 1 deletion src/lib/models/networks/pose_dla_dcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def forward(self, x):
return [z]


def get_pose_net(num_layers, heads, version, down_ratio=4, head_conv=256):
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
model = DLASeg('dla{}'.format(num_layers), heads,
pretrained=True,
down_ratio=down_ratio,
Expand Down

0 comments on commit d4700b5

Please sign in to comment.