From d888c2f24e935c5efaa0462fd516f7bf119c6a8c Mon Sep 17 00:00:00 2001 From: Ezra Erives <30280328+eje24@users.noreply.github.com> Date: Tue, 11 Feb 2025 23:01:57 -0500 Subject: [PATCH] Add lab two solutions --- solutions/lab_two_complete.ipynb | 2592 ++++++++++++++++++++++++++++++ 1 file changed, 2592 insertions(+) create mode 100644 solutions/lab_two_complete.ipynb diff --git a/solutions/lab_two_complete.ipynb b/solutions/lab_two_complete.ipynb new file mode 100644 index 0000000..edf355a --- /dev/null +++ b/solutions/lab_two_complete.ipynb @@ -0,0 +1,2592 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f4ca5863-5112-49cd-800c-61199d955cb6", + "metadata": {}, + "source": [ + "# Lab Two: Flow Matching and Score Matching" + ] + }, + { + "cell_type": "markdown", + "id": "ebb2a99a-6d00-4f94-aa23-a2146897321f", + "metadata": {}, + "source": [ + "Welcome to lab two! In this lab, we will provide an intuitive and hands-on walk-through of *flow matching* and *score matching*.\n", + "\n", + "**Instructions for registered students**:\n", + "1. Complete this lab.\n", + "2. Export this notebook to a PDF.\n", + "3. Submit the PDF to Gradescope via Canvas.\n", + "There are a total of *16 points* in this lab. Questions can be found by searching for the phrase \"Your job...\". If you have any questions or concerns, please come to office hours or fill out the following [feedback/question form here](https://forms.gle/iixgq4E2wkwudEb19). Thanks!" + ] + }, + { + "cell_type": "markdown", + "id": "82a79a8b-a061-4c93-aedb-3d2269011f36", + "metadata": {}, + "source": [ + "### Part 0: Miscellaneous Imports and Utility Functions\n", + "No questions here, but free to read through to familiarize yourself with these helper functions. Most of this is what you already completed in lab one!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e32fa50e-30d9-4048-9c8b-f0661aedeffe", + "metadata": {}, + "outputs": [], + "source": [ + "from abc import ABC, abstractmethod\n", + "from typing import Optional, List, Type, Tuple, Dict\n", + "import math\n", + "\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "import matplotlib.cm as cm\n", + "from matplotlib.axes._axes import Axes\n", + "import torch\n", + "import torch.distributions as D\n", + "from torch.func import vmap, jacrev\n", + "from tqdm import tqdm\n", + "import seaborn as sns\n", + "from sklearn.datasets import make_moons, make_circles\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef0e04ac-f9a1-4be0-9e8c-162f85901207", + "metadata": {}, + "outputs": [], + "source": [ + "class Sampleable(ABC):\n", + " \"\"\"\n", + " Distribution which can be sampled from\n", + " \"\"\"\n", + " @property\n", + " @abstractmethod\n", + " def dim(self) -> int:\n", + " \"\"\"\n", + " Returns:\n", + " - Dimensionality of the distribution\n", + " \"\"\"\n", + " pass\n", + " \n", + " @abstractmethod\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - num_samples: the desired number of samples\n", + " Returns:\n", + " - samples: shape (batch_size, dim)\n", + " \"\"\"\n", + " pass\n", + "\n", + "class Density(ABC):\n", + " \"\"\"\n", + " Distribution with tractable density\n", + " \"\"\"\n", + " @abstractmethod\n", + " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Returns the log density at x.\n", + " Args:\n", + " - x: shape (batch_size, dim)\n", + " Returns:\n", + " - log_density: shape (batch_size, 1)\n", + " \"\"\"\n", + " pass\n", + "\n", + "class Gaussian(torch.nn.Module, Sampleable, Density):\n", + " \"\"\"\n", + " Multivariate Gaussian distribution\n", + " \"\"\"\n", + " def __init__(self, mean: torch.Tensor, cov: torch.Tensor):\n", + " \"\"\"\n", + " mean: shape (dim,)\n", + " cov: shape (dim,dim)\n", + " \"\"\"\n", + " super().__init__()\n", + " self.register_buffer(\"mean\", mean)\n", + " self.register_buffer(\"cov\", cov)\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return self.mean.shape[0]\n", + "\n", + " @property\n", + " def distribution(self):\n", + " return D.MultivariateNormal(self.mean, self.cov, validate_args=False)\n", + "\n", + " def sample(self, num_samples) -> torch.Tensor:\n", + " return self.distribution.sample((num_samples,))\n", + " \n", + " def log_density(self, x: torch.Tensor):\n", + " return self.distribution.log_prob(x).view(-1, 1)\n", + "\n", + " @classmethod\n", + " def isotropic(cls, dim: int, std: float) -> \"Gaussian\":\n", + " mean = torch.zeros(dim)\n", + " cov = torch.eye(dim) * std ** 2\n", + " return cls(mean, cov)\n", + "\n", + "class GaussianMixture(torch.nn.Module, Sampleable, Density):\n", + " \"\"\"\n", + " Two-dimensional Gaussian mixture model, and is a Density and a Sampleable. Wrapper around torch.distributions.MixtureSameFamily.\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " means: torch.Tensor, # nmodes x data_dim\n", + " covs: torch.Tensor, # nmodes x data_dim x data_dim\n", + " weights: torch.Tensor, # nmodes\n", + " ):\n", + " \"\"\"\n", + " means: shape (nmodes, 2)\n", + " covs: shape (nmodes, 2, 2)\n", + " weights: shape (nmodes, 1)\n", + " \"\"\"\n", + " super().__init__()\n", + " self.nmodes = means.shape[0]\n", + " self.register_buffer(\"means\", means)\n", + " self.register_buffer(\"covs\", covs)\n", + " self.register_buffer(\"weights\", weights)\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return self.means.shape[1]\n", + "\n", + " @property\n", + " def distribution(self):\n", + " return D.MixtureSameFamily(\n", + " mixture_distribution=D.Categorical(probs=self.weights, validate_args=False),\n", + " component_distribution=D.MultivariateNormal(\n", + " loc=self.means,\n", + " covariance_matrix=self.covs,\n", + " validate_args=False,\n", + " ),\n", + " validate_args=False,\n", + " )\n", + "\n", + " def log_density(self, x: torch.Tensor) -> torch.Tensor:\n", + " return self.distribution.log_prob(x).view(-1, 1)\n", + "\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " return self.distribution.sample(torch.Size((num_samples,)))\n", + "\n", + " @classmethod\n", + " def random_2D(\n", + " cls, nmodes: int, std: float, scale: float = 10.0, x_offset: float = 0.0, seed = 0.0\n", + " ) -> \"GaussianMixture\":\n", + " torch.manual_seed(seed)\n", + " means = (torch.rand(nmodes, 2) - 0.5) * scale + x_offset * torch.Tensor([1.0, 0.0])\n", + " covs = torch.diag_embed(torch.ones(nmodes, 2)) * std ** 2\n", + " weights = torch.ones(nmodes)\n", + " return cls(means, covs, weights)\n", + "\n", + " @classmethod\n", + " def symmetric_2D(\n", + " cls, nmodes: int, std: float, scale: float = 10.0, x_offset: float = 0.0\n", + " ) -> \"GaussianMixture\":\n", + " angles = torch.linspace(0, 2 * np.pi, nmodes + 1)[:nmodes]\n", + " means = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) * scale + torch.Tensor([1.0, 0.0]) * x_offset\n", + " covs = torch.diag_embed(torch.ones(nmodes, 2) * std ** 2)\n", + " weights = torch.ones(nmodes) / nmodes\n", + " return cls(means, covs, weights)\n", + "\n", + "class MoonsSampleable(Sampleable):\n", + " \"\"\"\n", + " Implementation of the Moons distribution using sklearn's make_moons\n", + " \"\"\"\n", + " def __init__(self, device: torch.device, noise: float = 0.05, scale: float = 5.0, offset: Optional[torch.Tensor] = None):\n", + " \"\"\"\n", + " Args:\n", + " noise: Standard deviation of Gaussian noise added to the data\n", + " scale: How much to scale the data\n", + " offset: How much to shift the samples from the original distribution (2,)\n", + " \"\"\"\n", + " self.noise = noise\n", + " self.scale = scale\n", + " self.device = device\n", + " if offset is None:\n", + " offset = torch.zeros(2)\n", + " self.offset = offset.to(device)\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return 2\n", + "\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " num_samples: Number of samples to generate\n", + " Returns:\n", + " torch.Tensor: Generated samples with shape (num_samples, 3)\n", + " \"\"\"\n", + " samples, _ = make_moons(\n", + " n_samples=num_samples,\n", + " noise=self.noise,\n", + " random_state=None # Allow for random generation each time\n", + " )\n", + " return self.scale * torch.from_numpy(samples.astype(np.float32)).to(self.device) + self.offset\n", + "\n", + "class CirclesSampleable(Sampleable):\n", + " \"\"\"\n", + " Implementation of concentric circle distribution using sklearn's make_circles\n", + " \"\"\"\n", + " def __init__(self, device: torch.device, noise: float = 0.05, scale=5.0, offset: Optional[torch.Tensor] = None):\n", + " \"\"\"\n", + " Args:\n", + " noise: standard deviation of Gaussian noise added to the data\n", + " \"\"\"\n", + " self.noise = noise\n", + " self.scale = scale\n", + " self.device = device\n", + " if offset is None:\n", + " offset = torch.zeros(2)\n", + " self.offset = offset.to(device)\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return 2\n", + "\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " num_samples: number of samples to generate\n", + " Returns:\n", + " torch.Tensor: shape (num_samples, 3)\n", + " \"\"\"\n", + " samples, _ = make_circles(\n", + " n_samples=num_samples,\n", + " noise=self.noise,\n", + " factor=0.5,\n", + " random_state=None\n", + " )\n", + " return self.scale * torch.from_numpy(samples.astype(np.float32)).to(self.device) + self.offset\n", + "\n", + "class CheckerboardSampleable(Sampleable):\n", + " \"\"\"\n", + " Checkboard-esque distribution\n", + " \"\"\"\n", + " def __init__(self, device: torch.device, grid_size: int = 3, scale=5.0):\n", + " \"\"\"\n", + " Args:\n", + " noise: standard deviation of Gaussian noise added to the data\n", + " \"\"\"\n", + " self.grid_size = grid_size\n", + " self.scale = scale\n", + " self.device = device\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return 2\n", + "\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " num_samples: number of samples to generate\n", + " Returns:\n", + " torch.Tensor: shape (num_samples, 3)\n", + " \"\"\"\n", + " grid_length = 2 * self.scale / self.grid_size\n", + " samples = torch.zeros(0,2).to(device)\n", + " while samples.shape[0] < num_samples:\n", + " # Sample num_samples\n", + " new_samples = (torch.rand(num_samples,2).to(self.device) - 0.5) * 2 * self.scale\n", + " x_mask = torch.floor((new_samples[:,0] + self.scale) / grid_length) % 2 == 0 # (bs,)\n", + " y_mask = torch.floor((new_samples[:,1] + self.scale) / grid_length) % 2 == 0 # (bs,)\n", + " accept_mask = torch.logical_xor(~x_mask, y_mask)\n", + " samples = torch.cat([samples, new_samples[accept_mask]], dim=0)\n", + " return samples[:num_samples]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "037723a8-e212-46f5-ae75-997230282515", + "metadata": {}, + "outputs": [], + "source": [ + "# Several plotting utility functions\n", + "def hist2d_samples(samples, ax: Optional[Axes] = None, bins: int = 200, scale: float = 5.0, percentile: int = 99, **kwargs):\n", + " H, xedges, yedges = np.histogram2d(samples[:, 0], samples[:, 1], bins=bins, range=[[-scale, scale], [-scale, scale]])\n", + " \n", + " # Determine color normalization based on the 99th percentile\n", + " cmax = np.percentile(H, percentile)\n", + " cmin = 0.0\n", + " norm = cm.colors.Normalize(vmax=cmax, vmin=cmin)\n", + " \n", + " # Plot using imshow for more control\n", + " extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]\n", + " ax.imshow(H.T, extent=extent, origin='lower', norm=norm, **kwargs)\n", + "\n", + "def hist2d_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, bins=200, scale: float = 5.0, percentile: int = 99, **kwargs):\n", + " assert sampleable.dim == 2\n", + " if ax is None:\n", + " ax = plt.gca()\n", + " samples = sampleable.sample(num_samples).detach().cpu() # (ns, 2)\n", + " hist2d_samples(samples, ax, bins, scale, percentile, **kwargs)\n", + "\n", + "def scatter_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", + " assert sampleable.dim == 2\n", + " if ax is None:\n", + " ax = plt.gca()\n", + " samples = sampleable.sample(num_samples) # (ns, 2)\n", + " ax.scatter(samples[:,0].cpu(), samples[:,1].cpu(), **kwargs)\n", + "\n", + "def kdeplot_sampleable(sampleable: Sampleable, num_samples: int, ax: Optional[Axes] = None, **kwargs):\n", + " assert sampleable.dim == 2\n", + " if ax is None:\n", + " ax = plt.gca()\n", + " samples = sampleable.sample(num_samples) # (ns, 2)\n", + " sns.kdeplot(x=samples[:,0].cpu(), y=samples[:,1].cpu(), ax=ax, **kwargs)\n", + "\n", + "def imshow_density(density: Density, x_bounds: Tuple[float, float], y_bounds: Tuple[float, float], bins: int, ax: Optional[Axes] = None, x_offset: float = 0.0, **kwargs):\n", + " if ax is None:\n", + " ax = plt.gca()\n", + " x_min, x_max = x_bounds\n", + " y_min, y_max = y_bounds\n", + " x = torch.linspace(x_min, x_max, bins).to(device) + x_offset\n", + " y = torch.linspace(y_min, y_max, bins).to(device)\n", + " X, Y = torch.meshgrid(x, y)\n", + " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", + " density = density.log_density(xy).reshape(bins, bins).T\n", + " im = ax.imshow(density.cpu(), extent=[x_min, x_max, y_min, y_max], origin='lower', **kwargs)\n", + "\n", + "def contour_density(density: Density, bins: int, scale: float, ax: Optional[Axes] = None, x_offset:float = 0.0, **kwargs):\n", + " if ax is None:\n", + " ax = plt.gca()\n", + " x = torch.linspace(-scale + x_offset, scale + x_offset, bins).to(device)\n", + " y = torch.linspace(-scale, scale, bins).to(device)\n", + " X, Y = torch.meshgrid(x, y)\n", + " xy = torch.stack([X.reshape(-1), Y.reshape(-1)], dim=-1)\n", + " density = density.log_density(xy).reshape(bins, bins).T\n", + " im = ax.contour(density.cpu(), origin='lower', **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f47118d5-30d0-4374-81de-f47e3b96f6e3", + "metadata": {}, + "outputs": [], + "source": [ + "class ODE(ABC):\n", + " @abstractmethod\n", + " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Returns the drift coefficient of the ODE.\n", + " Args:\n", + " - xt: state at time t, shape (bs, dim)\n", + " - t: time, shape (batch_size, 1)\n", + " Returns:\n", + " - drift_coefficient: shape (batch_size, dim)\n", + " \"\"\"\n", + " pass\n", + "\n", + "class SDE(ABC):\n", + " @abstractmethod\n", + " def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Returns the drift coefficient of the ODE.\n", + " Args:\n", + " - xt: state at time t, shape (batch_size, dim)\n", + " - t: time, shape (batch_size, 1)\n", + " Returns:\n", + " - drift_coefficient: shape (batch_size, dim)\n", + " \"\"\"\n", + " pass\n", + "\n", + " @abstractmethod\n", + " def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Returns the diffusion coefficient of the ODE.\n", + " Args:\n", + " - xt: state at time t, shape (batch_size, dim)\n", + " - t: time, shape (batch_size, 1)\n", + " Returns:\n", + " - diffusion_coefficient: shape (batch_size, dim)\n", + " \"\"\"\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f98a6a4-67a7-4740-837f-156aac725c2b", + "metadata": {}, + "outputs": [], + "source": [ + "class Simulator(ABC):\n", + " @abstractmethod\n", + " def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor):\n", + " \"\"\"\n", + " Takes one simulation step\n", + " Args:\n", + " - xt: state at time t, shape (bs, dim)\n", + " - t: time, shape (bs,1)\n", + " - dt: time, shape (bs,1)\n", + " Returns:\n", + " - nxt: state at time t + dt (bs, dim)\n", + " \"\"\"\n", + " pass\n", + "\n", + " @torch.no_grad()\n", + " def simulate(self, x: torch.Tensor, ts: torch.Tensor):\n", + " \"\"\"\n", + " Simulates using the discretization gives by ts\n", + " Args:\n", + " - x_init: initial state at time ts[0], shape (batch_size, dim)\n", + " - ts: timesteps, shape (bs, num_timesteps,1)\n", + " Returns:\n", + " - x_final: final state at time ts[-1], shape (batch_size, dim)\n", + " \"\"\"\n", + " for t_idx in range(len(ts) - 1):\n", + " t = ts[:, t_idx]\n", + " h = ts[:, t_idx + 1] - ts[:, t_idx]\n", + " x = self.step(x, t, h)\n", + " return x\n", + "\n", + " @torch.no_grad()\n", + " def simulate_with_trajectory(self, x: torch.Tensor, ts: torch.Tensor):\n", + " \"\"\"\n", + " Simulates using the discretization gives by ts\n", + " Args:\n", + " - x_init: initial state at time ts[0], shape (bs, dim)\n", + " - ts: timesteps, shape (bs, num_timesteps, 1)\n", + " Returns:\n", + " - xs: trajectory of xts over ts, shape (batch_size, num\n", + " _timesteps, dim)\n", + " \"\"\"\n", + " xs = [x.clone()]\n", + " nts = ts.shape[1]\n", + " for t_idx in tqdm(range(nts - 1)):\n", + " t = ts[:,t_idx]\n", + " h = ts[:, t_idx + 1] - ts[:, t_idx]\n", + " x = self.step(x, t, h)\n", + " xs.append(x.clone())\n", + " return torch.stack(xs, dim=1)\n", + "\n", + "class EulerSimulator(Simulator):\n", + " def __init__(self, ode: ODE):\n", + " self.ode = ode\n", + " \n", + " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", + " return xt + self.ode.drift_coefficient(xt,t) * h\n", + "\n", + "class EulerMaruyamaSimulator(Simulator):\n", + " def __init__(self, sde: SDE):\n", + " self.sde = sde\n", + " \n", + " def step(self, xt: torch.Tensor, t: torch.Tensor, h: torch.Tensor):\n", + " return xt + self.sde.drift_coefficient(xt,t) * h + self.sde.diffusion_coefficient(xt,t) * torch.sqrt(h) * torch.randn_like(xt)\n", + "\n", + "def record_every(num_timesteps: int, record_every: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Compute the indices to record in the trajectory given a record_every parameter\n", + " \"\"\"\n", + " if record_every == 1:\n", + " return torch.arange(num_timesteps)\n", + " return torch.cat(\n", + " [\n", + " torch.arange(0, num_timesteps - 1, record_every),\n", + " torch.tensor([num_timesteps - 1]),\n", + " ]\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "9b91789e-d50d-4aef-9d98-4188aa20ceed", + "metadata": {}, + "source": [ + "### Part 1: Implementing Conditional Probability Paths\n", + "Recall from lecture and the class notes the basic premise of conditional flow matching: describe a *conditional probability path* $p_t(x|z)$, so that $p_1(x|z) = \\delta_z(x)$, and $p_0(z) = p_{\\text{simple}}$ (e.g., a Gaussian), and $p_t(x|z)$ interpolates continuously (we are not being rigorous here) between $p_0(x|z)$ and $p_1(x|z)$. Such a conditional path can be seen as corresponding to some corruption process which (in reverse time) drives the point $z$ at $t=1$ to be distribution as $p_0(x|z)$ at time $t=0$. Such a corruption process is given by the ODE\n", + "$$dX_t = u_t^{\\text{ref}}(X_t|z)\\,dt,\\quad \\quad X_0 \\sim p_{\\text{simple}}.$$\n", + "The drift $u_t^{\\text{ref}}(X_t|z)$ is referred to as the *conditional vector field*. By averaging $u_t^{\\text{ref}}(x|z)$ over all such choices of $z$, we obtain the *marginal* vector field $u_t^{\\text{ref}}(x)$. Flow matching proposes to exploit the fact that the *marginal probability path* $p_t(x)$ generated by the marginal vector field $u_t^{\\text{ref}}(x)$, bridges $p_{\\text{simple}}$ to $p_{\\text{data}}$. Since the conditional vector field $u_t^{\\text{ref}}(x|z)$ is often analytically available, we may implicitly regress against the unknown marginal vector field $u_t^{\\text{ref}}(x)$ by explicitly regressing against the conditional vector field $u_t^{\\text{ref}}(x|z)$." + ] + }, + { + "cell_type": "markdown", + "id": "8ca98cf8-4eae-4d3a-8f88-6424b80b06e5", + "metadata": {}, + "source": [ + "The central object in this construction is a *conditional probability path*, whose interface is implemented below in the class `ConditionalProbabilityPath`. In this lab, you will implement two subclasses: `GaussianConditionalProbabilityPath`, and `LinearConditionalProbabilityPath` corresponding to probability paths of the same names from the lectures and notes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4706df1f-33d0-4484-b702-99282fff23cf", + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionalProbabilityPath(torch.nn.Module, ABC):\n", + " \"\"\"\n", + " Abstract base class for conditional probability paths\n", + " \"\"\"\n", + " def __init__(self, p_simple: Sampleable, p_data: Sampleable):\n", + " super().__init__()\n", + " self.p_simple = p_simple\n", + " self.p_data = p_data\n", + "\n", + " def sample_marginal_path(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples from the marginal distribution p_t(x) = p_t(x|z) p(z)\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - x: samples from p_t(x), (num_samples, dim)\n", + " \"\"\"\n", + " num_samples = t.shape[0]\n", + " # Sample conditioning variable z ~ p(z)\n", + " z = self.sample_conditioning_variable(num_samples) # (num_samples, dim)\n", + " # Sample conditional probability path x ~ p_t(x|z)\n", + " x = self.sample_conditional_path(z, t) # (num_samples, dim)\n", + " return x\n", + "\n", + " @abstractmethod\n", + " def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples the conditioning variable z\n", + " Args:\n", + " - num_samples: the number of samples\n", + " Returns:\n", + " - z: samples from p(z), (num_samples, dim)\n", + " \"\"\"\n", + " pass\n", + " \n", + " @abstractmethod\n", + " def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples from the conditional distribution p_t(x|z)\n", + " Args:\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - x: samples from p_t(x|z), (num_samples, dim)\n", + " \"\"\"\n", + " pass\n", + " \n", + " @abstractmethod\n", + " def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates the conditional vector field u_t(x|z)\n", + " Args:\n", + " - x: position variable (num_samples, dim)\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - conditional_vector_field: conditional vector field (num_samples, dim)\n", + " \"\"\" \n", + " pass\n", + "\n", + " @abstractmethod\n", + " def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates the conditional score of p_t(x|z)\n", + " Args:\n", + " - x: position variable (num_samples, dim)\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - conditional_score: conditional score (num_samples, dim)\n", + " \"\"\" \n", + " pass" + ] + }, + { + "cell_type": "markdown", + "id": "7351c076-f555-4068-9337-d3d0311bc6de", + "metadata": {}, + "source": [ + "# Part 2: Gaussian Conditional Probability Paths\n", + "In this section, we'll implement a **Gaussian conditional probability path** via the class `GaussianConditionalProbabilityPath`. We will then use it to transform a simple source $p_{\\text{simple}} = N(0, I_d)$ into a Gaussian mixture $p_{\\text{data}}$. Later, we'll experiment with more exciting distributions. Recall that a Gaussian conditional probability path is given by\n", + "$$p_t(x|z) = N(x;\\alpha_t z,\\beta_t^2 I_d),\\quad\\quad\\quad p_{\\text{simple}}=N(0,I_d),$$\n", + "where $\\alpha_t: [0,1] \\to \\mathbb{R}$ and $\\beta_t: [0,1] \\to \\mathbb{R}$ are monotonic, continuously differentiable functions satisfying $\\alpha_1 = \\beta_0 = 1$ and $\\alpha_0 = \\beta_1 = 0$. In other words, this implies that $p_1(x|z) = \\delta_z$ and $p_0(x|z) = N(0, I_d)$ is a unit Gaussian. Before we dive into things, let's take a look at $p_{\\text{simple}}$ and $p_{\\text{data}}$. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4d99507-2f5d-44df-8aca-cbcd9940df34", + "metadata": {}, + "outputs": [], + "source": [ + "# Constants for the duration of our use of Gaussian conditional probability paths, to avoid polluting the namespace...\n", + "PARAMS = {\n", + " \"scale\": 15.0,\n", + " \"target_scale\": 10.0,\n", + " \"target_std\": 1.0,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56ab35ea-7c63-4bb7-ac35-48d92cc56517", + "metadata": {}, + "outputs": [], + "source": [ + "p_simple = Gaussian.isotropic(dim=2, std = 1.0).to(device)\n", + "p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device)\n", + "\n", + "fig, axes = plt.subplots(1,3, figsize=(24,8))\n", + "bins = 200\n", + "\n", + "scale = PARAMS[\"scale\"]\n", + "x_bounds = [-scale,scale]\n", + "y_bounds = [-scale,scale]\n", + "\n", + "axes[0].set_title('Heatmap of p_simple')\n", + "axes[0].set_xticks([])\n", + "axes[0].set_yticks([])\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=axes[0], vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "\n", + "\n", + "axes[1].set_title('Heatmap of p_data')\n", + "axes[1].set_xticks([])\n", + "axes[1].set_yticks([])\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=axes[1], vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "axes[2].set_title('Heatmap of p_simple and p_data')\n", + "axes[2].set_xticks([])\n", + "axes[2].set_yticks([])\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))" + ] + }, + { + "cell_type": "markdown", + "id": "574e24a6-59f7-4ea5-b7d8-efc1106a412a", + "metadata": {}, + "source": [ + "### Problem 2.1: Implementing $\\alpha_t$ and $\\beta_t$ " + ] + }, + { + "cell_type": "markdown", + "id": "2690fc06-d90e-4505-aff4-01d5e8279465", + "metadata": {}, + "source": [ + "Let's get started by implementing $\\alpha_t$ and $\\beta_t$. We can think of these simply as callable objects which fulfill the simple contract $\\alpha_1 = \\beta_0 = 1$ and $\\alpha_0 = \\beta_1 = 0$, and which can compute their time derivatives $\\dot{\\alpha}_t$ and $\\dot{\\beta}_t$. We implement them below via the classes `Alpha` and `Beta`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f28999de-9f02-4439-a462-39965d932eb3", + "metadata": {}, + "outputs": [], + "source": [ + "class Alpha(ABC):\n", + " def __init__(self):\n", + " # Check alpha_t(0) = 0\n", + " assert torch.allclose(\n", + " self(torch.zeros(1,1)), torch.zeros(1,1)\n", + " )\n", + " # Check alpha_1 = 1\n", + " assert torch.allclose(\n", + " self(torch.ones(1,1)), torch.ones(1,1)\n", + " )\n", + " \n", + " @abstractmethod\n", + " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates alpha_t. Should satisfy: self(0.0) = 0.0, self(1.0) = 1.0.\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - alpha_t (num_samples, 1)\n", + " \"\"\" \n", + " pass\n", + "\n", + " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates d/dt alpha_t.\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - d/dt alpha_t (num_samples, 1)\n", + " \"\"\" \n", + " t = t.unsqueeze(1) # (num_samples, 1, 1)\n", + " dt = vmap(jacrev(self))(t) # (num_samples, 1, 1, 1, 1)\n", + " return dt.view(-1, 1)\n", + " \n", + "class Beta(ABC):\n", + " def __init__(self):\n", + " # Check beta_0 = 1\n", + " assert torch.allclose(\n", + " self(torch.zeros(1,1)), torch.ones(1,1)\n", + " )\n", + " # Check beta_1 = 0\n", + " assert torch.allclose(\n", + " self(torch.ones(1,1)), torch.zeros(1,1)\n", + " )\n", + " \n", + " @abstractmethod\n", + " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates alpha_t. Should satisfy: self(0.0) = 1.0, self(1.0) = 0.0.\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - beta_t (num_samples, 1)\n", + " \"\"\" \n", + " pass \n", + "\n", + " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates d/dt beta_t.\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - d/dt beta_t (num_samples, 1)\n", + " \"\"\" \n", + " t = t.unsqueeze(1) # (num_samples, 1, 1)\n", + " dt = vmap(jacrev(self))(t) # (num_samples, 1, 1, 1, 1)\n", + " return dt.view(-1, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "205ab204-8f54-4011-88f7-d6765c1ae4e4", + "metadata": {}, + "source": [ + "In this section, we'll be using $$\\alpha_t = t \\quad \\quad \\text{and} \\quad \\quad \\beta_t = \\sqrt{1-t}.$$ It is not hard to check that both functions are continuously differentiable on $[0,1)$, and monotonic, that $\\alpha_1 = \\beta_0 = 1$, and that $\\alpha_0 = \\beta_1 = 0$.\n", + "\n", + "**Your job (2 points)**: Implement the `__call__` methods of the classes `LinearAlpha` and `SquareRootBeta` below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39faaaef-d2af-4617-b172-c1191761e129", + "metadata": {}, + "outputs": [], + "source": [ + "class LinearAlpha(Alpha):\n", + " \"\"\"\n", + " Implements alpha_t = t\n", + " \"\"\"\n", + " \n", + " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - alpha_t (num_samples, 1)\n", + " \"\"\" \n", + " return t\n", + "\n", + " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates d/dt alpha_t.\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - d/dt alpha_t (num_samples, 1)\n", + " \"\"\" \n", + " return torch.ones_like(t)\n", + "\n", + "class SquareRootBeta(Beta):\n", + " \"\"\"\n", + " Implements beta_t = rt(1-t)\n", + " \"\"\"\n", + " def __call__(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - beta_t (num_samples, 1)\n", + " \"\"\" \n", + " return torch.sqrt(1-t)\n", + "\n", + " def dt(self, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates d/dt alpha_t.\n", + " Args:\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - d/dt alpha_t (num_samples, 1)\n", + " \"\"\" \n", + " return - 0.5 / (torch.sqrt(1 - t) + 1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "b9c2bfe7-2bab-4a69-99fa-1553ddfcd93c", + "metadata": {}, + "source": [ + "Let us know turn towards the task of implementing the `GaussianConditionalProbabilityPath` path. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cc2084d-68ab-420e-adcd-ca8dfff34f74", + "metadata": {}, + "outputs": [], + "source": [ + "class GaussianConditionalProbabilityPath(ConditionalProbabilityPath):\n", + " def __init__(self, p_data: Sampleable, alpha: Alpha, beta: Beta):\n", + " p_simple = Gaussian.isotropic(p_data.dim, 1.0)\n", + " super().__init__(p_simple, p_data)\n", + " self.alpha = alpha\n", + " self.beta = beta\n", + "\n", + " def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples the conditioning variable z ~ p_data(x)\n", + " Args:\n", + " - num_samples: the number of samples\n", + " Returns:\n", + " - z: samples from p(z), (num_samples, dim)\n", + " \"\"\"\n", + " return p_data.sample(num_samples)\n", + " \n", + " def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples from the conditional distribution p_t(x|z) = N(alpha_t * z, beta_t**2 * I_d)\n", + " Args:\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - x: samples from p_t(x|z), (num_samples, dim)\n", + " \"\"\"\n", + " return self.alpha(t) * z + self.beta(t) * torch.randn_like(z)\n", + " \n", + " def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates the conditional vector field u_t(x|z)\n", + " Note: Only defined on t in [0,1)\n", + " Args:\n", + " - x: position variable (num_samples, dim)\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - conditional_vector_field: conditional vector field (num_samples, dim)\n", + " \"\"\" \n", + " alpha_t = self.alpha(t) # (num_samples, 1)\n", + " beta_t = self.beta(t) # (num_samples, 1)\n", + " dt_alpha_t = self.alpha.dt(t) # (num_samples, 1)\n", + " dt_beta_t = self.beta.dt(t) # (num_samples, 1)\n", + "\n", + " return (dt_alpha_t - dt_beta_t / beta_t * alpha_t) * z + dt_beta_t / beta_t * x\n", + "\n", + " def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates the conditional score of p_t(x|z) = N(alpha_t * z, beta_t**2 * I_d)\n", + " Note: Only defined on t in [0,1)\n", + " Args:\n", + " - x: position variable (num_samples, dim)\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - conditional_score: conditional score (num_samples, dim)\n", + " \"\"\" \n", + " alpha_t = self.alpha(t)\n", + " beta_t = self.beta(t)\n", + " return (z * alpha_t - x) / beta_t ** 2" + ] + }, + { + "cell_type": "markdown", + "id": "6ae68e0c-b534-41ca-8f05-86d3994ea34d", + "metadata": {}, + "source": [ + "### Problem 2.2: Gaussian Conditional Probability Path" + ] + }, + { + "cell_type": "markdown", + "id": "7f362c33-10fb-440d-aede-48486004c4b8", + "metadata": {}, + "source": [ + "**Your work (2 points)**: Implement the class method `sample_conditional_path` to sample from the conditional distribution $p_t(x|z) = N(x;\\alpha_t z,\\beta_t^2 I_d)$. You can check the correctness of your implementation by running the next two cells to generate an image of the conditional probability path and comparing these to the corresponding plot from Figure 6 in the lecture notes (the one labeled \"Ground-Truth Conditional Probability Path\").\n", + "\n", + "**Hint**: You may use the fact that the reandom variable $X \\sim N(\\mu, \\sigma^2 I_d)$ is obtained via $X = \\mu + \\sigma Z$, where $Z \\sim N(0, I_d)$." + ] + }, + { + "cell_type": "markdown", + "id": "3b3744cf-1ba6-438c-9e53-25e823ce696f", + "metadata": {}, + "source": [ + "We can now sample from, and thus visualize, the *conditional* probaability path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8d49b40-083d-4318-b028-a1f32fe839bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct conditional probability path\n", + "path = GaussianConditionalProbabilityPath(\n", + " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", + " alpha = LinearAlpha(),\n", + " beta = SquareRootBeta()\n", + ").to(device)\n", + "\n", + "scale = PARAMS[\"scale\"]\n", + "x_bounds = [-scale,scale]\n", + "y_bounds = [-scale,scale]\n", + "\n", + "plt.figure(figsize=(10,10))\n", + "plt.xlim(*x_bounds)\n", + "plt.ylim(*y_bounds)\n", + "plt.title('Gaussian Conditional Probability Path')\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "# Sample conditioning variable z\n", + "z = path.sample_conditioning_variable(1) # (1,2)\n", + "ts = torch.linspace(0.0, 1.0, 7).to(device)\n", + "\n", + "# Plot z\n", + "plt.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=75, label='z')\n", + "plt.xticks([])\n", + "plt.yticks([])\n", + "\n", + "# Plot conditional probability path at each intermediate t\n", + "num_samples = 1000\n", + "for t in ts:\n", + " zz = z.expand(num_samples, 2)\n", + " tt = t.unsqueeze(0).expand(num_samples, 1) # (samples, 1)\n", + " samples = path.sample_conditional_path(zz, tt) # (samples, 2)\n", + " plt.scatter(samples[:,0].cpu(), samples[:,1].cpu(), alpha=0.25, s=8, label=f't={t.item():.1f}')\n", + "\n", + "plt.legend(prop={'size': 18}, markerscale=3)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "ec77b226-8faf-4fd6-80f9-f2d151d529d0", + "metadata": {}, + "source": [ + "### Problem 2.3: Conditional Vector Field\n", + "From lecture and the notes, we know that the conditional vector field $u_t(x|z)$ is given by\n", + "$$u_t(x|z) = \\left(\\dot{\\alpha}_t-\\frac{\\dot{\\beta}_t}{\\beta_t}\\alpha_t\\right)z+\\frac{\\dot{\\beta}_t}{\\beta_t}x.$$" + ] + }, + { + "cell_type": "markdown", + "id": "76df0399-be66-4bcb-9095-03db72a6298e", + "metadata": {}, + "source": [ + "**Your work (2 points)**: Implement the class method `conditional_vector_field` to compute the conditional vector field $u_t(x|z)$.\n", + "\n", + "**Hint**: You can compute $\\dot{\\alpha}_t$ with `self.alpha.dt(t)`, which has been implemented for you. You may compute $\\dot{\\beta}_t$ similarly." + ] + }, + { + "cell_type": "markdown", + "id": "65dbeee3-36a7-4004-adf3-828bcfa66032", + "metadata": {}, + "source": [ + "We may now visualize the conditional trajectories corresponding to the ODE $$d X_t = u_t(X_t|z)dt, \\quad \\quad X_0 = x_0 \\sim p_{\\text{simple}}.$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5aee7e7-a730-4270-85d8-f860ca162830", + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionalVectorFieldODE(ODE):\n", + " def __init__(self, path: ConditionalProbabilityPath, z: torch.Tensor):\n", + " \"\"\"\n", + " Args:\n", + " - path: the ConditionalProbabilityPath object to which this vector field corresponds\n", + " - z: the conditioning variable, (1, dim)\n", + " \"\"\"\n", + " super().__init__()\n", + " self.path = path\n", + " self.z = z\n", + "\n", + " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Returns the conditional vector field u_t(x|z)\n", + " Args:\n", + " - x: state at time t, shape (bs, dim)\n", + " - t: time, shape (bs,.)\n", + " Returns:\n", + " - u_t(x|z): shape (batch_size, dim)\n", + " \"\"\"\n", + " bs = x.shape[0]\n", + " z = self.z.expand(bs, *self.z.shape[1:])\n", + " return self.path.conditional_vector_field(x,z,t)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f7b9b84-eb6d-4323-80ab-dc41b6281cb4", + "metadata": {}, + "outputs": [], + "source": [ + "# Run me for Problem 2.3!\n", + "\n", + "#######################\n", + "# Change these values #\n", + "#######################\n", + "num_samples = 1000\n", + "num_timesteps = 1000\n", + "num_marginals = 3\n", + "\n", + "########################\n", + "# Setup path and plot #\n", + "########################\n", + "\n", + "path = GaussianConditionalProbabilityPath(\n", + " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", + " alpha = LinearAlpha(),\n", + " beta = SquareRootBeta()\n", + ").to(device)\n", + "\n", + "\n", + "# Setup figure\n", + "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", + "scale = PARAMS[\"scale\"]\n", + "legend_size = 24\n", + "markerscale = 1.8\n", + "x_bounds = [-scale,scale]\n", + "y_bounds = [-scale,scale]\n", + "\n", + "# Sample conditioning variable z\n", + "torch.cuda.manual_seed(1)\n", + "z = path.sample_conditioning_variable(1) # (1,2)\n", + "\n", + "######################################\n", + "# Graph samples from conditional ODE #\n", + "######################################\n", + "ax = axes[1]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title('Samples from Conditional ODE', fontsize=20)\n", + "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "\n", + "# Construct integrator and plot trajectories\n", + "sigma = 0.5 # Can't make this too high or integration is numerically unstable!\n", + "ode = ConditionalVectorFieldODE(path, z)\n", + "simulator = EulerSimulator(ode)\n", + "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", + "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", + "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", + "\n", + "# Extract every n-th integration step to plot\n", + "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", + "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", + "ts_every_n = ts[0,every_n] # (nts // n,)\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].item()\n", + " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "\n", + "#########################################\n", + "# Graph Trajectories of Conditional ODE #\n", + "#########################################\n", + "ax = axes[2]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title('Trajectories of Conditional ODE', fontsize=20)\n", + "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", + "\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "for traj_idx in range(15):\n", + " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "\n", + "###################################################\n", + "# Graph Ground-Truth Conditional Probability Path #\n", + "###################################################\n", + "ax = axes[0]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title('Ground-Truth Conditional Probability Path', fontsize=20)\n", + "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", + "\n", + "\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", + " zz = z.expand(num_samples, 2)\n", + " marginal_samples = path.sample_conditional_path(zz, tt)\n", + " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3f9a579c-a0e7-48c3-b1b5-dae685d90b20", + "metadata": {}, + "source": [ + "**Note**: You may have noticed that since for Gaussian probability paths, $z \\sim p_{\\text{data}}(x)$, the method `GaussianConditionalProbabilityPath.sample_conditioning_variable` is effectively sampling from the data distribution. But wait - aren't we trying to learn to sample from $p_{\\text{data}}$ in the first place? This is a subtlety that we have glossed over thus far. The answer is that *in practice*, `sample_conditioning_variable` would return points from a finite *training set*, which is formally assumed to have been sampled IID from the true distribution $z \\sim p_{\\text{data}}$." + ] + }, + { + "cell_type": "markdown", + "id": "75225acb-5522-467e-82df-9fba32c0b708", + "metadata": {}, + "source": [ + "### Problem 2.4: The Conditional Score" + ] + }, + { + "cell_type": "markdown", + "id": "705cac87-6d1e-4df6-868c-e3d8dbd1c55d", + "metadata": {}, + "source": [ + "As in lecture may now visualize the conditional trajectories corresponding to the SDE $$d X_t = \\left[u_t(X_t|z) + \\frac{1}{2}\\sigma^2 \\nabla_x \\log p_t(X_t|z) \\right]dt + \\sigma\\, dW_t, \\quad \\quad X_0 = x_0 \\sim p_{\\text{simple}},$$\n", + "obtained by adding *Langevin dynamics* to the original ODE." + ] + }, + { + "cell_type": "markdown", + "id": "bf5de941-a496-4f79-8814-dd5d47f0118e", + "metadata": {}, + "source": [ + "**Your work (2 points)**: Implement the class method `conditional_score` to compute the conditional distribution $\\nabla_x \\log p_t(x|z)$, which we compute to be\n", + "$$\\nabla_x \\log p_t(x|z) = \\nabla_x N(x;\\alpha_t z,\\beta_t^2 I_d) = \\frac{\\alpha_t z - x}{\\beta_t^2}.$$\n", + "To check for correctness, use the next two cells to verify that samples from the conditional SDE match the samples drawn analytically from the conditional probability path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c0c2adf-0c0b-4b30-9de0-8c0dd8db47a8", + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionalVectorFieldSDE(SDE):\n", + " def __init__(self, path: ConditionalProbabilityPath, z: torch.Tensor, sigma: float):\n", + " \"\"\"\n", + " Args:\n", + " - path: the ConditionalProbabilityPath object to which this vector field corresponds\n", + " - z: the conditioning variable, (1, ...)\n", + " \"\"\"\n", + " super().__init__()\n", + " self.path = path\n", + " self.z = z\n", + " self.sigma = sigma\n", + "\n", + " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Returns the conditional vector field u_t(x|z)\n", + " Args:\n", + " - x: state at time t, shape (bs, dim)\n", + " - t: time, shape (bs,.)\n", + " Returns:\n", + " - u_t(x|z): shape (batch_size, dim)\n", + " \"\"\"\n", + " bs = x.shape[0]\n", + " z = self.z.expand(bs, *self.z.shape[1:])\n", + " return self.path.conditional_vector_field(x,z,t) + 0.5 * self.sigma**2 * self.path.conditional_score(x,z,t)\n", + "\n", + " def diffusion_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - x: state at time t, shape (bs, dim)\n", + " - t: time, shape (bs,.)\n", + " Returns:\n", + " - u_t(x|z): shape (batch_size, dim)\n", + " \"\"\"\n", + " return self.sigma * torch.randn_like(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ac806db-9f0f-49cc-b367-84739bda7d1e", + "metadata": {}, + "outputs": [], + "source": [ + "# Run me for Problem 2.3!\n", + "\n", + "#######################\n", + "# Change these values #\n", + "#######################\n", + "num_samples = 1000\n", + "num_timesteps = 1000\n", + "num_marginals = 3\n", + "sigma = 2.5\n", + "\n", + "########################\n", + "# Setup path and plot #\n", + "########################\n", + "\n", + "path = GaussianConditionalProbabilityPath(\n", + " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", + " alpha = LinearAlpha(),\n", + " beta = SquareRootBeta()\n", + ").to(device)\n", + "\n", + "\n", + "# Setup figure\n", + "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", + "scale = PARAMS[\"scale\"]\n", + "x_bounds = [-scale,scale]\n", + "y_bounds = [-scale,scale]\n", + "legend_size = 24\n", + "markerscale = 1.8\n", + "\n", + "# Sample conditioning variable z\n", + "torch.cuda.manual_seed(1)\n", + "z = path.sample_conditioning_variable(1) # (1,2)\n", + "\n", + "######################################\n", + "# Graph Samples from Conditional SDE #\n", + "######################################\n", + "ax = axes[1]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title('Samples from Conditional SDE', fontsize=20)\n", + "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "\n", + "# Construct integrator and plot trajectories\n", + "sde = ConditionalVectorFieldSDE(path, z, sigma)\n", + "simulator = EulerMaruyamaSimulator(sde)\n", + "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", + "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", + "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", + "\n", + "# Extract every n-th integration step to plot\n", + "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", + "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", + "ts_every_n = ts[0,every_n] # (nts // n,)\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].item()\n", + " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "\n", + "##########################################\n", + "# Graph Trajectories of Conditional SDE #\n", + "##########################################\n", + "ax = axes[2]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title('Trajectories of Conditional SDE', fontsize=20)\n", + "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", + "\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "for traj_idx in range(5):\n", + " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "\n", + "###################################################\n", + "# Graph Ground-Truth Conditional Probability Path #\n", + "###################################################\n", + "ax = axes[0]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title('Ground-Truth Conditional Probability Path', fontsize=20)\n", + "ax.scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20) # Plot z\n", + "\n", + "\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", + " zz = z.expand(num_samples, 2)\n", + " marginal_samples = path.sample_conditional_path(zz, tt)\n", + " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "4609e0a6-cc7f-429e-9f81-2423164b5c8d", + "metadata": {}, + "source": [ + "# Part 3: Flow Matching and Score Matching with Gaussian Conditional Probability Paths\n" + ] + }, + { + "cell_type": "markdown", + "id": "565e1f4f-f02e-4663-ae99-b62d9fee95b3", + "metadata": {}, + "source": [ + "### Problem 3.1 Flow Matching with Gaussian Conditional Probability Paths" + ] + }, + { + "cell_type": "markdown", + "id": "760936a6-fc17-452e-82a1-abb6e1e43c53", + "metadata": {}, + "source": [ + "Recall now that from lecture that our goal is to learn the *marginal vector field* $u_t(x)$ given by $$u_t^{\\text{ref}}(x) = \\mathbb{E}_{z \\sim p_t(z|x)}\\left[u_t^{\\text{ref}}(x|z)\\right].$$\n", + "Unfortunately, we don't actually know what $u_t^{\\text{ref}}(x)$ is! We will thus approximate $u_t^{\\text{ref}}(x)$ as a neural network $u_t^{\\theta}(x)$, and exploit the identity $$ u_t^{\\text{ref}}(x) = \\text{argmin}_{u_t(x)} \\,\\,\\mathbb{E}_{z \\sim p_t(z|x)} \\lVert u_t(x) - u_t^{\\text{ref}}(x|z)\\rVert^2$$ to obtain the **conditional flow matching objective**\n", + "$$ \\mathcal{L}_{\\text{CFM}}(\\theta) = \\,\\,\\mathbb{E}_{z \\sim p(z), x \\sim p_t(x|z)} \\lVert u_t^{\\theta}(x) - u_t^{\\text{ref}}(x|z)\\rVert^2.$$\n", + "To model $u_t^{\\theta}(x)$, we'll use a simple MLP. This network will take in both $x$ and $t$, and will return the learned vector field $u_t^{\\theta}(x)$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e57626e7-765e-4e39-aa46-ae403f7960ef", + "metadata": {}, + "outputs": [], + "source": [ + "def build_mlp(dims: List[int], activation: Type[torch.nn.Module] = torch.nn.SiLU):\n", + " mlp = []\n", + " for idx in range(len(dims) - 1):\n", + " mlp.append(torch.nn.Linear(dims[idx], dims[idx + 1]))\n", + " if idx < len(dims) - 2:\n", + " mlp.append(activation())\n", + " return torch.nn.Sequential(*mlp)\n", + "\n", + "class MLPVectorField(torch.nn.Module):\n", + " \"\"\"\n", + " MLP-parameterization of the learned vector field u_t^theta(x)\n", + " \"\"\"\n", + " def __init__(self, dim: int, hiddens: List[int]):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.net = build_mlp([dim + 1] + hiddens + [dim])\n", + "\n", + " def forward(self, x: torch.Tensor, t: torch.Tensor):\n", + " \"\"\"\n", + " Args:\n", + " - x: (bs, dim)\n", + " Returns:\n", + " - u_t^theta(x): (bs, dim)\n", + " \"\"\"\n", + " xt = torch.cat([x,t], dim=-1)\n", + " return self.net(xt) " + ] + }, + { + "cell_type": "markdown", + "id": "73bdbaaa-fecb-4753-a9a2-043b1dee585a", + "metadata": {}, + "source": [ + "Let's first define a general-purpose class `Trainer` to keep things tidy as we start training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94f618c8-5033-4f16-8826-f29d4f1772f8", + "metadata": {}, + "outputs": [], + "source": [ + "class Trainer(ABC):\n", + " def __init__(self, model: torch.nn.Module):\n", + " super().__init__()\n", + " self.model = model\n", + "\n", + " @abstractmethod\n", + " def get_train_loss(self, **kwargs) -> torch.Tensor:\n", + " pass\n", + "\n", + " def get_optimizer(self, lr: float):\n", + " return torch.optim.Adam(self.model.parameters(), lr=lr)\n", + "\n", + " def train(self, num_epochs: int, device: torch.device, lr: float = 1e-3, **kwargs) -> torch.Tensor:\n", + " # Start\n", + " self.model.to(device)\n", + " opt = self.get_optimizer(lr)\n", + " self.model.train()\n", + "\n", + " # Train loop\n", + " pbar = tqdm(enumerate(range(num_epochs)))\n", + " for idx, epoch in pbar:\n", + " opt.zero_grad()\n", + " loss = self.get_train_loss(**kwargs)\n", + " loss.backward()\n", + " opt.step()\n", + " pbar.set_description(f'Epoch {idx}, loss: {loss.item()}')\n", + "\n", + " # Finish\n", + " self.model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "4f270b97-ce2c-4b6b-b924-e4d976813e0e", + "metadata": {}, + "source": [ + "**Your work (2 points)**: Fill in `ConditionalFlowMatchingTrainer.get_train_loss` below. This function should implement the conditional flow matching objective $$\\mathcal{L}_{\\text{CFM}}(\\theta) = \\,\\,\\mathbb{E}_{\\textcolor{blue}{t \\in \\mathcal{U}[0,1), z \\sim p(z), x \\sim p_t(x|z)}} \\textcolor{green}{\\lVert u_t^{\\theta}(x) - u_t^{\\text{ref}}(x|z)\\rVert^2}$$\n", + "using a Monte-Carlo estimate of the form\n", + "$$\\frac{1}{N}\\sum_{i=1}^N \\textcolor{green}{\\lVert u_{t_i}^{\\theta}(x_i) - u_{t_i}^{\\text{ref}}(x_i|z_i)\\rVert^2}, \\quad \\quad \\quad \\forall i\\in[1, \\dots, N]: \\textcolor{blue}{\\,z_i \\sim p_{\\text{data}},\\, t_i \\sim \\mathcal{U}[0,1),\\, x_i \\sim p_t(\\cdot | z_i)}.$$\n", + "Here, $N$ is our *batch size*.\n", + "\n", + "\n", + "**Hint 1**: For sampling:\n", + "- You can sample `batch_size` points $z$ from $p_{\\text{data}}$ using `self.path.p_data.sample(batch_size)`.\n", + "- You can sample `batch_size` values of `t` using `torch.rand(batch_size, 1)`.\n", + "- You can sample `batch_size` points from `p_t(x|z)` using `self.path.p_simple.sample_conditional_path(z,t)`.\n", + "\n", + "**Hint 2**: For the loss function:\n", + "- You can access $u_t^{\\theta}(x)$ using `self.model(x,t)`.\n", + "- You can access $u_t^{\\text{ref}}(x|z)$ using `self.path.conditional_vector_field(x,z,t)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f01a592-40da-4e71-a7ce-b53363100c64", + "metadata": {}, + "outputs": [], + "source": [ + "class ConditionalFlowMatchingTrainer(Trainer):\n", + " def __init__(self, path: ConditionalProbabilityPath, model: MLPVectorField, **kwargs):\n", + " super().__init__(model, **kwargs)\n", + " self.path = path\n", + "\n", + " def get_train_loss(self, batch_size: int) -> torch.Tensor:\n", + " z = self.path.p_data.sample(batch_size) # (bs, dim)\n", + " t = torch.rand(batch_size,1).to(z) # (bs, 1)\n", + " x = self.path.sample_conditional_path(z,t) # (bs, dim)\n", + "\n", + " ut_theta = self.model(x,t) # (bs, dim)\n", + " ut_ref = self.path.conditional_vector_field(x,z,t) # (bs, dim)\n", + " error = torch.sum(torch.square(ut_theta - ut_ref), dim=-1) # (bs,)\n", + " return torch.mean(error)" + ] + }, + { + "cell_type": "markdown", + "id": "75fc5768-d14c-4477-88d1-7b27dfc61538", + "metadata": {}, + "source": [ + "Now let's train! This may take about a minute... **Remember, the loss should converge, but not to zero!**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a2d13db-7fd3-429b-ab54-86d937b4a9ff", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct conditional probability path\n", + "path = GaussianConditionalProbabilityPath(\n", + " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", + " alpha = LinearAlpha(),\n", + " beta = SquareRootBeta()\n", + ").to(device)\n", + "\n", + "# Construct learnable vector field\n", + "flow_model = MLPVectorField(dim=2, hiddens=[64,64,64,64])\n", + "\n", + "# Construct trainer\n", + "trainer = ConditionalFlowMatchingTrainer(path, flow_model)\n", + "losses = trainer.train(num_epochs=5000, device=device, lr=1e-3, batch_size=1000)" + ] + }, + { + "cell_type": "markdown", + "id": "b3e40ead-7bfc-4083-9d9c-bf0abd83a5c7", + "metadata": {}, + "source": [ + "Is our model any good? Let's visualize? First, we need to wrap our learned vector field in an subclass of `ODE` so that we can simulate it using our `Simulator` class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a25f8b48-cc63-4b52-b59e-5bc57813e338", + "metadata": {}, + "outputs": [], + "source": [ + "class LearnedVectorFieldODE(ODE):\n", + " def __init__(self, net: MLPVectorField):\n", + " self.net = net\n", + "\n", + " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - x: (bs, dim)\n", + " - t: (bs, dim)\n", + " Returns:\n", + " - u_t: (bs, dim)\n", + " \"\"\"\n", + " return self.net(x, t)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56152a16-abf6-47cb-aa0f-f8f69b8ab690", + "metadata": {}, + "outputs": [], + "source": [ + "#######################\n", + "# Change these values #\n", + "#######################\n", + "num_samples = 1000\n", + "num_timesteps = 1000\n", + "num_marginals = 3\n", + "\n", + "\n", + "##############\n", + "# Setup Plot #\n", + "##############\n", + "\n", + "scale = PARAMS[\"scale\"]\n", + "x_bounds = [-scale,scale]\n", + "y_bounds = [-scale,scale]\n", + "legend_size=24\n", + "markerscale=1.8\n", + "\n", + "# Setup figure\n", + "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", + "\n", + "###########################################\n", + "# Graph Samples from Learned Marginal ODE #\n", + "###########################################\n", + "ax = axes[1]\n", + "\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "ax.set_title(\"Samples from Learned Marginal ODE\", fontsize=20)\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "\n", + "# Construct integrator and plot trajectories\n", + "ode = LearnedVectorFieldODE(flow_model)\n", + "simulator = EulerSimulator(ode)\n", + "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", + "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", + "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", + "\n", + "# Extract every n-th integration step to plot\n", + "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", + "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", + "ts_every_n = ts[0,every_n] # (nts // n,)\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].item()\n", + " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", + "\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "##############################################\n", + "# Graph Trajectories of Learned Marginal ODE #\n", + "##############################################\n", + "ax = axes[2]\n", + "ax.set_title(\"Trajectories of Learned Marginal ODE\", fontsize=20)\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "for traj_idx in range(num_samples // 10):\n", + " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", + "\n", + "################################################\n", + "# Graph Ground-Truth Marginal Probability Path #\n", + "################################################\n", + "ax = axes[0]\n", + "ax.set_title(\"Ground-Truth Marginal Probability Path\", fontsize=20)\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", + " marginal_samples = path.sample_marginal_path(tt)\n", + " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + " \n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "24b0491a-9d82-413d-b6c9-95199e0f23e4", + "metadata": {}, + "source": [ + "### Problem 3.2: Score Matching with Gaussian Conditional Probability Paths" + ] + }, + { + "cell_type": "markdown", + "id": "48fe3eb5-51ee-4137-9626-6c1d1af910fe", + "metadata": {}, + "source": [ + "We have thus far used flow matching to train a model $u_t^{\\theta}(x) \\approx u_t^{\\text{ref}}$ so that $$d X_t = u_t^{\\theta}(X_t) dt $$ approximately passes through the desired marginal probability path $p_t(x)$. Now recall from lecture that we may augment the reference marginal vector field $u_t^{\\text{ref}}(x)$ with *Langevin dynamics* to add stochasticity while preserving the marginals, viz., $$dX_t = \\left[u_t^{\\text{ref}}(x) + \\frac{1}{2}\\sigma^2 \\nabla \\log p_t(x)\\right] dt + \\sigma d W_t.$$\n", + "Substituting our learned approximation $u_t^{\\theta}(x) \\approx u_t^{\\text{ref}}$ therefore yields \n", + "$$dX_t = \\left[u_t^{\\theta}(x) + \\frac{1}{2}\\sigma^2 \\nabla \\log p_t(x)\\right] dt + \\sigma d W_t.$$\n", + "There's just one issue, what's the marginal score $\\nabla \\log p_t(x)$? In Question 2.3, we computed the conditional score $\\nabla \\log p_t(x|z)$ of the Gaussian probability path. In the same way that we learned an approximation $u_t^{\\theta}(x) \\approx u_t^{\\text{ref}}$, we'd like to be able to learn a similar approximation $s_t^{\\theta}(x) \\approx \\nabla \\log p_t(x)$. Recall from lecture the identity $$\\nabla \\log p_t(x) = \\mathbb{E}_{z \\sim p_t(z|x)}\\left[\\nabla \\log p_t(x|z) \\right].$$ It then immediately follows that\n", + "$$\\nabla \\log p_t(x) = \\text{argmin}_{s_t(x)} \\,\\,\\mathbb{E}_{z \\sim p(z), x \\sim p_t(x|z)} \\lVert s_t(x) - \\nabla \\log p_x(x|z)\\rVert^2.$$\n", + "We thus obtain the **conditional score matching** loss\n", + "$$\\mathcal{L}_{\\text{CSM}}(\\theta) \\triangleq \\mathbb{E}_{t \\sim \\mathcal{U}[0,1), z \\sim p(z), x \\sim p_t(x|z)} \\lVert s_t^{\\theta}(x) - \\nabla \\log p_x(x|z)\\rVert^2.$$\n", + "Here, we will parameterize $s_t^{\\theta}(x): \\mathbb{R}^2 \\to \\mathbb{R}^2$ as a simple MLP, just like $u_t^{\\theta}(x)$." + ] + }, + { + "cell_type": "markdown", + "id": "7efcae65-2c86-4018-9f7c-ae9175607e06", + "metadata": {}, + "source": [ + "**Your job (2 points)**: Fill in method `ConditionalScoreMatchingTrainer.get_train_loss` to implement the conditional score matching loss $\\mathcal{L}_{\\text{CSM}}(\\theta)$.\n", + "\n", + "**Hint:** Remember to re-use your implementation of `GaussianConditionalProbabilityPath.conditional_score`!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96d654b2-13d4-47bf-a5c3-98606c580771", + "metadata": {}, + "outputs": [], + "source": [ + "class MLPScore(torch.nn.Module):\n", + " \"\"\"\n", + " MLP-parameterization of the learned score field\n", + " \"\"\"\n", + " def __init__(self, dim: int, hiddens: List[int]):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.net = build_mlp([dim + 1] + hiddens + [dim])\n", + "\n", + " def forward(self, x: torch.Tensor, t: torch.Tensor):\n", + " \"\"\"\n", + " Args:\n", + " - x: (bs, dim)\n", + " Returns:\n", + " - s_t^theta(x): (bs, dim)\n", + " \"\"\"\n", + " xt = torch.cat([x,t], dim=-1)\n", + " return self.net(xt) \n", + " \n", + "class ConditionalScoreMatchingTrainer(Trainer):\n", + " def __init__(self, path: ConditionalProbabilityPath, model: MLPScore, **kwargs):\n", + " super().__init__(model, **kwargs)\n", + " self.path = path\n", + "\n", + " def get_train_loss(self, batch_size: int) -> torch.Tensor:\n", + " z = self.path.p_data.sample(batch_size) # (bs, dim)\n", + " t = torch.rand(batch_size,1).to(z) # (bs, 1)\n", + " x = self.path.sample_conditional_path(z,t) # (bs, dim)\n", + "\n", + " s_theta = self.model(x,t) # (bs, dim)\n", + " s_ref = self.path.conditional_score(x,z,t) # (bs, dim)\n", + " mse = torch.sum(torch.square(s_theta - s_ref), dim=-1) # (bs,)\n", + " return torch.mean(mse)" + ] + }, + { + "cell_type": "markdown", + "id": "4f35e47c-1836-41b4-8234-37acc589ee70", + "metadata": {}, + "source": [ + "Now let's train! **Remember, the loss should converge, but not to zero!**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfff0733-cd1d-4e90-970a-3f9948f0ea65", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct conditional probability path\n", + "path = GaussianConditionalProbabilityPath(\n", + " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", + " alpha = LinearAlpha(),\n", + " beta = SquareRootBeta()\n", + ").to(device)\n", + "\n", + "# Construct learnable vector field\n", + "score_model = MLPScore(dim=2, hiddens=[64,64,64,64])\n", + "\n", + "# Construct trainer\n", + "trainer = ConditionalScoreMatchingTrainer(path, score_model)\n", + "losses = trainer.train(num_epochs=1000, device=device, lr=1e-3, batch_size=1000)" + ] + }, + { + "cell_type": "markdown", + "id": "a946925a-1230-48bc-bc6c-19b72bd26099", + "metadata": {}, + "source": [ + "Now let's visualize our work! Before we do however, we'll need to wrap our learned our flow model and score model in an instance of `SDE` so that we can integrate it using our `EulerMaruyamaIntegrator` class. This new class, `LangevinFlowSDE` will correspond to the dynamics $$dX_t = \\left[u_t^{\\theta}(x) + \\frac{1}{2}\\sigma^2 s_t^{\\theta}(x)\\right] dt + \\sigma d W_t.$$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afc4fae7-9dbb-4e5e-9839-8088f491549e", + "metadata": {}, + "outputs": [], + "source": [ + "class LangevinFlowSDE(SDE):\n", + " def __init__(self, flow_model: MLPVectorField, score_model: MLPScore, sigma: float):\n", + " \"\"\"\n", + " Args:\n", + " - path: the ConditionalProbabilityPath object to which this vector field corresponds\n", + " - z: the conditioning variable, (1, dim)\n", + " \"\"\"\n", + " super().__init__()\n", + " self.flow_model = flow_model\n", + " self.score_model = score_model\n", + " self.sigma = sigma\n", + "\n", + " def drift_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - x: state at time t, shape (bs, dim)\n", + " - t: time, shape (bs,.)\n", + " Returns:\n", + " - u_t(x|z): shape (batch_size, dim)\n", + " \"\"\"\n", + " return self.flow_model(x,t) + 0.5 * sigma ** 2 * self.score_model(x, t)\n", + "\n", + " def diffusion_coefficient(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " - x: state at time t, shape (bs, dim)\n", + " - t: time, shape (bs,.)\n", + " Returns:\n", + " - u_t(x|z): shape (batch_size, dim)\n", + " \"\"\"\n", + " return self.sigma * torch.randn_like(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e40fa803-25d7-4bbd-8230-e3650bdb79d4", + "metadata": {}, + "outputs": [], + "source": [ + "#######################\n", + "# Change these values #\n", + "#######################\n", + "num_samples = 1000\n", + "num_timesteps = 300\n", + "num_marginals = 3\n", + "sigma = 2.0 # Don't set sigma too large or you'll get numerical issues!\n", + "\n", + "\n", + "##############\n", + "# Setup Plot #\n", + "##############\n", + "\n", + "scale = PARAMS[\"scale\"]\n", + "x_bounds = [-scale,scale]\n", + "y_bounds = [-scale,scale]\n", + "legend_size = 24\n", + "markerscale = 1.8\n", + "\n", + "# Setup figure\n", + "fig, axes = plt.subplots(1,3, figsize=(36, 12))\n", + "\n", + "###########################################\n", + "# Graph Samples from Learned Marginal SDE #\n", + "###########################################\n", + "ax = axes[1]\n", + "ax.set_title(\"Samples from Learned Marginal SDE\", fontsize=20)\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "\n", + "# Construct integrator and plot trajectories\n", + "sde = LangevinFlowSDE(flow_model, score_model, sigma)\n", + "simulator = EulerMaruyamaSimulator(sde)\n", + "x0 = path.p_simple.sample(num_samples) # (num_samples, 2)\n", + "ts = torch.linspace(0.0, 1.0, num_timesteps).view(1,-1,1).expand(num_samples,-1,1).to(device) # (num_samples, nts, 1)\n", + "xts = simulator.simulate_with_trajectory(x0, ts) # (bs, nts, dim)\n", + "\n", + "# Extract every n-th integration step to plot\n", + "every_n = record_every(num_timesteps=num_timesteps, record_every=num_timesteps // num_marginals)\n", + "xts_every_n = xts[:,every_n,:] # (bs, nts // n, dim)\n", + "ts_every_n = ts[0,every_n] # (nts // n,)\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].item()\n", + " ax.scatter(xts_every_n[:,plot_idx,0].detach().cpu(), xts_every_n[:,plot_idx,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt:.2f}')\n", + "\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + "\n", + "###############################################\n", + "# Graph Trajectories of Learned Marginal SDE #\n", + "###############################################\n", + "ax = axes[2]\n", + "ax.set_title(\"Trajectories of Learned Marginal SDE\", fontsize=20)\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "for traj_idx in range(num_samples // 10):\n", + " ax.plot(xts[traj_idx,:,0].detach().cpu(), xts[traj_idx,:,1].detach().cpu(), alpha=0.5, color='black')\n", + "\n", + "################################################\n", + "# Graph Ground-Truth Marginal Probability Path #\n", + "################################################\n", + "ax = axes[0]\n", + "ax.set_title(\"Ground-Truth Marginal Probability Path\", fontsize=20)\n", + "ax.set_xlim(*x_bounds)\n", + "ax.set_ylim(*y_bounds)\n", + "ax.set_xticks([])\n", + "ax.set_yticks([])\n", + "\n", + "for plot_idx in range(xts_every_n.shape[1]):\n", + " tt = ts_every_n[plot_idx].unsqueeze(0).expand(num_samples, 1)\n", + " marginal_samples = path.sample_marginal_path(tt)\n", + " ax.scatter(marginal_samples[:,0].detach().cpu(), marginal_samples[:,1].detach().cpu(), marker='o', alpha=0.5, label=f't={tt[0,0].item():.2f}')\n", + "\n", + "# Plot source and target\n", + "imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + "imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + "\n", + "ax.legend(prop={'size': legend_size}, loc='upper right', markerscale=markerscale)\n", + " \n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "57aaf7f6-d7d0-4fe8-b415-8a069ac02f93", + "metadata": {}, + "source": [ + "### Question 3.3: Deriving the Marginal Score from the Marginal Flow\n", + "Recall from the notes and the lecture that for Gaussian probability paths $$u_t^{\\text{ref}}(x) = a_tx + b_t\\nabla \\log p_t^{\\text{ref}}(x).$$\n", + "\n", + "where $(a_t, b_t) = \\left(\\frac{\\dot{\\alpha}_t}{\\alpha_t}, \\beta_t^2 \\frac{\\dot{\\alpha}_t}{\\alpha_t} - \\dot{\\beta}_t \\beta_t\\right)$. Rearranging yields $$\\nabla \\log p_t^{\\text{ref}}(x) = \\frac{u_t^{\\text{ref}}(x) - a_tx}{b_t}.$$\n", + "\n", + "Therefore, we may instead exploit the fact that we have already trained $u_t^{\\theta}(x)$, to parameterize $s_t^{\\theta}(x)$ via\n", + "$$\\tilde{s}_t^{\\theta}(x) = \\frac{u_t^{\\theta}(x) - a_tx}{b_t} = \\frac{\\alpha_t u_t^{\\theta}(x) - \\dot{\\alpha}_t x}{\\beta_t^2 \\dot{\\alpha}_t - \\alpha_t \\dot{\\beta}_t \\beta_t},$$\n", + "so long as $\\beta_t^2 \\dot{\\alpha}_t - \\alpha_t \\dot{\\beta}_t \\beta_t \\neq 0$ (which is true for $t \\in [0,1)$ by monotonicity). Here, we differentiate $\\tilde{s}_t^{\\theta}(x)$ paramterized via $u_t^{\\theta}(x)$ from $s_t^{\\theta}(x)$ learned indepedently using score matching. Plugging in $\\alpha_t = t$ and $\\beta_t = \\sqrt{1-t}$, we find that $$\\beta_t^2 \\dot{\\alpha}_t - \\alpha_t \\dot{\\beta}_t \\beta_t = \\begin{cases} 1 - \\frac{t}{2} & \\text{if}\\,\\,t\\in [0,1)\\\\0 & \\text{if}\\,\\,{t=1}. \\end{cases}.$$ In the following visualization, we'll circumvent the issue at $t=1.0$ by taking $t=1 - \\varepsilon$ in place of $t=1$, for small $\\varepsilon \\approx 0$." + ] + }, + { + "cell_type": "markdown", + "id": "3e1f0dd1-dab6-44bd-aeb1-e472a6294af4", + "metadata": {}, + "source": [ + "**Your job (1 point)**: Implement $\\tilde{s}_t^{\\theta}(x)$ by filling in the body of `ScoreFromVectorField.forward` below. The next several cells generate a visualization comparing the flow-parameterized score $\\tilde{s}_t^{\\theta}(x)$ to our independently learned score $s_t^{\\theta}(x)$. You can check that your implementation is correct by making sure that the visualizations match." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91efc10a-7ef7-4ec6-975f-8e52e0cac2a1", + "metadata": {}, + "outputs": [], + "source": [ + "class ScoreFromVectorField(torch.nn.Module):\n", + " \"\"\"\n", + " MLP-parameterization of the learned score field\n", + " \"\"\"\n", + " def __init__(self, vector_field: MLPVectorField, alpha: Alpha, beta: Beta):\n", + " super().__init__()\n", + " self.vector_field = vector_field\n", + " self.alpha = alpha\n", + " self.beta = beta\n", + "\n", + " def forward(self, x: torch.Tensor, t: torch.Tensor):\n", + " \"\"\"\n", + " Args:\n", + " - x: (bs, dim)\n", + " Returns:\n", + " - \\tilde{s}_t^theta(x): (bs, dim)\n", + " \"\"\"\n", + " alpha_t = self.alpha(t)\n", + " beta_t = self.beta(t)\n", + " dt_alpha_t = self.alpha.dt(t)\n", + " dt_beta_t = self.beta.dt(t)\n", + "\n", + " num = alpha_t * self.vector_field(x,t) - dt_alpha_t * x\n", + " den = beta_t ** 2 * dt_alpha_t - alpha_t * dt_beta_t * beta_t\n", + "\n", + " return num / den " + ] + }, + { + "cell_type": "markdown", + "id": "0aefa1ba-794f-47bb-8fa1-2f6a53b486b1", + "metadata": {}, + "source": [ + "Now, let's compare our learned marginal score $s_t^{\\theta}(x)$ (an instance of `MLPScore`) to our flow-parameterized score (an instance of `ScoreFromVectorField`). We'll do so by plotting the vector fields across time and space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2b99180-43d1-4ef9-bcf9-5e53ee0819b1", + "metadata": {}, + "outputs": [], + "source": [ + "#######################\n", + "# Change these values #\n", + "#######################\n", + "num_bins = 30\n", + "num_marginals = 4\n", + "\n", + "##############################\n", + "# Construct probability path #\n", + "##############################\n", + "path = GaussianConditionalProbabilityPath(\n", + " p_data = GaussianMixture.symmetric_2D(nmodes=5, std=PARAMS[\"target_std\"], scale=PARAMS[\"target_scale\"]).to(device), \n", + " alpha = LinearAlpha(),\n", + " beta = SquareRootBeta()\n", + ").to(device)\n", + "\n", + "#########################\n", + "# Define score networks #\n", + "#########################\n", + "learned_score_model = score_model\n", + "flow_score_model = ScoreFromVectorField(flow_model, path.alpha, path.beta)\n", + "\n", + "\n", + "###############################\n", + "# Plot score fields over time #\n", + "###############################\n", + "fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 12))\n", + "axes = axes.reshape((2, num_marginals))\n", + "\n", + "scale = PARAMS[\"scale\"]\n", + "ts = torch.linspace(0.0, 0.9999, num_marginals).to(device)\n", + "xs = torch.linspace(-scale, scale, num_bins).to(device)\n", + "ys = torch.linspace(-scale, scale, num_bins).to(device)\n", + "xx, yy = torch.meshgrid(xs, ys)\n", + "xx = xx.reshape(-1,1)\n", + "yy = yy.reshape(-1,1)\n", + "xy = torch.cat([xx,yy], dim=-1)\n", + "\n", + "axes[0,0].set_ylabel(\"Learned with Score Matching\", fontsize=12)\n", + "axes[1,0].set_ylabel(\"Computed from $u_t^{{\\\\theta}}(x)$\", fontsize=12)\n", + "for idx in range(num_marginals):\n", + " t = ts[idx]\n", + " bs = num_bins ** 2\n", + " tt = t.view(1,1).expand(bs, 1)\n", + " \n", + " # Learned scores\n", + " learned_scores = learned_score_model(xy, tt)\n", + " learned_scores_x = learned_scores[:,0]\n", + " learned_scores_y = learned_scores[:,1]\n", + "\n", + " ax = axes[0, idx]\n", + " ax.quiver(xx.detach().cpu(), yy.detach().cpu(), learned_scores_x.detach().cpu(), learned_scores_y.detach().cpu(), scale=125, alpha=0.5)\n", + " imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + " imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + " ax.set_title(f'$s_{{t}}^{{\\\\theta}}$ at t={t.item():.2f}')\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " \n", + "\n", + " # Flow score model\n", + " ax = axes\n", + " flow_scores = flow_score_model(xy,tt)\n", + " flow_scores_x = flow_scores[:,0]\n", + " flow_scores_y = flow_scores[:,1]\n", + "\n", + " ax = axes[1, idx]\n", + " ax.quiver(xx.detach().cpu(), yy.detach().cpu(), flow_scores_x.detach().cpu(), flow_scores_y.detach().cpu(), scale=125, alpha=0.5)\n", + " imshow_density(density=path.p_simple, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Reds'))\n", + " imshow_density(density=path.p_data, x_bounds=x_bounds, y_bounds=y_bounds, bins=200, ax=ax, vmin=-10, alpha=0.25, cmap=plt.get_cmap('Blues'))\n", + " ax.set_title(f'$\\\\tilde{{s}}_{{t}}^{{\\\\theta}}$ at t={t.item():.2f}')\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])" + ] + }, + { + "cell_type": "markdown", + "id": "b58fd169-6891-45ab-b917-14cc3912bec8", + "metadata": {}, + "source": [ + "# Part 4: Flow Matching Between Arbitrary Distributions with a Linear Probability Path\n", + "In this section, we will consider an alterntive conditional probability path - the **linear conditional probability path** - which can be constructed as follows. Given a source distribution $p_{\\text{simple}}$ and a data distribution $p_{\\text{data}}$, for a fixed $z$ we may consider the *interpolant* $$X_t = (1-t) X_0 + tz$$\n", + "where $X_0 \\sim p_{\\text{simple}}$ is a random variable. We may then define $p_t(x|z)$ so that $X_t \\sim p_t(x|z)$. Then it is apparent that $p_0(x|z) = p_{\\text{simple}}(x)$ and $p_1(x| z)= \\delta_z(x)$. It is also not difficult to show that the conditional vector field is given by $u_t^{\\text{ref}}(x) = \\frac{z - x}{1-t}$ for $t \\in [0,1)$. We make two observations about the linear conditional probability path: First, unlike in the Gaussian probability path, we do not have a closed form for the conditional score $\\nabla \\log p_t(x|z)$. Second, there is no constraint that $p_{\\text{simple}}$ be a Gaussian, which we will exploit in Problem 4.3 to construct flows between arbitrary choices of $p_{\\text{simple}}$ and $p_{\\text{data}}$. First, let's examine some more complicated choices of $p_{\\text{data}}$." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c24d596d-e307-4344-b683-aa7b55adf0c0", + "metadata": {}, + "outputs": [], + "source": [ + "class CirclesSampleable(Sampleable):\n", + " \"\"\"\n", + " Implementation of concentric circle distribution using sklearn's make_circles\n", + " \"\"\"\n", + " def __init__(self, device: torch.device, noise: float = 0.05, scale=5.0, offset: Optional[torch.Tensor] = None):\n", + " \"\"\"\n", + " Args:\n", + " noise: standard deviation of Gaussian noise added to the data\n", + " \"\"\"\n", + " self.noise = noise\n", + " self.scale = scale\n", + " self.device = device\n", + " if offset is None:\n", + " offset = torch.zeros(2)\n", + " self.offset = offset.to(device)\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return 2\n", + "\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " num_samples: number of samples to generate\n", + " Returns:\n", + " torch.Tensor: shape (num_samples, 3)\n", + " \"\"\"\n", + " samples, _ = make_circles(\n", + " n_samples=num_samples,\n", + " noise=self.noise,\n", + " factor=0.5,\n", + " random_state=None\n", + " )\n", + " return self.scale * torch.from_numpy(samples.astype(np.float32)).to(self.device) + self.offset\n", + "\n", + "class CheckerboardSampleable(Sampleable):\n", + " \"\"\"\n", + " Checkboard-esque distribution\n", + " \"\"\"\n", + " def __init__(self, device: torch.device, grid_size: int = 3, scale=5.0):\n", + " \"\"\"\n", + " Args:\n", + " noise: standard deviation of Gaussian noise added to the data\n", + " \"\"\"\n", + " self.grid_size = grid_size\n", + " self.scale = scale\n", + " self.device = device\n", + "\n", + " @property\n", + " def dim(self) -> int:\n", + " return 2\n", + "\n", + " def sample(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " num_samples: number of samples to generate\n", + " Returns:\n", + " torch.Tensor: shape (num_samples, 3)\n", + " \"\"\"\n", + " grid_length = 2 * self.scale / self.grid_size\n", + " samples = torch.zeros(0,2).to(device)\n", + " while samples.shape[0] < num_samples:\n", + " # Sample num_samples\n", + " new_samples = (torch.rand(num_samples,2).to(self.device) - 0.5) * 2 * self.scale\n", + " x_mask = torch.floor((new_samples[:,0] + self.scale) / grid_length) % 2 == 0 # (bs,)\n", + " y_mask = torch.floor((new_samples[:,1] + self.scale) / grid_length) % 2 == 0 # (bs,)\n", + " accept_mask = torch.logical_xor(~x_mask, y_mask)\n", + " samples = torch.cat([samples, new_samples[accept_mask]], dim=0)\n", + " return samples[:num_samples]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3c44f89-ddbc-45c4-99fd-d1c8ea5027a3", + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize alternative choices of p_data\n", + "targets = {\n", + " \"circles\": CirclesSampleable(device),\n", + " \"moons\": MoonsSampleable(device, scale=3.5),\n", + " \"checkerboard\": CheckerboardSampleable(device, grid_size=4)\n", + "}\n", + "\n", + "###################################\n", + "# Graph Various Choices of p_data #\n", + "###################################\n", + "\n", + "fig, axes = plt.subplots(1, len(targets), figsize=(6 * len(targets), 6))\n", + "\n", + "num_samples = 20000\n", + "num_bins = 100\n", + "for idx, (target_name, target) in enumerate(targets.items()):\n", + " ax = axes[idx]\n", + " hist2d_sampleable(target, num_samples, bins=bins, scale=7.5, ax=ax)\n", + " ax.set_aspect('equal')\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.set_title(f'Histogram of {target_name}')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5e14a9a6-a5c7-40df-9acc-39af296b22e2", + "metadata": {}, + "source": [ + "### Problem 4.1: Linear Probability Paths\n", + "Below we define the `LinearConditionalProbabilityPath`. We purposely omit the implementation of `conditional_score` because, as mentioned earlier, there is no nice form for it!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d0aa7a5-49b2-48af-a2c3-552c00c3322d", + "metadata": {}, + "outputs": [], + "source": [ + "class LinearConditionalProbabilityPath(ConditionalProbabilityPath):\n", + " def __init__(self, p_simple: Sampleable, p_data: Sampleable):\n", + " super().__init__(p_simple, p_data)\n", + "\n", + " def sample_conditioning_variable(self, num_samples: int) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples the conditioning variable z ~ p_data(x)\n", + " Args:\n", + " - num_samples: the number of samples\n", + " Returns:\n", + " - z: samples from p(z), (num_samples, ...)\n", + " \"\"\"\n", + " return self.p_data.sample(num_samples)\n", + " \n", + " def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Samples the random variable X_t = (1-t) X_0 + tz\n", + " Args:\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - x: samples from p_t(x|z), (num_samples, dim)\n", + " \"\"\"\n", + " x0 = self.p_simple.sample(z.shape[0])\n", + " return (1 - t) * x0 + t * z\n", + " \n", + " def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Evaluates the conditional vector field u_t(x|z) = (z - x) / (1 - t)\n", + " Note: Only defined on t in [0,1)\n", + " Args:\n", + " - x: position variable (num_samples, dim)\n", + " - z: conditioning variable (num_samples, dim)\n", + " - t: time (num_samples, 1)\n", + " Returns:\n", + " - conditional_vector_field: conditional vector field (num_samples, dim)\n", + " \"\"\" \n", + " return (z - x) / (1 - t)\n", + "\n", + " def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Not known for Linear Conditional Probability Paths\n", + " \"\"\" \n", + " raise Exception(\"You should not be calling this function!\")" + ] + }, + { + "cell_type": "markdown", + "id": "9fea4cf9-4365-4b45-a3c3-be267c162ea9", + "metadata": {}, + "source": [ + "**Your work (2 points)**: Implement `LinearConditionalProbabilityPath.sample_conditional_path` and `LinearConditionalProbabilityPath.conditional_vector_field`." + ] + }, + { + "cell_type": "markdown", + "id": "9f357d6e-577e-47d9-96ea-6d13f3bfa795", + "metadata": {}, + "source": [ + "You can sanity check that the implementations are correct by ensuring that they are consistent with one another. The following visualization provides three sequences of graphs:\n", + "1. The first row shows the conditional probability path, as produced by your implemententation of `sample_conditional_path`.\n", + "2. The second row shows the conditional probability path, as produced by your implemententation of `conditional_vector_field`.\n", + "3. The third row shows the marginal probability path, as produced by `sample_marginal_path`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9b3d914-8508-41c9-8b1f-1a789e919711", + "metadata": {}, + "outputs": [], + "source": [ + "##########################\n", + "# Play around with these #\n", + "##########################\n", + "num_samples = 100000\n", + "num_timesteps = 500\n", + "num_marginals = 5\n", + "assert num_timesteps % (num_marginals - 1) == 0\n", + "\n", + "##########################################\n", + "# Construct conditional probability path #\n", + "##########################################\n", + "path = LinearConditionalProbabilityPath(\n", + " p_simple = Gaussian.isotropic(dim=2, std=1.0),\n", + " p_data = CheckerboardSampleable(device, grid_size=4)\n", + ").to(device)\n", + "z = path.p_data.sample(1) # (1,2)\n", + "\n", + "##############\n", + "# Setup plots #\n", + "##############\n", + "\n", + "fig, axes = plt.subplots(3, num_marginals, figsize=(6 * num_marginals, 6 * 3))\n", + "axes = axes.reshape(3, num_marginals)\n", + "scale = 6.0\n", + "\n", + "\n", + "#####################################################################\n", + "# Graph conditional probability paths using sample_conditional_path #\n", + "#####################################################################\n", + "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", + "for idx, t in enumerate(ts):\n", + " zz = z.expand(num_samples, -1)\n", + " tt = t.view(1,1).expand(num_samples,1)\n", + " xts = path.sample_conditional_path(zz, tt)\n", + " percentile = min(99 + 2 * torch.sin(t).item(), 100)\n", + " hist2d_samples(samples=xts.cpu(), ax=axes[0, idx], bins=300, scale=scale, percentile=percentile, alpha=1.0)\n", + " axes[0, idx].set_xlim(-scale, scale)\n", + " axes[0, idx].set_ylim(-scale, scale)\n", + " axes[0, idx].set_xticks([])\n", + " axes[0, idx].set_yticks([])\n", + " axes[0, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", + "axes[0, 0].set_ylabel(\"Conditional\", fontsize=20)\n", + "\n", + "# Plot z\n", + "axes[0,-1].scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20)\n", + "axes[0,-1].legend()\n", + "\n", + "######################################################################\n", + "# Graph conditional probability paths using conditional_vector_field #\n", + "######################################################################\n", + "ode = ConditionalVectorFieldODE(path, z)\n", + "simulator = EulerSimulator(ode)\n", + "ts = torch.linspace(0,1,num_timesteps).to(device)\n", + "record_every_idxs = record_every(len(ts), len(ts) // (num_marginals - 1))\n", + "x0 = path.p_simple.sample(num_samples)\n", + "xts = simulator.simulate_with_trajectory(x0, ts.view(1,-1,1).expand(num_samples,-1,1))\n", + "xts = xts[:,record_every_idxs,:]\n", + "for idx in range(xts.shape[1]):\n", + " xx = xts[:,idx,:]\n", + " tt = ts[record_every_idxs[idx]]\n", + " percentile = min(99 + 2 * torch.sin(tt).item(), 100)\n", + " hist2d_samples(samples=xx.cpu(), ax=axes[1, idx], bins=300, scale=scale, percentile=percentile, alpha=1.0)\n", + " axes[1, idx].set_xlim(-scale, scale)\n", + " axes[1, idx].set_ylim(-scale, scale)\n", + " axes[1, idx].set_xticks([])\n", + " axes[1, idx].set_yticks([])\n", + " axes[1, idx].set_title(f'$t={tt.item():.2f}$', fontsize=15)\n", + "axes[1, 0].set_ylabel(\"Conditional\", fontsize=20)\n", + "\n", + "# Plot z\n", + "axes[1,-1].scatter(z[:,0].cpu(), z[:,1].cpu(), marker='*', color='red', s=200, label='z',zorder=20)\n", + "axes[1,-1].legend()\n", + "\n", + "##################################################################\n", + "# Graph conditional probability paths using sample_marginal_path #\n", + "##################################################################\n", + "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", + "for idx, t in enumerate(ts):\n", + " zz = z.expand(num_samples, -1)\n", + " tt = t.view(1,1).expand(num_samples,1)\n", + " xts = path.sample_marginal_path(tt)\n", + " hist2d_samples(samples=xts.cpu(), ax=axes[2, idx], bins=300, scale=scale, percentile=99, alpha=1.0)\n", + " axes[2, idx].set_xlim(-scale, scale)\n", + " axes[2, idx].set_ylim(-scale, scale)\n", + " axes[2, idx].set_xticks([])\n", + " axes[2, idx].set_yticks([])\n", + " axes[2, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", + "axes[2, 0].set_ylabel(\"Marginal\", fontsize=20)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "aaf5f81d-bd38-4222-a14e-553a31afa964", + "metadata": {}, + "source": [ + "### Part 4.2: Flow Matching with Linear Probability Paths\n", + "Now, let's train a flow matching model using the linear conditional probability path! **Remember, the loss should converge, but not necessarily to zero!**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc9dfb02-8c65-436d-ad70-a6fee48ba035", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct conditional probability path\n", + "path = LinearConditionalProbabilityPath(\n", + " p_simple = Gaussian.isotropic(dim=2, std=1.0),\n", + " p_data = CheckerboardSampleable(device, grid_size=4)\n", + ").to(device)\n", + "\n", + "# Construct learnable vector field\n", + "linear_flow_model = MLPVectorField(dim=2, hiddens=[64,64,64,64])\n", + "\n", + "# Construct trainer\n", + "trainer = ConditionalFlowMatchingTrainer(path, linear_flow_model)\n", + "losses = trainer.train(num_epochs=10000, device=device, lr=1e-3, batch_size=2000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ebeb708-9829-4276-afe9-db7a930e741a", + "metadata": {}, + "outputs": [], + "source": [ + "##########################\n", + "# Play around With These #\n", + "##########################\n", + "num_samples = 50000\n", + "num_marginals = 5\n", + "\n", + "##############\n", + "# Setup Plots #\n", + "##############\n", + "\n", + "fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 6 * 2))\n", + "axes = axes.reshape(2, num_marginals)\n", + "scale = 6.0\n", + "\n", + "###########################\n", + "# Graph Ground-Truth Marginals #\n", + "###########################\n", + "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", + "for idx, t in enumerate(ts):\n", + " tt = t.view(1,1).expand(num_samples,1)\n", + " xts = path.sample_marginal_path(tt)\n", + " hist2d_samples(samples=xts.cpu(), ax=axes[0, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", + " axes[0, idx].set_xlim(-scale, scale)\n", + " axes[0, idx].set_ylim(-scale, scale)\n", + " axes[0, idx].set_xticks([])\n", + " axes[0, idx].set_yticks([])\n", + " axes[0, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", + "axes[0, 0].set_ylabel(\"Ground Truth\", fontsize=20)\n", + "\n", + "###############################################\n", + "# Graph Marginals of Learned Vector Field #\n", + "###############################################\n", + "ode = LearnedVectorFieldODE(linear_flow_model)\n", + "simulator = EulerSimulator(ode)\n", + "ts = torch.linspace(0,1,100).to(device)\n", + "record_every_idxs = record_every(len(ts), len(ts) // (num_marginals - 1))\n", + "x0 = path.p_simple.sample(num_samples)\n", + "xts = simulator.simulate_with_trajectory(x0, ts.view(1,-1,1).expand(num_samples,-1,1))\n", + "xts = xts[:,record_every_idxs,:]\n", + "for idx in range(xts.shape[1]):\n", + " xx = xts[:,idx,:]\n", + " hist2d_samples(samples=xx.cpu(), ax=axes[1, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", + " axes[1, idx].set_xlim(-scale, scale)\n", + " axes[1, idx].set_ylim(-scale, scale)\n", + " axes[1, idx].set_xticks([])\n", + " axes[1, idx].set_yticks([])\n", + " tt = ts[record_every_idxs[idx]]\n", + " axes[1, idx].set_title(f'$t={tt.item():.2f}$', fontsize=15)\n", + "axes[1, 0].set_ylabel(\"Learned\", fontsize=20) \n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "98e4271f-495d-4164-8183-8619c08af3d7", + "metadata": {}, + "source": [ + "### Problem 4.3: Bridging Between Arbitrary Source and Target\n", + "Notice that in our construction of the linear probability path, there is no need for $p_{\\text{simple}}$ to be a Gaussian. Let's try setting it to another distribution!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccec7698-4263-4c2c-a69e-e4a0b7a38724", + "metadata": {}, + "outputs": [], + "source": [ + "# Construct conditional probability path\n", + "path = LinearConditionalProbabilityPath(\n", + " p_simple = CirclesSampleable(device),\n", + " p_data = CheckerboardSampleable(device, grid_size=4)\n", + ").to(device)\n", + "\n", + "# Construct learnable vector field\n", + "bridging_flow_model = MLPVectorField(dim=2, hiddens=[100,100,100,100])\n", + "\n", + "# Construct trainer\n", + "trainer = ConditionalFlowMatchingTrainer(path, bridging_flow_model)\n", + "losses = trainer.train(num_epochs=20000, device=device, lr=1e-3, batch_size=2000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e7b2d3f-f647-42a8-a8a1-d7c183efb90e", + "metadata": {}, + "outputs": [], + "source": [ + "##########################\n", + "# Play around With These #\n", + "##########################\n", + "num_samples = 30000\n", + "num_marginals = 5\n", + "\n", + "##########################################\n", + "# Construct Conditional Probability Path #\n", + "##########################################\n", + "\n", + "path = LinearConditionalProbabilityPath(\n", + " p_simple = CirclesSampleable(device),\n", + " p_data = CheckerboardSampleable(device, grid_size=4)\n", + ").to(device)\n", + "\n", + "\n", + "##############\n", + "# Setup Plots #\n", + "##############\n", + "\n", + "fig, axes = plt.subplots(2, num_marginals, figsize=(6 * num_marginals, 6 * 2))\n", + "axes = axes.reshape(2, num_marginals)\n", + "scale = 6.0\n", + "\n", + "\n", + "###########################\n", + "# Graph Ground-Truth Marginals #\n", + "###########################\n", + "ts = torch.linspace(0.0, 1.0, num_marginals).to(device)\n", + "for idx, t in enumerate(ts):\n", + " tt = t.view(1,1).expand(num_samples,1)\n", + " xts = path.sample_marginal_path(tt)\n", + " hist2d_samples(samples=xts.cpu(), ax=axes[0, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", + " axes[0, idx].set_xlim(-scale, scale)\n", + " axes[0, idx].set_ylim(-scale, scale)\n", + " axes[0, idx].set_xticks([])\n", + " axes[0, idx].set_yticks([])\n", + " axes[0, idx].set_title(f'$t={t.item():.2f}$', fontsize=15)\n", + "axes[0, 0].set_ylabel(\"Ground Truth\", fontsize=20)\n", + "\n", + "###############################################\n", + "# Graph Learned Marginals #\n", + "###############################################\n", + "ode = LearnedVectorFieldODE(bridging_flow_model)\n", + "simulator = EulerSimulator(ode)\n", + "ts = torch.linspace(0,1,200).to(device)\n", + "record_every_idxs = record_every(len(ts), len(ts) // (num_marginals - 1))\n", + "x0 = path.p_simple.sample(num_samples)\n", + "xts = simulator.simulate_with_trajectory(x0, ts.view(1,-1,1).expand(num_samples,-1,1))\n", + "xts = xts[:,record_every_idxs,:]\n", + "for idx in range(xts.shape[1]):\n", + " xx = xts[:,idx,:]\n", + " hist2d_samples(samples=xx.cpu(), ax=axes[1, idx], bins=200, scale=scale, percentile=99, alpha=1.0)\n", + " axes[1, idx].set_xlim(-scale, scale)\n", + " axes[1, idx].set_ylim(-scale, scale)\n", + " axes[1, idx].set_xticks([])\n", + " axes[1, idx].set_yticks([])\n", + " tt = ts[record_every_idxs[idx]]\n", + " axes[1, idx].set_title(f'$t={tt.item():.2f}$', fontsize=15)\n", + "axes[1, 0].set_ylabel(\"Learned\", fontsize=20)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3d7bb3ca-939b-4a62-8870-ba9b2e103ad1", + "metadata": {}, + "source": [ + "**Your job (1 point)**: Play around with the choice of $p_{\\text{simple}}$ and $p_{\\text{data}}$. Any observations?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mtds", + "language": "python", + "name": "mtds" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.20" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}