-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update natural Euler sampler example
- Loading branch information
1 parent
4f7d8a6
commit 9b4aa4b
Showing
2 changed files
with
236 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import os\n", | ||
"import sys\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import warnings\n", | ||
"import copy\n", | ||
"import plotly.graph_objects as go\n", | ||
"\n", | ||
"import torch.distributions as dist\n", | ||
"\n", | ||
"from rectified_flow.utils import set_seed\n", | ||
"from rectified_flow.utils import match_dim_with_data\n", | ||
"from rectified_flow.datasets.toy_gmm import TwoPointGMM\n", | ||
"\n", | ||
"from rectified_flow.rectified_flow import RectifiedFlow\n", | ||
"from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity\n", | ||
"\n", | ||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from rectified_flow.datasets.toy_gmm import TwoPointGMM\n", | ||
"\n", | ||
"set_seed(0)\n", | ||
"n_samples = 50000\n", | ||
"pi_0 = dist.MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device))\n", | ||
"pi_1 = TwoPointGMM(x=15.0, y=2, std=0.3)\n", | ||
"D0 = pi_0.sample([n_samples])\n", | ||
"D1, labels = pi_1.sample_with_labels([n_samples])\n", | ||
"labels.tolist()\n", | ||
"\n", | ||
"from rectified_flow.flow_components.interpolation_solver import AffineInterp\n", | ||
"from rectified_flow.utils import visualize_2d_trajectories_plotly\n", | ||
"\n", | ||
"straight_interp = AffineInterp(\"straight\")\n", | ||
"spherical_interp = AffineInterp(\"spherical\")\n", | ||
"\n", | ||
"idx = torch.randperm(n_samples)[:1000]\n", | ||
"x_0 = D0[idx]\n", | ||
"x_1 = D1[idx]\n", | ||
"\n", | ||
"print(x_0.shape)\n", | ||
"\n", | ||
"straight_interp_list = []\n", | ||
"spherical_interp_list = []\n", | ||
"\n", | ||
"for t in np.linspace(0, 1, 50):\n", | ||
"\tx_t_straight, dot_x_t_straight = straight_interp.forward(x_0, x_1, t)\n", | ||
"\tx_t_spherical, dot_x_t_spherical = spherical_interp.forward(x_0, x_1, t)\n", | ||
"\tstraight_interp_list.append(x_t_straight)\n", | ||
"\tspherical_interp_list.append(x_t_spherical)\n", | ||
"\n", | ||
"visualize_2d_trajectories_plotly(\n", | ||
"\ttrajectories_dict={\"straight interp\": straight_interp_list, \"spherical interp\": spherical_interp_list},\n", | ||
"\tD1_gt_samples=D1[:5000],\n", | ||
"\tnum_trajectories=50,\n", | ||
"\ttitle=\"Interpolated Trajectories Visualization\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from rectified_flow.flow_components.interpolation_convertor import AffineInterpConverter\n", | ||
"\n", | ||
"def rf_trainer(rectified_flow, label = \"loss\", batch_size = 1024):\n", | ||
" model = rectified_flow.velocity_field\n", | ||
" optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n", | ||
"\n", | ||
" losses = []\n", | ||
" for step in range(5000):\n", | ||
" optimizer.zero_grad()\n", | ||
" x_0 = pi_0.sample([batch_size]).to(device)\n", | ||
" x_1 = pi_1.sample([batch_size]).to(device)\n", | ||
"\n", | ||
" loss = rectified_flow.get_loss(x_0, x_1)\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
" losses.append(loss.item())\n", | ||
"\n", | ||
" if step % 1000 == 0:\n", | ||
" print(f\"Epoch {step}, Loss: {loss.item()}\")\n", | ||
"\n", | ||
" plt.plot(losses, label=label)\n", | ||
" plt.legend()\n", | ||
"\n", | ||
"from rectified_flow.models.toy_mlp import MLPVelocity\n", | ||
"\n", | ||
"set_seed(0)\n", | ||
"straight_rf = RectifiedFlow(\n", | ||
" data_shape=(2,),\n", | ||
" velocity_field=MLPVelocity(2, hidden_sizes = [128, 128, 128]).to(device),\n", | ||
" interp=straight_interp,\n", | ||
" source_distribution=pi_0,\n", | ||
" device=device,\n", | ||
")\n", | ||
"\n", | ||
"set_seed(0)\n", | ||
"rf_trainer(rectified_flow=straight_rf, label=\"straight interp\")\n", | ||
"\n", | ||
"spherical_rf = AffineInterpConverter(straight_rf, AffineInterp(\"spherical\")).transform_rectified_flow()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Both vanilla Euler sampler, huge difference in the trajectories\n", | ||
"\n", | ||
"from rectified_flow.samplers import EulerSampler\n", | ||
"\n", | ||
"num_samples = 300\n", | ||
"num_steps = 10\n", | ||
"\n", | ||
"euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps)\n", | ||
"euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n", | ||
"\n", | ||
"euler_sampler_spherical = EulerSampler(spherical_rf, num_steps=num_steps)\n", | ||
"euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)\n", | ||
"\n", | ||
"visualize_2d_trajectories_plotly(\n", | ||
" trajectories_dict={\n", | ||
" \"straight rf\": euler_sampler_straight.trajectories,\n", | ||
" \"spherical rf\": euler_sampler_spherical.trajectories,\n", | ||
"\t},\n", | ||
" D1_gt_samples=D1[:num_samples*3],\n", | ||
" num_trajectories=50,\n", | ||
" title=\"Euler Sampler, straight rf vs spherical rf\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Unmatched time grid, both natural euler samplers, nearly identical\n", | ||
"\n", | ||
"from rectified_flow.samplers import CurvedEulerSampler\n", | ||
"\n", | ||
"num_samples = 300\n", | ||
"num_steps = 10\n", | ||
"\n", | ||
"natural_euler_sampler_straight = CurvedEulerSampler(straight_rf, num_steps=num_steps)\n", | ||
"natural_euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n", | ||
"\n", | ||
"natural_euler_sampler_spherical = CurvedEulerSampler(spherical_rf, num_steps=num_steps)\n", | ||
"natural_euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)\n", | ||
"\n", | ||
"visualize_2d_trajectories_plotly(\n", | ||
" trajectories_dict={\n", | ||
" \"straight rf\": natural_euler_sampler_straight.trajectories,\n", | ||
" \"spherical rf\": natural_euler_sampler_spherical.trajectories,\n", | ||
"\t},\n", | ||
" D1_gt_samples=D1[:num_samples*3],\n", | ||
" num_trajectories=50,\n", | ||
" title=\"Natural Euler Sampler, straight rf vs spherical rf, unmatched time gird\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Matched time grid, both natural euler samplers, exactly the same\n", | ||
"\n", | ||
"def convert_time(t, affine_interp):\n", | ||
" return [affine_interp.alpha(t) / (affine_interp.alpha(t) + affine_interp.beta(t)) for t in t]\n", | ||
"\n", | ||
"converted_time_grid = convert_time(natural_euler_sampler_spherical.time_grid, spherical_interp)\n", | ||
"print(converted_time_grid)\n", | ||
"\n", | ||
"natural_euler_sampler_straight = CurvedEulerSampler(straight_rf, time_grid=converted_time_grid)\n", | ||
"natural_euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n", | ||
"\n", | ||
"visualize_2d_trajectories_plotly(\n", | ||
" trajectories_dict={\n", | ||
" \"straight rf\": natural_euler_sampler_straight.trajectories,\n", | ||
" \"spherical rf\": natural_euler_sampler_spherical.trajectories,\n", | ||
"\t},\n", | ||
" D1_gt_samples=D1[:num_samples*3],\n", | ||
" num_trajectories=50,\n", | ||
" title=\"Natural Euler Sampler, straight rf vs spherical rf, matched time gird\",\n", | ||
")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "learning", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |