class BaseSynthesizerModel:
    """Common base class for sdgx synthesizer models.

    Provides device management, (de)serialization via ``torch.save`` /
    pickling, and reproducible random-state handling shared by concrete
    synthesizers.
    """

    # Either ``None`` (unseeded) or a (np.random.RandomState, torch.Generator)
    # pair — see ``set_random_state``.
    random_states = None

    def __init__(self, transformer=None, sampler=None) -> None:
        # These fields are expected to be (re)assigned when a concrete
        # model is initialized.
        self.transformer = transformer  # optional data transformer
        self.sampler = sampler          # optional data sampler
        self.model = None               # the fitted underlying model
        # NOTE(review): "UNFINED" looks like a typo (``UNFITTED``?), but the
        # string is preserved in case other code compares against it — confirm.
        self.status = "UNFINED"
        self.model_type = "MODEL_TYPE_UNDEFINED"
        self._device = "CPU"
        # Bug fix: ``set_device`` (invoked by ``__getstate__`` and ``save``)
        # dereferences ``self._generator``; it was previously never
        # initialized, so every one of those calls raised AttributeError.
        self._generator = None

    def fit(self, input_df, discrete_cols: Optional[List] = None):
        """Fit the model to ``input_df``; must be overridden by subclasses."""
        raise NotImplementedError

    def set_device(self, device):
        """Set the `device` to be used ('GPU' or 'CPU')."""
        self._device = device
        if self._generator is not None:
            self._generator.to(self._device)

    def __getstate__(self):
        """Return a picklable state dict with everything moved to CPU."""
        device_backup = self._device
        self.set_device(torch.device("cpu"))
        state = self.__dict__.copy()
        self.set_device(device_backup)
        # ``torch.Generator`` objects are not picklable: stash both RNG
        # states as raw state blobs and drop the live objects.
        if (
            isinstance(self.random_states, tuple)
            and isinstance(self.random_states[0], np.random.RandomState)
            and isinstance(self.random_states[1], torch.Generator)
        ):
            state["_numpy_random_state"] = self.random_states[0].get_state()
            state["_torch_random_state"] = self.random_states[1].get_state()
            state.pop("random_states")
        return state

    def __setstate__(self, state):
        """Restore pickled state, rebuilding the RNG objects if present."""
        if "_numpy_random_state" in state and "_torch_random_state" in state:
            np_state = state.pop("_numpy_random_state")
            torch_state = state.pop("_torch_random_state")
            current_torch_state = torch.Generator()
            current_torch_state.set_state(torch_state)
            current_numpy_state = np.random.RandomState()
            current_numpy_state.set_state(np_state)
            state["random_states"] = (current_numpy_state, current_torch_state)
        self.__dict__ = state
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.set_device(device)

    def save(self, path):
        """Serialize this model to ``path`` (always stored on CPU)."""
        device_backup = self._device
        self.set_device(torch.device("cpu"))
        torch.save(self, path)
        self.set_device(device_backup)

    @classmethod
    def load(cls, path):
        """Load a model saved with :meth:`save` onto the best available device."""
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = torch.load(path)
        model.set_device(device)
        return model

    def set_random_state(self, random_state):
        """Seed both the numpy and torch RNGs used by this model.

        Args:
            random_state: ``None`` to unseed, an ``int`` seed, or a
                ``(np.random.RandomState, torch.Generator)`` tuple.

        Raises:
            TypeError: if ``random_state`` is any other type.
        """
        if random_state is None:
            self.random_states = random_state
        elif isinstance(random_state, int):
            self.random_states = (
                np.random.RandomState(seed=random_state),
                torch.Generator().manual_seed(random_state),
            )
        elif (
            isinstance(random_state, tuple)
            and isinstance(random_state[0], np.random.RandomState)
            and isinstance(random_state[1], torch.Generator)
        ):
            self.random_states = random_state
        else:
            raise TypeError(
                f"`random_state` {random_state} expected to be an int or a tuple of "
                "(`np.random.RandomState`, `torch.Generator`)"
            )
