Skip to content

Commit

Permalink
Merge pull request #11 from fgnt/dev
Browse files Browse the repository at this point in the history
Add fast composition sampler and allow meeting sampling without enrollment phase
  • Loading branch information
thequilo authored Jul 13, 2023
2 parents 1ad6b5b + 2a4c7a2 commit 9066466
Show file tree
Hide file tree
Showing 11 changed files with 348 additions and 77 deletions.
32 changes: 23 additions & 9 deletions mms_msg/databases/classical/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
from typing import Callable

import numpy as np
import typing
if typing.TYPE_CHECKING:
from typing import Literal

import lazy_dataset.database
from lazy_dataset import Dataset
from lazy_dataset.database import JsonDatabase, Database
from lazy_dataset.database import Database
from mms_msg.databases.database import MMSMSGDatabase
from mms_msg import keys
from mms_msg.sampling.environment.rir import RIRSampler
from mms_msg.sampling.source_composition import get_composition_dataset
from mms_msg.sampling.source_composition import get_composition_dataset, sample_utterance_composition
from mms_msg.simulation.anechoic import anechoic_scenario_map_fn
from mms_msg.simulation.noise import white_microphone_noise
from mms_msg.simulation.reverberant import reverberant_scenario_map_fn
from mms_msg.simulation.reverberant import reverberant_scenario_map_fn, slice_channel
from mms_msg.simulation.truncation import truncate_min
from mms_msg.simulation.utils import load_audio

Expand All @@ -26,6 +29,7 @@ def __init__(
scaling_sampler: Callable[[dict], dict],
truncate_to_shortest: bool = True,
source_filter: Callable[[dict], bool] = None,
composition_sampler=sample_utterance_composition,
):
"""
Base database class for classical anechoic speech mixtures.
Expand All @@ -36,7 +40,7 @@ def __init__(
- no noise
Args:
source_json_path: Path to the source database json
source_database: Source database object
num_speakers: Number of speakers per mixture
offset_sampler: A sampling module to sample an offset
(key 'offset.original_source') for each utterance
Expand All @@ -45,7 +49,8 @@ def __init__(
truncate_to_shortest: Inspired by WSJ0-2/3mix. If True, the mixture is truncated
to the shortest utterance to ensure full overlap (the 'min' variant). If False,
utterances are not truncated (the 'max' variant)
source_filter: A function to filter the source examples
source_filter: A function to filter the source examples. This function is used
to filter all datasets from `source_database`
"""
super().__init__(source_database)
self.num_speakers = num_speakers
Expand All @@ -56,12 +61,14 @@ def __init__(
def source_filter(_):
return True
self.source_filter = source_filter
self.composition_sampler = composition_sampler

def get_mixture_dataset(self, name: str, rng: np.random.Generator) -> Dataset:
ds = get_composition_dataset(
input_dataset=self.source_database.get_dataset(name).filter(self.source_filter),
num_speakers=self.num_speakers,
rng=rng
rng=rng,
composition_sampler=self.composition_sampler
)
ds = ds.map(self.scaling_sampler)
ds = ds.map(self.overlap_sampler)
Expand All @@ -84,7 +91,10 @@ def __init__(self,
rir_database: Database,
snr_sampler: Callable[[dict], dict],
truncate_to_shortest: bool = True,
source_filter: Callable[[dict], bool] = None):
source_filter: Callable[[dict], bool] = None,
channel_slice: 'int | slice | Literal["one_random"] | Literal["all"]' = None,
composition_sampler=sample_utterance_composition,
):
"""
Base database class for classical reverberant speech mixtures.
Expand All @@ -94,7 +104,7 @@ def __init__(self,
- white microphone noise
Args:
source_json_path: Path to the source database json
source_database: Source database object
num_speakers: Number of speakers per mixture
offset_sampler: A sampling module to sample an offset
(key 'offset.original_source') for each utterance
Expand All @@ -107,9 +117,11 @@ def __init__(self,
to the shorter utterance to ensure full overlap. If 'max',
utterances are not truncated
"""
super().__init__(source_database, num_speakers, offset_sampler, scaling_sampler, truncate_to_shortest, source_filter)
super().__init__(source_database, num_speakers, offset_sampler, scaling_sampler, truncate_to_shortest, source_filter,
composition_sampler)
self.rir_database = rir_database
self.snr_sampler = snr_sampler
self.channel_slice = channel_slice

