Skip to content

Commit

Permalink
Merge pull request #14 from sibange/main
Browse files Browse the repository at this point in the history
Addition of the MeetingGenerator for generation of meeting data based on a given source dataset. Also includes small fix for the TransitionModel.
  • Loading branch information
thequilo authored Oct 1, 2024
2 parents 4dc1f2a + b7523b7 commit 7f7b8ea
Show file tree
Hide file tree
Showing 10 changed files with 1,807 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v2
Expand Down
12 changes: 9 additions & 3 deletions mms_msg/sampling/pattern/meeting/overlap_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@
import padertorch as pt


def _get_valid_overlap_region(examples, max_concurrent_spk, current_source):
def _get_valid_overlap_region(examples, max_concurrent_spk, current_source, use_vad=False):
"""
Compute maximum overlap that guarantees that no more than max_concurrent_spk are active at the same time.
Note: This function underestimates the maximal overlap to ensure regions sampled as silence or
repetitions of the same speaker will have no overlapping speech added later in the sampling .
Args:
examples:
max_concurrent_spk:
use_vad: (optional) When set to True the keys that represent the alignment of the vad data are
used to compute the valid overlap region.
Returns:
"""
speaker_end = np.asarray(examples['offset']['original_source']) + np.asarray(examples['num_samples']['observation'])
if use_vad:
speaker_end = np.asarray(examples['offset']['aligned_source']) + np.asarray(
examples['num_samples']['aligned_source'])
else:
speaker_end = np.asarray(examples['offset']['original_source']) + np.asarray(
examples['num_samples']['observation'])
speaker_id = examples['speaker_id']

return get_allowed_max_overlap(speaker_end, speaker_id, max_concurrent_spk, current_source['speaker_id'])
Expand Down
5 changes: 5 additions & 0 deletions mms_msg/sampling/pattern/meeting/state_based/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from . import action_handler
from . import dataset_statistics_estimation
from . import meeting_generator
from . import sampler
from . import transition_model
from . import weighted_meeting_sampler
427 changes: 427 additions & 0 deletions mms_msg/sampling/pattern/meeting/state_based/action_handler.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import logging
import sys
import numpy as np

from operator import add
from typing import Dict, Optional, Union
from lazy_dataset import Dataset, from_dict

from mms_msg.sampling.utils.distribution_model import DistributionModel
from mms_msg.sampling.pattern.meeting.state_based.transition_model import MarkovModel, MultiSpeakerTransitionModel


logger = logging.getLogger('dataset_statistics_estimation')


class MeetingStatisticsEstimatorMarkov:
"""
Class that estimates characteristics of existing dataset. For the speaker transitions a MarkovModel is created.
The model distinguishes between 4 transition (TH: Turn hold, TS: turn switch, OV: Overlap, BC: Backchannel).
These transitions are based on the following paper:
Improving the Naturalness of Simulated Conversations for End-to-End Neural Diarization,
https://arxiv.org/abs/2204.11232
Also, distributions for silence and overlap are computed, using a histogram-like distribution model.
This class also supports processing of datasets which utilize Voice activity detection (VAD) data.
When VAD data should be processed, the dataset must have the key 'aligned_source' for each speaker,
which describes the interval in which the speaker is active.
Each sample in the processed dataset must have the following keys:
- speaker_id: List with the speaker_ids, which belong to the sources in the example
- offset: Dictionary with a key that depends on the use of vad data:
'aligned_source' when it is used, 'original_source' otherwise.
The associated item must be a list of the offsets of the sources.
- speaker_end: Dictionary with a key that depends on the use of vad data:
'aligned_source' when it is used, 'original_source' otherwise.
The associated item must be a list of the speaker endings of the sources.
Properties:
dataset: (read only) last processed dataset
model: (read only) Markov model with state transition probabilities for the current dataset
silence_distribution: (read only) distribution model of the length of the silence
overlap_distribution: (read only) distribution model of the length of the overlap
"""

def __init__(self, dataset: Optional[Union[Dataset, Dict]] = None, use_vad: bool = False):
"""
Initialization of the Markov Statistics estimator. Optionally a dataset can be given as input.
The Estimator is then used with the given dataset.
Args:
dataset: dataset: dataset that should be processed
use_vad: (optional) Set to True, when VAD data is
present in the dataset and that data should be recognized for sampling
"""
self._dataset = None
self._model = None
self._silence_distribution = None
self._overlap_distribution = None

if dataset is not None:
self.fit(dataset, use_vad)

def fit(self, dataset: [Dataset, Dict], use_vad: bool = False) -> None:
"""
Iterates over the given dataset and computes the MarkovModel and the distributions for silence and overlap.
The dataset, model and according distributions are then stored in the class.
Overrides the previously fitted dataset.
Args:
dataset: dataset that should be processed
use_vad: (optional) Set to True, when VAD data is
present in the dataset and that data should be recognized for sampling
"""

