Skip to content

Commit

Permalink
Merge pull request WZMIAOMIAO#168 from WZMIAOMIAO/dev
Browse files Browse the repository at this point in the history
fix a err
  • Loading branch information
WZMIAOMIAO authored Mar 4, 2021
2 parents 9838ba7 + efa1eab commit b830989
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
3 changes: 2 additions & 1 deletion tensorflow_classification/Test6_mobilenet/model_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def _inverted_res_block(x,
x = act(name=prefix + 'expand/' + act.__name__)(x)

if stride == 2:
x = layers.ZeroPadding2D(padding=correct_pad(exp_c, kernel_size),
input_size = (x.shape[1], x.shape[2]) # height, width
x = layers.ZeroPadding2D(padding=correct_pad(input_size, kernel_size),
name=prefix + 'depthwise/pad')(x)

x = layers.DepthwiseConv2D(kernel_size=kernel_size,
Expand Down
14 changes: 8 additions & 6 deletions tensorflow_classification/Test6_mobilenet/train_mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def main():
batch_size = 16
epochs = 20
num_classes = 5
freeze_layer = False

# data generator with data augmentation
train_ds, val_ds = generate_ds(data_root, im_height, im_width, batch_size)
Expand All @@ -34,12 +35,13 @@ def main():
assert os.path.exists(pre_weights_path), "cannot find {}".format(pre_weights_path)
model.load_weights(pre_weights_path, by_name=True, skip_mismatch=True)

# freeze layer, only training 2 last layers
for layer in model.layers:
if layer.name not in ["Conv_2", "Logits/Conv2d_1c_1x1"]:
layer.trainable = False
else:
print("training: " + layer.name)
if freeze_layer is True:
# freeze layer, only training 2 last layers
for layer in model.layers:
if layer.name not in ["Conv_2", "Logits/Conv2d_1c_1x1"]:
layer.trainable = False
else:
print("training: " + layer.name)

model.summary()

Expand Down

0 comments on commit b830989

Please sign in to comment.