Skip to content

Commit

Permalink
Add a repeat padding option for MergedDataset.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571362625
  • Loading branch information
sdenton4 authored and copybara-github committed Oct 6, 2023
1 parent 0635efe commit b57a326
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions chirp/projects/multicluster/data_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def from_folder_of_folders(
load_audio: bool = True,
target_sample_rate: int = -2,
audio_file_pattern: str = '*',
pad_type: str = 'zeros',
) -> 'MergedDataset':
"""Generating MergedDataset via folder-of-folders method.
Expand All @@ -82,6 +83,7 @@ def from_folder_of_folders(
sample rate.
audio_file_pattern: The glob pattern to use for finding audio files within
the sub-folders.
pad_type: Padding strategy for short audio, either 'zeros' or 'repeat'.
Returns:
MergedDataset
Expand All @@ -97,6 +99,7 @@ def from_folder_of_folders(
load_audio=load_audio,
target_sample_rate=target_sample_rate,
audio_file_pattern=audio_file_pattern,
pad_type=pad_type,
)
elapsed = time.time() - st
print(f'\n...embedded dataset in {elapsed:5.2f}s...')
Expand Down Expand Up @@ -257,15 +260,27 @@ def pool_time_axis(embeddings, pool_method, axis=1):
raise ValueError('Unrecognized reduction method.')


def _pad_audio(audio: np.ndarray, target_length: int) -> np.ndarray:
def _pad_audio(
audio: np.ndarray, target_length: int, pad_type: str = 'zeros'
) -> np.ndarray:
"""Pad audio to target_length."""
if len(audio.shape) > 1:
raise ValueError('audio should be a flat array.')
if audio.shape[0] > target_length:
return audio
pad_amount = target_length - audio.shape[0]
front = pad_amount // 2
back = pad_amount - front
return np.pad(audio, [(front, back)], 'constant')
if pad_type == 'zeros':
pad_amount = target_length - audio.shape[0]
front = pad_amount // 2
back = pad_amount - front
return np.pad(audio, [(front, back)], 'constant')
elif pad_type == 'repeat':
# repeat audio until longer than target_length.
num_repeats = target_length // audio.shape[0] + 1
repeated_audio = np.repeat(audio, num_repeats, axis=0)
start = repeated_audio.shape[0] - target_length // 2
padded = repeated_audio[start : start + target_length]
return padded
raise ValueError('Unrecognized padding method.')


def embed_dataset(
Expand All @@ -276,6 +291,7 @@ def embed_dataset(
load_audio: bool = True,
target_sample_rate: int = -1,
audio_file_pattern: str = '*',
pad_type: str = 'zeros',
) -> Tuple[Sequence[str], Dict[str, np.ndarray]]:
"""Add embeddings to an eval dataset.
Expand All @@ -295,6 +311,7 @@ def embed_dataset(
raw audio with no resampling. If -2, uses the embedding_model sample rate.
audio_file_pattern: The glob pattern to use for finding audio files within
the sub-folders.
pad_type: Padding style for short audio, either 'zeros' or 'repeat'.
Returns:
Ordered labels and a Dict containing the entire embedded dataset.
Expand Down Expand Up @@ -344,7 +361,7 @@ def embed_dataset(
):
audio_size = audio.shape[0]
if window_size > audio_size:
audio = _pad_audio(audio, window_size)
audio = _pad_audio(audio, window_size, pad_type)
audio = audio.astype(np.float32)
outputs = embedding_model.embed(audio)

Expand Down

0 comments on commit b57a326

Please sign in to comment.