def get_mixture_dataset(self, name: str, rng: np.random.Generator) -> Dataset:
return super().get_mixture_dataset(name, rng).map(
Expand All @@ -118,6 +130,8 @@ def get_mixture_dataset(self, name: str, rng: np.random.Generator) -> Dataset:

def load_example(self, example: dict) -> dict:
example = load_audio(example, keys.ORIGINAL_SOURCE, keys.RIR)
if self.channel_slice is not None:
example = slice_channel(example, channel_slice=self.channel_slice, squeeze=True)
example = reverberant_scenario_map_fn(example)
example = white_microphone_noise(example)
if self.truncate_to_shortest:
Expand Down
3 changes: 3 additions & 0 deletions mms_msg/databases/classical/full_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def SMSWSJ(
source_json_path=database_jsons / 'wsj_8k.json',
scenario_json_path=data_dir.db_dir / 'sms_wsj' / 'rirs' / 'scenarios.json',
num_speakers=2,
channel_slice=None,
):
"""
A database similar to the SMS-WSJ database
Expand All @@ -67,4 +68,6 @@ def SMSWSJ(
snr_sampler=UniformSNRSampler(20, 30),
rir_database=SMSWSJRIRDatabase(scenario_json_path),
source_filter=filter_punctuation_pronunciation,
truncate_to_shortest=False,
channel_slice=channel_slice,
)
23 changes: 16 additions & 7 deletions mms_msg/databases/meeting/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from mms_msg import keys
from mms_msg.databases.database import MMSMSGDatabase
from mms_msg.sampling.environment.rir import RIRSampler
from mms_msg.sampling.source_composition import get_composition_dataset
from mms_msg.sampling.source_composition import get_composition_dataset, sample_utterance_composition
from mms_msg.simulation.anechoic import anechoic_scenario_map_fn
from mms_msg.simulation.noise import white_microphone_noise
from mms_msg.simulation.reverberant import reverberant_scenario_map_fn
from mms_msg.simulation.reverberant import reverberant_scenario_map_fn, slice_channel
from mms_msg.simulation.utils import load_audio
from paderbox.io.data_dir import database_jsons


class AnechoicMeetingDatabase(MMSMSGDatabase):
Expand All @@ -19,6 +18,7 @@ def __init__(
scaling_sampler,
snr_sampler,
source_filter=None,
composition_sampler=sample_utterance_composition,
):
super().__init__(source_database)

Expand All @@ -30,14 +30,16 @@ def __init__(
def source_filter(_):
return True
self.source_filter = source_filter
self.composition_sampler = composition_sampler

def get_mixture_dataset(self, name, rng):
input_ds = self.source_database.get_dataset(name).filter(self.source_filter)

ds = get_composition_dataset(
input_dataset=input_ds,
num_speakers=self.num_speakers,
rng=rng
rng=rng,
composition_sampler=self.composition_sampler,
)
ds = ds.map(self.scaling_sampler)
ds = ds.map(self.meeting_sampler(input_ds))
Expand All @@ -61,25 +63,30 @@ def __init__(
scaling_sampler,
snr_sampler,
rir_database,
source_filter=None
source_filter=None,
composition_sampler=sample_utterance_composition,
channel_slice=None,
):
super().__init__(
source_database,
num_speakers,
meeting_sampler,
scaling_sampler,
snr_sampler,
source_filter
source_filter,
composition_sampler,
)
self.rir_database = rir_database
self.channel_slice = channel_slice

def get_mixture_dataset(self, name, rng):
input_ds = self.source_database.get_dataset(name).filter(self.source_filter)

ds = get_composition_dataset(
input_dataset=input_ds,
num_speakers=self.num_speakers,
rng=rng
rng=rng,
composition_sampler=self.composition_sampler,
)
ds = ds.map(self.scaling_sampler)
ds = ds.map(RIRSampler(self.rir_database.get_dataset(name)))
Expand All @@ -89,6 +96,8 @@ def get_mixture_dataset(self, name, rng):

def load_example(self, example):
    """
    Load the audio for one sampled example and simulate the mixture.

    The steps are applied in this order:
        1. `load_audio` for the keys `keys.ORIGINAL_SOURCE` and `keys.RIR`
        2. `slice_channel` with `squeeze=True`, only if `self.channel_slice`
           is not None
        3. `reverberant_scenario_map_fn`
        4. `white_microphone_noise`

    Args:
        example: Example dict as produced by the mixture dataset

    Returns:
        The result of passing the example through the steps above
    """
    loaded = load_audio(example, keys.ORIGINAL_SOURCE, keys.RIR)

    # Channel selection is optional; None leaves the loaded data untouched
    selection = self.channel_slice
    if selection is not None:
        loaded = slice_channel(loaded, channel_slice=selection, squeeze=True)

    return white_microphone_noise(reverberant_scenario_map_fn(loaded))
4 changes: 4 additions & 0 deletions mms_msg/sampling/environment/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from mms_msg.sampling.utils.rng import get_rng_example

__all__ = [
'sample_uniform_snr',
'UniformSNRSampler',
]

def sample_uniform_snr(example, *, min_snr: float = 20, max_snr: float = 30):
example['snr'] = float(
Expand Down
7 changes: 5 additions & 2 deletions mms_msg/sampling/environment/rir.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from dataclasses import dataclass
from pathlib import Path

import lazy_dataset
from lazy_dataset.database import JsonDatabase
from mms_msg.databases.reverberation.sms_wsj import SMSWSJRIRDatabase

__all__ = [
'sample_rirs',
'RIRSampler',
]


def sample_rirs(example: dict, *, rir_dataset: lazy_dataset.Dataset):
# Assume the examples have a running index
Expand Down
12 changes: 11 additions & 1 deletion mms_msg/sampling/pattern/classical/offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
from mms_msg.sampling.utils.rng import get_rng_example
from mms_msg.sampling.utils.utils import update_num_samples

__all__ = [
'assign_offset',
'sample_offsets_sms_wsj',
'SMSWSJOffsetSampler',
'sample_offsets_constant',
'ConstantOffsetSampler',
'sample_partial_overlap',
'PartialOverlapOffsetSampler',
]


def assign_offset(example, offset):
assert keys.OFFSET not in example
Expand Down Expand Up @@ -56,7 +66,7 @@ def sample_partial_overlap(example, *, minimum_overlap, maximum_overlap):
overlap = rng.uniform(minimum_overlap, maximum_overlap)
num_samples = example[keys.NUM_SAMPLES][keys.ORIGINAL_SOURCE]
assert len(num_samples) == 2, (len(num_samples), num_samples)
overlap_samples = sum(num_samples)*overlap / (1 + overlap)
overlap_samples = sum(num_samples) * overlap / (1 + overlap)
offset = [0, int(max(num_samples[0] - overlap_samples, 0))]
assign_offset(example, offset)
return example
Expand Down
17 changes: 16 additions & 1 deletion mms_msg/sampling/pattern/meeting/meeting_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@

logger = logging.getLogger('meeting')

__all__ = [
'MeetingSampler',
'sample_meeting_from_full_overlap',
]

@dataclass
class _MeetingSampler:
input_dataset: Iterable
duration: int
overlap_sampler: OverlapSampler
scenario_sequence_sampler: callable = sample_balanced
sample_enrollment_phase: bool = True

def __post_init__(self):
if isinstance(self.scenario_sequence_sampler, str):
Expand Down Expand Up @@ -54,6 +59,12 @@ def __call__(self, example):

examples = []

# The enrollment phase makes sure that every speaker is active at least once
# and that the base examples appear in the generated meeting.
# If the enrollment phase is disabled, neither of these is guaranteed.
if not self.sample_enrollment_phase:
base_examples = []

# Add base examples to be sure that each speaker is active at least once
while (
max([
Expand Down Expand Up @@ -178,9 +189,11 @@ def sample_meeting_from_full_overlap(
maximum_overlap=40000
),
scenario_sequence_sampler: callable = sample_balanced,
sample_enrollment_phase: bool = True
):
_MeetingSampler(
input_dataset, duration, overlap_sampler, scenario_sequence_sampler
input_dataset, duration, overlap_sampler, scenario_sequence_sampler,
sample_enrollment_phase,
)(example)


Expand Down Expand Up @@ -213,6 +226,7 @@ class MeetingSampler(pt.Configurable):
maximum_overlap=8 * 8000,
)
)
sample_enrollment_phase: bool = True

def __call__(self, dataset: Dataset):
"""
Expand All @@ -231,4 +245,5 @@ def __call__(self, dataset: Dataset):
duration=self.duration,
scenario_sequence_sampler=self.scenario_sequence_sampler,
overlap_sampler=self.overlap_sampler,
sample_enrollment_phase=self.sample_enrollment_phase,
)
Loading

0 comments on commit 9066466

Please sign in to comment.