Commit

🍂🦎
jcopo committed Oct 16, 2024
1 parent b46ff14 commit 184415a
Showing 25 changed files with 60 additions and 225 deletions.
2 changes: 0 additions & 2 deletions .gitignore
@@ -7,8 +7,6 @@ __pycache__/
# C extensions
*.so

*.png
*.gif

sandbox.ipynb

49 changes: 43 additions & 6 deletions README.md
@@ -1,16 +1,53 @@
Structure of the repository:
Code repository for the paper [Bayesian Experimental Design via Contrastive Diffusions](https://arxiv.org/abs/2410.11826v1)

# Structure of the repository:

- `diffuse/`: contains the source code of the Diffuse tool with the following files:
- `mixtures.py`: contains the implementation of the mixture models used to test diffusion on a toy example
- `sde.py`: implementation of the linear stochastic differential equation and its reverse-time counterpart used for sampling
- `image.py`: tools for image processing and masking for MNIST
- `score_matching.py`: implementation of the score matching loss used to train the Diffusion Model
- `unet.py`: implementation of the U-Net architecture used for the Diffusion Model
- `mnist_train.py`: script to train the Diffusion Model on MNIST (a minimal setup sketch follows this list)
- `conditional.py`: implementation of the conditional sampling procedure

- `vraie_vie`: contains the code for the medical imaging application: anomaly detection on the WMH dataset 🧠
- `mni_coregistration.py`: coregisters the original WMH dataset to MNI space 👩‍🔬
- `create_data.py`: creates the dataloader that iterates over the slices of each subject 🚀

- `test`: contains the test scripts for the Diffuse tool with a test on a toy example with Gaussian Mixtures
- `examples/`: contains usage examples of the Diffuse tool, including a toy example with Gaussian mixtures
- `design_mnist.py`: script to run the design optimization for handwritten digit retrieval
- `mixture_evolution.py`: plots the evolution of the mixture models under the noising and denoising processes
- `test/`: contains the tests for the Diffuse tool, including a toy example with Gaussian mixtures
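
As a quick orientation, the snippet below sketches how these pieces fit together, mirroring the setup visible in `diffuse/mnsit_train.py` in this commit. It is a minimal sketch, not a documented entry point: the import paths and the scalar `dt` passed to `UNet` are assumptions based on this commit's diff.

```python
import jax
import jax.numpy as jnp

# Assumed module paths; the classes themselves appear in this commit's diffs.
from diffuse.sde import SDE, LinearSchedule
from diffuse.unet import UNet

key = jax.random.PRNGKey(0)
batch_size, shape_sample = 32, (28, 28, 1)  # MNIST-sized images
n_t = 256
dt = 2.0 / n_t  # time step over the horizon T = 2.0 (assumption)

# Noise schedule and forward SDE, as in diffuse/mnsit_train.py
beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta)

# Score network: U-Net with pixel-shuffle upsampling
nn_unet = UNet(dt, 64, upsampling="pixel_shuffle")
init_params = nn_unet.init(
    key, jnp.ones((batch_size, *shape_sample)), jnp.ones((batch_size,))
)
```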

# Visualization of the design optimization procedure:
<p align="center">
<img src="img/mnist_444/samples_0.png" width="100%" />
<img src="img/mnist_444/samples_1.png" width="100%" />
<img src="img/mnist_444/samples_2.png" width="100%" />
<img src="img/mnist_444/samples_3.png" width="100%" />
<img src="img/mnist_444/samples_4.png" width="100%" />
<img src="img/mnist_444/samples_5.png" width="100%" />
</p>

**Figure:** Image reconstruction. First 6 experiments (rows): ground-truth image, measurement at experiment $k$, and samples from the current prior $p(\theta|\mathcal{D}_{k-1})$, with best (upper) and worst (lower) weights in each sub-row. The samples incorporate past measurement information as the procedure advances.

# Comparison with random measurements:
<p align="center">
<img src="img/comparison_1.png" width="100%" />
<img src="img/comparison_2.png" width="100%" />
<img src="img/comparison_3.png" width="100%" />
<img src="img/comparison_4.png" width="100%" />
</p>

**Figure:** Optimized vs. random designs: measured outcome $y$ (2nd vs. 3rd column) and parameter $\theta$ estimates (reconstruction) with highest weights (upper vs. lower sub-row).



# Forward / reverse diffusion process on mixtures:
<p align="center">
<img src="img/backward_process.gif" width="380" />
<img src="img/forward_process.gif" width="380" />
</p>

For tests and plots of the diffusion on a mixture of Gaussians, run:
```bash
pytest --plot
```
15 changes: 4 additions & 11 deletions diffuse/conditional.py
@@ -33,6 +33,9 @@ def tree_unflatten(cls, aux_data, children):

@dataclass
class CondSDE(SDE):
"""
cond_sde.mask.restore acts as the matrix A^T
"""
mask: SquareMask
tf: float
score: Callable[[Array, float], Array]
@@ -85,9 +88,6 @@ def reverse_diffusion(state):
x, t = state
return cond_reverse_diffusion(CondState(x, y, xi, t), self)

# jax.debug.print("meas{}\n", measure(xi, img, self.mask))
# jax.debug.print("y{}\n", y.shape)
# jax.debug.print("diff{}\n", measure(xi, img, self.mask) - y )

x, _ = euler_maryama_step(
SDEState(x, t), dt, key, revese_drift, reverse_diffusion
@@ -97,26 +97,19 @@


def cond_reverse_drift(state: CondState, cond_sde: CondSDE) -> Array:
# stack together x and y and apply reverse drift
x, y, xi, t = state
# img = restore(xi, x, cond_sde.mask, y)
# return cond_sde.reverse_drift(SDEState(img, t))
drift_x = cond_sde.reverse_drift(SDEState(x, t))
beta_t = cond_sde.beta(cond_sde.tf - t)
meas_x = cond_sde.mask.measure(xi, x)
alpha_t = jnp.exp(cond_sde.beta.integrate(0.0, t))
# here if needed we average over y

drift_y = (
beta_t * cond_sde.mask.restore(xi, jnp.zeros_like(x), y - meas_x) / alpha_t
)
# f = lambda y: beta_t * (y - meas_x) / alpha_t
# drifts = jax.vmap(f)(y)
# drift_y = drifts.mean(axis=0)
return drift_x + drift_y


def cond_reverse_diffusion(state: CondState, cond_sde: CondSDE) -> Array:
# stack together x and y and apply reverse diffusion
x, y, xi, t = state
img = cond_sde.mask.restore(xi, x, y)
return cond_sde.reverse_diffusion(SDEState(img, t))
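
The docstring added to `CondSDE` above notes that `cond_sde.mask.restore` acts as the matrix $A^T$ of the masking operator. Below is a minimal sketch of that reading, using only the `SquareMask` constructor and the `measure`/`restore` calls visible in this commit's diffs; the image shape is an assumption and the comments describe the intended interpretation rather than verified behaviour.

```python
import jax.numpy as jnp

from diffuse.images import SquareMask  # file diffed later in this commit

x = jnp.ones((28, 28, 1))        # dummy image theta (shape is an assumption)
mask = SquareMask(10, x.shape)   # square measurement window, as in images.py
xi = jnp.array([15.0, 15.0])     # design: window position, as in images.py

# "A theta": apply the measurement window to the image.
y = mask.measure(xi, x)

# "A^T y": push the measurement back into image space on a blank canvas,
# the same call cond_reverse_drift makes with jnp.zeros_like(x).
back = mask.restore(xi, jnp.zeros_like(x), y)

# cond_reverse_diffusion instead restores y into the current sample x.
img = mask.restore(xi, x, y)
```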
127 changes: 0 additions & 127 deletions diffuse/filter.py

This file was deleted.

2 changes: 0 additions & 2 deletions diffuse/images.py
@@ -35,7 +35,6 @@ def make(self, xi: Array) -> Array:
mask = jax.nn.sigmoid(
(-jnp.maximum(y_dist, x_dist) + mask_half_size) / softness
)
# return jnp.where(mask > 0.5, 1.0, 0.0)[..., None]
return mask[..., None]

def measure_from_mask(self, hist_mask: Array, img: Array):
@@ -58,7 +57,6 @@ def restore(self, xi: Array, img: Array, measured: Array):
xs = einops.rearrange(xs, "b h w -> b h w 1")

x = xs[0]
# x = jax.random.normal(jax.random.PRNGKey(0), x.shape)

mask = SquareMask(10, x.shape)
xi = jnp.array([15.0, 15.0])
9 changes: 0 additions & 9 deletions diffuse/inference.py
@@ -94,11 +94,9 @@ def calculate_drift_expt_post(
\beta(t) \sum_{n=1}^N \nu_n \nabla_\theta \log p(y_n^t | \theta'_t, \xi)
term to add for conditional diffusion
"""
# pdb.set_trace()
drifts = jax.vmap(calculate_drift_y, in_axes=(None, None, None, 0))(
cond_sde, sde_state, design, y
)
# drifts = calculate_drift_y(cond_sde, t, xi, x, y)
drift_y = drifts.mean(axis=0)
return drift_y

@@ -126,10 +124,8 @@ def particle_step(
n_particles = sde_state.position.shape[0]
idx = stratified(rng_key, weights, n_particles)

#return sde_state.position, weights
return jax.lax.cond(
(ess_val < 0.6 * n_particles) & (ess_val > 0.2 * n_particles),
#(ess_val > 0.2 * n_particles),
lambda x: (x[idx], weights[idx]),
lambda x: (x, weights),
sde_state.position,
@@ -152,13 +148,9 @@ def logpdf_change_y(
alpha = jnp.sqrt(jnp.exp(cond_sde.beta.integrate(0.0, t)))
cov = cond_sde.reverse_diffusion(x_sde_state) * jnp.sqrt(dt) + alpha

# mean = cond_sde.mask.measure(design, x + drift_x * dt)
mean = cond_sde.mask.measure_from_mask(design, x + drift_x * dt)
logsprobs = jax.scipy.stats.norm.logpdf(y_next, mean, cov)
logsprobs = cond_sde.mask.measure_from_mask(design, logsprobs)
#jax.experimental.io_callback(plot_lines, None, logsprobs)
#jax.experimental.io_callback(sigle_plot, None, y_next)
#logsprobs = jax.vmap(cond_sde.mask.measure, in_axes=(None, 0))(design, logsprobs)
logsprobs = einops.reduce(logsprobs, "t ... -> t ", "sum")
return logsprobs

@@ -234,7 +226,6 @@ def step(state, itr):
key, u, u_next = itr
keys = jax.random.split(key, n_ts)
n_state = update_joint(state_xt, u, u_next, key)
# n_state = jax.vmap(update_joint, in_axes=(SDEState(None, 0), 0, 0, 0))(state_xt, u, u_next, keys)
return n_state, n_state

end_state, hist = jax.lax.scan(step, (state_x, weights), (keys, u_0Tm, u_1T))
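
For reference, `particle_step` in the hunk above resamples only when the effective sample size falls inside a band (between 0.2 N and 0.6 N particles). The following self-contained sketch reproduces that criterion with a basic stratified resampler; it mirrors the logic visible in the diff but does not use the repository's own `stratified` helper, and weights are assumed normalized.

```python
import jax
import jax.numpy as jnp


def ess(weights):
    """Effective sample size of normalized importance weights."""
    return 1.0 / jnp.sum(weights**2)


def stratified_resample(key, weights, n):
    """Basic stratified resampling: one uniform draw per equal-width stratum."""
    u = (jnp.arange(n) + jax.random.uniform(key, (n,))) / n
    return jnp.searchsorted(jnp.cumsum(weights), u)


def maybe_resample(key, positions, weights):
    n = weights.shape[0]
    ess_val = ess(weights)
    idx = stratified_resample(key, weights, n)
    # Same band as particle_step above: resample only for intermediate ESS.
    # (Standard SMC would usually also reset weights to 1/n after resampling.)
    return jax.lax.cond(
        (ess_val < 0.6 * n) & (ess_val > 0.2 * n),
        lambda p: (p[idx], weights[idx]),
        lambda p: (p, weights),
        positions,
    )
```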
6 changes: 0 additions & 6 deletions diffuse/mnsit_train.py
@@ -23,16 +23,11 @@
xs = jax.random.permutation(key, xs, axis=0)
data = einops.rearrange(xs, "b h w -> b h w 1")
shape_sample = data.shape[1:]
# plt.imshow(data[0], cmap='gray')
# plt.show()
# dt = jnp.linspace(0, 2.0, n_t)
# dt = jnp.array([2.0 / n_t] * batch_size)

beta = LinearSchedule(b_min=0.02, b_max=5.0, t0=0.0, T=2.0)
sde = SDE(beta)

nn_unet = UNet(dt, 64, upsampling="pixel_shuffle")
# init_params = nn_unet.init(key, data[:batch_size], dt)
init_params = nn_unet.init(
key, jnp.ones((batch_size, *shape_sample)), jnp.ones((batch_size,))
)
@@ -71,7 +66,6 @@ def step(key, params, opt_state, ema_state, data):

for epoch in range(n_epochs):
subkey, key = jax.random.split(key)
# data = jax.random.permutation(subkey, data, axis=0)
idx = jax.random.choice(
subkey, data.shape[0], (nsteps_per_epoch, batch_size), replace=False
)
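
The training script above delegates its loss to `diffuse/score_matching.py`, which is not part of this diff. For orientation only, here is a generic denoising score-matching objective for a VP-type SDE: it is not the repository's implementation, and the closed-form mean and variance used below are the standard VP-SDE expressions, assumed (not verified here) to match the `LinearSchedule` used above.

```python
import jax
import jax.numpy as jnp


def dsm_loss(params, score_fn, key, x0, beta_integral, tf=2.0):
    """Generic denoising score-matching loss for a VP-type SDE (sketch).

    score_fn(params, x_t, t) approximates grad_x log p_t(x_t);
    beta_integral(t) should return int_0^t beta(s) ds for the chosen schedule.
    """
    b = x0.shape[0]
    k_t, k_eps = jax.random.split(key)
    t = jax.random.uniform(k_t, (b,), minval=1e-3, maxval=tf)

    # Closed-form marginal of the VP forward SDE:
    #   x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps,  a_bar = exp(-int_0^t beta)
    a_bar = jnp.exp(-beta_integral(t)).reshape(b, 1, 1, 1)  # image-shaped x_0 assumed
    eps = jax.random.normal(k_eps, x0.shape)
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * eps

    # The conditional score of the forward marginal is -eps / sqrt(1 - a_bar);
    # regress the network output onto it.
    target = -eps / jnp.sqrt(1.0 - a_bar)
    pred = score_fn(params, x_t, t)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=(1, 2, 3)))
```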
17 changes: 0 additions & 17 deletions diffuse/neural_networks.py

This file was deleted.
