
add ridge regression model
ddbourgin committed Jul 3, 2019
1 parent bf40ca3 commit 87d0d21
Showing 3 changed files with 151 additions and 79 deletions.
Binary file modified linear_models/img/plot_bayes.png
187 changes: 121 additions & 66 deletions linear_models/lm.py
@@ -1,6 +1,9 @@
import numbers
import sys
import numpy as np

sys.path.append("..")
from utils.testing import is_symmetric_positive_definite, is_number


class LinearRegression:
"""
@@ -13,13 +16,18 @@ class LinearRegression:
y - bX ~ N(0, sigma^2 * I)
y | X, b ~ N(bX, sigma^2 * I)
The loss for the model is simply the squared error between the model
predictions and the true values:
Loss = ||y - bX||^2
The MLE for the model parameters b can be computed in closed form via the
normal equation:
b = (X^T X)^{-1} X^T y
where (X^T X)^{-1} X^T is sometimes called the pseudoinverse or the
Moore-Penrose inverse.
where (X^T X)^{-1} X^T is known as the pseudoinverse / Moore-Penrose
inverse.
"""

def __init__(self, fit_intercept=True):
@@ -41,6 +49,62 @@ def predict(self, X):
return np.dot(X, self.beta)


class RidgeRegression:
"""
Ridge regression uses the same simple linear regression model but adds a
penalty on the L2-norm of the coefficients to the loss function.
This is sometimes known as Tikhonov regularization.
In particular, the ridge model is still simply
y = bX + e where e ~ N(0, sigma^2 * I)
except now the loss for the model is calculated as
RidgeLoss = ||y - bX||^2 + alpha * ||b||^2
The coefficients that minimize this penalized loss can be computed in closed
form via the adjusted normal equation:
b = (X^T X + alpha I)^{-1} X^T y
where (X^T X + alpha I)^{-1} X^T is the pseudoinverse / Moore-Penrose
inverse adjusted for the L2 penalty on the model coefficients.
"""

def __init__(self, alpha=1, fit_intercept=True):
"""
A ridge regression model fit via the normal equation.
Parameters
----------
alpha : float (default: 1)
L2 regularization coefficient. Higher values correspond to a larger
penalty on the L2 norm of the model coefficients.
fit_intercept : bool (default: True)
Whether to fit an intercept term in addition to the model
coefficients.
"""
self.beta = None
self.alpha = alpha
self.fit_intercept = fit_intercept

def fit(self, X, y):
# convert X to a design matrix if we're fitting an intercept
if self.fit_intercept:
X = np.c_[np.ones(X.shape[0]), X]

A = self.alpha * np.eye(X.shape[1])
pseudo_inverse = np.dot(np.linalg.inv(X.T @ X + A), X.T)
self.beta = pseudo_inverse @ y

def predict(self, X):
# convert X to a design matrix if we're fitting an intercept
if self.fit_intercept:
X = np.c_[np.ones(X.shape[0]), X]
return np.dot(X, self.beta)
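For reference, the new class can be exercised end to end roughly like this (a usage sketch on made-up data; it assumes lm.py is on the import path and uses the default alpha=1):

import numpy as np
from lm import RidgeRegression

rng = np.random.RandomState(0)
X = rng.randn(50, 2)
y = X @ np.array([1.5, -2.0]) + 3.0 + 0.1 * rng.randn(50)

model = RidgeRegression(alpha=1, fit_intercept=True)
model.fit(X, y)
y_hat = model.predict(X)
print("train MSE: {:.4f}".format(np.mean((y - y_hat) ** 2)))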


class LogisticRegression:
def __init__(self, penalty="l2", gamma=0, fit_intercept=True):
"""
@@ -87,7 +151,10 @@ def _NLL(self, X, y, y_pred):
Penalized negative log likelihood of the targets under the current
model.
NLL = -1/N ([sum_{i=0}^N y_i log(y_pred_i) + (1-y_i) log(1-y_pred_i)] - (gamma ||b||) / 2)
NLL = -1/N * (
[sum_{i=0}^N y_i log(y_pred_i) + (1-y_i) log(1-y_pred_i)] -
(gamma ||b||) / 2
)
"""
N, M = X.shape
order = 2 if self.penalty == "l2" else 1
@@ -113,10 +180,32 @@ def predict(self, X):

class BayesianLinearRegressionUnknownVariance:
"""
Bayesian linear regression extends the simple linear regression model by
introducing priors on model parameters b and/or sigma.
Bayesian Linear Regression
--------------------------
In its general form, Bayesian linear regression extends the simple linear
regression model by introducing priors on model parameters b and/or the
error variance sigma^2.
The introduction of a prior allows us to quantify the uncertainty in our
parameter estimates for b by replacing the MLE point estimate in simple
linear regression with an entire posterior *distribution*, p(b | X, y,
sigma), simply by applying Bayes rule:
p(b | X, y) = [ p(y | X, b) * p(b | sigma) ] / p(y | X)
We can also quantify the uncertainty in our predictions y* for some new
data X* with the posterior predictive distribution:
p(y* | X*, X, Y) = \int_{b} p(y* | X*, b) p(b | X, y) db
If both b and error variance sigma^2 are unknown, the conjugate prior
Depending on the choice of prior it may be impossible to compute an
analytic form for the posterior / posterior predictive distribution. In
these cases, it is common to use approximations, either via MCMC or
variational inference.
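When only samples from the posterior are available (the MCMC / variational case above), the predictive integral is usually replaced by a Monte Carlo average, along these lines (a generic sketch; b_samples is assumed to be an array of posterior draws of b):

import numpy as np

def predictive_mean(X_star, b_samples):
    # approximate E[y* | X*, X, y] = \int X* b p(b | X, y) db by averaging
    # the model's predictions over posterior draws of b
    return np.mean([X_star @ b for b in b_samples], axis=0)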
Bayesian Regression w/ unknown variance
---------------------------------------
If *both* b and the error variance sigma^2 are unknown, the conjugate prior
for the Gaussian likelihood is the Normal-Gamma distribution (univariate
likelihood) or the Normal-Inverse-Wishart distribution (multivariate
likelihood).
@@ -127,35 +216,19 @@ class BayesianLinearRegressionUnknownVariance:
sigma^2 ~ InverseGamma(alpha, beta)
b | sigma^2 ~ N(b_mean, sigma^2 * b_V)
where alpha, beta, b_V, and mu are hyperparameters of the prior.
where alpha, beta, b_V, and b_mean are parameters of the prior.
Multivariate:
b, Sigma ~ NIW(b_mean, lambda, Psi, rho)
b | Sigma ~ N(b_mean, 1/lambda * Sigma)
Sigma ~ W^{-1}(Psi, rho)
where mu, lambda, Psi, and rho are hyperparameters of the prior.
where b_mean, lambda, Psi, and rho are parameters of the prior.
The introduction of a prior allows us to quantify the uncertainty in our
parameter estimates for b by replacing the MLE point estimate in simple
linear regression with an entire posterior *distribution*, p(b | X, y,
sigma), simply by applying Bayes rule:
p(b | X, y) = [ p(y | X, b) * p(b | sigma) ] / p(y | X)
We can also quantify the uncertainty in our predictions y* for some new
data X* with the posterior predictive distribution:
p(y* | X*, X, Y) = \int_{b} p(y* | X*, b) p(b | X, y) db
Depending on the choice of prior it may be impossible to compute an
analytic form for the posterior / posterior predictive distribution. In
these cases, it is common to use approximations, either via MCMC or
variational inference.
Thankfully, however, for the above prior we *can* compute the posterior
distributions for the model parameters in closed form:
Due to the conjugacy of the above priors with the Gaussian likelihood of
the linear regression model, we can compute the posterior distributions for
the model parameters in closed form:
B = (y - X b_mean)
shape = N + alpha
@@ -169,7 +242,7 @@ class BayesianLinearRegressionUnknownVariance:
b | X, y, sigma^2 ~ N(mu_b, cov_b)
which allows us a closed form for the posterior predictive distribution as
This also gives us a closed form for the posterior predictive distribution as
well:
y* | X*, X, Y ~ N(X* mu_b, X* cov_b X*^T + I)
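Given mu_b and cov_b, that predictive distribution is a single matrix expression (a sketch assuming unit noise variance so the additive term is the identity, as in the formula above):

import numpy as np

def posterior_predictive(X_star, mu_b, cov_b):
    # y* | X*, X, y ~ N(X* mu_b, X* cov_b X*^T + I)
    mean = X_star @ mu_b
    cov = X_star @ cov_b @ X_star.T + np.eye(X_star.shape[0])
    return mean, cov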
@@ -286,17 +359,11 @@ def predict(self, X):

class BayesianLinearRegressionKnownVariance:
"""
Bayesian linear regression extends the simple linear regression model by
introducing priors on model parameters b and/or sigma.
If we happen to already know the error variance sigma^2, the conjugate
prior on b is Gaussian. A common parameterization is:
b | sigma, b_V ~ N(b_mean, sigma^2 * b_V)
where b_mean, sigma and b_V are hyperparameters. Ridge regression is a
special case of this model where b_mean = 0, sigma = 1 and b_V = I (ie.,
the prior on b is a zero-mean, unit covariance Gaussian).
Bayesian Linear Regression
--------------------------
In its general form, Bayesian linear regression extends the simple linear
regression model by introducing priors on model parameters b and/or the
error variance sigma^2.
The introduction of a prior allows us to quantify the uncertainty in our
parameter estimates for b by replacing the MLE point estimate in simple
@@ -315,8 +382,20 @@ class BayesianLinearRegressionKnownVariance:
these cases, it is common to use approximations, either via MCMC or
variational inference.
Thankfully, however, for the above prior we *can* compute the posterior
distribution over the model parameters in closed form:
Bayesian linear regression with known variance
----------------------------------------------
If we happen to already know the error variance sigma^2, the conjugate
prior on b is Gaussian. A common parameterization is:
b | sigma, b_V ~ N(b_mean, sigma^2 * b_V)
where b_mean, sigma and b_V are hyperparameters. Ridge regression is a
special case of this model where b_mean = 0, sigma = 1 and b_V = I (i.e.,
the prior on b is a zero-mean, unit covariance Gaussian).
Due to the conjugacy of the above prior with the Gaussian likelihood in the
linear regression model, we can compute the posterior distribution over the
model parameters in closed form:
A = (b_V^{-1} + X^T X)^{-1}
mu_b = A b_V^{-1} b_mean + A X^T y
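These two lines translate directly into NumPy, and with b_mean = 0 and b_V = I the posterior mean coincides with the ridge solution at alpha = 1, which is a convenient cross-check between the two models (a sketch; names mirror the docstring):

import numpy as np

def posterior_mean(X, y, b_mean, b_V):
    # A = (b_V^{-1} + X^T X)^{-1};  mu_b = A b_V^{-1} b_mean + A X^T y
    b_V_inv = np.linalg.inv(b_V)
    A = np.linalg.inv(b_V_inv + X.T @ X)
    return A @ b_V_inv @ b_mean + A @ X.T @ y

# with b_mean = 0 and b_V = I this reduces to (I + X^T X)^{-1} X^T y,
# i.e. the ridge estimate with alpha = 1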
@@ -429,27 +508,3 @@ def predict(self, X):

def sigmoid(x):
return 1 / (1 + np.exp(-x))
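One small caveat with this helper: np.exp(-x) overflows for large negative x and emits a RuntimeWarning (the returned value is still the correct limit of 0). If the warning ever matters, a numerically safer variant is (a sketch, equivalent in exact arithmetic, not part of this diff):

import numpy as np

def sigmoid_stable(x):
    # same value as 1 / (1 + exp(-x)), but exp is only ever called on
    # non-positive arguments, so it cannot overflow for large |x|
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))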


def is_symmetric(X):
return np.allclose(X, X.T)


def is_number(x):
return isinstance(x, numbers.Number)


def is_symmetric_positive_definite(X):
"""
Check that X is a symmetric, positive-definite matrix
"""
if is_symmetric(X):
try:
# if matrix is symmetric, check whether the Cholesky decomposition
# (defined only for symmetric/Hermitian positive definite matrices)
# exists
np.linalg.cholesky(X)
return True
except np.linalg.LinAlgError:
return False
return False
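As a quick usage note, the Cholesky-based check behaves as expected on small examples (illustrative only):

is_symmetric_positive_definite(np.eye(3))                      # True
is_symmetric_positive_definite(np.array([[1.0, 2.0],
                                          [2.0, 1.0]]))        # False (indefinite)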
43 changes: 30 additions & 13 deletions linear_models/tests.py
@@ -20,6 +20,7 @@


from lm import (
RidgeRegression,
LinearRegression,
BayesianLinearRegressionKnownVariance,
BayesianLinearRegressionUnknownVariance,
@@ -133,6 +134,11 @@ def plot_bayes():
y_pred = LR.predict(X_test)
loss = np.mean((y_test - y_pred) ** 2)

ridge = RidgeRegression(alpha=1, fit_intercept=True)
ridge.fit(X_train, y_train)
y_pred = ridge.predict(X_test)
loss_ridge = np.mean((y_test - y_pred) ** 2)

LR_var = BayesianLinearRegressionKnownVariance(
b_mean=np.c_[intercept, coefs][0],
b_sigma=np.sqrt(std),
@@ -154,11 +160,13 @@ def plot_bayes():
xmax = max(X_test) + 0.1 * (max(X_test) - min(X_test))
X_plot = np.linspace(xmin, xmax, 100)
y_plot = LR.predict(X_plot)
y_plot_ridge = ridge.predict(X_plot)
y_plot_var = LR_var.predict(X_plot)
y_plot_novar = LR_novar.predict(X_plot)

y_true = [np.dot(x, coefs) + intercept for x in X_plot]
fig, axes = plt.subplots(1, 3)
fig, axes = plt.subplots(1, 4)

axes = axes.flatten()
axes[0].scatter(X_test, y_test)
axes[0].plot(X_plot, y_plot, label="MLE")
@@ -167,22 +175,17 @@ def plot_bayes():
axes[0].legend()
# axes[0].fill_between(X_plot, y_plot - error, y_plot + error)

axes[1].plot(X_plot, y_plot_var, label="MAP")
mu, cov = LR_var.posterior["b"]["mu"], LR_var.posterior["b"]["cov"]
for k in range(200):
b_samp = np.random.multivariate_normal(mu, cov)
y_samp = [np.dot(x, b_samp[1]) + b_samp[0] for x in X_plot]
axes[1].plot(X_plot, y_samp, c="green", alpha=0.05)
axes[1].scatter(X_test, y_test)
axes[1].plot(X_plot, y_plot_ridge, label="MLE")
axes[1].plot(X_plot, y_true, label="True fn")
axes[1].legend()
axes[1].set_title(
"Bayesian Regression (known variance)\nMAP Test MSE: {:.2f}".format(loss_var)
"Ridge Regression (alpha=1)\nMLE Test MSE: {:.2f}".format(loss_ridge)
)
axes[1].legend()
print("plotted ridge.. {:.2f} MSE".format(loss_ridge))

axes[2].plot(X_plot, y_plot_novar, label="MAP")
mu = LR_novar.posterior["b | sigma**2"]["mu"]
cov = LR_novar.posterior["b | sigma**2"]["cov"]
axes[2].plot(X_plot, y_plot_var, label="MAP")
mu, cov = LR_var.posterior["b"]["mu"], LR_var.posterior["b"]["cov"]
for k in range(200):
b_samp = np.random.multivariate_normal(mu, cov)
y_samp = [np.dot(x, b_samp[1]) + b_samp[0] for x in X_plot]
@@ -191,6 +194,20 @@ def plot_bayes():
axes[2].plot(X_plot, y_true, label="True fn")
axes[2].legend()
axes[2].set_title(
"Bayesian Regression (known variance)\nMAP Test MSE: {:.2f}".format(loss_var)
)

axes[3].plot(X_plot, y_plot_novar, label="MAP")
mu = LR_novar.posterior["b | sigma**2"]["mu"]
cov = LR_novar.posterior["b | sigma**2"]["cov"]
for k in range(200):
b_samp = np.random.multivariate_normal(mu, cov)
y_samp = [np.dot(x, b_samp[1]) + b_samp[0] for x in X_plot]
axes[3].plot(X_plot, y_samp, c="green", alpha=0.05)
axes[3].scatter(X_test, y_test)
axes[3].plot(X_plot, y_true, label="True fn")
axes[3].legend()
axes[3].set_title(
"Bayesian Regression (unknown variance)\nMAP Test MSE: {:.2f}".format(
loss_novar
)
@@ -201,7 +218,7 @@ def plot_bayes():
ax.yaxis.set_ticklabels([])

# plt.tight_layout()
fig.set_size_inches(7, 2.5)
fig.set_size_inches(10, 2.5)
plt.savefig("plot_bayes.png", dpi=300)
plt.close("all")

