Skip to content

Commit

Permalink
Add mutual information metric (hitsz-ids#101)
Browse files Browse the repository at this point in the history
* test

* test_v2

* no-test

* pair_v1

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove_old_mi_sim

* modify single&multi_table MISim

* modify single_mi_sim by using pair_sim instance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify multi_mi_sim by using pair_sim instance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change_class_name_err

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify_paircolumn

* mi only needs dataframe

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify based on review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* complete test_mi_sim

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify test file

* change_var_name

* Update sdgx/metrics/multi_table/multitable_mi_sim.py

Co-authored-by: MoooCat <[email protected]>

* add MULTI_TABLE_DEMO_DATA

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify comments

* JSD->MISIM

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* modify base of pair_column

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add cls

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change self into cls instance

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change cls

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* series2array

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test

* test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add label_encoder for category in mi_sim

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use series.array

* change le_fit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change transform type to np.array instead of list

* add astype

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* series2array

* foo

* change test_suit

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* all right?

* all right

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Z712023 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Z712023 <[email protected]>
  • Loading branch information
4 people authored Jan 16, 2024
1 parent d29a2a0 commit dae869e
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 6 deletions.
2 changes: 0 additions & 2 deletions sdgx/metrics/column/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,8 @@ def calculate(
cls, real_data: pd.Series | pd.DataFrame, synthetic_data: pd.Series | pd.DataFrame
):
"""Calculate the metric value between columns between real table and synthetic table.
Args:
real_data(pd.DataFrame or pd.Series): the real (original) data table / column.
synthetic_data(pd.DataFrame or pd.Series): the synthetic (generated) data table / column.
"""
# This method should first check the input
Expand Down
2 changes: 1 addition & 1 deletion sdgx/metrics/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def check_output(raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
"""
raise NotImplementedError()

Expand Down
71 changes: 71 additions & 0 deletions sdgx/metrics/multi_table/multitable_mi_sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.metrics.cluster import normalized_mutual_info_score

from sdgx.metrics.multi_table.base import MultiTableMetric
from sdgx.metrics.pair_column.mi_sim import MISim


class MISim(MultiTableMetric):
"""MISim : Mutual Information Similarity
This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data.
Currently, we support discrete and continuous(need to be discretized) columns as inputs.
"""

def __init__(self) -> None:
super().__init__()
self.lower_bound = 0
self.upper_bound = 1
self.metric_name = "mutual_information_similarity"
self.numerical_bins = 50

@classmethod
def calculate(
real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict
) -> pd.DataFrame:
"""
Calculate the Mutual Information Similarity between a real column and a synthetic column.
Args:
real_data (pd.DataFrame): The real data.
synthetic_data (pd.DataFrame): The synthetic data.
metadata(dict): The metadata that describes the data type of each column
Returns:
MI_similarity (float): The metric value.
"""

# 传入概率分布数组

columns = synthetic_data.columns
n = len(columns)
mi_sim_instance = MISim()
nMI_sim = np.zeros((n, n))

for i in range(len(columns)):
for j in range(len(columns)):
syn_data = pd.concat(
[synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1
)
real_data = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1)

nMI_sim[i][j] = mi_sim_instance.calculate(real_data, syn_data, metadata)

MI_sim = np.sum(nMI_sim) / n / n
# test
MISim.check_output(MI_sim)

return MI_sim

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
"""
instance = cls()
if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound:
raise ValueError
75 changes: 75 additions & 0 deletions sdgx/metrics/pair_column/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pandas as pd

from sdgx.log import logger


class PairMetric(object):
"""PairMetric
Metrics used to evaluate the quality of synthetic data columns.
"""

upper_bound = None
lower_bound = None
metric_name = "Correlation"

def __init__(self) -> None:
pass

@classmethod
def check_input(cls, src_col: pd.Series, tar_col: pd.Series, metadata: dict):
"""Input check for table input.
Args:
src_data(pd.Series ): the source data column.
tar_data(pd.Series): the target data column .
metadata(dict): The metadata that describes the data type of each column
"""
# Input parameter must not contain None value
if real_data is None or synthetic_data is None:
raise TypeError("Input contains None.")
# check column_names
tar_name = tar_col.name
src_name = src_col.name

# check column_types
if metadata[tar_name] != metadata[src_name]:
raise TypeError("Type of Pair is Conflicting.")

# if type is pd.Series, return directly
if isinstance(real_data, pd.Series):
return src_col, tar_col

# if type is not pd.Series or pd.DataFrame tranfer it to Series
try:
src_col = pd.Series(src_col)
tar_col = pd.Series(tar_col)
return src_col, tar_col
except Exception as e:
logger.error(f"An error occurred while converting to pd.Series: {e}")

return None, None

@classmethod
def calculate(cls, src_col: pd.Series, tar_col: pd.Series, metadata):
"""Calculate the metric value between pair-columns between real table and synthetic table.
Args:
src_data(pd.Series ): the source data column.
tar_data(pd.Series): the target data column .
metadata(dict): The metadata that describes the data type of each column
"""
# This method should first check the input
# such as:
real_data, synthetic_data = PairMetric.check_input(src_col, tar_col)

raise NotImplementedError()

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the Mutual Information.
"""
raise NotImplementedError()

pass
97 changes: 97 additions & 0 deletions sdgx/metrics/pair_column/mi_sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.metrics.cluster import normalized_mutual_info_score
from sklearn.preprocessing import LabelEncoder

from sdgx.metrics.pair_column.base import PairMetric
from sdgx.utils import time2int


class MISim(PairMetric):
"""MISim : Mutual Information Similarity
This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data.
Currently, we support discrete and continuous(need to be discretized) columns as inputs.
"""

def __init__(instance) -> None:
super().__init__()
instance.lower_bound = 0
instance.upper_bound = 1
instance.metric_name = "mutual_information_similarity"
instance.numerical_bins = 50

@classmethod
def calculate(
cls,
src_col: pd.Series,
tar_col: pd.Series,
metadata: dict,
) -> float:
"""
Calculate the MI similarity for the source data colum and the target data column.
Args:
src_data(pd.Series ): the source data column.
tar_data(pd.Series): the target data column .
metadata(dict): The metadata that describes the data type of each columns
Returns:
MI_similarity (float): The metric value.
"""

# 传入概率分布数组
instance = cls()

col_name = src_col.name
data_type = metadata[col_name]

if data_type == "numerical":
x = np.array(src_col.array)
src_col = pd.cut(
x,
instance.numerical_bins,
labels=range(instance.numerical_bins),
)
x = np.array(tar_col.array)
tar_col = pd.cut(
x,
instance.numerical_bins,
labels=range(instance.numerical_bins),
)
src_col = src_col.to_numpy()
tar_col = tar_col.to_numpy()

elif data_type == "category":
le = LabelEncoder()
src_list = list(set(src_col.array))
tar_list = list(set(tar_col.array))
fit_list = tar_list + src_list
le.fit(fit_list)

src_col = le.transform(np.array(src_col.array))
tar_col = le.transform(np.array(tar_col.array))

elif data_type == "datetime":
src_col = src_col.apply(time2int)
tar_col = tar_col.apply(time2int)
src_col = pd.cut(
src_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins)
)
tar_col = pd.cut(
tar_col, bins=instance.numerical_bins, labels=range(instance.numerical_bins)
)
src_col = src_col.to_numpy()
tar_col = tar_col.to_numpy()

MI_sim = normalized_mutual_info_score(src_col, tar_col)
return MI_sim

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the MI similarity.
"""
pass
4 changes: 2 additions & 2 deletions sdgx/metrics/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def check_input(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):

return None, None

def calculate(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):
def calculate(cls, real_data: pd.DataFrame, synthetic_data: pd.DataFrame):
"""Calculate the metric value between a real table and a synthetic table.
Args:
Expand All @@ -71,7 +71,7 @@ def check_output(raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the JSD metric.
raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
"""
raise NotImplementedError()

Expand Down
67 changes: 67 additions & 0 deletions sdgx/metrics/single_table/single_mi_sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import pandas as pd
from scipy.stats import entropy
from sklearn.metrics.cluster import normalized_mutual_info_score

from sdgx.metrics.pair_column.mi_sim import MISim
from sdgx.metrics.single_table.base import SingleTableMetric


class SinTabMISim(SingleTableMetric):
"""MISim : Mutual Information Similarity
This class is used to calculate the Mutual Information Similarity between the target columns of real data and synthetic data.
Currently, we support discrete and continuous(need to be discretized) columns as inputs.
"""

def __init__(self) -> None:
super().__init__()
self.lower_bound = 0
self.upper_bound = 1
self.metric_name = "mutual_information_similarity"
self.numerical_bins = 50

@classmethod
def calculate(real_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata) -> pd.DataFrame:
"""
Calculate the Mutual Information Similarity between a real column and a synthetic column.
Args:
real_data (pd.DataFrame): The real data.
synthetic_data (pd.DataFrame): The synthetic data.
metadata(dict): The metadata that describes the data type of each column
Returns:
MI_similarity (float): The metric value.
"""

# 传入概率分布数组

columns = synthetic_data.columns
n = len(columns)
mi_sim_instance = MISim()
nMI_sim = np.zeros((n, n))

for i in range(len(columns)):
for j in range(len(columns)):
syn_data = pd.concat(
[synthetic_data[columns[i]], synthetic_data[columns[j]]], axis=1
)
real_data = pd.concat([real_data[columns[i]], real_data[columns[j]]], axis=1)

nMI_sim[i][j] = mi_sim_instance.calculate(real_data, syn_data, metadata)

MI_sim = np.sum(nMI_sim) / n / n
MISim.check_output(MI_sim)

return MI_sim

@classmethod
def check_output(cls, raw_metric_value: float):
"""Check the output value.
Args:
raw_metric_value (float): the calculated raw value of the Mutual Information Similarity.
"""
instance = cls()
if raw_metric_value < instance.lower_bound or raw_metric_value > instance.upper_bound:
raise ValueError
9 changes: 8 additions & 1 deletion sdgx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import socket
import threading
import time
import urllib.request
import warnings
from contextlib import closing
Expand All @@ -26,8 +27,8 @@
"find_free_port",
"download_multi_table_demo_data",
"get_demo_single_table",
"time2int",
]

MULTI_TABLE_DEMO_DATA = {
"rossman": {
"parent_table": "store",
Expand Down Expand Up @@ -99,6 +100,12 @@ def get_demo_single_table(data_dir: str | Path = "./dataset"):
return pd_obj, discrete_cols


def time2int(datetime, form):
time_array = time.strptime(datetime, form)
time_stamp = int(time.mktime(time_array))
return time_stamp


class Singleton(type):
"""
metaclass for singleton, thread-safe.
Expand Down
Loading

0 comments on commit dae869e

Please sign in to comment.