Merge branch 'master' of github.com:jeongyoonlee/Kaggler
jeongyoonlee committed Jun 10, 2021
2 parents ca4dd68 + 768f865 commit f4ac8cf
Showing 3 changed files with 115 additions and 35 deletions.
2 changes: 1 addition & 1 deletion kaggler/__init__.py
@@ -1,4 +1,4 @@
__version__ = '0.9.9'
__version__ = '0.9.10'
__all__ = ['const',
'data_io',
'ensemble',
106 changes: 73 additions & 33 deletions kaggler/preprocessing/autoencoder.py
@@ -94,11 +94,13 @@ def mask_inputs():
class DAELayer(Layer):
"""A DAE layer with one pair of the encoder and decoder."""

def __init__(self, encoding_dim=128, noise_std=.0, swap_prob=.2, mask_prob=.0, seed=42, **kwargs):
def __init__(self, encoding_dim=128, n_encoder=1, noise_std=.0, swap_prob=.2, mask_prob=.0, seed=42,
**kwargs):
"""Initialize a DAE (Denoising AutoEncoder) layer.
Args:
encoding_dim (int): the numbers of hidden units in encoding/decoding layers.
n_encoder (int): the numbers of hidden encoding layers.
noise_std (float): standard deviation of gaussian noise to be added to features.
swap_prob (float): probability to add swap noise to features.
mask_prob (float): probability to add zero masking to features.
@@ -108,12 +110,14 @@ def __init__(self, encoding_dim=128, noise_std=.0, swap_prob=.2, mask_prob=.0, s
super().__init__(**kwargs)

self.encoding_dim = encoding_dim
self.n_encoder = n_encoder
self.noise_std = noise_std
self.swap_prob = swap_prob
self.mask_prob = mask_prob
self.seed = seed

self.encoder = Dense(encoding_dim, activation='relu', name=f'{self.name}_encoder')
self.encoders = [Dense(encoding_dim, activation='relu', name=f'{self.name}_encoder_{i}')
for i in range(self.n_encoder)]

def build(self, input_shape):
self.input_dim = input_shape[-1]
@@ -136,7 +140,13 @@ def call(self, inputs, training):
masked_inputs = ZeroNoiseMasker(probs=[self.mask_prob] * self.input_dim,
seed=[self.seed] * 2)(masked_inputs)

encoded = self.encoder(masked_inputs)
x = masked_inputs
encoded_list = []
for encoder in self.encoders:
x = encoder(x)
encoded_list.append(x)

encoded = Concatenate()(encoded_list) if len(encoded_list) > 1 else encoded_list[0]
decoded = self.decoder(encoded)

rec_loss = K.mean(mean_squared_error(inputs, decoded))
@@ -147,6 +157,7 @@ def call(self, inputs, training):
def get_config(self):
config = super().get_config().copy()
config.update({'encoding_dim': self.encoding_dim,
'n_encoder': self.n_encoder,
'noise_std': self.noise_std,
'swap_prob': self.swap_prob,
'mask_prob': self.mask_prob,
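
The core of the DAELayer change: the layer now builds a list of n_encoder Dense encoders, feeds the (noised) inputs through them in sequence, and concatenates every intermediate output before decoding, so the encoded width becomes encoding_dim * n_encoder. A minimal standalone Keras sketch of that stacking pattern, with illustrative names only (this is not Kaggler's API):

from tensorflow.keras.layers import Concatenate, Dense, Input
from tensorflow.keras.models import Model


def stacked_encoder(input_dim, encoding_dim=128, n_encoder=3):
    """Chain n_encoder Dense layers and concatenate their outputs."""
    inputs = Input((input_dim,))
    x = inputs
    outputs = []
    for i in range(n_encoder):
        x = Dense(encoding_dim, activation='relu', name=f'encoder_{i}')(x)
        outputs.append(x)
    # Same pattern as DAELayer.call(): concatenate only when there is more than one encoder.
    encoded = Concatenate()(outputs) if len(outputs) > 1 else outputs[0]
    return Model(inputs, encoded)


model = stacked_encoder(input_dim=20)
print(model.output_shape)  # (None, 384), i.e. encoding_dim * n_encoder
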
@@ -156,9 +167,9 @@ def get_config(self):
class DAE(base.BaseEstimator):
"""Denoising AutoEncoder feature transformer."""

def __init__(self, cat_cols=[], num_cols=[], n_emb=[], encoding_dim=128, n_layer=1, noise_std=.0,
swap_prob=.2, mask_prob=.0, dropout=.2, min_obs=10, n_epoch=100, batch_size=1024,
random_state=42):
def __init__(self, cat_cols=[], num_cols=[], n_emb=[], encoding_dim=128, n_layer=1, n_encoder=1,
noise_std=.0, swap_prob=.2, mask_prob=.0, dropout=.2, min_obs=1, n_epoch=10, batch_size=1024,
random_state=42, label_encoding=False):
"""Initialize a DAE (Denoising AutoEncoder) class object.
Args:
@@ -167,6 +178,7 @@ def __init__(self, cat_cols=[], num_cols=[], n_emb=[], encoding_dim=128, n_layer
n_emb (int or list of int): the numbers of embedding features used for columns.
encoding_dim (int): the numbers of hidden units in encoding/decoding layers.
n_layer (int): the numbers of the encoding/decoding layer pairs
n_encoder (int): the numbers of encoding layers in each of the encoding/decoding pairs
noise_std (float): standard deviation of gaussian noise to be added to features.
swap_prob (float): probability to add swap noise to features.
mask_prob (float): probability to add zero masking to features.
@@ -175,6 +187,7 @@ def __init__(self, cat_cols=[], num_cols=[], n_emb=[], encoding_dim=128, n_layer
n_epoch (int): the number of epochs to train a neural network with embedding layer
batch_size (int): the size of mini-batches in model training
random_state (int or np.RandomState): random seed.
label_encoding (bool): to label-encode categorical columns (True) or not (False)
"""
assert cat_cols or num_cols
self.cat_cols = cat_cols
@@ -191,9 +204,10 @@ def __init__(self, cat_cols=[], num_cols=[], n_emb=[], encoding_dim=128, n_layer
else:
raise ValueError('n_emb should be int or list')

assert (encoding_dim > 0) and (n_layer > 0)
assert (encoding_dim > 0) and (n_layer > 0) and (n_encoder > 0)
self.encoding_dim = encoding_dim
self.n_layer = n_layer
self.n_encoder = n_encoder

assert (0. <= noise_std) and (0. <= swap_prob < 1.) and (0. <= mask_prob < 1.) and (0. <= dropout < 1.)
self.noise_std = noise_std
@@ -215,7 +229,9 @@ def __init__(self, cat_cols=[], num_cols=[], n_emb=[], encoding_dim=128, n_layer

# Get an integer seed from np.random.RandomState to use it for tensorflow
self.seed = self.random_state_.get_state()[1][0]
self.lbe = LabelEncoder(min_obs=min_obs)
self.label_encoding = label_encoding
if self.label_encoding:
self.lbe = LabelEncoder(min_obs=min_obs)

def build_model(self, X, y=None):
inputs = []
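
With the new label_encoding flag left at its default of False, DAE no longer label-encodes categorical columns itself, so they are expected to arrive already integer-encoded. A hedged sketch of doing that step up front with Kaggler's LabelEncoder, the same class the diff uses internally when label_encoding=True; the toy DataFrame and the import path from kaggler.preprocessing are assumptions here, not taken from this commit:

import pandas as pd
from kaggler.preprocessing import LabelEncoder  # assumed import path

df = pd.DataFrame({'cat_1': ['a', 'b', 'a', 'c'] * 25})
lbe = LabelEncoder(min_obs=2)  # same constructor argument as in the diff above
df[['cat_1']] = lbe.fit_transform(df[['cat_1']])  # categorical values become integer labels
print(df['cat_1'].unique())
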
@@ -242,48 +258,71 @@ def build_model(self, X, y=None):

inputs = inputs + [num_inputs]
else:
merged_inputs = Concatenate()(embeddings)
merged_inputs = Concatenate()(embeddings) if len(embeddings) > 1 else embeddings[0]

dae_layers = []
for i in range(self.n_layer):
dae_layers.append(DAELayer(encoding_dim=self.encoding_dim, noise_std=self.noise_std,
swap_prob=self.swap_prob, mask_prob=self.mask_prob,
dae_layers.append(DAELayer(encoding_dim=self.encoding_dim, n_encoder=self.n_encoder,
noise_std=self.noise_std, swap_prob=self.swap_prob, mask_prob=self.mask_prob,
seed=self.seed, name=f'dae_layer_{i}'))

encoded, decoded = dae_layers[i](merged_inputs)
_, merged_inputs = dae_layers[i](merged_inputs, training=False)

self.encoder = Model(inputs=inputs, outputs=encoded, name='encoder_head')
self.dae = Model(inputs=inputs, outputs=decoded, name='decoder_head')
self.encoder = Model(inputs=inputs, outputs=encoded, name='encoder_model')
self.dae = Model(inputs=inputs, outputs=decoded, name='decoder_model')
self.dae.compile(optimizer='adam')

def fit(self, X, y=None):
def fit(self, X, y=None, validation_data=None):
"""Train DAE
Args:
X (pandas.DataFrame): features to encode
y (pandas.Series, optional): not used
validation_data (list of pandas.DataFrame and pandas.Series): validation features and target
Returns:
A trained DAE object.
"""
if self.cat_cols:
if validation_data is not None:
if y is None:
X_val = validation_data[0]
y_val = None
else:
X_val, y_val = validation_data

if self.cat_cols and self.label_encoding:
X[self.cat_cols] = self.lbe.fit_transform(X[self.cat_cols])

if validation_data is not None:
X_val[self.cat_cols] = self.lbe.fit_transform(X_val[self.cat_cols])

self.build_model(X, y)

features = [X[col].values for col in self.cat_cols]
if self.num_cols:
features += [X[self.num_cols].values]

if validation_data is not None:
features_val = [X_val[col].values for col in self.cat_cols]
if self.num_cols:
features_val += [X_val[self.num_cols].values]

es = EarlyStopping(monitor='val_loss', min_delta=.0, patience=5, verbose=1, mode='min',
baseline=None, restore_best_weights=True)
rlr = ReduceLROnPlateau(monitor='val_loss', factor=.5, patience=3, min_lr=1e-6, mode='min')
self.dae.fit(x=features, y=y,
epochs=self.n_epoch,
validation_split=.2,
batch_size=self.batch_size,
callbacks=[es, rlr])
if validation_data is None:
self.dae.fit(x=features, y=y,
epochs=self.n_epoch,
validation_split=.2,
batch_size=self.batch_size,
callbacks=[es, rlr])
else:
self.dae.fit(x=features, y=y,
epochs=self.n_epoch,
validation_data=(features_val, y_val),
batch_size=self.batch_size,
callbacks=[es, rlr])

def transform(self, X):
"""Encode features using the DAE trained
@@ -295,7 +334,7 @@ def transform(self, X):
Encoding matrix for features
"""
X = X.copy()
if self.cat_cols:
if self.cat_cols and self.label_encoding:
X[self.cat_cols] = self.lbe.transform(X[self.cat_cols])

features = [X[col].values for col in self.cat_cols]
Expand All @@ -304,17 +343,18 @@ def transform(self, X):

return self.encoder.predict(features)

def fit_transform(self, X, y=None):
def fit_transform(self, X, y=None, validation_data=None):
"""Train DAE and encode features using the DAE trained
Args:
X (pandas.DataFrame): features to encode
y (pandas.Series, optional): not used
validation_data (list of pandas.DataFrame and pandas.Series): validation features and target
Returns:
Encoding matrix for features
"""
self.fit(X, y)
self.fit(X, y, validation_data)
return self.transform(X)
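
Taken together, the new DAE options can be exercised roughly as in the new tests further below: pass a held-out frame via validation_data instead of relying on the internal validation_split, and stack several encoders per DAE layer with n_encoder. A hedged usage sketch on toy data; column names, sizes, and hyperparameters are illustrative:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from kaggler.preprocessing import DAE

rng = np.random.RandomState(42)
df = pd.DataFrame({
    'cat_1': rng.randint(0, 10, 1000),   # already integer-encoded categoricals
    'cat_2': rng.randint(0, 5, 1000),
    'num_1': rng.rand(1000),
    'num_2': rng.rand(1000),
})
cat_cols, num_cols = ['cat_1', 'cat_2'], ['num_1', 'num_2']
feature_cols = cat_cols + num_cols

trn, val = train_test_split(df, test_size=.2, shuffle=True, random_state=42)
dae = DAE(cat_cols=cat_cols, num_cols=num_cols, encoding_dim=64,
          n_encoder=3, random_state=42)
X = dae.fit_transform(trn[feature_cols], validation_data=[val[feature_cols]])
assert X.shape[1] == 64 * 3  # encoded width is encoding_dim * n_encoder
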


@@ -346,18 +386,19 @@ def build_model(self, X, y=None):

inputs = inputs + [num_inputs]
else:
merged_inputs = Concatenate()(embeddings)
merged_inputs = Concatenate()(embeddings) if len(embeddings) > 1 else embeddings[0]

dae_layers = []
for i in range(self.n_layer):
dae_layers.append(DAELayer(encoding_dim=self.encoding_dim, noise_std=self.noise_std,
swap_prob=self.swap_prob, mask_prob=self.mask_prob,
dae_layers.append(DAELayer(encoding_dim=self.encoding_dim, n_encoder=self.n_encoder,
noise_std=self.noise_std, swap_prob=self.swap_prob, mask_prob=self.mask_prob,
seed=self.seed, name=f'dae_layer_{i}'))

encoded, decoded = dae_layers[i](merged_inputs)
_, merged_inputs = dae_layers[i](merged_inputs, training=False)

self.encoder = Model(inputs=inputs, outputs=encoded, name='encoder_head')
self.encoder = Model(inputs=inputs, outputs=encoded, name='encoder_model')
self.decoder = Model(inputs=inputs, outputs=decoded, name='decoder_model')

if y.dtype in [np.int32, np.int64]:
n_uniq = len(np.unique(y))
Expand All @@ -375,24 +416,23 @@ def build_model(self, X, y=None):
self.output_loss = 'mean_squared_error'

# supervised head
supervised_inputs = Input((self.encoding_dim,), name='supervised_inputs')
x = Dense(1024, 'relu')(supervised_inputs)
x = Dense(1024, 'relu')(encoded)
x = Dropout(.3)(x)
supervised_outputs = Dense(self.n_class, activation=self.output_activation)(x)
self.supervised = Model(inputs=supervised_inputs, outputs=supervised_outputs, name='supervised_head')

self.dae = Model(inputs=inputs, outputs=self.supervised(self.encoder(inputs)), name='decoder_head')
self.dae = Model(inputs=inputs, outputs=supervised_outputs, name='supervised_model')
self.dae.compile(optimizer='adam', loss=self.output_loss)

def fit(self, X, y):
def fit(self, X, y, validation_data=None):
"""Train supervised DAE
Args:
X (pandas.DataFrame): features to encode
y (pandas.Series): target variable
validation_data (list of pandas.DataFrame and pandas.Series): validation features and target
Returns:
A trained SDAE object.
"""
assert y is not None, 'SDAE needs y (target variable) for fit()'
super().fit(X, y)
super().fit(X, y, validation_data)
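
The SDAE change above drops the separate supervised_inputs placeholder and its standalone supervised head model; the dense head is now built directly on the encoder output, so embeddings, DAE layers, and head all train end-to-end against the supervised loss. A minimal sketch of that wiring under illustrative names, where a single Dense layer stands in for the embedding and DAE stack:

from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.models import Model

inputs = Input((20,), name='features')
encoded = Dense(128, activation='relu', name='encoder')(inputs)  # stand-in for the DAE layers

x = Dense(1024, activation='relu')(encoded)
x = Dropout(.3)(x)
outputs = Dense(1, activation='sigmoid')(x)  # head for a binary target

supervised_model = Model(inputs=inputs, outputs=outputs)  # what fit() trains
encoder_model = Model(inputs=inputs, outputs=encoded)     # what transform() uses afterwards
supervised_model.compile(optimizer='adam', loss='binary_crossentropy')
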
42 changes: 41 additions & 1 deletion tests/test_encoders.py
@@ -1,5 +1,5 @@
from kaggler.preprocessing import DAE, SDAE, TargetEncoder, EmbeddingEncoder, FrequencyEncoder
from sklearn.model_selection import KFold
from sklearn.model_selection import KFold, train_test_split

from .const import RANDOM_SEED, TARGET_COL

@@ -20,6 +20,35 @@ def test_DAE(generate_data):
assert X.shape[1] == encoding_dim


def test_DAE_with_validation_data(generate_data):
encoding_dim = 10

df = generate_data()
feature_cols = [x for x in df.columns if x != TARGET_COL]
cat_cols = [x for x in feature_cols if df[x].nunique() < 100]
num_cols = [x for x in feature_cols if x not in cat_cols]

dae = DAE(cat_cols=cat_cols, num_cols=num_cols, encoding_dim=encoding_dim, random_state=RANDOM_SEED)
trn, val = train_test_split(df, test_size=.2, shuffle=True, random_state=RANDOM_SEED)
X = dae.fit_transform(trn[feature_cols], validation_data=[val[feature_cols]])
assert X.shape[1] == encoding_dim


def test_DAE_with_multiple_encoders(generate_data):
encoding_dim = 10
n_encoder = 3

df = generate_data()
feature_cols = [x for x in df.columns if x != TARGET_COL]
cat_cols = [x for x in feature_cols if df[x].nunique() < 100]
num_cols = [x for x in feature_cols if x not in cat_cols]

dae = DAE(cat_cols=cat_cols, num_cols=num_cols, encoding_dim=encoding_dim, n_encoder=n_encoder,
random_state=RANDOM_SEED)
X = dae.fit_transform(df[feature_cols])
assert X.shape[1] == encoding_dim * n_encoder


def test_SDAE(generate_data):
encoding_dim = 10

@@ -32,6 +61,17 @@ def test_SDAE(generate_data):
X = dae.fit_transform(df[feature_cols], df[TARGET_COL])
assert X.shape[1] == encoding_dim

trn, val = train_test_split(df, test_size=.2, shuffle=True, random_state=RANDOM_SEED)
X = dae.fit_transform(trn[feature_cols], trn[TARGET_COL],
validation_data=(val[feature_cols], val[TARGET_COL]))
assert X.shape[1] == encoding_dim

n_encoder = 3
dae = SDAE(cat_cols=cat_cols, num_cols=num_cols, encoding_dim=encoding_dim, n_encoder=n_encoder,
random_state=RANDOM_SEED)
X = dae.fit_transform(df[feature_cols], df[TARGET_COL])
assert X.shape[1] == encoding_dim * n_encoder


def test_TargetEncoder(generate_data):
df = generate_data()
