-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsvgp.py
153 lines (130 loc) · 5.02 KB
/
svgp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import matplotlib.pyplot as plt
import torch
import numpy as np
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# When the CI environment variable is set, main() runs only a couple of
# optimisation steps so the script finishes quickly.
smoke_test = os.environ.get("CI") is not None
# The script requires a Pyro 1.8.4 release; fail fast on anything else.
assert pyro.__version__.startswith("1.8.4")
# Fix the random seed so the sampled data set and the fit are reproducible.
pyro.set_rng_seed(0)
def plot(
    X_inputs,
    y_inputs,
    y_index=None,
    plot_observed_data=False,
    plot_predictions=False,
    n_prior_samples=0,
    model=None,
    kernel=None,
    n_test=500,
    ax=None,
):
    """Draw observations, model predictions and/or GP prior samples on one axes.

    Args:
        X_inputs: Tensor of input locations. Used for the observation scatter
            and as the inputs at which predictions are evaluated. Assumed to
            be 1-D (``fill_between`` needs a 1-D x) -- TODO confirm.
        y_inputs: Tensor of observed targets. When ``y_index`` is given only
            ``y_inputs[y_index]`` is plotted.
        y_index: Optional index selecting one output from ``y_inputs`` and
            from the model's predictive mean and variance.
        plot_observed_data: If true, scatter the observations as black crosses.
        plot_predictions: If true, plot the predictive mean (red line) and a
            two-standard-deviation band around it. Requires ``model``.
        n_prior_samples: Number of draws from the GP prior to plot. Requires
            ``model`` (for the noise term) and ``kernel``.
        model: Object callable as ``model(X, full_cov=False)`` returning a
            ``(mean, variance)`` pair. For prior samples it must expose
            ``likelihood.variance`` (VariationalSparseGP) or ``noise``.
        kernel: Kernel whose ``forward`` gives the prior covariance.
        n_test: Number of grid points on [-0.5, 5.5] used for prior samples.
        ax: Matplotlib axes to draw on. A new 12x6 figure is created if None.
    """
    if y_index is not None:
        y_inputs = y_inputs[y_index]
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 6))

    if plot_observed_data:
        ax.plot(X_inputs.numpy(), y_inputs.numpy(), "kx")

    if plot_predictions:
        # compute predictive mean and variance at the supplied inputs
        with torch.no_grad():
            mean, cov = model(X_inputs, full_cov=False)
        if y_index is not None:
            mean = mean[y_index]
            cov = cov[y_index]
        sd = cov**0.5  # standard deviation at each input point x

        # A line plot and a filled band connect consecutive points, so the
        # inputs (and the matching predictions) must be in ascending x order;
        # otherwise randomly ordered inputs produce a zig-zag scribble.
        order = torch.argsort(X_inputs)
        Xtest = X_inputs[order]
        mean = mean[..., order]
        sd = sd[..., order]

        ax.plot(Xtest.numpy(), mean.numpy(), "r", lw=2)  # plot the mean
        ax.fill_between(
            Xtest.numpy(),  # plot the two-sigma uncertainty about the mean
            (mean - 2.0 * sd).numpy(),
            (mean + 2.0 * sd).numpy(),
            color="C0",
            alpha=0.3,
        )

    if n_prior_samples > 0:  # plot samples from the GP prior
        Xtest = torch.linspace(-0.5, 5.5, n_test)  # test inputs
        # The sparse variational model reads its observation noise from the
        # likelihood; any other model is expected to expose it as ``noise``.
        if isinstance(model, gp.models.VariationalSparseGP):
            noise = model.likelihood.variance
        else:
            noise = model.noise
        cov = kernel.forward(Xtest) + noise.expand(n_test).diag()
        samples = dist.MultivariateNormal(
            torch.zeros(n_test), covariance_matrix=cov
        ).sample(sample_shape=(n_prior_samples,))
        ax.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)
