Skip to content

Commit

Permalink
Add audio_to_images_batch and sample_clips_batch
Browse files Browse the repository at this point in the history
Topic: batch_cli_commands
  • Loading branch information
hmartiro committed Jan 29, 2023
1 parent 1dd4dbe commit 45d36a3
Showing 1 changed file with 139 additions and 1 deletion.
140 changes: 139 additions & 1 deletion riffusion/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
Command line tools for riffusion.
"""

import random
import typing as T
from multiprocessing.pool import ThreadPool
from pathlib import Path

import argh
import numpy as np
import pydub
import tqdm
from PIL import Image

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
Expand Down Expand Up @@ -96,7 +100,7 @@ def sample_clips(
audio: str,
output_dir: str,
num_clips: int = 1,
duration_ms: int = 5000,
duration_ms: int = 5120,
mono: bool = False,
extension: str = "wav",
seed: int = -1,
Expand Down Expand Up @@ -127,12 +131,146 @@ def sample_clips(
print(f"Wrote {clip_path}")


def audio_to_images_batch(
    *,
    audio_dir: str,
    output_dir: str,
    image_extension: str = "jpg",
    step_size_ms: int = 10,
    num_frequencies: int = 512,
    min_frequency: int = 0,
    max_frequency: int = 10000,
    power_for_image: float = 0.25,
    mono: bool = False,
    sample_rate: int = 44100,
    device: str = "cuda",
    num_threads: T.Optional[int] = None,
    limit: int = -1,
):
    """
    Process audio clips into spectrograms in batch, multi-threaded.

    Args:
        audio_dir: Directory whose direct file entries are loaded as audio
            clips (non-recursive).
        output_dir: Directory for the images; created if missing. Each image is
            named ``<clip stem>.<image_extension>``, so clips sharing a stem
            overwrite each other.
        image_extension: One of "jpg", "jpeg" or "png".
        step_size_ms: Forwarded to SpectrogramParams.
        num_frequencies: Forwarded to SpectrogramParams.
        min_frequency: Forwarded to SpectrogramParams.
        max_frequency: Forwarded to SpectrogramParams.
        power_for_image: Forwarded to SpectrogramParams.
        mono: If True clips are converted to one channel, otherwise to two.
        sample_rate: Clips with a different frame rate are resampled to this.
        device: Forwarded to SpectrogramImageConverter.
        num_threads: Number of worker threads; None uses the ThreadPool default.
        limit: If positive, only the first ``limit`` files (sorted) are used.

    Raises:
        KeyError: If image_extension is not a supported extension.
    """
    # Resolve the PIL format once, before any expensive work, so an unsupported
    # extension fails immediately instead of after the first conversion.
    image_format = {"jpg": "JPEG", "jpeg": "JPEG", "png": "PNG"}[image_extension]

    # Directories can never be decoded as audio, so leave them out up front.
    audio_paths = sorted(p for p in Path(audio_dir).glob("*") if p.is_file())

    if limit > 0:
        audio_paths = audio_paths[:limit]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    params = SpectrogramParams(
        step_size_ms=step_size_ms,
        num_frequencies=num_frequencies,
        min_frequency=min_frequency,
        max_frequency=max_frequency,
        power_for_image=power_for_image,
        stereo=not mono,
        sample_rate=sample_rate,
    )

    converter = SpectrogramImageConverter(params=params, device=device)

    def process_one(audio_path: Path) -> None:
        # Load; unreadable files are skipped (best effort) but reported.
        try:
            segment = pydub.AudioSegment.from_file(str(audio_path))
        except Exception as exc:
            # tqdm.write keeps the message from corrupting the progress bar.
            tqdm.tqdm.write(f"Skipping {audio_path}: could not load audio ({exc})")
            return

        # TODO(hayk): Sanity checks on clip

        # Channel count must agree with the `stereo` flag given to params.
        if mono and segment.channels != 1:
            segment = segment.set_channels(1)
        elif not mono and segment.channels != 2:
            segment = segment.set_channels(2)

        # Frame rate
        if segment.frame_rate != params.sample_rate:
            segment = segment.set_frame_rate(params.sample_rate)

        # Convert
        image = converter.spectrogram_image_from_audio(segment)

        # Save, carrying over the EXIF data attached to the generated image.
        image_path = output_path / f"{audio_path.stem}.{image_extension}"
        image.save(image_path, exif=image.getexif(), format=image_format)

    # The context manager shuts the worker threads down once every result has
    # been consumed, instead of leaving the pool open.
    with ThreadPool(processes=num_threads) as pool:
        with tqdm.tqdm(total=len(audio_paths)) as pbar:
            for _ in pool.imap_unordered(process_one, audio_paths):
                pbar.update()


def sample_clips_batch(
    *,
    audio_dir: str,
    output_dir: str,
    num_clips_per_file: int = 1,
    duration_ms: int = 5120,
    mono: bool = False,
    extension: str = "mp3",
    num_threads: T.Optional[int] = None,
    limit: int = -1,
    seed: int = -1,
):
    """
    Sample short clips from a directory of audio files, multi-threaded.

    Args:
        audio_dir: Directory whose direct file entries are loaded as audio
            (non-recursive).
        output_dir: Directory for the exported clips; created if missing.
        num_clips_per_file: Number of clips to cut from each audio file.
        duration_ms: Length of every clip in milliseconds.
        mono: If True the audio is converted to one channel before sampling.
        extension: File extension and export format of the clips.
        num_threads: Number of worker threads; None uses the ThreadPool default.
        limit: If positive, only the first ``limit`` files (sorted) are used.
        seed: If non-negative, clip start positions are reproducible across
            runs; otherwise they are drawn from an unseeded generator.
    """
    # Directories can never be decoded as audio, so leave them out up front.
    audio_paths = sorted(p for p in Path(audio_dir).glob("*") if p.is_file())

    if limit > 0:
        audio_paths = audio_paths[:limit]

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    def process_one(audio_path: Path) -> None:
        try:
            segment = pydub.AudioSegment.from_file(str(audio_path))
        except Exception as exc:
            # tqdm.write keeps the message from corrupting the progress bar.
            tqdm.tqdm.write(f"Skipping {audio_path}: could not load audio ({exc})")
            return

        if mono:
            segment = segment.set_channels(1)

        segment_duration_ms = int(segment.duration_seconds * 1000)
        last_start_ms = segment_duration_ms - duration_ms
        if last_start_ms < 0:
            # The file is shorter than one clip; skip it rather than letting
            # the random draw raise and abort the whole batch.
            tqdm.tqdm.write(
                f"Skipping {audio_path}: {segment_duration_ms} ms is shorter "
                f"than the requested {duration_ms} ms"
            )
            return

        # One generator per file, seeded from the seed and the file name, so
        # results do not depend on the order in which threads pick up files.
        if seed >= 0:
            rng = random.Random(f"{seed}:{audio_path.name}")
        else:
            rng = random.Random()

        for i in range(num_clips_per_file):
            # randint is inclusive, so a file of exactly duration_ms yields 0.
            clip_start_ms = rng.randint(0, last_start_ms)
            clip = segment[clip_start_ms : clip_start_ms + duration_ms]

            clip_name = (
                f"{audio_path.stem}_{i}"
                f"_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
            )
            clip.export(output_path / clip_name, format=extension)

    # The context manager shuts the worker threads down once every result has
    # been consumed, instead of leaving the pool open.
    with ThreadPool(processes=num_threads) as pool:
        with tqdm.tqdm(total=len(audio_paths)) as pbar:
            for _ in pool.imap_unordered(process_one, audio_paths):
                pbar.update()


if __name__ == "__main__":
    # Functions handed to argh for dispatch when this file is run as a script.
    cli_commands = [
        audio_to_image,
        image_to_audio,
        sample_clips,
        print_exif,
        audio_to_images_batch,
        sample_clips_batch,
    ]
    argh.dispatch_commands(cli_commands)

0 comments on commit 45d36a3

Please sign in to comment.