a/sdgx/statistics/single_table/copula.py b/sdgx/statistics/single_table/copula.py index 6eb34e50..b9776a65 100644 --- a/sdgx/statistics/single_table/copula.py +++ b/sdgx/statistics/single_table/copula.py @@ -1,6 +1,6 @@ """ Wrappers around copulas models. - 需要修改:fit接口以适应性能优化措施 + 需要修改: fit接口以适应性能优化措施 """ import logging import warnings @@ -14,16 +14,19 @@ from copulas import multivariate from rdt.transformers import OneHotEncoder -from sdv.errors import NonParametricError -from sdv.single_table.base import BaseSingleTableSynthesizer -from sdv.single_table.utils import ( +# transformer 以及 sampler 已经拆分,挪到 transform/ 目录中 +# from sdgx.transform.sampler import DataSamplerCTGAN +from sdgx.transform.transformer import DataTransformerCTGAN +from sdgx.errors import NonParametricError +from sdgx.statistics.base import BaseSynthesizerModel +from sdgx.utils.utils import ( flatten_dict, log_numerical_distributions_error, unflatten_dict, validate_numerical_distributions) LOGGER = logging.getLogger(__name__) -class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer): +class GaussianCopulaSynthesizer(BaseSynthesizerModel): """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. Args: @@ -66,7 +69,7 @@ class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer): # 这里代表的是几种可选的分布,这里通常使用 beta # 其他分布的内容我们可以以后再了解。 - # 这个应该是属于 类变量的 + # 这个应该是属于 类变量的 _DISTRIBUTIONS = { 'norm': copulas.univariate.GaussianUnivariate, 'beta': copulas.univariate.BetaUnivariate, @@ -102,16 +105,17 @@ def get_distribution_class(cls, distribution): # 初始化方法,参数的类型需要仔细分析一下 def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=None, numerical_distributions=None, default_distribution=None): - super().__init__( - metadata, # fit 数据集的元数据 - enforce_min_max_values=enforce_min_max_values, # 限制最大最小value,一般都是true - enforce_rounding=enforce_rounding, # 保持相同的小数点位数, 一般都是true - locales=locales, # 和语言设置有关,这个暂时不太清楚 - ) - # 验证分布? 
- # 这个时候还没有输入数据的,只有 metadata + + self.metadata = metadata, # fit 数据集的元数据 + self.enforce_min_max_values = enforce_min_max_values, # 限制最大最小value,一般都是true + self.enforce_rounding = enforce_rounding, # 保持相同的小数点位数, 一般都是true + self.locales = locales, # 和语言设置有关,这个暂时不太清楚 + + # 验证分布? + # 这个时候还没有输入数据的,只有 metadata # 应该是通过 metadata 来验证分布的 - validate_numerical_distributions(numerical_distributions, self.metadata.columns) + validate_numerical_distributions( + numerical_distributions, self.metadata) # 这里的意思是,如果没有指定分布,那么就使用 beta 分布 # 下面分别是两个参数,重新赋值 self.numerical_distributions = numerical_distributions or {} @@ -119,7 +123,8 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, # 以下几行代码是使用 class method 把字符串类型的分布 # 转化为 copulas.univariate 的分布实例 - self._default_distribution = self.get_distribution_class(self.default_distribution) + self._default_distribution = self.get_distribution_class( + self.default_distribution) self._numerical_distributions = { field: self.get_distribution_class(distribution) for field, distribution in self.numerical_distributions.items() @@ -134,13 +139,15 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, # 初步的性能优化方案可以是: # 1. 增加增量学习机制,防止内存消耗过多 # 2. 增加 - def _fit(self, processed_data): - """Fit the model to the table. + def fit(self, processed_data): + + # 载入 transformer + self._transformer = DataTransformerCTGAN() + self._transformer.fit(processed_data, self.metadata[0]) + + # 使用 transformer 处理数据 + processed_data = pd.DataFrame(self._transformer.transform(processed_data)) - Args: - processed_data (pandas.DataFrame): - Data to be learned. 
def sample(self, num_rows, conditions=None):
    """Draw synthetic rows from the fitted copula model.

    The raw samples produced by the underlying ``copulas`` model are
    mapped back to the original data space through the fitted transformer.

    Args:
        num_rows (int):
            Number of rows to sample.
        conditions (dict or None):
            Optional sampling conditions forwarded to the underlying model.

    Returns:
        The inverse-transformed sampled data, as produced by the
        transformer's ``inverse_transform``.
    """
    raw_samples = self._model.sample(num_rows, conditions=conditions)
    return self._transformer.inverse_transform(raw_samples.to_numpy())
def flatten_array(nested, prefix=''):
    """Flatten an array as a dict.

    Args:
        nested (list, numpy.array):
            Iterable to flatten.
        prefix (str):
            Name to append to the array indices. Defaults to ``''``.

    Returns:
        dict:
            Flattened array, with indices joined to ``prefix`` by ``'__'``.
    """
    result = {}
    for index in range(len(nested)):
        prefix_key = '__'.join([prefix, str(index)]) if len(prefix) else str(index)

        value = nested[index]
        if isinstance(value, (list, np.ndarray)):
            result.update(flatten_array(value, prefix=prefix_key))

        elif isinstance(value, dict):
            result.update(flatten_dict(value, prefix=prefix_key))

        else:
            result[prefix_key] = value

    return result


# Scalar values under these keys carry no model parameters and are dropped
# when flattening (dict/list values under the same keys are still kept).
IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type']


def flatten_dict(nested, prefix=''):
    """Flatten a dictionary.

    This method returns a flatten version of a dictionary, concatenating key
    names with double underscores.

    Args:
        nested (dict):
            Original dictionary to flatten.
        prefix (str):
            Prefix to append to key name. Defaults to ``''``.

    Returns:
        dict:
            Flattened dictionary.
    """
    result = {}

    for key, value in nested.items():
        prefix_key = '__'.join([prefix, str(key)]) if len(prefix) else key

        if key in IGNORED_DICT_KEYS and not isinstance(value, (dict, list)):
            continue

        elif isinstance(value, dict):
            result.update(flatten_dict(value, prefix_key))

        elif isinstance(value, (np.ndarray, list)):
            result.update(flatten_array(value, prefix_key))

        else:
            result[prefix_key] = value

    return result


def log_numerical_distributions_error(numerical_distributions, processed_data_columns, logger):
    """Log error when numerical distributions columns don't exist anymore.

    Args:
        numerical_distributions (dict):
            Mapping of column name to requested distribution.
        processed_data_columns (iterable):
            Column names that survived preprocessing.
        logger (logging.Logger):
            Logger used to report the dropped columns.
    """
    unseen_columns = numerical_distributions.keys() - set(processed_data_columns)
    for column in unseen_columns:
        logger.info(
            f"Requested distribution '{numerical_distributions[column]}' "
            f"cannot be applied to column '{column}' because it no longer "
            'exists after preprocessing.'
        )


def _key_order(key_value):
    """Sort key for flattened keys: split on ``'__'`` and compare numerically
    where a part is a digit string, so array indices sort as 0, 1, 2, ..."""
    parts = []
    for part in key_value[0].split('__'):
        if part.isdigit():
            part = int(part)

        parts.append(part)

    return parts


def unflatten_dict(flat):
    """Transform a flattened dict into its original form.

    Args:
        flat (dict):
            Flattened dict.

    Returns:
        dict:
            Nested dict (if corresponds)
    """
    unflattened = {}

    for key, value in sorted(flat.items(), key=_key_order):
        if '__' in key:
            key, subkey = key.split('__', 1)
            subkey, name = subkey.rsplit('__', 1)

            if name.isdigit():
                # ``key__row__col`` encodes a 2D array (e.g. a correlation
                # matrix); rebuild it row by row in sorted index order.
                column_index = int(name)
                row_index = int(subkey)

                array = unflattened.setdefault(key, [])

                if len(array) == row_index:
                    row = []
                    array.append(row)
                elif len(array) == row_index + 1:
                    row = array[row_index]
                else:
                    # This should never happen
                    raise ValueError('There was an error unflattening the extension.')

                if len(row) == column_index:
                    row.append(value)
                else:
                    # This should never happen
                    raise ValueError('There was an error unflattening the extension.')

            else:
                # ``key__subkey__name`` encodes a nested dict entry.
                subdict = unflattened.setdefault(key, {})
                if subkey.isdigit():
                    subkey = int(subkey)

                inner = subdict.setdefault(subkey, {})
                inner[name] = value

        else:
            unflattened[key] = value

    return unflattened


def validate_numerical_distributions(numerical_distributions, metadata_columns):
    """Validate ``numerical_distributions``.

    Raise an error if it's not None or dict, or if its columns are not present
    in the metadata.

    Args:
        numerical_distributions (dict):
            Dictionary that maps field names from the table that is being
            modeled with the distribution that needs to be used.
        metadata_columns (list):
            Columns present in the metadata.

    Raises:
        TypeError: if ``numerical_distributions`` is neither None nor a dict.
        ValueError: if any key is not a known metadata column.
    """
    if numerical_distributions:
        if not isinstance(numerical_distributions, dict):
            raise TypeError('numerical_distributions can only be None or a dict instance.')

        invalid_columns = numerical_distributions.keys() - set(metadata_columns)
        if invalid_columns:
            # Bug fix: this check was commented out, silently accepting
            # unknown columns even though the docstring promises an error.
            # The original referenced ``SynthesizerInputError``, which is not
            # defined in this package, so raise ``ValueError`` instead.
            raise ValueError(
                'Invalid column names found in the numerical_distributions dictionary '
                f'{invalid_columns}. The column names you provide must be present '
                'in the metadata.'
            )