Skip to content

Commit

Permalink
Add options to freeze encoder and unfreeze batch normalization/squeez…
Browse files Browse the repository at this point in the history
…e-and-excitation, save trained model
  • Loading branch information
vdng9338 committed Jan 25, 2022
1 parent 83c2ba2 commit 3b9af93
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
30 changes: 30 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import nemo
import nemo.collections.asr as nemo_asr
import torch.nn as nn
import pytorch_lightning as pl
import yaml
from datetime import datetime
from omegaconf import DictConfig

FREEZE_ENCODER = True
UNFREEZE_SQUEEZE_EXCITATION = True
UNFREEZE_BATCH_NORM = True

DATE = datetime.strftime(datetime.now(), "%Y-%m-%d_%H-%M-%S")

CONFIG_PATH = 'model/config.yaml'

Expand All @@ -17,11 +24,34 @@
model.setup_validation_data(val_data_config=params['model']['validation_ds'])
model.setup_optimization(optim_config=params['model']['optim'])

# Freeze the encoder: from https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb
def unfreeze_squeeze_excitation(m):
if "SqueezeExcite" in type(m).__name__:
m.train()
for param in m.parameters():
param.requires_grad_(True)

def unfreeze_batch_norm(m):
if type(m) == nn.BatchNorm1d:
m.train()
for param in m.parameters():
param.requires_grad_(True)

if FREEZE_ENCODER:
model.encoder.freeze()
if UNFREEZE_SQUEEZE_EXCITATION:
model.encoder.apply(unfreeze_squeeze_excitation)
if UNFREEZE_BATCH_NORM:
model.encoder.apply(unfreeze_batch_norm)

trainer = pl.Trainer(**params['trainer'])

print("\n========== Start FIT")
trainer.fit(model)

print("\n========== Done fitting")
model.save_to(f"{params['name']}_fr_{DATE}.nemo")

# p = '/home/jovyan/projet-ml/data/libri-dataset/dev-clean/1272/128104/1272-128104-0000.flac'
# txt = model.transcribe([p])
# print(txt)
2 changes: 1 addition & 1 deletion model/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ model:

trainer:
gpus: 1 # number of gpus
max_epochs: 1
max_epochs: 20
max_steps: null # computed at runtime if not set
num_nodes: 1 # Should be set via SLURM variable `SLURM_JOB_NUM_NODES`
accelerator: ddp
Expand Down

0 comments on commit 3b9af93

Please sign in to comment.