-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from sibange/main
Addition of the MeetingGenerator for generation of meeting data based on a given source dataset. Also includes small fix for the TransitionModel.
- Loading branch information
Showing
10 changed files
with
1,807 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,6 @@ | ||
from . import action_handler | ||
from . import dataset_statistics_estimation | ||
from . import meeting_generator | ||
from . import sampler | ||
from . import transition_model | ||
from . import weighted_meeting_sampler |
427 changes: 427 additions & 0 deletions
427
mms_msg/sampling/pattern/meeting/state_based/action_handler.py
Large diffs are not rendered by default.
Oops, something went wrong.
181 changes: 181 additions & 0 deletions
181
mms_msg/sampling/pattern/meeting/state_based/dataset_statistics_estimation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import logging | ||
import sys | ||
import numpy as np | ||
|
||
from operator import add | ||
from typing import Dict, Optional, Union | ||
from lazy_dataset import Dataset, from_dict | ||
|
||
from mms_msg.sampling.utils.distribution_model import DistributionModel | ||
from mms_msg.sampling.pattern.meeting.state_based.transition_model import MarkovModel, MultiSpeakerTransitionModel | ||
|
||
|
||
logger = logging.getLogger('dataset_statistics_estimation') | ||
|
||
|
||
class MeetingStatisticsEstimatorMarkov:
    """
    Estimates speaker-transition statistics of an existing meeting dataset.

    For the speaker transitions a MarkovModel over four states is fitted
    (TH: turn hold, TS: turn switch, OV: overlap, BC: backchannel).
    These transitions are based on the following paper:
    Improving the Naturalness of Simulated Conversations for End-to-End Neural Diarization,
    https://arxiv.org/abs/2204.11232

    In addition, the observed silence and overlap durations are collected and each passed
    to a DistributionModel.

    Datasets with voice activity detection (VAD) data are supported: with ``use_vad=True``
    the per-source values are read from the 'aligned_source' entries instead of the
    'original_source' entries.

    Each sample in the processed dataset must have the following keys:
        - speaker_id: List with the speaker ids, one per source in the example.
        - offset: Dictionary with the key 'aligned_source' (when VAD data is used) or
                  'original_source' (otherwise). The associated item must be a list with
                  the offsets of the sources.
        - num_samples: Dictionary with the same key as ``offset``. The associated item must
                       be a list with the lengths of the sources; the end of each source is
                       computed as ``offset + num_samples``.

    Properties:
        dataset: (read only) last processed dataset
        model: (read only) transition model with the state transition probabilities of the last dataset
        silence_distribution: (read only) distribution model of the length of the silence
        overlap_distribution: (read only) distribution model of the length of the overlap
    """

    def __init__(self, dataset: Optional[Union[Dataset, Dict]] = None, use_vad: bool = False):
        """
        Initialization of the Markov statistics estimator.

        When a dataset is given, the estimator is fitted to it immediately.

        Args:
            dataset: (optional) dataset that should be processed
            use_vad: (optional) Set to True, when VAD data is
                present in the dataset and that data should be recognized for sampling
        """
        self._dataset = None
        self._model = None
        self._silence_distribution = None
        self._overlap_distribution = None

        if dataset is not None:
            self.fit(dataset, use_vad)

    def fit(self, dataset: Union[Dataset, Dict], use_vad: bool = False) -> None:
        """
        Iterates over the given dataset and computes the MarkovModel and the distributions
        for silence and overlap.

        The dataset, model and according distributions are then stored in the class.
        Overrides the previously fitted dataset.

        Args:
            dataset: dataset that should be processed
            use_vad: (optional) Set to True, when VAD data is
                present in the dataset and that data should be recognized for sampling

        Raises:
            AssertionError: If the given dataset is empty.
        """
        logger.info("Begin with processing the dataset")
        if len(dataset) == 0:
            raise AssertionError('Cannot compute statistics for an empty dataset.')

        # Make sure dataset has the type Dataset (isinstance also accepts dict subclasses)
        if isinstance(dataset, dict):
            dataset = from_dict(dataset)
        # Remove FilterExceptions
        self._dataset = dataset.catch()

        # Index of a state: 0 = TH, 1 = TS, 2 = OV, 3 = BC
        state_occurrence_counter = np.array([0] * 4)
        state_transition_counter = np.zeros((4, 4))

        silence_durations = []
        overlap_durations = []

        num_speakers = 0

        # Depending on the usage of VAD data different keys are used
        source_key = 'aligned_source' if use_vad else 'original_source'

        for n, sample in enumerate(self._dataset):
            if n % 100 == 0:
                logger.info(f'Processed samples: {n}')

            offsets = sample['offset'][source_key]
            speaker_ends = list(map(add, offsets, sample['num_samples'][source_key]))
            speaker_ids = sample['speaker_id']

            num_speakers = max(num_speakers, len(set(speaker_ids)))

            # A sample without any source has no first foreground speaker and
            # contributes no transition, so it is skipped instead of failing on index 0.
            if len(speaker_ids) == 0 or len(speaker_ends) == 0:
                logger.warning('Skipping sample %d, because it contains no sources', n)
                continue

            # The first source only initializes the foreground information.
            current_state = 0
            last_foreground_end = speaker_ends[0]
            last_foreground_speaker = speaker_ids[0]

            for speaker_id, offset, speaker_end in list(zip(speaker_ids, offsets, speaker_ends))[1:]:
                state_occurrence_counter[current_state] += 1

                if last_foreground_speaker == speaker_id:
                    # Turn-hold
                    # NOTE: the gap is stored without a sign check, so a source of the same
                    # speaker that starts before the previous one ended yields a negative value.
                    new_state = 0
                    silence_durations.append(offset - last_foreground_end)
                elif last_foreground_end < offset:
                    # Turn-switch: another speaker starts after the foreground speaker ended
                    new_state = 1
                    silence_durations.append(offset - last_foreground_end)
                elif last_foreground_end < speaker_end:
                    # Overlap: another speaker starts before the foreground speaker ended
                    # and keeps talking after that end
                    new_state = 2
                    overlap_durations.append(last_foreground_end - offset)
                else:
                    # Backchannel: the other source does not extend past the foreground end
                    new_state = 3

                # Adjust foreground information, in all states except backchannel
                if new_state in (0, 1, 2):
                    last_foreground_end = speaker_end
                    last_foreground_speaker = speaker_id

                state_transition_counter[current_state][new_state] += 1
                current_state = new_state

        # Add at least one sample to the durations, otherwise the DistributionModel can not be fitted
        if len(silence_durations) == 0:
            silence_durations = [0]
        if len(overlap_durations) == 0:
            overlap_durations = [0]

        # Fit silence and overlap distributions
        self._silence_distribution = DistributionModel(silence_durations)
        self._overlap_distribution = DistributionModel(overlap_durations)

        # Fixes matrix when some states are never reached, otherwise the resulting matrix is not a
        # stochastic matrix, but this is required for the markov model: a never visited state gets
        # a self-transition with probability 1, which also avoids the division by 0 below.
        for state in np.flatnonzero(state_occurrence_counter == 0):
            state_transition_counter[state] = 0
            state_transition_counter[state, state] = 1
            state_occurrence_counter[state] = 1

        # Fit MarkovModel and create fitting SpeakerTransitionModel
        transition_probabilities = state_transition_counter / state_occurrence_counter[:, None]
        self._model = MultiSpeakerTransitionModel(MarkovModel(transition_probabilities,
                                                              state_names=["TH", "TS", "OV", "BC"]),
                                                  num_speakers=num_speakers)

        logger.info("Finished processing the dataset")

    @property
    def model(self) -> MultiSpeakerTransitionModel:
        return self._model

    @property
    def silence_distribution(self) -> DistributionModel:
        return self._silence_distribution

    @property
    def overlap_distribution(self) -> DistributionModel:
        return self._overlap_distribution

    @property
    def dataset(self) -> Dataset:
        return self._dataset
