
Commit

resolve conflict for initializing whisper
modantailleur authored and jnwnlee committed Feb 13, 2024
1 parent c23b8d1 commit 73b6cac
Showing 4 changed files with 106 additions and 57 deletions.
138 changes: 82 additions & 56 deletions fadtk/model_loader.py
@@ -5,25 +5,25 @@
import numpy as np
import soundfile
import os

import torch
import librosa
from torch import nn
from pathlib import Path
from hypy_utils.downloader import download_file
import torch.nn.functional as F
import importlib.util
from .utils import chunk_np_array

from . import panns

log = logging.getLogger(__name__)


class ModelLoader(ABC):
"""
Abstract class for loading a model and getting embeddings from it. The model should be loaded in the `load_model` method.
"""
def __init__(self, name: str, num_features: int, sr: int):
def __init__(self, name: str, num_features: int, sr: int, audio_len: int):
self.audio_len = audio_len
self.model = None
self.sr = sr
self.num_features = num_features
@@ -56,16 +56,21 @@ def _get_embedding(self, audio: np.ndarray):
def load_wav(self, wav_file: Path):
wav_data, _ = soundfile.read(wav_file, dtype='int16')
wav_data = wav_data / 32768.0 # Convert to [-1.0, +1.0]

if self.audio_len is not None:
if wav_data.shape[0] < self.sr*self.audio_len:
padding_size = self.sr*self.audio_len - len(wav_data)
wav_data = np.pad(wav_data, (0, padding_size), 'constant', constant_values=(0, 0))
elif wav_data.shape[0] > self.sr*self.audio_len:
wav_data = wav_data[:self.sr*self.audio_len]
return wav_data
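
For orientation, here is a minimal sketch of the contract the ModelLoader docstring describes, assuming load_model and _get_embedding are the abstract hooks to fill in. The DummyModel class and its reshape-as-embedding trick are purely illustrative and are not part of fadtk.

import numpy as np

from fadtk.model_loader import ModelLoader


class DummyModel(ModelLoader):
    """Illustrative subclass only: frames raw audio and calls that its 'embedding'."""

    def __init__(self, audio_len=None):
        # name, num_features, sr, audio_len, mirroring the signature shown above.
        super().__init__("dummy", 8, 16000, audio_len)

    def load_model(self):
        # A real loader would build the network here and move it to self.device.
        self.model = lambda x: x.reshape(-1, 8)

    def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
        usable = len(audio) - (len(audio) % 8)  # trim so the reshape is exact
        return self.model(audio[:usable])
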


class VGGishModel(ModelLoader):
"""
S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017
"""
def __init__(self, use_pca=False, use_activation=False):
super().__init__("vggish", 128, 16000)
def __init__(self, use_pca=False, use_activation=False, audio_len=None):
super().__init__("vggish", 128, 16000, audio_len)
self.use_pca = use_pca
self.use_activation = use_activation

@@ -86,10 +91,14 @@ class PANNsModel(ModelLoader):
"""
Kong, Qiuqiang, et al., "Panns: Large-scale pretrained audio neural networks for audio pattern recognition.",
IEEE/ACM Transactions on Audio, Speech, and Language Processing 28 (2020): 2880-2894.
Specify the model to use (cnn14-32k, cnn14-16k, wavegram-logmel).
You can also specify whether to send the full provided audio or one-second chunks of it (e.g. cnn14-32k-1s); this was shown
to have a very low impact on performance.
"""
def __init__(self, variant: Literal['32k', '16k'] = '32k'):
super().__init__('panns' if variant == '32k' else f"panns-{variant}", 2048,
sr=32000 if variant == '32k' else 16000)
def __init__(self, variant: Literal['cnn14-32k', 'cnn14-32k-1s', 'cnn14-16k', 'wavegram-logmel'], audio_len=None):
super().__init__(f"panns-{variant}", 2048,
sr=16000 if variant == 'cnn14-16k' else 32000, audio_len=audio_len)
self.variant = variant

def load_model(self):
@@ -106,9 +115,12 @@ def load_model(self):
f"wget -P {ckpt_dir} %s"
% ("https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth")
)

os.system(
f"wget -P {ckpt_dir} %s"
% ("https://zenodo.org/records/3987831/files/Wavegram_Logmel_Cnn14_mAP%3D0.439.pth")
)
features_list = ["2048", "logits"]
if self.variant == '16k':
if self.variant == 'cnn14-16k':
self.model = panns.Cnn14(
features_list=features_list,
sample_rate=16000,
@@ -119,7 +131,7 @@ def load_model(self):
fmax=8000,
classes_num=527,
)
elif self.variant == '32k':
elif self.variant == 'cnn14-32k':
self.model = panns.Cnn14(
features_list=features_list,
sample_rate=32000,
@@ -130,26 +142,40 @@ def load_model(self):
fmax=14000,
classes_num=527,
)
elif self.variant == 'wavegram-logmel':
self.model = panns.Wavegram_Logmel_Cnn14(
sample_rate=32000,
window_size=1024,
hop_size=320,
mel_bins=64,
fmin=50,
fmax=14000,
classes_num=527,
)
self.model.eval()
self.model.to(self.device)

def _get_embedding(self, audio: np.ndarray) -> np.ndarray:
if '-1s' in self.variant:
audio = chunk_np_array(audio, self.sr)
audio = torch.from_numpy(audio).float().to(self.device)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
emb = self.model.forward(audio)["2048"]
if 'cnn14' in self.variant:
emb = self.model.forward(audio)["2048"]
else:
emb = self.model.forward(audio)["embedding"]
return emb
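
As a rough usage sketch, not taken from the repository's documentation: the variants described in the PANNsModel docstring could be exercised as below. The path example.wav is a placeholder, the audio is assumed to already be at the model's sample rate, and _get_embedding is called directly purely for illustration.

from fadtk.model_loader import PANNsModel

model = PANNsModel('cnn14-32k', audio_len=None)  # or 'cnn14-16k', 'wavegram-logmel', 'cnn14-32k-1s'
model.load_model()                               # fetches the Zenodo checkpoints, as in load_model above

audio = model.load_wav('example.wav')            # placeholder path, assumed to be 32 kHz mono here
emb = model._get_embedding(audio)                # 2048-dimensional PANNs features as a torch tensor
print(emb.shape)
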



class EncodecEmbModel(ModelLoader):
"""
Encodec model from https://github.com/facebookresearch/encodec
This version uses the embedding outputs (continuous values of 128 features).
"""
def __init__(self, variant: Literal['48k', '24k'] = '24k'):
def __init__(self, variant: Literal['48k', '24k'] = '24k', audio_len=None):
super().__init__('encodec-emb' if variant == '24k' else f"encodec-emb-{variant}", 128,
sr=24000 if variant == '24k' else 48000)
sr=24000 if variant == '24k' else 48000, audio_len=audio_len)
self.variant = variant

def load_model(self):
@@ -225,8 +251,8 @@ class DACModel(ModelLoader):
pip install descript-audio-codec
"""
def __init__(self):
super().__init__("dac-44kHz", 1024, 44100)
def __init__(self, audio_len=None):
super().__init__("dac-44kHz", 1024, 44100, audio_len=audio_len)

def load_model(self):
from dac.utils import load_model
@@ -290,8 +316,8 @@ class MERTModel(ModelLoader):
Please specify the layer to use (1-12).
"""
def __init__(self, size='v1-95M', layer=12, limit_minutes=6):
super().__init__(f"MERT-{size}" + ("" if layer == 12 else f"-{layer}"), 768, 24000)
def __init__(self, size='v1-95M', layer=12, limit_minutes=6, audio_len=None):
super().__init__(f"MERT-{size}" + ("" if layer == 12 else f"-{layer}"), 768, 24000, audio_len=audio_len)
self.huggingface_id = f"m-a-p/MERT-{size}"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr
@@ -325,8 +351,8 @@ class CLAPLaionModel(ModelLoader):
CLAP model from https://github.com/LAION-AI/CLAP
"""

def __init__(self, type: Literal['audio', 'music']):
super().__init__(f"clap-laion-{type}", 512, 48000)
def __init__(self, type: Literal['audio', 'music'], audio_len=None):
super().__init__(f"clap-laion-{type}", 512, 48000, audio_len=audio_len)
self.type = type

if type == 'audio':
@@ -425,8 +451,8 @@ class CdpamModel(ModelLoader):
"""
CDPAM model from https://github.com/pranaymanocha/PerceptualAudio/tree/master/cdpam
"""
def __init__(self, mode: Literal['acoustic', 'content']) -> None:
super().__init__(f"cdpam-{mode}", 512, 22050)
def __init__(self, mode: Literal['acoustic', 'content'], audio_len=None) -> None:
super().__init__(f"cdpam-{mode}", 512, 22050, audio_len=audio_len)
self.mode = mode
assert mode in ['acoustic', 'content'], "Mode must be 'acoustic' or 'content'"

@@ -467,8 +493,8 @@ class CLAPModel(ModelLoader):
"""
CLAP model from https://github.com/microsoft/CLAP
"""
def __init__(self, type: Literal['2023']):
super().__init__(f"clap-{type}", 1024, 44100)
def __init__(self, type: Literal['2023'], audio_len=None):
super().__init__(f"clap-{type}", 1024, 44100, audio_len=audio_len)
self.type = type

if type == '2023':
@@ -531,11 +557,11 @@ class W2V2Model(ModelLoader):
Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large').
"""
def __init__(self, size: Literal['base', 'large'], layer: Literal['12', '24'], limit_minutes=6):
def __init__(self, size: Literal['base', 'large'], layer: Literal['12', '24'], limit_minutes=6, audio_len=None):
model_dim = 768 if size == 'base' else 1024
model_identifier = f"w2v2-{size}" + ("" if (layer == 12 and size == 'base') or (layer == 24 and size == 'large') else f"-{layer}")

super().__init__(model_identifier, model_dim, 16000)
super().__init__(model_identifier, model_dim, 16000, audio_len=audio_len)
self.huggingface_id = f"facebook/wav2vec2-{size}-960h"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr
@@ -568,11 +594,11 @@ class HuBERTModel(ModelLoader):
Please specify the size ('base' or 'large') and the layer to use (1-12 for 'base' or 1-24 for 'large').
"""
def __init__(self, size: Literal['base', 'large'], layer: Literal['12', '24'], limit_minutes=6):
def __init__(self, size: Literal['base', 'large'], layer: Literal['12', '24'], limit_minutes=6, audio_len=None):
model_dim = 768 if size == 'base' else 1024
model_identifier = f"hubert-{size}" + ("" if (layer == 12 and size == 'base') or (layer == 24 and size == 'large') else f"-{layer}")

super().__init__(model_identifier, model_dim, 16000)
super().__init__(model_identifier, model_dim, 16000, audio_len=audio_len)
self.huggingface_id = f"facebook/hubert-{size}-ls960"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr
@@ -605,11 +631,11 @@ class WavLMModel(ModelLoader):
Please specify the model size ('base', 'base-plus', or 'large') and the layer to use (1-12 for 'base' or 'base-plus' and 1-24 for 'large').
"""
def __init__(self, size: Literal['base', 'base-plus', 'large'], layer: Literal['12', '24'], limit_minutes=6):
def __init__(self, size: Literal['base', 'base-plus', 'large'], layer: Literal['12', '24'], limit_minutes=6, audio_len=None):
model_dim = 768 if size in ['base', 'base-plus'] else 1024
model_identifier = f"wavlm-{size}" + ("" if (layer == 12 and size in ['base', 'base-plus']) or (layer == 24 and size == 'large') else f"-{layer}")

super().__init__(model_identifier, model_dim, 16000)
super().__init__(model_identifier, model_dim, 16000, audio_len=audio_len)
self.huggingface_id = f"patrickvonplaten/wavlm-libri-clean-100h-{size}"
self.layer = layer
self.limit = limit_minutes * 60 * self.sr
@@ -642,7 +668,7 @@ class WhisperModel(ModelLoader):
Please specify the model size ('tiny', 'base', 'small', 'medium', or 'large').
"""
def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large']):
def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large'], audio_len=None):
dimensions = {
'tiny': 384,
'base': 512,
@@ -653,7 +679,7 @@ def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large']):
model_dim = dimensions.get(size)
model_identifier = f"whisper-{size}"

super().__init__(model_identifier, model_dim, 16000)
super().__init__(model_identifier, model_dim, 16000, audio_len=audio_len)
self.huggingface_id = f"openai/whisper-{size}"

def load_model(self):
Expand All @@ -674,28 +700,28 @@ def _get_embedding(self, audio: np.ndarray) -> np.ndarray:

return out



def get_all_models() -> list[ModelLoader]:
def get_all_models(audio_len=None) -> list[ModelLoader]:
ms = [
CLAPModel('2023'),
CLAPLaionModel('audio'), CLAPLaionModel('music'),
VGGishModel(),
PANNsModel('32k'), PANNsModel('16k'),
*(MERTModel(layer=v) for v in range(1, 13)),
EncodecEmbModel('24k'), EncodecEmbModel('48k'),
# DACModel(),
# CdpamModel('acoustic'), CdpamModel('content'),
*(W2V2Model('base', layer=v) for v in range(1, 13)),
*(W2V2Model('large', layer=v) for v in range(1, 25)),
*(HuBERTModel('base', layer=v) for v in range(1, 13)),
*(HuBERTModel('large', layer=v) for v in range(1, 25)),
*(WavLMModel('base', layer=v) for v in range(1, 13)),
*(WavLMModel('base-plus', layer=v) for v in range(1, 13)),
*(WavLMModel('large', layer=v) for v in range(1, 25)),
WhisperModel('tiny'), WhisperModel('small'),
WhisperModel('base'), WhisperModel('medium'),
WhisperModel('large'),
CLAPModel('2023', audio_len=audio_len),
CLAPLaionModel('audio', audio_len=audio_len), CLAPLaionModel('music', audio_len=audio_len),
VGGishModel(audio_len=audio_len),
PANNsModel('cnn14-32k', audio_len=audio_len), PANNsModel('cnn14-16k', audio_len=audio_len),
PANNsModel('wavegram-logmel', audio_len=audio_len),
# PANNs1sModel('32k', audio_len=audio_len), PANNs1sModel('16k', audio_len=audio_len),
*(MERTModel(layer=v, audio_len=audio_len) for v in range(1, 13)),
EncodecEmbModel('24k', audio_len=audio_len), EncodecEmbModel('48k', audio_len=audio_len),
DACModel(audio_len=audio_len),
CdpamModel('acoustic', audio_len=audio_len), CdpamModel('content', audio_len=audio_len),
*(W2V2Model('base', layer=v, audio_len=audio_len) for v in range(1, 13)),
*(W2V2Model('large', layer=v, audio_len=audio_len) for v in range(1, 25)),
*(HuBERTModel('base', layer=v, audio_len=audio_len) for v in range(1, 13)),
*(HuBERTModel('large', layer=v, audio_len=audio_len) for v in range(1, 25)),
*(WavLMModel('base', layer=v, audio_len=audio_len) for v in range(1, 13)),
*(WavLMModel('base-plus', layer=v, audio_len=audio_len) for v in range(1, 13)),
*(WavLMModel('large', layer=v, audio_len=audio_len) for v in range(1, 25)),
WhisperModel('tiny', audio_len=audio_len), WhisperModel('small', audio_len=audio_len),
WhisperModel('base', audio_len=audio_len), WhisperModel('medium', audio_len=audio_len),
WhisperModel('large', audio_len=audio_len),
]
if importlib.util.find_spec("dac") is not None:
ms.append(DACModel())
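
A hedged sketch of how the new audio_len option might be threaded through get_all_models. The 4-second value is an arbitrary illustration, and the .name lookup assumes each loader exposes the identifier passed to its constructor; neither comes from this commit.

from fadtk.model_loader import get_all_models, PANNsModel

# Build every registered loader, asking each one to pad or trim its input to a fixed length.
models = {m.name: m for m in get_all_models(audio_len=4)}
print(len(models), "model loaders registered")

# Or configure a single loader directly with the same option.
panns_16k = PANNsModel('cnn14-16k', audio_len=4)
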
2 changes: 1 addition & 1 deletion fadtk/panns/__init__.py
@@ -1 +1 @@
from .models import Cnn14, Cnn14_16k
from .models import Cnn14, Cnn14_16k, Wavegram_Logmel_Cnn14
3 changes: 3 additions & 0 deletions fadtk/panns/models.py
@@ -2887,6 +2887,9 @@ def __init__(
self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

self.init_weight()
current_file_dir = os.path.dirname(os.path.realpath(__file__))
state_dict = torch.load(f"{current_file_dir}/ckpt/Wavegram_Logmel_Cnn14_mAP=0.439.pth")
self.load_state_dict(state_dict["model"])

def init_weight(self):
init_layer(self.pre_conv0)
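
One possible refinement of the checkpoint loading added above (an assumption, not something this commit does) is to map the weights to CPU first so that constructing Wavegram_Logmel_Cnn14 does not require a GPU. The helper name below is hypothetical; only the "model" key of the state dict is taken from the code above.

import torch

def load_panns_checkpoint(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    """Load a PANNs checkpoint onto CPU and leave device placement to the caller."""
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict["model"])  # PANNs checkpoints store weights under the "model" key
    return model
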
20 changes: 20 additions & 0 deletions fadtk/utils.py
@@ -66,3 +66,23 @@ def get_cache_embedding_path(model: str, audio_dir: PathLike) -> Path:
"""
audio_dir = Path(audio_dir)
return audio_dir.parent / "embeddings" / model / audio_dir.with_suffix(".npy").name

def chunk_np_array(np_array, chunk_size, discard_remainder=True):
"""
Split a NumPy array into chunks of a specified size.
Parameters:
- np_array: The input NumPy array to be split.
- chunk_size: The size of each chunk.
- discard_remainder: If True, discard any remaining elements that don't fit perfectly into chunks.
Returns:
A NumPy array of dimensions (num_chunks, chunk_size)
"""
if discard_remainder:
num_chunks = len(np_array) // chunk_size
output = np.array([np_array[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)])
else:
output = np.array([np_array[i:i + chunk_size] for i in range(0, len(np_array), chunk_size)])

return output
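
A quick usage sketch of the helper above; the sample rate and duration are arbitrary. This mirrors how the PANNs '-1s' variants in model_loader.py split audio into one-second chunks before embedding.

import numpy as np

from fadtk.utils import chunk_np_array

sr = 16000
audio = np.random.randn(sr * 3 + 123)  # roughly 3 s of fake mono audio plus a remainder
chunks = chunk_np_array(audio, sr)     # one-second chunks, remainder discarded by default
print(chunks.shape)                    # (3, 16000)
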

1 comment on commit 73b6cac

@jnwnlee
Owner


add more panns options + add an audio length to fit as option
8681d6a
