forked from google-research/google-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_vi.py
226 lines (195 loc) · 8.94 KB
/
run_vi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run Variational Inference."""
import os
import numpy as onp
from jax import numpy as jnp
import jax
import tensorflow.compat.v2 as tf
import argparse
from bnn_hmc.utils import checkpoint_utils
from bnn_hmc.utils import cmd_args_utils
from bnn_hmc.utils import logging_utils
from bnn_hmc.utils import train_utils
from bnn_hmc.utils import optim_utils
from bnn_hmc.utils import script_utils
from bnn_hmc.core import vi
# Command-line interface: common + SGD flags from cmd_args_utils, plus the
# MFVI-specific options below. Parsed at import time; `args` is read by the
# functions in this module (e.g. train_model).
parser = argparse.ArgumentParser(description="Run MFVI training")
cmd_args_utils.add_common_flags(parser)
cmd_args_utils.add_sgd_flags(parser)
parser.add_argument(
    "--optimizer",
    type=str,
    default="Adam",
    choices=["SGD", "Adam"],
    # Help text previously claimed "default: SGD" although default is "Adam".
    help="Choice of optimizer; (SGD or Adam; default: Adam)")
parser.add_argument(
    "--vi_sigma_init",
    type=float,
    default=1e-3,
    help="Initial value of the standard deviation over the "
    "weights in MFVI (default: 1e-3)")
parser.add_argument(
    "--vi_ensemble_size",
    type=int,
    default=20,
    help="Size of the ensemble sampled in the VI evaluation "
    "(default: 20)")
parser.add_argument(
    "--mean_init_checkpoint",
    type=str,
    default=None,
    help="SGD checkpoint to use for initialization of the "
    "mean of the MFVI approximation")
args = parser.parse_args()
# Configure JAX (TPU address and float64 mode) from the parsed flags.
train_utils.set_up_jax(args.tpu_ip, args.use_float64)
def get_optimizer(lr_schedule, args):
  """Build the optimizer selected by ``args.optimizer``.

  Args:
    lr_schedule: step-size schedule forwarded to the optim_utils constructor.
    args: parsed command-line namespace; reads ``optimizer`` and, for SGD,
      ``momentum_decay``.

  Returns:
    The optimizer returned by ``optim_utils.make_sgd_optimizer`` or
    ``optim_utils.make_adam_optimizer``.

  Raises:
    ValueError: if ``args.optimizer`` is neither "SGD" nor "Adam".
  """
  if args.optimizer == "SGD":
    return optim_utils.make_sgd_optimizer(
        lr_schedule, momentum_decay=args.momentum_decay)
  if args.optimizer == "Adam":
    return optim_utils.make_adam_optimizer(lr_schedule)
  # argparse `choices` protects the CLI path, but any other caller passing an
  # unexpected value previously hit an opaque UnboundLocalError here.
  raise ValueError("Unknown optimizer: {!r}".format(args.optimizer))
def get_dirname_tfwriter(args):
  """Derive the run sub-directory name from hyperparameters and set up logging.

  The name encodes the MFVI init sigma (plus a ``_meaninit`` marker when a
  mean-init checkpoint is given), the optimizer, the initial step size, the
  epochs / weight decay / batch size / temperature, and the seed.

  Args:
    args: parsed command-line namespace.

  Returns:
    ``(dirname, tf_writer)`` as returned by ``script_utils.prepare_logging``.

  Raises:
    ValueError: if ``args.optimizer`` is neither "SGD" nor "Adam".
  """
  method_name = "mfvi_initsigma_{}".format(args.vi_sigma_init)
  if args.mean_init_checkpoint:
    method_name += "_meaninit"

  if args.optimizer == "SGD":
    optimizer_name = "opt_sgd_{}".format(args.momentum_decay)
  elif args.optimizer == "Adam":
    optimizer_name = "opt_adam"
  else:
    # Fail before creating any logging directory; previously this fell
    # through to an UnboundLocalError on `optimizer_name`.
    raise ValueError("Unknown optimizer: {!r}".format(args.optimizer))

  lr_schedule_name = "lr_sch_i_{}".format(args.init_step_size)
  hypers_name = "_epochs_{}_wd_{}_batchsize_{}_temp_{}".format(
      args.num_epochs, args.weight_decay, args.batch_size, args.temperature)
  subdirname = "{}__{}__{}__{}__seed_{}".format(method_name, optimizer_name,
                                                lr_schedule_name, hypers_name,
                                                args.seed)
  dirname, tf_writer = script_utils.prepare_logging(subdirname, args)
  return dirname, tf_writer
def make_vi_ensemble_predict_fn(predict_fn, ensemble_upd_fn, args):
  """Create a predictor that combines several sampled VI predictions.

  Args:
    predict_fn: callable ``(net_apply, params, net_state, ds)`` returning
      ``(new_net_state, predictions)``; it is used as a ``jax.lax.scan`` body,
      so the state it returns is threaded into the next call.
    ensemble_upd_fn: callable ``(running, num_so_far, new_pred)`` returning the
      updated running ensemble; the first call receives ``running=None``.
    args: namespace providing ``vi_ensemble_size`` (number of scan steps).

  Returns:
    ``vi_ensemble_predict_fn(net_apply, params, net_state, ds)`` which returns
    ``(final_net_state, ensemble_predictions)``.
  """

  def vi_ensemble_predict_fn(net_apply, params, net_state, ds):

    def _one_sample(carry_state, _unused_index):
      return predict_fn(net_apply, params, carry_state, ds)

    final_state, stacked_preds = jax.lax.scan(
        _one_sample,
        init=net_state,
        xs=jnp.arange(args.vi_ensemble_size))

    # Fold the per-sample predictions (stacked on the leading axis by scan)
    # into a single ensemble, one sample at a time.
    combined = None
    for idx, sample_pred in enumerate(stacked_preds):
      combined = ensemble_upd_fn(combined, idx, sample_pred)
    return final_state, combined

  return vi_ensemble_predict_fn
def train_model():
  """Train a mean-field VI (MFVI) approximation; log and checkpoint progress.

  Reads the module-level ``args``. Builds data, model and MFVI
  parameterization, restores state via
  ``script_utils.get_initialization_dict``, optionally initializes the
  variational mean from an SGD checkpoint (fresh runs only), then trains from
  ``start_iteration`` up to ``args.num_epochs``. Every ``args.eval_freq``
  epochs (and on the last epoch) the variational mean and a sampled ensemble
  are evaluated on the test set; every ``args.save_freq`` epochs (and on the
  last epoch) a checkpoint is written. Stats go to TensorBoard and stdout.
  """
  # Initialize training directory
  dirname, tf_writer = get_dirname_tfwriter(args)

  # Initialize data, model, losses and metrics
  (train_set, test_set, net_apply, params, net_state, key, log_likelihood_fn, _,
   _, predict_fn, ensemble_upd_fn, metrics_fns,
   tabulate_metrics) = script_utils.get_data_model_fns(args)

  # Convert the model to MFVI parameterization
  net_apply, mean_apply, _, params, net_state = vi.get_mfvi_model_fn(
      net_apply, params, net_state, seed=0, sigma_init=args.vi_sigma_init)
  prior_kl = vi.make_kl_with_gaussian_prior(args.weight_decay, args.temperature)
  vi_ensemble_predict_fn = make_vi_ensemble_predict_fn(predict_fn,
                                                       ensemble_upd_fn, args)

  # Initialize step-size schedule and optimizer
  num_batches, total_steps = script_utils.get_num_batches_total_steps(
      args, train_set)
  num_devices = len(jax.devices())
  lr_schedule = optim_utils.make_cosine_lr_schedule(args.init_step_size,
                                                    total_steps)
  optimizer = get_optimizer(lr_schedule, args)

  # Initialize variables
  opt_state = optimizer.init(params)
  # Stack one copy of net_state per device and split the key per device.
  net_state = jax.pmap(lambda _: net_state)(jnp.arange(num_devices))
  key = jax.random.split(key, num_devices)
  init_dict = checkpoint_utils.make_sgd_checkpoint_dict(-1, params, net_state,
                                                        opt_state, key)
  init_dict = script_utils.get_initialization_dict(dirname, args, init_dict)
  start_iteration, params, net_state, opt_state, key = (
      checkpoint_utils.parse_sgd_checkpoint_dict(init_dict))
  start_iteration += 1

  # Loading mean checkpoint. Only on a fresh run (start_iteration == 0, i.e.
  # the -1 placeholder above was kept): when resuming from a saved VI
  # checkpoint, overwriting the mean would discard the trained mean while
  # keeping the trained stds and optimizer state.
  if args.mean_init_checkpoint:
    if start_iteration == 0:
      print("Initializing VI mean from the provided checkpoint")
      ckpt_dict = checkpoint_utils.load_checkpoint(args.mean_init_checkpoint)
      mean_params = checkpoint_utils.parse_sgd_checkpoint_dict(ckpt_dict)[1]
      params["mean"] = mean_params
    else:
      print("Resuming from iteration {}; ignoring --mean_init_checkpoint"
            .format(start_iteration))

  # Define train epoch
  sgd_train_epoch = script_utils.time_fn(
      train_utils.make_sgd_train_epoch(net_apply, log_likelihood_fn, prior_kl,
                                       optimizer, num_batches))

  # Train
  for iteration in range(start_iteration, args.num_epochs):
    (params, net_state, opt_state, elbo_avg, key), iteration_time = (
        sgd_train_epoch(params, net_state, opt_state, train_set, key))

    # Evaluate the model
    train_stats = {"ELBO": elbo_avg, "KL": prior_kl(params)}
    test_stats, ensemble_stats = {}, {}
    is_last_epoch = iteration == args.num_epochs - 1
    if (iteration % args.eval_freq == 0) or is_last_epoch:
      # Evaluate the mean
      _, test_predictions, train_predictions, test_stats, train_stats_ = (
          script_utils.evaluate(mean_apply, params, net_state, train_set,
                                test_set, predict_fn, metrics_fns, prior_kl))
      train_stats.update(train_stats_)
      # "KL" above already reports the prior term; drop the duplicate.
      train_stats.pop("prior", None)

      # Evaluate the ensemble. Unpack first: converting the whole
      # (net_state, predictions) tuple with onp.asarray builds a ragged /
      # object array, which recent NumPy versions reject.
      net_state, ensemble_predictions = vi_ensemble_predict_fn(
          net_apply, params, net_state, test_set)
      ensemble_predictions = onp.asarray(ensemble_predictions)
      ensemble_stats = train_utils.evaluate_metrics(ensemble_predictions,
                                                    test_set[1], metrics_fns)

    # Save checkpoint
    if iteration % args.save_freq == 0 or is_last_epoch:
      checkpoint_name = checkpoint_utils.make_checkpoint_name(iteration)
      checkpoint_path = os.path.join(dirname, checkpoint_name)
      checkpoint_dict = checkpoint_utils.make_sgd_checkpoint_dict(
          iteration, params, net_state, opt_state, key)
      checkpoint_utils.save_checkpoint(checkpoint_path, checkpoint_dict)

    # Log results
    step_size = lr_schedule(opt_state[-1].count)
    other_logs = script_utils.get_common_logs(iteration, iteration_time, args)
    other_logs["hypers/step_size"] = step_size
    other_logs["hypers/momentum"] = args.momentum_decay
    logging_dict = logging_utils.make_logging_dict(train_stats, test_stats,
                                                   ensemble_stats)
    logging_dict.update(other_logs)
    script_utils.write_to_tensorboard(tf_writer, logging_dict, iteration)

    # Add a histogram of MFVI stds (softplus of the stored inverse-softplus
    # parameters). jax.tree_util replaces the deprecated jax.tree_map /
    # jax.tree_leaves aliases.
    with tf_writer.as_default():
      stds = jax.tree_util.tree_map(jax.nn.softplus,
                                    params["inv_softplus_std"])
      stds = jnp.concatenate(
          [std.reshape(-1) for std in jax.tree_util.tree_leaves(stds)])
      tf.summary.histogram("MFVI/param_stds", stds, step=iteration)

    tabulate_dict = script_utils.get_tabulate_dict(tabulate_metrics,
                                                   logging_dict)
    tabulate_dict["lr"] = step_size
    table = logging_utils.make_table(tabulate_dict, iteration - start_iteration,
                                     args.tabulate_freq)
    print(table)
if __name__ == "__main__":
  # Script entry point: report visible devices, then run MFVI training.
  script_utils.print_visible_devices()
  train_model()