142 changes: 142 additions & 0 deletions
142
mms_msg/sampling/pattern/meeting/state_based/meeting_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from typing import Any, Dict, Optional, Type, Union

from lazy_dataset import Dataset

import mms_msg
from mms_msg.sampling.pattern.meeting.state_based.dataset_statistics_estimation import MeetingStatisticsEstimatorMarkov
from mms_msg.sampling.pattern.meeting.state_based.weighted_meeting_sampler import WeightedMeetingSampler
from mms_msg.sampling.pattern.meeting.state_based.action_handler import DistributionActionHandler
from mms_msg.sampling.pattern.meeting.state_based.sampler import DistributionSilenceSampler, DistributionOverlapSampler
|
||
|
||
class MeetingGenerator:
    """
    Class for generating meetings that aim to replicate the state transition probabilities of another dataset.

    The samples that are used to generate the artificial data are taken from an input dataset
    that can be independent of the dataset the state transitions are estimated from.
    This class uses a Markov based model for the different transitions of the speakers and tries to balance
    the activity of all speakers in each meeting.

    Properties:
        model: Transition Model
        silence: Silence Distribution
        overlap: Overlap Distribution
    """
    def __init__(self, estimator_class: Type = MeetingStatisticsEstimatorMarkov):
        """
        Initialize the Meeting Generator

        Args:
            estimator_class: Class that should be used to determine the statistics on the input dataset.
                The constructor must accept at least two parameters: dataset, use_vad
                Also must have the following properties: model, silence_distribution, overlap_distribution
        """
        # All three stay None until fit() is called; generate() checks for that.
        self.model = None
        self.silence = None
        self.overlap = None

        self.estimator_class = estimator_class

    def fit(self, source_dataset: Union[Dict, Dataset], use_vad: bool = False) -> None:
        """ Estimates the speaker transitions and the overlap/silence distributions of a given dataset.

        This data can then be used in the generate function. Must be called at least once before calling generate.
        It is possible that given VAD is considered when estimating the distributions.

        Args:
            source_dataset: Dataset for which the statistics are estimated.
                            This data can then be used to generate new meetings.
            use_vad: Should VAD data be used. When set to True the fit function uses VAD data,
                     when estimating the distributions of the source dataset.

        Returns: None
        """
        db_sampler = self.estimator_class(dataset=source_dataset, use_vad=use_vad)

        self.model = db_sampler.model

        self.silence = DistributionSilenceSampler(distribution=db_sampler.silence_distribution)
        self.overlap = DistributionOverlapSampler(max_concurrent_spk=2, distribution=db_sampler.overlap_distribution)

    def generate(self, input_dataset: Union[Dict, Dataset], num_speakers: int = 2, duration: int = 960000,
                 num_meetings: Optional[int] = None, use_vad: bool = False) -> Dataset:
        """Generate a dataset of artificial meetings, with sources from the input_dataset.

        The distribution of the generated dataset follows the last fitted distribution,
        so the fit method must be called at least once before calling this method.
        Also, can utilize VAD data.

        Args:
            input_dataset: Dataset from which the sources are drawn, that are used for generation of new meetings
            num_speakers: Number of speakers the meetings in the generated datasets should have.
            duration: Duration that the newly generated examples should roughly have, can be slightly exceeded.
            num_meetings: Number of meetings that should be generated.
                          When not given, all entries of the composition dataset are used.
            use_vad: Should VAD data be used. When set to True VAD data is used
                     during generation of the new dataset and the output dataset also has VAD information.

        Returns: The composition dataset built from input_dataset (truncated to num_meetings
                 entries when given), mapped through the meeting sampler.

        Raises:
            ValueError: If fit was not called before.
        """
        if self.model is None or self.silence is None or self.overlap is None:
            raise ValueError('No dataset is fitted, you have to use the fit method first.')

        if self.model.num_speakers != num_speakers:
            try:
                self.model.change_num_speakers(num_speakers)
            except TypeError:
                # Best effort: generation is still attempted with the unchanged model.
                print('Cannot change the number of speakers of the transition model. '
                      'It is possible that the generation fails for the desired number of speakers.')

        ds = mms_msg.sampling.source_composition.get_composition_dataset(input_dataset, num_speakers=num_speakers)

        if num_meetings is not None:
            ds = ds[:num_meetings]

        action_handler = DistributionActionHandler(overlap_sampler=self.overlap, silence_sampler=self.silence)
        meeting_sampler = WeightedMeetingSampler(transition_model=self.model, duration=duration,
                                                 action_handler=action_handler, use_vad=use_vad)

        # Calling the sampler with the source datasets yields the function that is mapped over the examples.
        return ds.map(meeting_sampler({'*': input_dataset}))
