Skip to content

Commit

Permalink
Add Colab links
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkMnDragon committed Dec 27, 2024
1 parent 38ff912 commit 096f3ae
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 20 deletions.
17 changes: 17 additions & 0 deletions examples/editing_flux_dev.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/editing_flux_dev.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/lqiang67/rectified-flow.git\n",
"%cd rectified-flow/"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down
17 changes: 17 additions & 0 deletions examples/inference_flux_dev.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/inference_flux_dev.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/lqiang67/rectified-flow.git\n",
"%cd rectified-flow/"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
27 changes: 22 additions & 5 deletions examples/interpolation_conversion.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/interpolation_conversion.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/lqiang67/rectified-flow.git\n",
"%cd rectified-flow/"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -295,7 +312,7 @@
"\t\t\"straight to spherical interp\": straight_to_spherical_interp_list,\n",
"\t\t\"spherical to straight interp\": spherical_to_straight_interp_list,\n",
"\t},\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tD1_gt_samples=D1[:2000],\n",
"\tnum_trajectories=100,\n",
"\ttitle=\"Interpolated Trajectories Visualization\",\n",
")"
Expand Down Expand Up @@ -462,7 +479,7 @@
" \"1rf straight\": euler_sampler_straight.trajectories,\n",
" \"1rf spherical\": euler_sampler_spherical.trajectories,\n",
"\t},\n",
" D1_gt_samples=D1[:5000],\n",
" D1_gt_samples=D1[:2000],\n",
" num_trajectories=100,\n",
" title=\"Euler Sampler Visualization\",\n",
")"
Expand Down Expand Up @@ -547,7 +564,7 @@
" \"straight rf\": euler_sampler_straight.trajectories, \n",
" \"straight to spherical rf\": euler_sampler_converted_spherical.trajectories\n",
" },\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tD1_gt_samples=D1[:2000],\n",
"\tnum_trajectories=100,\n",
"\ttitle=f\"Straight Converted to Spherical v.s. Original Straight RF\",\n",
")"
Expand Down Expand Up @@ -588,7 +605,7 @@
" \"spherical rf\": euler_sampler_spherical.trajectories, \n",
" \"converted spherical rf\": euler_sampler_converted_spherical.trajectories\n",
" },\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tD1_gt_samples=D1[:2000],\n",
"\tnum_trajectories=100,\n",
"\ttitle=f\"Converte Straight to Spherical RF v.s. Spherical RF\",\n",
")"
Expand Down Expand Up @@ -674,7 +691,7 @@
" \"reparam spherical rf\": euler_sampler_spherical.trajectories, \n",
" \"straight to spherical rf\": euler_sampler_converted_spherical.trajectories\n",
" },\n",
"\tD1_gt_samples=D1[:5000],\n",
"\tD1_gt_samples=D1[:2000],\n",
"\tnum_trajectories=100,\n",
"\ttitle=f\"Re-parametrized Straight RF v.s. Converted Straight to Spherical RF\",\n",
")"
Expand Down
23 changes: 20 additions & 3 deletions examples/natural_euler_sampler.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/natural_euler_sampler.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/lqiang67/rectified-flow.git\n",
"%cd rectified-flow/"
]
},
{
"cell_type": "code",
"execution_count": 1,
Expand Down Expand Up @@ -124,7 +141,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Both vanilla Euler sampler, huge difference in the trajectories\n",
"# Both vanilla Euler sampler, noticeable difference in final generated samples\n",
"\n",
"from rectified_flow.samplers import EulerSampler\n",
"\n",
Expand Down Expand Up @@ -154,7 +171,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Unmatched time grid, both natural euler samplers, nearly identical\n",
"# Unmatched time grid, both natural euler samplers, nearly identical final generated samples\n",
"\n",
"from rectified_flow.samplers import CurvedEulerSampler\n",
"\n",
Expand Down Expand Up @@ -184,7 +201,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Matched time grid, both natural euler samplers, exactly the same\n",
"# Matched time grid, both natural euler samplers, exactly the same final generated samples\n",
"\n",
"def convert_time(t, affine_interp):\n",
" return [affine_interp.alpha(t) / (affine_interp.alpha(t) + affine_interp.beta(t)) for t in t]\n",
Expand Down
23 changes: 20 additions & 3 deletions examples/samplers_2d_toys.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/samplers_2d_toys.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/lqiang67/rectified-flow.git\n",
"%cd rectified-flow/"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -380,7 +397,7 @@
"\\text{Effective Noise at time } t = \\text{step\\_size} \\times \\text{noise\\_scale} \\times \\beta_t^{\\text{noise\\_decay\\_rate}}\n",
"$$\n",
"\n",
"Check out this [blog post](https://rectifiedflow.github.io/blog/2024/samplers/) for more details."
"Check out this [blog post](https://rectifiedflow.github.io/blog/2024/diffusion/) for more details."
]
},
{
Expand Down Expand Up @@ -463,7 +480,7 @@
"sde_sampler = OverShootingSampler(\n",
" rectified_flow=straight_rf,\n",
" num_steps=10,\n",
" num_samples=500,\n",
" num_samples=1000,\n",
" c=15.0,\n",
" overshooting_method=\"t+dt\"\n",
")\n",
Expand All @@ -475,7 +492,7 @@
"# Plot CurvedEulerSampler results\n",
"visualize_2d_trajectories_plotly(\n",
" {\"overshooting\": sde_sampler.trajectories},\n",
" D1[:5000], # D1 defined previously\n",
" D1[:1000], # D1 defined previously\n",
" num_trajectories=100,\n",
" show_legend=True\n",
")\n",
Expand Down
35 changes: 26 additions & 9 deletions examples/train_2d_toys.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/lqiang67/rectified-flow/blob/main/examples/train_2d_toys.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/lqiang67/rectified-flow.git\n",
"%cd rectified-flow/"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -64,7 +81,7 @@
"D1, labels = pi_1.sample_with_labels([n_samples])\n",
"labels.tolist()\n",
"\n",
"plt.figure(figsize=(5, 5))\n",
"plt.figure(figsize=(3, 3))\n",
"plt.title(r'Samples from $\\pi_0$ and $\\pi_1$')\n",
"plt.scatter(D0[:, 0].cpu(), D0[:, 1].cpu(), alpha=0.5, label=r'$\\pi_0$')\n",
"plt.scatter(D1[:, 0].cpu(), D1[:, 1].cpu(), alpha=0.5, label=r'$\\pi_1$')\n",
Expand Down Expand Up @@ -132,8 +149,8 @@
" \n",
"visualize_2d_trajectories_plotly(\n",
" trajectories_dict={\n",
" \"upper interpolation\": interp_upper,\n",
"\t\t\"lower interpolation\": interp_lower\n",
" \"upper\": interp_upper,\n",
"\t\t\"lower\": interp_lower\n",
" },\n",
" D1_gt_samples=torch.cat([x_1_upper, x_1_lower], dim=0),\n",
" num_trajectories=100,\n",
Expand Down Expand Up @@ -167,7 +184,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -221,7 +238,7 @@
},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
"batch_size = 1024\n",
"\n",
"losses = []\n",
Expand Down Expand Up @@ -349,7 +366,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -392,7 +409,7 @@
"\toptimizer.step()\n",
"\tlosses.append(loss.item())\n",
"\n",
"\tif step % 200 == 0:\n",
"\tif step % 1000 == 0:\n",
"\t\tprint(f\"Epoch {step}, Loss: {loss.item()}\")\n",
" \n",
"plt.plot(losses)"
Expand Down Expand Up @@ -449,7 +466,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -483,7 +500,7 @@
"\toptimizer.step()\n",
"\tlosses.append(loss.item())\n",
"\n",
"\tif step % 200 == 0:\n",
"\tif step % 1000 == 0:\n",
"\t\tprint(f\"Epoch {step}, Loss: {loss.item()}\")\n",
" \n",
"plt.plot(losses)"
Expand Down

0 comments on commit 096f3ae

Please sign in to comment.