def plot_inducing_points(Xu, ax=None):
    """Mark every inducing-point location with a vertical line and label them.

    Args:
        Xu: Iterable of x positions, one per inducing point.
        ax: Matplotlib axes to draw on. Defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()

    handle = None
    for xu in Xu:
        handle = ax.axvline(xu, color="red", linestyle="-.", alpha=0.5)

    if handle is None:
        # No inducing points were drawn, so there is nothing to put in a legend.
        return

    # All lines share one style, so a single handle is enough for the legend.
    ax.legend(
        handles=[handle],
        labels=["Inducing Point Locations"],
        bbox_to_anchor=(0.5, 1.15),
        loc="upper center",
    )
def plot_loss(loss):
    """Plot the sequence of loss values against its index via pyplot.

    Args:
        loss: Sequence of loss values, one per optimisation step.
    """
    plt.plot(loss)
    plt.ylabel("Loss")
    plt.xlabel("Iterations")
def _plot_each_output(X, y, Xu, model=None):
    """Draw one figure per output ``y[i, j]`` with the inducing points marked.

    When ``model`` is given its predictions are drawn as well.

    NOTE(review): the source's indentation was lost, so the loop nesting of
    the original calls is ambiguous. Here the inducing points are drawn on
    every figure and ``plt.show()`` is called once for the whole batch --
    confirm this is the intended layout.
    """
    n_rows, n_cols = y.shape[:2]
    for i in range(n_rows):
        for j in range(n_cols):
            # plot() opens a new figure when no axes is passed; gca() then
            # returns that figure's axes for the inducing-point markers.
            plot(
                X_inputs=X,
                y_inputs=y,
                y_index=(i, j),
                model=model,
                plot_observed_data=True,
                plot_predictions=model is not None,
            )
            plot_inducing_points(Xu, plt.gca())
    plt.show()


def main():
    """Fit a sparse variational GP to eight noisy sinusoids and plot the result.

    Steps:
      1. Sample 1000 inputs uniformly on [0, 5] and build eight targets
         ``0.5 * sin(k * x) + Normal(0, 0.2)`` noise for k = 3..10, arranged
         as a (2, 4, N) tensor.
      2. Plot each target with the initial inducing points
         (20 points at 0, 0.25, ..., 4.75).
      3. Train a ``VariationalSparseGP`` (RBF kernel, Gaussian likelihood) with
         Adam on the Trace_ELBO loss: 2000 steps, or 2 in smoke-test mode.
      4. Plot the loss curve, then each target with the model's predictions
         and the learned inducing points.
    """
    N = 1000
    X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
    # One (1, N) row per frequency; rows for k = 3..6 form y[0] and rows for
    # k = 7..10 form y[1] after the reshape.
    y = torch.cat(
        [
            0.5 * torch.sin(freq * X)
            + dist.Normal(0.0, 0.2).sample(sample_shape=(1, N))
            for freq in range(3, 11)
        ],
        0,
    ).reshape(2, 4, N)

    # initialize the inducing inputs
    Xu = torch.arange(20.0) / 4.0
    _plot_each_output(X, y, Xu)

    # initialize the kernel and model
    pyro.clear_param_store()
    kernel = gp.kernels.RBF(input_dim=1)
    likelihood = gp.likelihoods.Gaussian()
    # we increase the jitter for better numerical stability
    vsgp = gp.models.VariationalSparseGP(
        X, y, kernel, Xu=Xu, likelihood=likelihood, jitter=1.0e-5, whiten=True
    )

    # Plain gradient-based variational inference: minimise the negative ELBO
    # over all model parameters (the optimizer is given vsgp.parameters()).
    optimizer = torch.optim.Adam(vsgp.parameters(), lr=0.005)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    num_steps = 2 if smoke_test else 2000
    losses = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = loss_fn(vsgp.model, vsgp.guide)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    plot_loss(losses)
    plt.show()

    # Show the fit for every output together with the learned inducing points.
    _plot_each_output(X, y, vsgp.Xu.data.numpy(), model=vsgp)


if __name__ == "__main__":
    main()