|
||
|
||
class MeetingGeneratorMap:
    """Class for generating meetings that aim to replicate the state transition probabilities of another dataset.

    Can be mapped to an existing dataset created with get_composition_dataset()
    to generate a meeting for each example in the dataset.
    This class uses a Markov based model for the different transitions of the speakers and tries to balance
    the activity of all speakers in each meeting.

    Properties:
        meeting_sampler: Weighted meeting sampler initialized with the statistics from the source dataset
                         which uses samples from the input dataset.
    """

    def __init__(self, source_dataset: Union[Dict, Dataset], input_dataset: Union[Dict, Dataset],
                 duration: int = 960000, use_vad: bool = False,
                 estimator_class: Type = MeetingStatisticsEstimatorMarkov):
        """
        Initialize the Meeting Generator Map

        Args:
            source_dataset: Dataset for which the statistics are estimated.
                            This data can then be used to generate new meetings.
            input_dataset: Dataset from which the sources are drawn, that are used for generation of new meetings
            duration: Duration that the newly generated examples should roughly have, can be slightly exceeded.
            use_vad: Should VAD data be used. When set to True VAD data is used
                     during generation of the new dataset and the output dataset also has VAD information.
            estimator_class: Class that should be used to determine the statistics on the source dataset.
                The constructor must accept at least two parameters: dataset, use_vad
                Also must have the following properties: model, silence_distribution, overlap_distribution
        """
        # The statistics are estimated once, at construction time.
        db_sampler = estimator_class(dataset=source_dataset, use_vad=use_vad)

        model = db_sampler.model

        silence = DistributionSilenceSampler(distribution=db_sampler.silence_distribution)
        overlap = DistributionOverlapSampler(max_concurrent_spk=2, distribution=db_sampler.overlap_distribution)

        action_handler = DistributionActionHandler(overlap_sampler=overlap, silence_sampler=silence)
        self.meeting_sampler = WeightedMeetingSampler(transition_model=model, duration=duration,
                                                      action_handler=action_handler,
                                                      use_vad=use_vad)({'*': input_dataset})

    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the prepared meeting sampler to a single example and return its result."""
        return self.meeting_sampler(example)
Oops, something went wrong.