logger.info("Begin with processing the dataset")
if len(dataset) == 0:
raise AssertionError('Cannot compute statistics for an empty dataset.')

# Make sure dataset has the type Dataset
if type(dataset) is dict:
dataset = from_dict(dataset)
# Remove FilterExceptions
self._dataset = dataset.catch()

state_occurence_counter = np.array([0] * 4)
state_transition_counter = np.zeros((4, 4))

silence_durations = []
overlap_durations = []

num_speakers = 0

for n, sample in enumerate(self._dataset):
if n % 100 == 0:
logger.info(f'Processed samples: {n}')

# Depending on the usage of VAD data different keys are used
if use_vad:
offsets = sample['offset']['aligned_source']
speaker_ends = list(map(add, offsets, sample['num_samples']['aligned_source']))
else:
offsets = sample['offset']['original_source']
speaker_ends = list(map(add, offsets, sample['num_samples']['original_source']))
speaker_ids = sample['speaker_id']

num_speakers = max(num_speakers, len(set(speaker_ids)))

current_state = 0
last_foreground_end = speaker_ends[0]
last_foreground_speaker = speaker_ids[0]

for speaker_id, offset, speaker_end in list(zip(speaker_ids, offsets, speaker_ends))[1:]:
state_occurence_counter[current_state] += 1
# Turn-hold
if last_foreground_speaker == speaker_id:
new_state = 0
silence_durations.append(offset - last_foreground_end)

# Turn-switch
elif last_foreground_end < offset:
new_state = 1
silence_durations.append(offset - last_foreground_end)

# Overlap
elif last_foreground_end < speaker_end:
new_state = 2
overlap_durations.append(last_foreground_end - offset)
# Backchannel
else:
new_state = 3

# Adjust foreground information, in all states except backchannel
if new_state in (0, 1, 2):
last_foreground_end = speaker_end
last_foreground_speaker = speaker_id

state_transition_counter[current_state][new_state] += 1

current_state = new_state

# Add at least one sample to the durations, otherwise the DistributionModel can not be fitted
if len(silence_durations) == 0:
silence_durations = [0]
if len(overlap_durations) == 0:
overlap_durations = [0]

# Fit silence and overlap distributions

self._silence_distribution = DistributionModel(silence_durations)
self._overlap_distribution = DistributionModel(overlap_durations)

# Fixes matrix when some states are never reached, otherwise the resulting matrix is not a stochastic matrix,
# but this required for the markov model. (Fixed problem: division by 0, leads to infinite values)
for i in range(len(state_occurence_counter)):
if state_occurence_counter[i] == 0:
state_transition_counter[i] = np.zeros((1, len(state_transition_counter[i])))
state_transition_counter[i][i] = 1
state_occurence_counter[i] = 1

# Fit MarkovModel and create fitting SpeakerTransitionModel
self._model = MultiSpeakerTransitionModel(MarkovModel(state_transition_counter/state_occurence_counter[:, None],
state_names=["TH", "TS", "OV", "BC"]),
num_speakers=num_speakers)

logger.info("Finished processing the dataset")

@property
def model(self) -> MultiSpeakerTransitionModel:
return self._model

@property
def silence_distribution(self) -> DistributionModel:
return self._silence_distribution

@property
def overlap_distribution(self) -> DistributionModel:
return self._overlap_distribution

@property
def dataset(self) -> Dataset:
return self._dataset
142 changes: 142 additions & 0 deletions mms_msg/sampling/pattern/meeting/state_based/meeting_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from mms_msg.sampling.pattern.meeting.state_based.dataset_statistics_estimation import MeetingStatisticsEstimatorMarkov
from mms_msg.sampling.pattern.meeting.state_based.weighted_meeting_sampler import WeightedMeetingSampler
from mms_msg.sampling.pattern.meeting.state_based.action_handler import DistributionActionHandler
from mms_msg.sampling.pattern.meeting.state_based.sampler import DistributionSilenceSampler, DistributionOverlapSampler
from lazy_dataset import Dataset
from typing import Dict, Type, Optional, Any
import mms_msg


class MeetingGenerator:
"""
Class for generating meetings that aim to replicate the state transition probabilities of another dataset.
The samples that are used to generate the artificial data is from an input dataset
that can be independent of the dataset the state transitions are estimates from.
This class uses a Markov based model for the different transitions of the speakers and tries to balance
the activity of all speakers in each meeting.
Properties:
model: Transition Model
silence: Silence Distribution
overlap: Overlap Distribution
"""
def __init__(self, estimator_class: Type = MeetingStatisticsEstimatorMarkov):
"""
Initialize the Meeting Generator
Args:
estimator_class: Class that should be used to determine the statistics on the input dataset.
The constructor must accept at least two parameters: dataset, use_vad
Also must have the following properties: model, silence_distribution, overlap_distribution
"""
self.model = None
self.silence = None
self.overlap = None

