Skip to content

Commit

Permalink
Update natural Euler sampler example
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkMnDragon committed Dec 18, 2024
1 parent 4f7d8a6 commit 9b4aa4b
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/interpolation_conversion.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 1,
"metadata": {
"id": "o4HSmy1HV9_G"
},
Expand Down Expand Up @@ -490,7 +490,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 7,
"metadata": {
"id": "XFT8hdWZKKsI"
},
Expand Down Expand Up @@ -530,7 +530,7 @@
"source": [
"# Try different num_steps, e.g. [5, 10, 50, 100, 500]\n",
"num_samples = 500\n",
"num_steps = 200\n",
"num_steps = 10\n",
"\n",
"euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps)\n",
"euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n",
Expand All @@ -549,7 +549,7 @@
" },\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tnum_trajectories=100,\n",
"\ttitle=f\"Convert Straight to Spherical RF v.s. Original Straight RF\",\n",
"\ttitle=f\"Straight Converted to Spherical v.s. Original Straight RF\",\n",
")"
]
},
Expand Down Expand Up @@ -676,7 +676,7 @@
" },\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tnum_trajectories=100,\n",
"\ttitle=f\"Re-parametrized Straight RF v.s. Converted Straight to Spherical RF\",\n",
"\ttitle=f\"Re-parametrized Straight RF v.s. Converted Straight to Spherical RF\",\n",
")"
]
}
Expand Down
231 changes: 231 additions & 0 deletions examples/natural_euler_sampler.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"import sys\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import warnings\n",
"import copy\n",
"import plotly.graph_objects as go\n",
"\n",
"import torch.distributions as dist\n",
"\n",
"from rectified_flow.utils import set_seed\n",
"from rectified_flow.utils import match_dim_with_data\n",
"from rectified_flow.datasets.toy_gmm import TwoPointGMM\n",
"\n",
"from rectified_flow.rectified_flow import RectifiedFlow\n",
"from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rectified_flow.datasets.toy_gmm import TwoPointGMM\n",
"\n",
"set_seed(0)\n",
"n_samples = 50000\n",
"pi_0 = dist.MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device))\n",
"pi_1 = TwoPointGMM(x=15.0, y=2, std=0.3)\n",
"D0 = pi_0.sample([n_samples])\n",
"D1, labels = pi_1.sample_with_labels([n_samples])\n",
"labels.tolist()\n",
"\n",
"from rectified_flow.flow_components.interpolation_solver import AffineInterp\n",
"from rectified_flow.utils import visualize_2d_trajectories_plotly\n",
"\n",
"straight_interp = AffineInterp(\"straight\")\n",
"spherical_interp = AffineInterp(\"spherical\")\n",
"\n",
"idx = torch.randperm(n_samples)[:1000]\n",
"x_0 = D0[idx]\n",
"x_1 = D1[idx]\n",
"\n",
"print(x_0.shape)\n",
"\n",
"straight_interp_list = []\n",
"spherical_interp_list = []\n",
"\n",
"for t in np.linspace(0, 1, 50):\n",
"\tx_t_straight, dot_x_t_straight = straight_interp.forward(x_0, x_1, t)\n",
"\tx_t_spherical, dot_x_t_spherical = spherical_interp.forward(x_0, x_1, t)\n",
"\tstraight_interp_list.append(x_t_straight)\n",
"\tspherical_interp_list.append(x_t_spherical)\n",
"\n",
"visualize_2d_trajectories_plotly(\n",
"\ttrajectories_dict={\"straight interp\": straight_interp_list, \"spherical interp\": spherical_interp_list},\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tnum_trajectories=50,\n",
"\ttitle=\"Interpolated Trajectories Visualization\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from rectified_flow.flow_components.interpolation_convertor import AffineInterpConverter\n",
"\n",
"def rf_trainer(rectified_flow, label = \"loss\", batch_size = 1024):\n",
" model = rectified_flow.velocity_field\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
"\n",
" losses = []\n",
" for step in range(5000):\n",
" optimizer.zero_grad()\n",
" x_0 = pi_0.sample([batch_size]).to(device)\n",
" x_1 = pi_1.sample([batch_size]).to(device)\n",
"\n",
" loss = rectified_flow.get_loss(x_0, x_1)\n",
" loss.backward()\n",
" optimizer.step()\n",
" losses.append(loss.item())\n",
"\n",
" if step % 1000 == 0:\n",
" print(f\"Epoch {step}, Loss: {loss.item()}\")\n",
"\n",
" plt.plot(losses, label=label)\n",
" plt.legend()\n",
"\n",
"from rectified_flow.models.toy_mlp import MLPVelocity\n",
"\n",
"set_seed(0)\n",
"straight_rf = RectifiedFlow(\n",
" data_shape=(2,),\n",
" velocity_field=MLPVelocity(2, hidden_sizes = [128, 128, 128]).to(device),\n",
" interp=straight_interp,\n",
" source_distribution=pi_0,\n",
" device=device,\n",
")\n",
"\n",
"set_seed(0)\n",
"rf_trainer(rectified_flow=straight_rf, label=\"straight interp\")\n",
"\n",
"spherical_rf = AffineInterpConverter(straight_rf, AffineInterp(\"spherical\")).transform_rectified_flow()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Both vanilla Euler sampler, huge difference in the trajectories\n",
"\n",
"from rectified_flow.samplers import EulerSampler\n",
"\n",
"num_samples = 300\n",
"num_steps = 10\n",
"\n",
"euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps)\n",
"euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n",
"\n",
"euler_sampler_spherical = EulerSampler(spherical_rf, num_steps=num_steps)\n",
"euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)\n",
"\n",
"visualize_2d_trajectories_plotly(\n",
" trajectories_dict={\n",
" \"straight rf\": euler_sampler_straight.trajectories,\n",
" \"spherical rf\": euler_sampler_spherical.trajectories,\n",
"\t},\n",
" D1_gt_samples=D1[:num_samples*3],\n",
" num_trajectories=50,\n",
" title=\"Euler Sampler, straight rf vs spherical rf\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Unmatched time grid, both natural euler samplers, nearly identical\n",
"\n",
"from rectified_flow.samplers import CurvedEulerSampler\n",
"\n",
"num_samples = 300\n",
"num_steps = 10\n",
"\n",
"natural_euler_sampler_straight = CurvedEulerSampler(straight_rf, num_steps=num_steps)\n",
"natural_euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n",
"\n",
"natural_euler_sampler_spherical = CurvedEulerSampler(spherical_rf, num_steps=num_steps)\n",
"natural_euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)\n",
"\n",
"visualize_2d_trajectories_plotly(\n",
" trajectories_dict={\n",
" \"straight rf\": natural_euler_sampler_straight.trajectories,\n",
" \"spherical rf\": natural_euler_sampler_spherical.trajectories,\n",
"\t},\n",
" D1_gt_samples=D1[:num_samples*3],\n",
" num_trajectories=50,\n",
" title=\"Natural Euler Sampler, straight rf vs spherical rf, unmatched time gird\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Matched time grid, both natural euler samplers, exactly the same\n",
"\n",
"def convert_time(t, affine_interp):\n",
" return [affine_interp.alpha(t) / (affine_interp.alpha(t) + affine_interp.beta(t)) for t in t]\n",
"\n",
"converted_time_grid = convert_time(natural_euler_sampler_spherical.time_grid, spherical_interp)\n",
"print(converted_time_grid)\n",
"\n",
"natural_euler_sampler_straight = CurvedEulerSampler(straight_rf, time_grid=converted_time_grid)\n",
"natural_euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)\n",
"\n",
"visualize_2d_trajectories_plotly(\n",
" trajectories_dict={\n",
" \"straight rf\": natural_euler_sampler_straight.trajectories,\n",
" \"spherical rf\": natural_euler_sampler_spherical.trajectories,\n",
"\t},\n",
" D1_gt_samples=D1[:num_samples*3],\n",
" num_trajectories=50,\n",
" title=\"Natural Euler Sampler, straight rf vs spherical rf, matched time gird\",\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "learning",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 9b4aa4b

Please sign in to comment.