Skip to content

Commit

Permalink
refactor: Extract the alignment model loading logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jim60105 committed Aug 27, 2023
1 parent 43ba06f commit 357463c
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 33 deletions.
18 changes: 18 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"cSpell.words": [
"anuragshas",
"comodoro",
"ftspeech",
"imvladikon",
"jonatasgrosman",
"kingabzpro",
"kresnik",
"mpoyraz",
"nguyenvulebinh",
"saattrupdan",
"torchaudio",
"VOXPOPULI",
"xlsr",
"Yehor"
]
}
38 changes: 5 additions & 33 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
ARG LANG=en

# Base image
FROM nvcr.io/nvidia/pytorch:23.07-py3 as base
FROM nvcr.io/nvidia/pytorch:23.07-py3
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app
Expand All @@ -11,39 +9,13 @@ COPY ./whisperX/requirements.txt .
RUN python3 -m pip install --no-cache-dir -r ./requirements.txt ujson

# Preload fast-whisper
ARG WHISPER_MODEL=tiny.en
ARG WHISPER_MODEL=base
RUN python3 -c 'import faster_whisper; model = faster_whisper.WhisperModel("'${WHISPER_MODEL}'")'

# Preload align model
FROM base AS align-en
ARG ALIGN_MODEL=WAV2VEC2_ASR_BASE_960H
RUN python3 -c 'import torchaudio; bundle = torchaudio.pipelines.__dict__["'${ALIGN_MODEL}'"]; align_model = bundle.get_model(); labels = bundle.get_labels()'

FROM base AS align-fr
ARG ALIGN_MODEL=VOXPOPULI_ASR_BASE_10K_FR
RUN python3 -c 'import torchaudio; bundle = torchaudio.pipelines.__dict__["'${ALIGN_MODEL}'"]; align_model = bundle.get_model(); labels = bundle.get_labels()'

FROM base AS align-de
ARG ALIGN_MODEL=VOXPOPULI_ASR_BASE_10K_DE
RUN python3 -c 'import torchaudio; bundle = torchaudio.pipelines.__dict__["'${ALIGN_MODEL}'"]; align_model = bundle.get_model(); labels = bundle.get_labels()'

FROM base AS align-es
ARG ALIGN_MODEL=VOXPOPULI_ASR_BASE_10K_ES
RUN python3 -c 'import torchaudio; bundle = torchaudio.pipelines.__dict__["'${ALIGN_MODEL}'"]; align_model = bundle.get_model(); labels = bundle.get_labels()'

FROM base AS align-it
ARG ALIGN_MODEL=VOXPOPULI_ASR_BASE_10K_IT
RUN python3 -c 'import torchaudio; bundle = torchaudio.pipelines.__dict__["'${ALIGN_MODEL}'"]; align_model = bundle.get_model(); labels = bundle.get_labels()'

FROM base AS align-ja
ARG ALIGN_MODEL=jonatasgrosman/wav2vec2-large-xlsr-53-japanese
RUN python3 -c 'from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor; processor = Wav2Vec2Processor.from_pretrained("'${ALIGN_MODEL}'"); align_model = Wav2Vec2ForCTC.from_pretrained("'${ALIGN_MODEL}'")'

FROM base AS align-zh
ARG ALIGN_MODEL=jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn
RUN python3 -c 'from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor; processor = Wav2Vec2Processor.from_pretrained("'${ALIGN_MODEL}'"); align_model = Wav2Vec2ForCTC.from_pretrained("'${ALIGN_MODEL}'")'

FROM align-${LANG} AS final
ARG LANG=en
COPY load_align_model.py .
RUN python load_align_model.py ${LANG}

# Install whisperX
COPY ./whisperX/ .
Expand Down
50 changes: 50 additions & 0 deletions load_align_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Pre-download the whisperX alignment model for a single language.

Usage: python load_align_model.py <language-code>

Intended to run at Docker image build time: the loaded model objects are
discarded immediately — the point is to populate the on-disk model cache
(torchaudio / Hugging Face) so the runtime container needs no network.
"""
import sys

import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Model tables mirror the whisperX defaults:
# https://github.com/m-bain/whisperX/blob/v3.1.1/whisperx/alignment.py#L21
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
}


def load_align_model(lang: str) -> None:
    """Download (and thereby cache) the alignment model for ``lang``.

    Raises:
        ValueError: if ``lang`` has no known alignment model.
    """
    if lang in DEFAULT_ALIGN_MODELS_TORCH:
        # torchaudio pipeline bundle: fetching the model + labels is what
        # populates the torch hub cache.
        bundle = torchaudio.pipelines.__dict__[DEFAULT_ALIGN_MODELS_TORCH[lang]]
        bundle.get_model()
        bundle.get_labels()
    elif lang in DEFAULT_ALIGN_MODELS_HF:
        # Hugging Face model: both processor and model weights are cached.
        model_name = DEFAULT_ALIGN_MODELS_HF[lang]
        Wav2Vec2Processor.from_pretrained(model_name)
        Wav2Vec2ForCTC.from_pretrained(model_name)
    else:
        raise ValueError(f"Unsupported language: {lang}")


if __name__ == "__main__":
    # Fail with a clear usage message instead of a bare IndexError when the
    # language argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <language-code>")
    load_align_model(sys.argv[1])

0 comments on commit 357463c

Please sign in to comment.