self.estimator_class = estimator_class

def fit(self, source_dataset: [Dict, Dataset], use_vad: [bool] = False) -> None:
""" Estimates the speaker transitions and the overlap/silence distributions of a given dataset.
This data can then be used in the generate function. Must be called at least once before calling generate.
It is possible that given VAD is considered when estimating the distributions.
Args:
source_dataset: Dataset for which the statistics are estimated.
This data can then be used to generate new meetings.
use_vad: Should VAD data be used. When set to True the fit function uses VAD data,
when estimating the distributions of the source dataset.
Returns: None
"""
db_sampler = self.estimator_class(dataset=source_dataset, use_vad=use_vad)

self.model = db_sampler.model

self.silence = DistributionSilenceSampler(distribution=db_sampler.silence_distribution)
self.overlap = DistributionOverlapSampler(max_concurrent_spk=2, distribution=db_sampler.overlap_distribution)

def generate(self, input_dataset: [Dict, Dataset], num_speakers: int = 2, duration: int = 960000,
num_meetings: Optional[int] = None, use_vad: bool = False) -> Dataset:
"""Generate a dataset of artificial meeting, with sources from the input_dataset.
The distribution of the generated dataset follows the last fitted distribution,
so the fit method must be called at least once before calling this method.
Also, can utilize VAD data.
Args:
input_dataset: Dataset from which the sources are drawn, that are used for generation new meetings
num_speakers: Number of speakers the meetings in the generated datasets should have.
duration: Duration that the newly generated examples should roughly have, can be slightly exceeded.
num_meetings: Number of meeting that should be generated.
When not given the number of entries in the input dataset is used.
use_vad: Should VAD data be used. When set to true VAD data is used,
during generation of the new dataset and the output dataset hat also VAD information.
Returns: Output dataset, with as many entries as the input dataset has samples
or a lower amount when specified with num_meetings.
"""

if self.model is None or self.silence is None or self.overlap is None:
raise ValueError('No dataset is fitted, you have to use the fit method first.')

if self.model.num_speakers != num_speakers:
try:
self.model.change_num_speakers(num_speakers)
except TypeError:
print('Cannot change the number of speakers of the transition model.'
'It is possible that the generation fails for the desired number of speakers.')

ds = mms_msg.sampling.source_composition.get_composition_dataset(input_dataset, num_speakers=num_speakers)

if num_meetings is not None:
ds = ds[:num_meetings]

return ds.map(WeightedMeetingSampler(transition_model=self.model, duration=duration,
action_handler=DistributionActionHandler(overlap_sampler=self.overlap,
silence_sampler=self.silence),
use_vad=use_vad)({'*': input_dataset}))


class MeetingGeneratorMap:
"""Class for generating meetings that aim to replicate the state transition probabilities of another dataset.
Can be mapped to an existing dataset created with get_composition_dataset()
to generate a meeting for each example in the dataset.
This class uses a Markov based model for the different transitions of the speakers and tries to balance
the activity of all speakers in each meeting.
Properties:
meeting_sampler: Weighted meeting sampler initialized with the statistics from the source dataset
which uses samples form the input dataset.
"""

def __init__(self, source_dataset: [Dict, Dataset], input_dataset: [Dict, Dataset], duration: int = 960000,
use_vad: bool = False, estimator_class: Type = MeetingStatisticsEstimatorMarkov):
"""
Initialize the Meeting Generator Map
Args:
source_dataset: Dataset for which the statistics are estimated.
This data can then be used to generate new meetings.
input_dataset: Dataset from which the sources are drawn, that are used for generation new meetings
duration: Duration that the newly generated examples should roughly have, can be slightly exceeded.
use_vad: Should VAD data be used. When set to true VAD data is used,
during generation of the new dataset and the output dataset hat also VAD information.
estimator_class: Class that should be used to determine the statistics on the input dataset.
The constructor must accept at least two parameters: dataset, use_vad
Also must have the following properties: model, silence_distribution, overlap_distribution
"""
db_sampler = estimator_class(dataset=source_dataset, use_vad=use_vad)

model = db_sampler.model

silence = DistributionSilenceSampler(distribution=db_sampler.silence_distribution)
overlap = DistributionOverlapSampler(max_concurrent_spk=2, distribution=db_sampler.overlap_distribution)

self.meeting_sampler = WeightedMeetingSampler(transition_model=model, duration=duration,
action_handler=DistributionActionHandler(overlap_sampler=overlap,
silence_sampler=silence),
use_vad=use_vad)({'*': input_dataset})

def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
return self.meeting_sampler(example)
Loading

0 comments on commit 7f7b8ea

Please sign in to comment.