diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 44a33b3..5e59747 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 diff --git a/mms_msg/sampling/pattern/meeting/overlap_sampler.py b/mms_msg/sampling/pattern/meeting/overlap_sampler.py index 1ecc3d5..7a4fe94 100644 --- a/mms_msg/sampling/pattern/meeting/overlap_sampler.py +++ b/mms_msg/sampling/pattern/meeting/overlap_sampler.py @@ -5,7 +5,7 @@ import padertorch as pt -def _get_valid_overlap_region(examples, max_concurrent_spk, current_source): +def _get_valid_overlap_region(examples, max_concurrent_spk, current_source, use_vad=False): """ Compute maximum overlap that guarantees that no more than max_concurrent_spk are active at the same time. Note: This function underestimates the maximal overlap to ensure regions sampled as silence or @@ -13,11 +13,17 @@ def _get_valid_overlap_region(examples, max_concurrent_spk, current_source): Args: examples: max_concurrent_spk: - + use_vad: (optional) When set to True the keys that represent the alignment of the vad data are + used to compute the valid overlap region. Returns: """ - speaker_end = np.asarray(examples['offset']['original_source']) + np.asarray(examples['num_samples']['observation']) + if use_vad: + speaker_end = np.asarray(examples['offset']['aligned_source']) + np.asarray( + examples['num_samples']['aligned_source']) + else: + speaker_end = np.asarray(examples['offset']['original_source']) + np.asarray( + examples['num_samples']['observation']) speaker_id = examples['speaker_id'] return get_allowed_max_overlap(speaker_end, speaker_id, max_concurrent_spk, current_source['speaker_id']) diff --git a/mms_msg/sampling/pattern/meeting/state_based/__init__.py b/mms_msg/sampling/pattern/meeting/state_based/__init__.py index a49a382..6e131fa 100644 --- a/mms_msg/sampling/pattern/meeting/state_based/__init__.py +++ b/mms_msg/sampling/pattern/meeting/state_based/__init__.py @@ -1 +1,6 @@ +from . import action_handler +from . import dataset_statistics_estimation +from . import meeting_generator +from . import sampler from . import transition_model +from . import weighted_meeting_sampler diff --git a/mms_msg/sampling/pattern/meeting/state_based/action_handler.py b/mms_msg/sampling/pattern/meeting/state_based/action_handler.py new file mode 100644 index 0000000..a06a566 --- /dev/null +++ b/mms_msg/sampling/pattern/meeting/state_based/action_handler.py @@ -0,0 +1,427 @@ +from abc import ABC, abstractmethod +import copy +import logging +import numpy as np +from typing import Optional, Set, Dict, Union, Any, List, Tuple +from lazy_dataset import Dataset + +from mms_msg.sampling.utils import sequence_sampling +from mms_msg.sampling.utils.rng import get_rng + +from mms_msg.sampling.pattern.meeting.state_based.sampler import SilenceSampler, BackchannelStartSampler, OverlapSampler + +logger = logging.getLogger('meeting_generation') + + +class ActionHandler(ABC): + """ + Class for the state processing during the state-based generation of a meeting example. + + Each ActionHandler has tags, which consist of a list of the events that could be processed by it. 
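+
+    A typical lifecycle, as an illustrative sketch (the names normalized_ds, example_id, scenario_ids,
+    base_examples and examples are placeholders, not part of this module; 'TS' stands for an example action):
+
+        handler.set_datasets({'*': normalized_ds})
+        success, source, offset, state = handler.start(example_id, scenario_ids, base_examples, 0)
+        success, source, offset, state = handler.next_scenario('TS', 1, examples, base_examples)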
+ """ + + @abstractmethod + def start(self, example_id: str, scenario_ids: List[str], base_examples: List[Dict], scenario_id_index: int, + env_state: Optional[Any] = None, **kwargs) -> Tuple[bool, Dict, int, Any]: + """ + Starts the sampling for a new meeting and return the first sampled source, together with an offset. + The offset should be a non-negative integer. + + Args: + example_id: ID of the current meeting for which a start example should be sampled + scenario_ids: IDs of the scenarios/speakers which should be in the generated meeting + base_examples: List of base examples, each scenario/speaker has one according base example + scenario_id_index: Index of the starting speaker/scenario with respect to the list of base examples + env_state: (optional) State of the environment, can be used to provide additional information + to the ActionHandler + + Returns: Tuple with four entries: + - Boolean that indicates the success of the start method + (Sampling of fitting source for the given example_id was successful) + - Sampled source + - offset + - potentially changed environment state + """ + raise NotImplementedError + + @abstractmethod + def next_scenario(self, action: Any, scenario_id_index: int, examples: List[Dict], base_examples: List[Dict], + env_state: Optional[Any] = None, **kwargs) -> Tuple[bool, Optional[Dict], int, Any]: + """ + Samples a matching source and offset for the given action and speaker. + + Args: + action: Action that determines the transition between the last speaker and the current speaker + scenario_id_index: Index of the current speaker/scenario with respect to the list of base examples + examples: A list of source that the newly generated meeting example contains to this point + base_examples: List of base examples, each scenario/speaker has one according base example + env_state: (optional) State of the environment, can be used to provide additional information + to the ActionHandler + + Returns: Tuple with four entries: + - Boolean that indicates the success of the Action + - Sampled source, when available + - offset + - potentially changed environment state + """ + raise NotImplementedError + + @abstractmethod + def set_datasets(self, normalized_datasets: Dict[str, Dataset], use_vad: bool = False) -> None: + """ + Sets the datasets from which the sources for the corresponding actions are sampled. + First a specify dataset is used, otherwise the default dataset ('*' as key). + + Args: + normalized_datasets: Dictionary which maps actions to normalized datasets (keys: action, item: dataset) + Datasets can be normalized by calling mms_msg.sampling.utils.cache_and_normalize_input_dataset + use_vad: (optional) Is VAD data in the given datasets and should these data be used during + the selection of samples and offset. Default Value: False + """ + raise NotImplementedError + + @property + @abstractmethod + def tags(self) -> Set[str]: + """ + Tags of this class. The tags contain information about which types of action + the ActionHandler can process. A '*' means that this handler can process any type of action. + + Returns: Set with all tags + """ + raise NotImplementedError + + +class DistributionActionHandler(ActionHandler): + """ + Action Handler that can handle four actions: (TH: Turn Hold, TS: Turn Switch, OV: Overlap, BC: Backchannel). + For each action a source with a fitting offset is sampled. 
+    The offset is determined by first computing an intermediate value that depends on the action
+    (TH -> silence, TS -> silence, OV -> overlap, BC -> backchannel offset).
+    For each of these three values (silence, overlap, backchannel offset) a dedicated sampler
+    with its own distribution is used.
+    When the offset is calculated from these intermediate values, VAD data is also taken into account,
+    if available.
+
+    Important! When selecting a fitting source for the OV action, the overlap is computed after a source is selected.
+    Due to this, the resulting overlap distribution depends heavily on the length of the samples
+    in the given input dataset.
+    This also has the effect that, when using an input dataset with a similar mean sample length,
+    the resulting overlap distribution is skewed toward smaller values.
+    When sampling and generation are applied recursively multiple times with the same input dataset
+    (generate a dataset with the action handler, use it as the source dataset, repeat, ...),
+    the resulting overlap distribution shrinks with each iteration. Thus, this is not recommended.
+
+    Properties:
+        overlap_sampler: Used sampler for the overlap
+        silence_sampler: Used sampler for the silence
+        backchannel_start_sampler: Used sampler for the offset of the backchannel source
+        bc_border_margin: Used as minimal overlap during the OV action and minimal spacing
+            of the backchannel source from the borders of the foreground source.
+        use_vad: Whether VAD data is present in the given datasets and should be used
+            for determining sources and offsets.
+
+        scenario_ids: (read only) Scenario IDs of the used speakers in the currently processed example
+        example_id: (read only) ID of the currently processed example
+        last_foreground_speaker: (read only) Scenario ID of the speaker that is currently active in the foreground
+        base_examples: (read only) Base examples for the scenarios of all used speakers of the active example
+        grouped_datasets: (read only) Dictionary with the datasets that are used to sample for each action.
+            The datasets themselves are grouped by scenario_id. The actions are the keys,
+            while the datasets are the values in the dictionary. '*' marks the default dataset.
+    """
+    def __init__(self, overlap_sampler: OverlapSampler, silence_sampler: SilenceSampler,
+                 backchannel_start_sampler: BackchannelStartSampler = BackchannelStartSampler(),
+                 bc_border_margin: int = 100):
+        """
+        Initialization of the action handler with samplers.
+        Important: After the initialization you have to provide datasets via the set_datasets function.
+
+        Args:
+            overlap_sampler: Used sampler for the overlap
+            silence_sampler: Used sampler for the silence
+            backchannel_start_sampler: Used sampler for the offset of the backchannel source
+            bc_border_margin: (optional) Used as minimal spacing of the backchannel source
+                from the borders of the foreground source.
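+
+        Example (illustrative sketch; silence_dist, overlap_dist and normalized_dataset are placeholders
+        for already fitted DistributionModel instances and a normalized input dataset, and the sampler
+        classes come from mms_msg.sampling.pattern.meeting.state_based.sampler):
+
+            handler = DistributionActionHandler(
+                overlap_sampler=DistributionOverlapSampler(max_concurrent_spk=2, distribution=overlap_dist),
+                silence_sampler=DistributionSilenceSampler(distribution=silence_dist),
+            )
+            handler.set_datasets({'*': normalized_dataset})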
+ """ + + self.overlap_sampler = overlap_sampler + self.silence_sampler = silence_sampler + self.backchannel_start_sampler = backchannel_start_sampler + self.bc_border_margin = bc_border_margin + self.use_vad = None + + self._scenario_ids = None + self._example_id = None + + self._last_foreground_scenario = None + + self._base_examples = None + self._grouped_datasets = None + + def set_datasets(self, normalized_datasets: Dict[str, Dataset], use_vad: bool = False) -> None: + self._grouped_datasets = {key: dataset.groupby(lambda x: x['scenario']) for + (key, dataset) in normalized_datasets.items()} + self.use_vad = use_vad + + if not ('*' in self._grouped_datasets.keys() + or self.tags.issubset(self._grouped_datasets.keys())): + raise AssertionError(("The Tags of the normalized datasets and the ActionHandler do not fit, some actions " + "have no fitting dataset and thus no fitting sources can be sampled" + " for these actions. You can set a dataset as default that is used" + " when no specific dataset is available by using '*' as key."), + "Missing Tags: ", self.tags.difference(self._grouped_datasets.keys())) + + def start(self, example_id: str, scenario_ids: List[str], base_examples: List[Dict], scenario_id_index: int, + env_state: Optional[Any] = None, **kwargs) -> Tuple[bool, Dict[str, Any], int, Any]: + self._scenario_ids = scenario_ids + self._example_id = example_id + + # Adding the first speaker + current_source = copy.deepcopy(base_examples[scenario_id_index]) + offset = 0 + self._last_foreground_scenario = scenario_ids[scenario_id_index] + + return True, current_source, offset, None + + def next_scenario(self, action: Any, scenario_id_index: int, examples: List[Dict], base_examples: List[Dict], + env_state: Optional[Any] = None, **kwargs) -> Tuple[bool, Optional[Dict[str, Any]], int, Any]: + + assert self._grouped_datasets is not None, \ + "set_datasets has to be called, before using the ActionHandler" + + current_scenario = self._scenario_ids[scenario_id_index] + segment_idx = len(examples) + + offset = 0 + current_source = None + + if self.use_vad: + source_key = 'aligned_source' + source_key2 = 'aligned_source' + else: + source_key = 'original_source' + source_key2 = 'observation' + + # Select the fitting dataset for the current action, when for this action no dataset is available, + # the default dataset is chosen + if action in self._grouped_datasets.keys(): + current_dataset = self._grouped_datasets[action] + else: + current_dataset = self._grouped_datasets["*"] + + # Determine fitting source and offset + try: + if action in ("TH", "TS"): + current_source, offset = self._action_th_ts(current_scenario, current_dataset, examples, segment_idx, + source_key) + elif action == "OV": + current_source, offset = self._action_ov(current_scenario, current_dataset, examples, segment_idx, + source_key) + elif action == "BC": + current_source, offset = self._action_bc(current_scenario, current_dataset, examples, segment_idx, + source_key, source_key2) + except ValueError: + # Sampling of Offset failed + return False, None, -1, None + + if not (current_source is None): + if self.use_vad: + offset = offset - current_source['offset']['aligned_source'] + + # Prevent negative offsets + offset = max(0, offset) + + return True, current_source, offset, None + else: + return False, None, -1, None + + def _sample_source(self, current_scenario: str, current_dataset: Union[Dict, Dataset], examples: List[Dict], + segment_idx: int) -> Dict[str, Any]: + """ + Internal function that samples a source from 
the current scenario from the current dataset using + the random round-robin method. To achieve consistency for multiple executions all previously + sampled examples and the index of the current examples are used as seed for the random number generator. + + Args: + current_scenario: Scenario from which the source should be sampled + current_dataset: Dataset from which the source should be sampled + examples: List of previously sampled sources + segment_idx: Index of the currently sampled source (used as seed for rng) + + Returns: Dictionary which represents the sampled source + """ + + current_source_id = sequence_sampling.sample_random_round_robin( + current_dataset[current_scenario].keys(), + sequence=[x['example_id'] for x in examples if x['scenario'] == current_scenario], + rng=get_rng(self._example_id, 'example', segment_idx), + ) + current_source = copy.deepcopy(current_dataset[current_scenario][current_source_id]) + return current_source + + def _action_th_ts(self, current_scenario: str, current_dataset: Union[Dict, Dataset], examples: List[Dict], + segment_idx: int, source_key: str) -> Tuple[Dict[str, Any], int]: + """ + Internal function for handling the Turn hold (TH) and Turn switch (TS) actions. + Samples a fitting source from the given dataset, computes the offset of the sampled source and updates + the current active foreground speaker. + + Args: + current_scenario: Scenario from which the source should be sampled + current_dataset: Dataset from which the source should be sampled + examples: List of previously sampled sources (used as seed for rng) + segment_idx: Index of the currently sampled source (used as seed for rng) + source_key: Key for accessing the values in the dictionary which represent single sources. + Depends on the usage of VAD data (No VAD: original_source, VAD: aligned_source) + + Returns: Tuple of the sampled source and the corresponding offset + """ + + current_source = self._sample_source(current_scenario, current_dataset, examples, segment_idx) + silence = self.silence_sampler(get_rng(self._example_id, segment_idx, 'silence')) + self._last_foreground_scenario = current_scenario + offset = max([x['speaker_end'][source_key] for x in examples]) + silence + + return current_source, offset + + def _action_ov(self, current_scenario: str, current_dataset: Union[Dict, Dataset], examples: List[Dict], + segment_idx: int, source_key: str) -> Tuple[Dict[str, Any], int]: + """ + Internal function for handling the Overlap (OV) action. + Samples a fitting source from the given dataset, computes the offset of the sampled source and updates + the current active foreground speaker. + + Args: + current_scenario: Scenario from which the source should be sampled + current_dataset: Dataset from which the source should be sampled + examples: List of previously sampled sources (used as seed for rng) + segment_idx: Index of the currently sampled source + source_key: Key for accessing the values in the dictionary which represent single sources. 
+ Depends on the usage of VAD data (No VAD: original_source, VAD: aligned_source) + + Returns: Tuple of the sampled source and the corresponding offset + """ + + current_source = self._sample_source(current_scenario, current_dataset, examples, segment_idx) + overlap = self.overlap_sampler(examples, current_source, + rng=get_rng(self._example_id, segment_idx, 'overlap'), + use_vad=self.use_vad) + self._last_foreground_scenario = current_scenario + + offset = max([x['speaker_end'][source_key] for x in examples]) - overlap + + return current_source, offset + + def _action_bc(self, current_scenario: str, current_dataset: Union[Dict, Dataset], examples: List[Dict], + segment_idx: int, source_key: str, source_key2: str) -> Tuple[Optional[Dict[str, Any]], int]: + """ + Internal function for handling the Backchannel action (BC). + Samples a fitting source from the given dataset and computes the offset of the sampled source. + + Args: + current_scenario: Scenario from which the source should be sampled + current_dataset: Dataset from which the source should be sampled + examples: List of previously sampled sources (used as seed for rng) + segment_idx: Index of the currently sampled source + source_key: Key for accessing the values in the dictionary which represent single sources. + Depends on the usage of VAD data (No VAD: original_source, VAD: aligned_source) + source_key2: Second Key for accessing the values in the dictionary which represent single sources. + Depends on the usage of VAD data (No VAD: observation, VAD: aligned_source) + + Returns: Tuple of the sampled source and the corresponding offset + """ + + last_foreground_example = list(filter(lambda x: x['scenario'] == self._last_foreground_scenario, examples))[-1] + + backchannel_speaker_ends = [ + x['speaker_end'][source_key] + for x in examples + if x['scenario'] != self._last_foreground_scenario + ] + + foreground_length = last_foreground_example['num_samples'][source_key] + free_backchannel_length = (last_foreground_example['speaker_end'][source_key] + - max(backchannel_speaker_ends + [0])) + + max_allowed_length = min(foreground_length, free_backchannel_length) - 2 * self.bc_border_margin + + # Rejection sampling of the backchannel source + current_source = rejection_sampling(get_rng(self._example_id, 'example', segment_idx), current_scenario, + current_dataset, examples, max_length=max_allowed_length) + current_source = copy.deepcopy(current_source) + + if current_source is not None: + min_possible_start_offset = max(backchannel_speaker_ends + [0] + + [last_foreground_example['offset'][source_key]]) \ + + self.bc_border_margin + max_possible_start_offset = last_foreground_example['speaker_end'][source_key] - \ + current_source['num_samples'][source_key2] - self.bc_border_margin + + offset = self.backchannel_start_sampler(min_possible_start_offset, max_possible_start_offset, + get_rng(self._example_id, segment_idx, 'start_offset')) + return current_source, offset + else: + logger.warning("No fitting backchannel source found.") + return None, 0 + + @property + def tags(self) -> Set[str]: + return {"TS", "TH", "OV", "BC"} + + @property + def scenario_ids(self) -> List[str]: + return self._scenario_ids + + @property + def example_id(self) -> str: + return self._example_id + + @property + def last_foreground_speaker(self) -> str: + return self._last_foreground_scenario + + @property + def grouped_datasets(self) -> Dict[str, Union[Dict, Dataset]]: + return self._grouped_datasets + + +def rejection_sampling(rng: np.random.Generator, 
current_scenario: str, current_dataset: Union[Dict, Dataset], + examples: List[Dict], max_tries: int = 100, min_length: int = 0, + max_length: Optional[int] = None) -> Optional[Dict[str, Any]]: + """ + Uses rejection sampling to get a source that has more than min and less than max samples. + When no fitting sample can be found, None is returned. + + Args: + rng: random number generator that is used for the sampling + current_scenario: scenario from which the source should be sampled + current_dataset: dataset from which the source should be sampled + examples: list of sources that were sampled until now + max_tries: maximum amount of tries, when after these amount of tries no fitting source is found, + None is returned + min_length: minimal amount of samples that the source should have + max_length: maximal amount of samples that the source should have + + Returns: source: when its fitting, None: when no fitting source is found + """ + + sequence = [x['example_id'] for x in examples if x['scenario'] == current_scenario] + rejected_sources = [] + + for _ in range(max_tries): + current_source_id = sequence_sampling.sample_random_round_robin( + current_dataset[current_scenario].keys(), + sequence=sequence + rejected_sources, + rng=rng + ) + current_source = copy.deepcopy(current_dataset[current_scenario][current_source_id]) + + if current_source['num_samples']['observation'] >= min_length and ( + max_length is None or current_source['num_samples']['observation'] <= max_length): + return current_source + else: + rejected_sources.append(current_source['example_id']) + + # When no fitting source is found None is returned + return None diff --git a/mms_msg/sampling/pattern/meeting/state_based/dataset_statistics_estimation.py b/mms_msg/sampling/pattern/meeting/state_based/dataset_statistics_estimation.py new file mode 100644 index 0000000..3aa13b6 --- /dev/null +++ b/mms_msg/sampling/pattern/meeting/state_based/dataset_statistics_estimation.py @@ -0,0 +1,181 @@ +import logging +import sys +import numpy as np + +from operator import add +from typing import Dict, Optional, Union +from lazy_dataset import Dataset, from_dict + +from mms_msg.sampling.utils.distribution_model import DistributionModel +from mms_msg.sampling.pattern.meeting.state_based.transition_model import MarkovModel, MultiSpeakerTransitionModel + + +logger = logging.getLogger('dataset_statistics_estimation') + + +class MeetingStatisticsEstimatorMarkov: + """ + Class that estimates characteristics of existing dataset. For the speaker transitions a MarkovModel is created. + The model distinguishes between 4 transition (TH: Turn hold, TS: turn switch, OV: Overlap, BC: Backchannel). + These transitions are based on the following paper: + Improving the Naturalness of Simulated Conversations for End-to-End Neural Diarization, + https://arxiv.org/abs/2204.11232 + + Also, distributions for silence and overlap are computed, using a histogram-like distribution model. + This class also supports processing of datasets which utilize Voice activity detection (VAD) data. + When VAD data should be processed, the dataset must have the key 'aligned_source' for each speaker, + which describes the interval in which the speaker is active. + + Each sample in the processed dataset must have the following keys: + - speaker_id: List with the speaker_ids, which belong to the sources in the example + - offset: Dictionary with a key that depends on the use of vad data: + 'aligned_source' when it is used, 'original_source' otherwise. 
+ The associated item must be a list of the offsets of the sources. + - speaker_end: Dictionary with a key that depends on the use of vad data: + 'aligned_source' when it is used, 'original_source' otherwise. + The associated item must be a list of the speaker endings of the sources. + + Properties: + dataset: (read only) last processed dataset + model: (read only) Markov model with state transition probabilities for the current dataset + silence_distribution: (read only) distribution model of the length of the silence + overlap_distribution: (read only) distribution model of the length of the overlap + """ + + def __init__(self, dataset: Optional[Union[Dataset, Dict]] = None, use_vad: bool = False): + """ + Initialization of the Markov Statistics estimator. Optionally a dataset can be given as input. + The Estimator is then used with the given dataset. + + Args: + dataset: dataset: dataset that should be processed + use_vad: (optional) Set to True, when VAD data is + present in the dataset and that data should be recognized for sampling + """ + self._dataset = None + self._model = None + self._silence_distribution = None + self._overlap_distribution = None + + if dataset is not None: + self.fit(dataset, use_vad) + + def fit(self, dataset: [Dataset, Dict], use_vad: bool = False) -> None: + """ + Iterates over the given dataset and computes the MarkovModel and the distributions for silence and overlap. + The dataset, model and according distributions are then stored in the class. + Overrides the previously fitted dataset. + + Args: + dataset: dataset that should be processed + use_vad: (optional) Set to True, when VAD data is + present in the dataset and that data should be recognized for sampling + """ + + logger.info("Begin with processing the dataset") + if len(dataset) == 0: + raise AssertionError('Cannot compute statistics for an empty dataset.') + + # Make sure dataset has the type Dataset + if type(dataset) is dict: + dataset = from_dict(dataset) + # Remove FilterExceptions + self._dataset = dataset.catch() + + state_occurence_counter = np.array([0] * 4) + state_transition_counter = np.zeros((4, 4)) + + silence_durations = [] + overlap_durations = [] + + num_speakers = 0 + + for n, sample in enumerate(self._dataset): + if n % 100 == 0: + logger.info(f'Processed samples: {n}') + + # Depending on the usage of VAD data different keys are used + if use_vad: + offsets = sample['offset']['aligned_source'] + speaker_ends = list(map(add, offsets, sample['num_samples']['aligned_source'])) + else: + offsets = sample['offset']['original_source'] + speaker_ends = list(map(add, offsets, sample['num_samples']['original_source'])) + speaker_ids = sample['speaker_id'] + + num_speakers = max(num_speakers, len(set(speaker_ids))) + + current_state = 0 + last_foreground_end = speaker_ends[0] + last_foreground_speaker = speaker_ids[0] + + for speaker_id, offset, speaker_end in list(zip(speaker_ids, offsets, speaker_ends))[1:]: + state_occurence_counter[current_state] += 1 + # Turn-hold + if last_foreground_speaker == speaker_id: + new_state = 0 + silence_durations.append(offset - last_foreground_end) + + # Turn-switch + elif last_foreground_end < offset: + new_state = 1 + silence_durations.append(offset - last_foreground_end) + + # Overlap + elif last_foreground_end < speaker_end: + new_state = 2 + overlap_durations.append(last_foreground_end - offset) + # Backchannel + else: + new_state = 3 + + # Adjust foreground information, in all states except backchannel + if new_state in (0, 1, 2): + 
last_foreground_end = speaker_end + last_foreground_speaker = speaker_id + + state_transition_counter[current_state][new_state] += 1 + + current_state = new_state + + # Add at least one sample to the durations, otherwise the DistributionModel can not be fitted + if len(silence_durations) == 0: + silence_durations = [0] + if len(overlap_durations) == 0: + overlap_durations = [0] + + # Fit silence and overlap distributions + + self._silence_distribution = DistributionModel(silence_durations) + self._overlap_distribution = DistributionModel(overlap_durations) + + # Fixes matrix when some states are never reached, otherwise the resulting matrix is not a stochastic matrix, + # but this required for the markov model. (Fixed problem: division by 0, leads to infinite values) + for i in range(len(state_occurence_counter)): + if state_occurence_counter[i] == 0: + state_transition_counter[i] = np.zeros((1, len(state_transition_counter[i]))) + state_transition_counter[i][i] = 1 + state_occurence_counter[i] = 1 + + # Fit MarkovModel and create fitting SpeakerTransitionModel + self._model = MultiSpeakerTransitionModel(MarkovModel(state_transition_counter/state_occurence_counter[:, None], + state_names=["TH", "TS", "OV", "BC"]), + num_speakers=num_speakers) + + logger.info("Finished processing the dataset") + + @property + def model(self) -> MultiSpeakerTransitionModel: + return self._model + + @property + def silence_distribution(self) -> DistributionModel: + return self._silence_distribution + + @property + def overlap_distribution(self) -> DistributionModel: + return self._overlap_distribution + + @property + def dataset(self) -> Dataset: + return self._dataset diff --git a/mms_msg/sampling/pattern/meeting/state_based/meeting_generator.py b/mms_msg/sampling/pattern/meeting/state_based/meeting_generator.py new file mode 100644 index 0000000..ef0224f --- /dev/null +++ b/mms_msg/sampling/pattern/meeting/state_based/meeting_generator.py @@ -0,0 +1,142 @@ +from mms_msg.sampling.pattern.meeting.state_based.dataset_statistics_estimation import MeetingStatisticsEstimatorMarkov +from mms_msg.sampling.pattern.meeting.state_based.weighted_meeting_sampler import WeightedMeetingSampler +from mms_msg.sampling.pattern.meeting.state_based.action_handler import DistributionActionHandler +from mms_msg.sampling.pattern.meeting.state_based.sampler import DistributionSilenceSampler, DistributionOverlapSampler +from lazy_dataset import Dataset +from typing import Dict, Type, Optional, Any +import mms_msg + + +class MeetingGenerator: + """ + Class for generating meetings that aim to replicate the state transition probabilities of another dataset. + The samples that are used to generate the artificial data is from an input dataset + that can be independent of the dataset the state transitions are estimates from. + + This class uses a Markov based model for the different transitions of the speakers and tries to balance + the activity of all speakers in each meeting. + + Properties: + model: Transition Model + silence: Silence Distribution + overlap: Overlap Distribution + """ + def __init__(self, estimator_class: Type = MeetingStatisticsEstimatorMarkov): + """ + Initialize the Meeting Generator + + Args: + estimator_class: Class that should be used to determine the statistics on the input dataset. 
+ The constructor must accept at least two parameters: dataset, use_vad + Also must have the following properties: model, silence_distribution, overlap_distribution + """ + self.model = None + self.silence = None + self.overlap = None + + self.estimator_class = estimator_class + + def fit(self, source_dataset: [Dict, Dataset], use_vad: [bool] = False) -> None: + """ Estimates the speaker transitions and the overlap/silence distributions of a given dataset. + This data can then be used in the generate function. Must be called at least once before calling generate. + It is possible that given VAD is considered when estimating the distributions. + + Args: + source_dataset: Dataset for which the statistics are estimated. + This data can then be used to generate new meetings. + use_vad: Should VAD data be used. When set to True the fit function uses VAD data, + when estimating the distributions of the source dataset. + + Returns: None + """ + db_sampler = self.estimator_class(dataset=source_dataset, use_vad=use_vad) + + self.model = db_sampler.model + + self.silence = DistributionSilenceSampler(distribution=db_sampler.silence_distribution) + self.overlap = DistributionOverlapSampler(max_concurrent_spk=2, distribution=db_sampler.overlap_distribution) + + def generate(self, input_dataset: [Dict, Dataset], num_speakers: int = 2, duration: int = 960000, + num_meetings: Optional[int] = None, use_vad: bool = False) -> Dataset: + """Generate a dataset of artificial meeting, with sources from the input_dataset. + The distribution of the generated dataset follows the last fitted distribution, + so the fit method must be called at least once before calling this method. + Also, can utilize VAD data. + + Args: + input_dataset: Dataset from which the sources are drawn, that are used for generation new meetings + num_speakers: Number of speakers the meetings in the generated datasets should have. + duration: Duration that the newly generated examples should roughly have, can be slightly exceeded. + num_meetings: Number of meeting that should be generated. + When not given the number of entries in the input dataset is used. + use_vad: Should VAD data be used. When set to true VAD data is used, + during generation of the new dataset and the output dataset hat also VAD information. + + Returns: Output dataset, with as many entries as the input dataset has samples + or a lower amount when specified with num_meetings. + """ + + if self.model is None or self.silence is None or self.overlap is None: + raise ValueError('No dataset is fitted, you have to use the fit method first.') + + if self.model.num_speakers != num_speakers: + try: + self.model.change_num_speakers(num_speakers) + except TypeError: + print('Cannot change the number of speakers of the transition model.' + 'It is possible that the generation fails for the desired number of speakers.') + + ds = mms_msg.sampling.source_composition.get_composition_dataset(input_dataset, num_speakers=num_speakers) + + if num_meetings is not None: + ds = ds[:num_meetings] + + return ds.map(WeightedMeetingSampler(transition_model=self.model, duration=duration, + action_handler=DistributionActionHandler(overlap_sampler=self.overlap, + silence_sampler=self.silence), + use_vad=use_vad)({'*': input_dataset})) + + +class MeetingGeneratorMap: + """Class for generating meetings that aim to replicate the state transition probabilities of another dataset. 
+ Can be mapped to an existing dataset created with get_composition_dataset() + to generate a meeting for each example in the dataset. + + This class uses a Markov based model for the different transitions of the speakers and tries to balance + the activity of all speakers in each meeting. + + Properties: + meeting_sampler: Weighted meeting sampler initialized with the statistics from the source dataset + which uses samples form the input dataset. + """ + + def __init__(self, source_dataset: [Dict, Dataset], input_dataset: [Dict, Dataset], duration: int = 960000, + use_vad: bool = False, estimator_class: Type = MeetingStatisticsEstimatorMarkov): + """ + Initialize the Meeting Generator Map + + Args: + source_dataset: Dataset for which the statistics are estimated. + This data can then be used to generate new meetings. + input_dataset: Dataset from which the sources are drawn, that are used for generation new meetings + duration: Duration that the newly generated examples should roughly have, can be slightly exceeded. + use_vad: Should VAD data be used. When set to true VAD data is used, + during generation of the new dataset and the output dataset hat also VAD information. + estimator_class: Class that should be used to determine the statistics on the input dataset. + The constructor must accept at least two parameters: dataset, use_vad + Also must have the following properties: model, silence_distribution, overlap_distribution + """ + db_sampler = estimator_class(dataset=source_dataset, use_vad=use_vad) + + model = db_sampler.model + + silence = DistributionSilenceSampler(distribution=db_sampler.silence_distribution) + overlap = DistributionOverlapSampler(max_concurrent_spk=2, distribution=db_sampler.overlap_distribution) + + self.meeting_sampler = WeightedMeetingSampler(transition_model=model, duration=duration, + action_handler=DistributionActionHandler(overlap_sampler=overlap, + silence_sampler=silence), + use_vad=use_vad)({'*': input_dataset}) + + def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]: + return self.meeting_sampler(example) diff --git a/mms_msg/sampling/pattern/meeting/state_based/sampler.py b/mms_msg/sampling/pattern/meeting/state_based/sampler.py new file mode 100644 index 0000000..fa62cea --- /dev/null +++ b/mms_msg/sampling/pattern/meeting/state_based/sampler.py @@ -0,0 +1,273 @@ +import numpy as np +import padertorch as pt + +from dataclasses import dataclass +from typing import Optional, Dict, List, Any +from abc import ABC, abstractmethod + +from mms_msg.sampling.utils.distribution_model import DistributionModel +from mms_msg.sampling.utils import collate_fn +from mms_msg.sampling.pattern.meeting.overlap_sampler import _get_valid_overlap_region + + +class SilenceSampler(ABC): + """ + Abstract class to allow sampling of an integer value according to a certain distribution. + Optionally the sampling can be restricted by a minimum and a maximum bound. This bound is guaranteed, + when the sampler is called with (). Alternatively a value can be directly sampled with sample_silence, + but then the bounds are not enforced. When the sampling fails a ValueError should be raised. 
+
+    Properties:
+        hard_minimum_value: (optional) minimum value that should be sampled
+        hard_maximum_value: (optional) maximum value that should be sampled
+    """
+    hard_minimum_value: int = 0
+    hard_maximum_value: int = 1000000
+
+    def __call__(self, rng: np.random.random = np.random.default_rng(), minimum_value: Optional[int] = None,
+                 maximum_value: Optional[int] = None) -> int:
+        """
+        Samples an integer as silence value according to both the class bounds and those given as parameters.
+
+        Args:
+            rng: (optional) The numpy rng that should be used, the rng should generate a number in the interval [0,1)
+                When not set a new uniform rng is used.
+            minimum_value: (optional) minimum value that should be sampled,
+                is overridden by the hard limits of the class.
+            maximum_value: (optional) maximum value that should be sampled,
+                is overridden by the hard limits of the class.
+
+        Returns: Integer that is guaranteed to be within both given bounds, the class bounds and the parameter bounds.
+        """
+        if minimum_value is None:
+            minimum_value = self.hard_minimum_value
+        else:
+            minimum_value = max(minimum_value, self.hard_minimum_value)
+
+        if maximum_value is None:
+            maximum_value = self.hard_maximum_value
+        else:
+            # Clamp to the hard upper bound so that both the parameter and the class bound are respected
+            maximum_value = min(maximum_value, self.hard_maximum_value)
+
+        if minimum_value >= maximum_value:
+            raise ValueError('The maximum value must be greater than the minimum value. You have to change either'
+                             ' the hard bounds of the class or the parameter bounds used when calling the sampler.')
+
+        return self.sample_silence(rng, minimum_value, maximum_value)
+
+    @abstractmethod
+    def sample_silence(self, rng: np.random.random = np.random.default_rng(), minimum_value: Optional[int] = None,
+                       maximum_value: Optional[int] = None) -> int:
+        """
+        Samples an integer according to the given bounds.
+
+        Args:
+            rng: (optional) The numpy rng that should be used, the rng should generate a number in the interval [0,1)
+                When not set a new uniform rng is used.
+            minimum_value: (optional) minimum value that should be sampled.
+            maximum_value: (optional) maximum value that should be sampled.
+
+        Returns: Integer that is guaranteed to be within the given bounds.
+        """
+        raise NotImplementedError()
+
+
+@dataclass
+class UniformSilenceSampler(SilenceSampler):
+    """
+    Generate uniform integer samples between a given min and max value.
+    Optionally the sampling can be restricted by a minimum and a maximum bound. This bound is guaranteed
+    when the sampler is called with (). Alternatively a value can be directly sampled with sample_silence,
+    but then the bounds are not enforced. When the sampling fails a ValueError is raised.
+
+    Properties:
+        hard_minimum_value: (optional) minimum value that should be sampled
+        hard_maximum_value: (optional) maximum value that should be sampled
+    """
+
+    def sample_silence(self, rng: np.random.random = np.random.default_rng(), minimum_value: Optional[int] = None,
+                       maximum_value: Optional[int] = None) -> int:
+        return rng.integers(minimum_value, maximum_value)
+
+
+@dataclass
+class DistributionSilenceSampler(SilenceSampler):
+    """
+    Generates samples using a given distribution.
+    Optionally the sampling can be restricted by a minimum and a maximum bound. This bound is guaranteed
+    when the sampler is called with (). Alternatively a value can be directly sampled with sample_silence,
+    but then the bounds are not enforced. When the sampling fails a ValueError is raised.
+
+    Properties:
+        distribution: Distribution from which the silence values are sampled.
+ minimum_value: Minimum value that can be sampled, when not given then it depends on the given distribution + maximum_value: Maximum value that can be sampled, when not given then it depends on the given distribution + """ + + distribution: DistributionModel + minimum_value: Optional[int] = None + maximum_value: Optional[int] = None + + def __post_init__(self): + if self.minimum_value is None: + self.minimum_value = int(self.distribution.min_value) + if self.maximum_value is None: + self.maximum_value = int(self.distribution.max_value) + + def sample_silence(self, rng: np.random.random = np.random.default_rng(), minimum_value: Optional[int] = None, + maximum_value: Optional[int] = None) -> int: + return self.distribution.sample_value(rng, minimum_value=minimum_value, maximum_value=maximum_value) + + +class OverlapSampler(ABC): + """ Abstract class that allows to construct an Overlap sampler, which is used to sample overlap values for the + generation of a meeting. It is guaranteed that the given overlap values are valid and only a maximum number + of speakers is active simultaneously, when the sampler is called with (). + The sampling process of the values must be implemented in a subclass. + When the sampling fails a ValueError should be returned. + + Properties: + max_concurrent_spk: Maximum number of concurrent active speakers. + hard_minimum_overlap: Hard minimum value for the overlap + hard_maximum_overlap: Hard maximum value for the overlap + """ + max_concurrent_spk: int + hard_minimum_overlap: int = 0 + hard_maximum_overlap: int = 1000000 + + def __call__(self, examples: List[Dict], current_source: Dict[str, Any], + rng: np.random.random = np.random.default_rng(), use_vad: bool = False) -> int: + """ + Determines the maximum allowed overlap and that samples an overlap value through the function _sample_overlap. + + Args: + examples: List of all examples that are currently present in the meeting. + current_source: Source for which the overlap should be determined. + rng: The numpy rng that should be used, the rng should generate a number in the interval [0,1). + When not set a uniform rng is used. + use_vad: (optional) Is VAD data in the given datasets and should these data be used during + the selection of samples and offset. Default Value: False + + Returns: Sampled overlap + """ + + maximum_overlap = _get_valid_overlap_region(collate_fn(examples), self.max_concurrent_spk, current_source, + use_vad) + examples = examples[:] + + if use_vad: + examples.sort(key=lambda x: x['speaker_end']['aligned_source']) + if len(examples) > 1: + maximum_overlap = min(maximum_overlap, + examples[-1]['speaker_end']['aligned_source'] + - examples[-2]['speaker_end']['aligned_source']) + maximum_overlap = min(maximum_overlap, current_source['num_samples']['aligned_source'], + examples[-1]['num_samples']['aligned_source']) + else: + maximum_overlap = min(maximum_overlap, current_source['num_samples']['observation']) + + maximum_overlap = min(maximum_overlap, self.hard_maximum_overlap) + + overlap = self._sample_overlap(self.hard_minimum_overlap, maximum_overlap, rng, examples, current_source) + + return overlap + + @abstractmethod + def _sample_overlap(self, minimum_overlap: int, maximum_overlap: int, + rng: np.random.random = np.random.default_rng(), examples: List[Dict] = None, + current_source: Dict[str, Any] = None) -> int: + """ + Internal function that samples overlap with respect to the maximum and minimum allowed overlap. 
+ Also, can take the previous examples and the current source as parameters + when required for sampling the overlap. + + Args: + minimum_overlap: Minimum for the overlap that is sampled + maximum_overlap: Maximum for the overlap that is sampled + rng: The numpy rng that should be used, the rng should generate a number in the interval [0,1). + When not set a uniform rng is used. + examples: (optional) List of all examples that are currently present in the meeting. + current_source: (optional) Source for which the overlap should be determined. + + Returns: Sampled overlap + """ + raise NotImplementedError + + +@dataclass +class DistributionOverlapSampler(OverlapSampler): + """ + Class which is used to sample overlap values for the generation of a meeting using a DistributionModel. + It is guaranteed that the given overlap values are valid and only a maximum number + of speakers is active simultaneously, when the sampler is called with (). + When the sampling fails a ValueError is returned. + + Properties: + max_concurrent_spk: Maximum number of concurrent active speakers. + distribution: DistributionModel from which the overlap should be sampled. + hard_minimum_overlap: Hard minimum value for the overlap + hard_maximum_overlap: Hard maximum value for the overlap + """ + + max_concurrent_spk: int + distribution: DistributionModel + hard_minimum_overlap: int = 0 + hard_maximum_overlap: int = 1000000 + + def _sample_overlap(self, minimum_overlap: int, maximum_overlap: int, + rng: np.random.random = np.random.default_rng(), examples: List[Dict] = None, + current_source: Dict[str, Any] = None) -> int: + """ + Internal function that samples overlap from the distribution with respect to the maximum + and minimum allowed overlap. examples and current_source are not used + + Args: + maximum_overlap: Maximum for the overlap that is sampled + minimum_overlap: Minimum for the overlap that is sampled + rng: The numpy rng that should be used, the rng should generate a number in the interval [0,1). + When not set a uniform rng is used. + examples: (Not used in the implementation of this function) + current_source: (Not used in the implementation of this function) + + Returns: Sampled overlap according to the distribution + """ + + return self.distribution.sample_value(rng, minimum_value=minimum_overlap, maximum_value=maximum_overlap) + + +@dataclass +class BackchannelStartSampler: + """ + Class that can be used for the sampling the starting distance of the backchannel source from the beginning of the + foreground source. + Example: Foreground offset: 2000, start_distance: 1500 => Backchannel offset: 3500 + + Important: The current implementation does not guarantee, that the sampled start distance for + the backchannel is valid, that must be ensured through the given parameters. + + Properties: + minimum_start_distance: Hard minimum value for the start distance + maximum_start_distance: Hard maximum value for the start distance + """ + minimum_start_distance: int = 0 + maximum_start_distance: int = 16000000 + + def __call__(self, minimum_possible_start: int, maximum_possible_start: int, + rng: np.random.random = np.random.default_rng()) -> int: + """ + Samples the offset of a backchannel example, while making sure that the hard minimum and maximum distances + are followed. Do not guarantee that the sampled offset is valid that has to be ensured + through the input parameters. The offset is sampled uniformly in the possible range. 
+ + Args: + minimum_possible_start: Minimum possible offset for the backchannel source + maximum_possible_start: Maximum possible offset for the backchannel source + rng: The numpy rng that should be used, the rng should generate a number in the interval [0,1). + When not set a uniform rng is used. + + Returns: Offset of the backchannel source + """ + + maximum_start = min(minimum_possible_start+self.maximum_start_distance, maximum_possible_start) + minimum_start = min(minimum_possible_start+self.minimum_start_distance, self.maximum_start_distance) + + return rng.integers(minimum_start, maximum_start) diff --git a/mms_msg/sampling/pattern/meeting/state_based/transition_model.py b/mms_msg/sampling/pattern/meeting/state_based/transition_model.py index 286a571..2bad2ae 100644 --- a/mms_msg/sampling/pattern/meeting/state_based/transition_model.py +++ b/mms_msg/sampling/pattern/meeting/state_based/transition_model.py @@ -1,7 +1,12 @@ +from __future__ import annotations +import json import numpy as np -from typing import Optional, List, Union, Generic, TypeVar, Tuple, Any, Set +from typing import Optional, List, Union, Generic, TypeVar, Tuple, Any, Set, Dict from abc import ABC, abstractmethod from copy import deepcopy +import sys + +from mms_msg.sampling.pattern.meeting.scenario_sequence_sampler import sample_balanced class StateTransitionModel(ABC): @@ -37,6 +42,55 @@ def reset(self) -> None: """ raise NotImplementedError + @staticmethod + @abstractmethod + def to_json(obj: StateTransitionModel) -> str: + """ Static method that serializes a StateTransitionModel into a json string. + + Args: + obj: StateTransitionModel which should be serialized + + Returns: Json string which contains the data of the given object + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def from_json(json_string: str) -> StateTransitionModel: + """ Static method that creates a StateTransitionModel from a given json string. + + Args: + json_string: Json string that contains the data required for the StateTransitionModel + + Returns: StateTransitionModel constructed from the data of the json string. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def save(obj: StateTransitionModel, filepath: Optional[str] = 'state_transition_model.json') -> None: + """ Static method that saves the given StateTransitionModel to a file belonging to the given filepath. + When the file exists its contests will be overwritten. When it not exists it is created. + The used dataformat is json, so a .json file extension is recommended. + + Args: + obj: StateTransitionModel which should be saved + filepath: Path to the file where the model should be saved. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def load(filepath: Optional[str] = 'distribution_model.json') -> StateTransitionModel: + """Static method that loads a StateTransitionModel from file belonging to the given filepath. + + Args: + filepath: Path to the file where the model is saved. + + Returns: StateTransitionModel constructed from the data of the given file. 
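+
+        Example (illustrative sketch, using the MarkovModel subclass defined below):
+
+            MarkovModel.save(model, 'state_transition_model.json')
+            model = MarkovModel.load('state_transition_model.json')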
+ """ + raise NotImplementedError + MST = TypeVar('MST') # MarkovStateType @@ -48,7 +102,7 @@ class MarkovModel(StateTransitionModel, Generic[MST]): Properties: s0: Name or index of the starting state size: Number of states in the model - states: (read-only) Names of the states + state_names: (read-only) Names of the states current_state_index: Index of the current state last_state_index: Index of the previous state """ @@ -194,6 +248,39 @@ def __repr__(self): f"Probability matrix:\n {self.probability_matrix}" ) + @staticmethod + def to_json(obj: MarkovModel) -> str: + def json_default(o): + if isinstance(o, np.ndarray): + return o.tolist() + else: + return o.__dict__ + return json.dumps(obj, default=json_default) + + @staticmethod + def from_json(json_string: str) -> MarkovModel: + obj = MarkovModel(np.ones((1, 1))) + data = json.loads(json_string) + for k, v in data.items(): + if k == 'probability_matrix': + obj.__dict__[k] = np.asarray(v) + else: + obj.__dict__[k] = v + return obj + + @staticmethod + def save(obj: MarkovModel, filepath: Optional[str] = 'state_transition_model.json') -> None: + with open(filepath, 'w+') as file: + json_string = MarkovModel.to_json(obj) + file.write(json_string) + + @staticmethod + def load(filepath: Optional[str] = 'state_transition_model.json') -> MarkovModel: + with open(filepath, 'r') as file: + json_string = file.read() + obj = MarkovModel.from_json(json_string) + return obj + class SpeakerTransitionModel(ABC): """ @@ -231,6 +318,18 @@ def next(self, rng: np.random.random = np.random.default_rng(), last_action_succ """ raise NotImplementedError + @abstractmethod + def change_num_speakers(self, num_speakers: int = 2) -> None: + """ + Tries to change the number of speakers in the transition model. + This can be used to create meetings with a different number of speakers than in the source dataset. + When the change is not possible due to the structure of the transition model, + this function should throw an TypeError. + + Args: + num_speakers: New number of speakers that the transition model should use for its output. + """ + @abstractmethod def reset(self) -> None: """ @@ -248,28 +347,87 @@ def tags(self) -> Set[str]: """ pass + @staticmethod + @abstractmethod + def to_json(obj: SpeakerTransitionModel) -> str: + """ Static method that serializes a SpeakerTransitionModel into a json string. + + Args: + obj: SpeakerTransitionModel which should be serialized + + Returns: Json string which contains the data of the given object + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def from_json(json_string: str) -> SpeakerTransitionModel: + """ Static method that creates a SpeakerTransitionModel from a given json string. + + Args: + json_string: Json string that contains the data required for the SpeakerTransitionModel + + Returns: SpeakerTransitionModel constructed from the data of the json string. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def save(obj: SpeakerTransitionModel, filepath: Optional[str] = 'speaker_transition_model.json') -> None: + """ Static method that saves the given SpeakerTransitionModel to a file belonging to the given filepath. + When the file exists its contests will be overwritten. When it not exists it is created. + The used dataformat is json, so a .json file extension is recommended. + + Args: + obj: SpeakerTransitionModel which should be saved + filepath: Path to the file where the model should be saved. 
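+
+        Example (illustrative sketch, using the MultiSpeakerTransitionModel subclass defined below):
+
+            MultiSpeakerTransitionModel.save(transition_model, 'speaker_transition_model.json')
+            transition_model = MultiSpeakerTransitionModel.load('speaker_transition_model.json')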
+ """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def load(filepath: Optional[str] = 'speaker_transition_model.json') -> SpeakerTransitionModel: + """Static method that loads a SpeakerTransitionModel from file belonging to the given filepath. + + Args: + filepath: Path to the file where the model is saved. + + Returns: SpeakerTransitionModel constructed from the data of the given file. + """ + raise NotImplementedError -class TwoSpeakerTransitionModel(SpeakerTransitionModel): - """SpeakerTransitionModel which implements the transition between two speakers. - For the transition of the states a transitionModel is used. - Supports four possible actions: + +class MultiSpeakerTransitionModel(SpeakerTransitionModel): + """SpeakerTransitionModel which implements the transition between multiple speakers. + For the transition of the states a transitionModel is used. + Supports four possible actions: - TH: Turn-hold (no speaker change with silence) - TS: Turn-switch (speaker change with silence) - OV: Overlap (speaker change with overlap) - BC: Backchannel (speaker in backchannel, totally overlapped from foreground speaker) + Currently, there are two implemented modes, that determine the selection of the next speaker (random, balanced) + Corresponding paper: Improving the Naturalness of Simulated Conversations for End-to-End Neural Diarization, https://arxiv.org/abs/2204.11232 Properties: transition_model: StateTransitionModel used for the selection of the action current_active_index: Index of current active speaker + last_active_index: Index of last active speaker tries: Current number of tries for finding an action that can be successfully executed max_tries: Maximum number of tries to find a valid action, after this is surpassed a StopIteration Exception is returned tags: Set of all actions that can be returned by the model. + num_speakers: Number of speakers in the transition model. + mode: Currently, there are two implemented modes, that determine the selection of the next speaker: + - random: The next speaker is chosen at random. + - balanced: The next speaker is the speaker which has the least amount of speech time in the + currently generated section of the meeting. + """ - def __init__(self, transition_model: StateTransitionModel, max_tries: int = 50) -> None: + def __init__(self, transition_model: StateTransitionModel, max_tries: int = 50, num_speakers: int = 2, + mode: str = 'balanced') -> None: """ Initialization with a StateTransitionModel and the number of maximal tries for a successful action. When it is not possible to execute the action, the StateTransitionModel is reverted @@ -279,42 +437,107 @@ def __init__(self, transition_model: StateTransitionModel, max_tries: int = 50) Args: transition_model: Underlying transition model, that is internally used to determine the next action. max_tries: Maximum number of tries + num_speakers: Number of speakers in the transition model """ self.transition_model = transition_model self.current_active_index = 0 + self.last_active_index = 0 # Required to revert internal state after sampling of a fitting source fails. 
self.tries = 0 self.max_tries = max_tries + self.num_speakers = num_speakers + self.mode = mode def start(self, env_state: Optional[Any] = None, **kwargs) -> Tuple[int, Any]: self.reset() return self.current_active_index, env_state def next(self, rng: np.random.random = np.random.default_rng(), last_action_success: bool = True, - env_state: Optional[Any] = None, **kwargs) -> Tuple[str, int, Any]: + env_state: Optional[Any] = None, examples: List[Dict] = None, **kwargs) -> Tuple[str, int, Any]: + """ + Returns the next action and the next speaker according to the implemented transition model. + When the model cannot find an action that can be executed a StopIteration error is raised. + + Args: + rng: rng that can be used in the determination of the next state + last_action_success: (optional) status of the execution of the last action, + can be used when the actions can fail. + env_state: (optional) Additional information about the state of the environment + examples: (optional) List of the previously chosen samples of the current meeting + + Returns: Next action (state), the index of the next active speaker and the current environment state + """ + + if examples is None: + examples = [] + if last_action_success: self.tries = 0 elif self.tries < self.max_tries: self.transition_model.step_back() + self.current_active_index = self.last_active_index self.tries += 1 else: raise StopIteration('Number of max tries exceeded') action = self.transition_model.next(rng) + self.last_active_index = self.current_active_index + # In the case of a backchannel, the next speaker changes, but the active speaker in the foreground does not. if action == "BC": - return action, (self.current_active_index + 1) % 2, None + return action, self._next_speaker(rng=rng, examples=examples, env_state=env_state, **kwargs), None # Speaker changes in the case of a turn-switch or overlap if action in ("TS", "OV"): - self.current_active_index = (self.current_active_index + 1) % 2 + self.current_active_index = self._next_speaker(rng=rng, examples=examples, env_state=env_state, **kwargs) return action, self.current_active_index, env_state + def change_num_speakers(self, num_speakers: int = 2): + self.num_speakers = num_speakers + + def _next_speaker(self, rng: np.random.random, examples: List[Dict], **kwargs) -> int: + """ + Internal function that determines the next speaker, depending on the current mode. + + Args: + rng: rng that can be used in the determination of the next speaker. + examples: List of the previously chosen samples of the current meeting. + **kwargs: Additional keyword arguments given to the function. + + Returns: Index of the next speaker. 
+ """ + if self.mode == 'balanced': + speakers = set() + + for source in examples: + speakers.add(source['speaker_id']) + + speakers = list(speakers) + if len(speakers) < self.num_speakers: + speakers.extend([str(i) for i in range(len(speakers), self.num_speakers)]) + + next_index = speakers.index(str(sample_balanced(scenarios=speakers, examples=examples, rng=rng))) + + # When the active speaker is selected, a random speaker is chosen, + # because the selected action requires a speaker change, + if next_index != self.current_active_index: + return next_index + else: + possible_speakers = list({i for i in range(self.num_speakers)}.difference({self.current_active_index})) + return int(rng.choice(possible_speakers, size=1)) + + elif self.mode == 'random': + possible_speakers = list({i for i in range(self.num_speakers)}.difference({self.current_active_index})) + return int(rng.choice(possible_speakers, size=1)) + else: + raise AssertionError(f'Selected mode ({self.mode}) is not supported. Supported modes: random, balanced') + def reset(self) -> None: self.transition_model.reset() self.current_active_index = 0 + self.last_active_index = 0 @property def tags(self) -> Set[str]: @@ -322,7 +545,44 @@ def tags(self) -> Set[str]: def __repr__(self): return ( - f"TwoSpeakerTransitionModel: " - f"Current state: {self.current_active_index}\n" + f"MultiSpeakerTransitionModel: " + f"Current state: {self.current_active_index} " + f"Number of speakers:{self.num_speakers}\n" f"TransitionModel:\n{self.transition_model}" ) + + @staticmethod + def to_json(obj: MultiSpeakerTransitionModel) -> str: + def json_default(o): + if isinstance(o, StateTransitionModel): + return type(o).__name__, o.__class__.to_json(o) + else: + return o.__dict__ + + return json.dumps(obj, default=json_default) + + @staticmethod + def from_json(json_string: str) -> MultiSpeakerTransitionModel: + obj = MultiSpeakerTransitionModel(MarkovModel(np.ones((1, 1)))) + data = json.loads(json_string) + for k, v in data.items(): + if k == 'transition_model': + # Restriction of possible classes to the ones in this file to prevent + # possible injection of malicious code through the json string + obj.__dict__[k] = getattr(sys.modules[__name__], v[0]).from_json(v[1]) + else: + obj.__dict__[k] = v + return obj + + @staticmethod + def save(obj: MultiSpeakerTransitionModel, filepath: Optional[str] = 'speaker_transition_model.json') -> None: + with open(filepath, 'w+') as file: + json_string = MultiSpeakerTransitionModel.to_json(obj) + file.write(json_string) + + @staticmethod + def load(filepath: Optional[str] = 'speaker_transition_model.json') -> MultiSpeakerTransitionModel: + with open(filepath, 'r') as file: + json_string = file.read() + obj = MultiSpeakerTransitionModel.from_json(json_string) + return obj diff --git a/mms_msg/sampling/pattern/meeting/state_based/weighted_meeting_sampler.py b/mms_msg/sampling/pattern/meeting/state_based/weighted_meeting_sampler.py new file mode 100644 index 0000000..82e1d42 --- /dev/null +++ b/mms_msg/sampling/pattern/meeting/state_based/weighted_meeting_sampler.py @@ -0,0 +1,376 @@ +import logging +import sys +from dataclasses import dataclass +from typing import Any, Optional, Union, Dict + +from cached_property import cached_property + +from mms_msg import keys +import padertorch as pt +import paderbox as pb +from lazy_dataset import Dataset, FilterException, from_dict +from mms_msg.sampling.utils import update_num_samples, cache_and_normalize_input_dataset, collate_fn +from mms_msg.sampling.utils.rng import 
get_rng_example
+
+from mms_msg.sampling.pattern.meeting.state_based.transition_model import SpeakerTransitionModel
+from mms_msg.sampling.pattern.meeting.state_based.action_handler import ActionHandler
+
+logger = logging.getLogger('meeting_generation')
+
+
+class _WeightedMeetingSampler:
+    """
+    Main class of the WeightedMeetingSampler.
+
+    The WeightedMeetingSampler requires two internal components to work:
+     - A TransitionModel which determines the sequence of the speakers and transition types between them (actions).
+     - An ActionHandler which processes the action by sampling a source and determining a fitting offset
+
+    TransitionModel and ActionHandler are abstract classes, so the behavior of the meeting generator
+    can be customized to a great extent. Example implementations of these classes, which are also used
+    in meeting_generator.py, can be found in transition_model.py and action_handler.py.
+
+    During generation of a new meeting, the WeightedMeetingSampler runs through the following loop
+    until the desired length of the meeting is reached or an error occurs during the generation:
+     - The transition model outputs the next action and speaker
+     - The action handler processes these two values and samples a fitting source of the speaker
+       together with a fitting offset
+     - The keys of the selected source are adjusted to the sampled offset
+
+    When the generation of a meeting fails, a FilterException is raised.
+
+    Properties:
+        input_datasets: Dictionary of datasets that are used for the generation of new dialogue examples.
+            For each dataset an action can be specified; when this action is selected, the next sample
+            is drawn from the corresponding dataset. When this property is set, the normalized_datasets
+            property is also adapted to the new input datasets.
+        normalized_datasets: (read only) dictionary of the normalized input datasets
+        duration: Duration that the newly generated examples should roughly have, can be slightly exceeded
+        transition_model: Transition model that determines the sequence of speakers and the transitions (actions)
+            between them.
+        action_handler: Action handler that is responsible for processing the transitions from the transition model.
+            Draws fitting samples from the input dataset for each action.
+        use_vad: Whether VAD data should be used; only activate this when VAD data exists in the input dataset
+    """
+
+    def __init__(self, input_datasets: Dict[str, Union[Dataset, Dict]], transition_model: SpeakerTransitionModel,
+                 action_handler: ActionHandler, duration: Optional[int] = 960000, use_vad: Optional[bool] = False,
+                 force_fitting_tags: Optional[bool] = True):
+        """
+        Initializes the WeightedMeetingSampler with a TransitionModel and an ActionHandler.
+        Also, the input datasets for the different actions are set.
+
+        Args:
+            input_datasets: Dictionary of datasets that are used for the generation of new dialogue examples.
+                For each dataset an action can be specified; when this action is selected, the next sample
+                is drawn from the corresponding dataset.
+            transition_model: Transition model that determines the sequence of speakers and the transitions (actions)
+                between them.
+            action_handler: Action handler that is responsible for processing the transitions from the transition model.
+                Draws fitting samples from the input dataset for each transition.
+            duration: Duration that the newly generated examples should roughly have, can be slightly exceeded
+            use_vad: Whether VAD data should be used; only activate this when VAD data exists in the input dataset
+            force_fitting_tags: When set to True, the class enforces that all actions that can be generated
+                from the transition model can be processed by the action handler, by comparing their tags.
+        """
+        self._input_datasets = dict()
+        for k, v in input_datasets.items():
+            # Transform datasets given in the shape of a dictionary into the Dataset type
+            if type(v) is dict:
+                v = from_dict(v)
+            # Remove FilterExceptions
+            self._input_datasets[k] = v.catch()
+
+        self.duration = duration
+        self.transition_model = transition_model
+        self.action_handler = action_handler
+        self.action_handler.set_datasets(self.normalized_datasets, use_vad)
+        self.use_vad = use_vad
+
+        if force_fitting_tags:
+            # Check if arbitrary actions can be handled by the action handler (*)
+            # or all possible actions of the transition model can be handled (subset relation)
+            if not ('*' in self.action_handler.tags
+                    or self.transition_model.tags.issubset(self.action_handler.tags)):
+                raise AssertionError(("The Tags of the TransitionModel and the ActionHandler do not fit, some actions "
                                      "from the TransitionModel can't be handled by the ActionHandler. "
+                                      "To disable the enforcement of fitting tags set force_fitting_tags to False."),
+                                     "Not supported Tags: ",
+                                     self.transition_model.tags.difference(self.action_handler.tags))
+
+    @property
+    def input_datasets(self):
+        return self._input_datasets
+
+    @input_datasets.setter
+    def input_datasets(self, input_datasets):
+        # Invalidate cached property that is calculated from input_datasets
+        delattr(self, "normalized_datasets")
+
+        self._input_datasets = dict()
+        for k, v in input_datasets.items():
+            # Transform datasets given in the shape of a dictionary into the Dataset type
+            if type(v) is dict:
+                v = from_dict(v)
+            # Remove FilterExceptions
+            self._input_datasets[k] = v.catch()
+
+        # Update ActionHandler with new datasets
+        self.action_handler.set_datasets(self.normalized_datasets, self.use_vad)
+
+    @cached_property
+    def normalized_datasets(self) -> Dict[str, Union[Dataset, Dict]]:
+        return {key: cache_and_normalize_input_dataset(dataset) for (key, dataset) in self._input_datasets.items()}
+
+    def _log_action(self, i: int, action: Any, current_source: Dict[str, Any]) -> None:
+        """
+        Internal function that logs the action and corresponding source. Logged keys of the source: scenario, offset,
+        speaker_end. When VAD is used the aligned_source values are logged, otherwise the original_source is used.
+
+        Args:
+            i: index of the action in the final output
+            action: action that should be logged
+            current_source: current active source that should be logged
+        """
+
+        key = 'original_source'
+        if self.use_vad:
+            key = 'aligned_source'
+
+        logger.info("i: %s, action: %s, scenario: %s, offset: %s, speaker_end: %s", i, action,
+                    current_source['scenario'],
+                    current_source['offset'][key], current_source['speaker_end'][key])
+
+    def _set_keys(self, source: Dict[str, Any], offset: int) -> None:
+        """
+        Internal function that sets the offset, num_samples and speaker_end keys for a source with a certain offset.
+        Additional keys are set when VAD is used in the generation process.
+        The keys are set in-place, so the source used as input is changed.
+
+        Args:
+            source: source for which the keys are set
+            offset: desired offset of the source; this must be the offset of the original_source,
+                not the aligned_source
+        """
+
+        # Aligned source: Speaker active
+
+        for key in ['offset', 'num_samples', 'speaker_end']:
+            if key not in source.keys():
+                source[key] = dict()
+
+        if self.use_vad:
+            source['offset']['aligned_source'] = offset + source['offset']['aligned_source']
+            source['speaker_end']['aligned_source'] = (source['offset']['aligned_source']
+                                                       + source['num_samples']['aligned_source'])
+
+        source['offset']['original_source'] = offset
+        source['num_samples']['original_source'] = source['num_samples']['observation']
+        source['speaker_end']['original_source'] = offset + source['num_samples']['original_source']
+
+    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Generates a meeting, with transitions according to the given transition model
+        and sources and offsets according to the given action handler.
+        The generation is based on base examples which select the active speakers.
+
+        Args:
+            example: Base example, which determines which speakers are active in the generated meeting
+
+        Returns: Meeting generated according to the transition model and the action handler.
+        """
+
+        example_id = example['example_id']
+
+        if 'ST' in self.normalized_datasets.keys():
+            base_examples = [self.normalized_datasets['ST'][source_id] for source_id in example['source_id']]
+        else:
+            base_examples = [self.normalized_datasets['*'][source_id] for source_id in example['source_id']]
+
+        logger.info(f'Generating meeting with example ID "{example_id}"')
+
+        # Sample stuff that is constant over the meeting
+        scenario_ids = [x['scenario'] for x in base_examples]
+
+        # A generated example is valid if every chosen speaker has at least one sample
+        valid = False
+
+        examples = []
+
+        transition_rng = get_rng_example(example, 'transition')
+
+        # Get first speaker
+        scenario_id_index, state = self.transition_model.start()
+        success, current_source, offset, state = self.action_handler.start(example_id, scenario_ids, base_examples,
+                                                                           scenario_id_index, env_state=state)
+        if success:
+            self._set_keys(current_source, offset)
+            self._log_action(0, "ST", current_source)
+            examples.append(current_source)
+            valid = True
+        else:
+            logger.error("Start action failed.")
+
+        # Add samples until the length overshoots the desired length, then stop
+        while valid and (max([example[keys.OFFSET][keys.ORIGINAL_SOURCE] + example['num_samples']['observation']
+                              for example in examples], default=0) < self.duration):
+
+            try:
+                action, scenario_id_index, state = self.transition_model.next(rng=transition_rng,
+                                                                              last_action_success=success,
+                                                                              env_state=state, examples=examples)
+
+                success, current_source, offset, state = self.action_handler.next_scenario(action, scenario_id_index,
+                                                                                            examples, base_examples,
+                                                                                            env_state=state)
+
+                if success:
+                    self._set_keys(current_source, offset)
+                    self._log_action(len(examples), action, current_source)
+                    examples.append(current_source)
+                else:
+                    logger.warning(f"Cannot find a fitting source for action {action}, retrying.")
+
+            except StopIteration:
+                # No valid action possible.
+                valid = False
+                logger.error("No valid action possible.")
+
+        # Check if each speaker has at least one appearance
+        scenarios = {scenario_example['scenario'] for scenario_example in examples}
+
+        if scenarios != set(scenario_ids):
+            logger.error(f'The speakers present in the meeting do not correspond to those in the base examples.'
+                         f' Missing scenarios: ' + str(set(scenario_ids).difference(scenarios)).replace('set()', '{}') +
+                         f' Surplus scenarios: ' + str(scenarios.difference(set(scenario_ids))).replace('set()', '{}'))
+            valid = False
+
+        # Generation not successful when the first state cannot be initialized, no valid action can be found
+        # or the present speakers do not fit the base examples
+        if not valid:
+            logger.error('Generation not successful for: %s', example_id)
+            # When the generation of an example is not successful a filter exception is raised
+            raise FilterException()
+
+        # Collate the examples and copy over / replicate things that are already
+        # present in the base full overlap example.
+        # Use the same format as SMS-WSJ.
+        # Heuristic: Replicate nothing that is in the collated
+        # example. For the rest, we have a white- and blacklist of keys that should
+        # or should not be replicated. Keys that are not replicated are copied from
+        # the input example to the output example.
+
+        replicate_whitelist = (
+            'log_weights',
+            'audio_path.rir',
+            'source_position',
+        )
+        replicate_blacklist = (
+            'sensor_positions',
+            'room_dimensions',
+            'example_id',
+            'num_speakers',
+            'source_dataset',
+            'sound_decay_time',
+            'sensor_position',
+            'snr',
+        )
+
+        # Collate
+        collated_examples = collate_fn(examples)
+
+        # Handle some special keys prior to replication
+        collated_examples['source_id'] = collated_examples.pop('example_id')
+        flat_example = pb.utils.nested.flatten(example)
+        speaker_ids = flat_example['speaker_id']
+        collated_examples['num_samples'].pop('observation')
+
+        sources = collated_examples['audio_path'].pop('observation')
+        collated_examples['audio_path']['original_source'] = sources
+        collated_examples['source_dataset'] = collated_examples['dataset']
+        update_num_samples(collated_examples)
+        collated_examples['dataset'] = example['dataset']
+
+        # Copy and replicate
+        flat_collated_example = pb.utils.nested.flatten(collated_examples)
+        for key in flat_example.keys():
+            if key not in flat_collated_example:
+                if key in replicate_whitelist:
+                    if key == 'source_position':
+                        # Special case: nested lists
+                        assert len(flat_example[key][0]) == len(speaker_ids), (flat_example[key], speaker_ids)
+                        transposed = zip(*flat_example[key])
+                        m = dict(zip(speaker_ids, transposed))
+                        transposed = [m[s] for s in flat_collated_example['speaker_id']]
+                        flat_collated_example[key] = list(map(list, zip(*transposed)))
+                    else:
+                        assert len(flat_example[key]) == len(speaker_ids), (flat_example[key], speaker_ids)
+                        m = dict(zip(speaker_ids, flat_example[key]))
+                        flat_collated_example[key] = [m[s] for s in flat_collated_example['speaker_id']]
+                else:
+                    if key not in replicate_blacklist:
+                        # Add keys that you need to blacklist/whitelist
+                        raise RuntimeError(
+                            f'Key {key} not found in replicate_whitelist or '
+                            f'replicate_blacklist.\n'
+                            f'replicate whitelist={replicate_whitelist},\n'
+                            f'replicate blacklist={replicate_blacklist},\n'
+                        )
+                    flat_collated_example[key] = flat_example[key]
+        collated_example = pb.utils.nested.deflatten(flat_collated_example)
+        return collated_example
+
+
+@dataclass(frozen=True)
+class WeightedMeetingSampler(pt.Configurable):
+    """
+    Wrapper class for _WeightedMeetingSampler that
+    samples a meeting from (full-overlap) base examples.
+
+    Properties:
+        transition_model: Transition model that determines the sequence of speakers and the transitions (actions)
+            between them.
+        action_handler: Action handler that is responsible for processing the transitions from the transition model.
+ Draws fitting samples from the input dataset for each action. + duration: Duration that the newly generated examples should roughly have, can be slightly exceeded + use_vad: Should VAD data be used, only activate when VAD exists in the input dataset + force_fitting_tags: When set to True, the class enforces that all actions that can be generated + from the transition model can be processed by the action handler, by comparing their tags. + """ + + transition_model: SpeakerTransitionModel + action_handler: ActionHandler + duration: int = 120 * 8000 + use_vad: bool = False + force_fitting_tags: bool = True + + def __call__(self, input_datasets: Dict[str, Union[Dict, Dataset]]) -> _WeightedMeetingSampler: + """ + Initialises the _WeightedMeetingSampler with input datasets and uses + the values of the dataclass for the other parameters. + + Args: + input_datasets: Dictionary of datasets that are used for generation of new dialogue examples, + for each dataset an action can be specified, when this action is selected the following sample + is drawn from the corresponding dataset. + + The examples in the datasets must contain the following keys: + - scenario (string): Only utterances with the same scenario are put + into the same meeting for the same speaker. Example: In LibriSpeech, + the environment changes heavily for different chapters, even if + the speaker stays the same. So, we want the chapter to stay the same + during one meeting. + - num_samples (int): length of audio signal + - (optional) vad (ArrayInterval or numpy array): VAD information, with sample + resolution + + Returns: WeightedMeetingSampler that is initialised with the given input datasets + """ + return _WeightedMeetingSampler( + input_datasets=input_datasets, + duration=self.duration, + transition_model=self.transition_model, + action_handler=self.action_handler, + use_vad=self.use_vad, + force_fitting_tags=self.force_fitting_tags + ) diff --git a/mms_msg/sampling/utils/distribution_model.py b/mms_msg/sampling/utils/distribution_model.py index b4dedd1..c55e4f9 100644 --- a/mms_msg/sampling/utils/distribution_model.py +++ b/mms_msg/sampling/utils/distribution_model.py @@ -1,3 +1,6 @@ +from __future__ import annotations +import json +import sys import numpy as np from typing import Optional, List, Union, Tuple from collections import Counter @@ -22,11 +25,13 @@ class DistributionModel: def __init__(self, samples: Optional[List[Union[int, float]]] = None, bin_size: Union[int, float] = 100, allow_negative_samples: bool = False): """ - :param bin_size: size of the histogram bins - :param samples: (optional) list of samples that should be added - :param allow_negative_samples: (optional) Allowing negative values to be added to the model. - Disabled by default. + Args: + samples: (optional) list of samples that should be added + bin_size: (optional) size of the histogram bins + allow_negative_samples: (optional) Allowing negative values to be added to the model. + Disabled by default. 
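+
+        Example (illustrative sketch of the intended usage; the sample values are made up):
+            >>> model = DistributionModel(samples=[120, 380, 240, 310], bin_size=100)  # hypothetical samples
+            >>> value = model.sample_value(rng=np.random.default_rng(0))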
""" + self.n = 0 self._distribution_prob = None self._bin_size = bin_size @@ -69,9 +74,10 @@ def standard_deviation(self) -> float: return self._standard_deviation def clear(self) -> None: - """ Removes all samples from the model and resets the related statistical values - :return: None """ + Removes all samples from the model and resets the related statistical values + """ + self._distribution_prob = None self._min_value = None self._max_value = None @@ -80,9 +86,11 @@ def clear(self) -> None: self._standard_deviation = None def fit(self, samples: Union[List[Union[int, float]]]) -> None: - """ Fits the distribution model to a number of samples. Previously estimated values will be overwritten. - :param samples: Samples to which the model is fitted. The samples can be given as list or as set. - :return: None + """ + Fits the distribution model to a number of samples. Previously estimated values will be overwritten. + + Args: + samples: Samples to which the model is fitted. The samples can be given as list or as set. """ if len(samples) == 0: @@ -127,17 +135,21 @@ def sample_value(self, rng: Optional[np.random.random] = None, random_state: Opt """ Sample a value according to the currently estimated distribution saved in the distribution model. It is also possible to restrict the area to an interval from which a sample is drawn. - In this case, the distribution inside the interval is normalized to the probability 1 and then used for sampling. - - :param rng: (optional) The numpy rng that should be used, the rng should generate a number in the interval [0,1) - If not set a new uniform rng is used. - :param random_state: (optional) Seed for the default random number generator. - If not set, no seed is used for the rng, so the samples are no reproducible. - :param sample_integer: (optional) When set to true, the sampled value is an integer, otherwise it is a float. - Default: True. - :param minimum_value: (optional) minimal value that should be sampled (including minimum_value) - :param maximum_value: (optional) maximum value that should be sampled (excluding maximum_value) - :return: sample according to the distribution Integer, when sample_integer is True. + In this case, the distribution inside the interval is normalized to the probability 1 and then used for sampling + + Args: + rng: (optional) The numpy rng that should be used, the rng should generate a number in the interval [0,1) + If not set a new uniform rng is used. + random_state: (optional) Seed for the default random number generator. + If not set, no seed is used for the rng, so the samples are no reproducible. + sample_integer: (optional) When set to true, the sampled value is an integer, otherwise it is a float. + Default: True. + minimum_value: (optional) minimal value that should be sampled (including minimum_value) + maximum_value: (optional) maximum value that should be sampled (excluding maximum_value) + + Returns: Sample according to the distribution Integer. Returns an integer when sample_integer is set to True, + otherwise returns a float. + """ if rng is None: @@ -146,6 +158,9 @@ def sample_value(self, rng: Optional[np.random.random] = None, random_state: Opt if self.n == 0: raise AssertionError("No samples has been added to the model. 
Sampling not possible.") + if minimum_value is not None and maximum_value is not None and minimum_value >= maximum_value: + raise ValueError('When given the maximum value must be greater than the minimum value.') + if minimum_value is None: p_min = 0 else: @@ -156,6 +171,11 @@ def sample_value(self, rng: Optional[np.random.random] = None, random_state: Opt else: p_s = self.get_cdf_value(maximum_value)-p_min + # Check if it is possible to sample a value, 1e-06 is used instead of 0, due to floating point precision + if p_s <= 1e-06: + raise ValueError('The probability that an element is in the given boundaries is 0 according to the' + ' underlying model.') + temp = p_min + rng.random()*p_s for (val, prob) in self.distribution_prob: @@ -175,8 +195,10 @@ def get_cdf_value(self, value: Union[int, float]) -> float: Returns the value of the cumulative distribution function (cdf) for the given value. In other words returns the probability that a random sample is smaller than value. - :param value: Value for which the cdf should be evaluated - :return: Output of the cdf function at the given value. + Args: + value: Value for which the cdf should be evaluated + + Returns: Output of the cdf function at the given value. """ if value < self.min_value: @@ -198,34 +220,34 @@ def __repr__(self): ret += " Minimum value:" + str(self.min_value) ret += " Maximum value:" + str(self.max_value) ret += " Expected value:" + str(self.expected_value) - ret += " Standard derivation:" + str(self.standard_deviation) + ret += " Standard deviation:" + str(self.standard_deviation) ret += " Variance:" + str(self.variance) return ret - def plot(self, show = False, fig = None, ax = None): + def plot(self, show: bool = False, ax=None): """ Creates a plot of the distribution model using matplotlib and - returns a figure and axes with the corresponding plot. - @:param show: (optional) When set to True the figure is directly shown - @:param fig: (optional) Figure on which a new axes with the plot is created. - Will be overwritten when ax is given. - When not given and also ax is not provided the function creates a new figure - with one axes and uses this for the plot. - @:param ax: (optional) axes on which the plot is created, when not provided - the function creates a new axes on the figure, when also the figure is not provided - then the function creates a new figure with one axes and uses this for the plot. - :return: Figure and axes with the plot of the distribution. - When an axis but no figure is given as input then the tuple (None,ax) is returned. + returns a figure and axes with the corresponding plot. + + Args: + show: (optional) When set to True the figure is directly shown + ax: (optional) axes on which the plot is created, when not provided + the function creates a new axes on the figure, when also the figure is not provided + then the function creates a new figure with one axes and uses this for the plot. + + Returns: Figure and axes with the plot of the distribution. + When an axis is given as input then the tuple (None,ax) is returned. """ + import matplotlib.pyplot as plt if self.n == 0: raise AssertionError("No samples has been added to the model. 
Plot is empty.") - if fig is None and ax is None: + fig = None + + if ax is None: fig, ax = plt.subplots() - elif fig is not None: - ax = fig.add_axes() ax.hist(list(map(lambda x: x[0], self.distribution_prob)), bins=int((self.max_value - self.min_value)/self.bin_size), @@ -239,16 +261,74 @@ def plot(self, show = False, fig = None, ax = None): plt.show() return fig, ax + @staticmethod + def to_json(obj: DistributionModel) -> str: + """ Static method that serializes a DistributionModel into a json string. + + Args: + obj: DistributionModel which should be serialized + + Returns: Json string which contains the data of the given object + """ + return json.dumps(obj, default=lambda o: o.__dict__) + + @staticmethod + def from_json(json_string: str) -> DistributionModel: + """ Static method that creates a DistributionModel from a given json string. + + Args: + json_string: Json string that contains the data required for the DistributionModel + + Returns: DistributionModel constructed from the data of the json string. + """ + obj = DistributionModel() + data = json.loads(json_string) + for k, v in data.items(): + obj.__dict__[k] = v + return obj + + @staticmethod + def save(obj: DistributionModel, filepath: Optional[str] = 'distribution_model.json') -> None: + """ Static method that saves the given DistributionModel to a file belonging to the given filepath. + When the file exists its contests will be overwritten. When it not exists it is created. + The used dataformat is json, so a .json file extension is recommended. + + Args: + obj: DistributionModel which should be saved + filepath: Path to the file where the model should be saved. + """ + with open(filepath, 'w+') as file: + json_string = DistributionModel.to_json(obj) + file.write(json_string) + + @staticmethod + def load(filepath: Optional[str] = 'distribution_model.json') -> DistributionModel: + """Static method that loads a DistributionModel from file belonging to the given filepath. + + Args: + filepath: Path to the file where the model is saved. + + Returns: DistributionModel constructed from the data of the given file. + """ + with open(filepath, 'r') as file: + json_string = file.read() + obj = DistributionModel.from_json(json_string) + return obj + def statistical_distance(d1: DistributionModel, d2: DistributionModel) -> float: """ Calculates the statistical distance (total variation distance, https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures) of two distribution models (d1 and d2). - :param d1: DistributionModel for comparison - :param d2: DistributionModel for comparison - :return: statistical distance + + Args: + d1: DistributionModel for comparison + d2: DistributionModel for comparison + + Returns: statistical distance """ + if d1.n == 0: raise AssertionError("No samples has been added to the first model. No comparison possible.") elif d2.n == 0: