RectifiedFlow is a simple, unified PyTorch codebase for diffusion and flow models. It offers an easy-to-use platform for training and inference, focusing on simplicity, flexibility, and quick prototyping. The library includes:
- Companion Resources: Companion materials, including lecture notes and beginner-friendly notebooks, covering concepts from basics to advanced implementations.
- Unified ODE Framework: Train and infer rectified flow (RF) and diffusion models using a unified ODE approach, including 1-rectified flow from data (flow matching), reflow for speedup, diffusion as RF+Langevin, post-training conversion of affine interpolation schemes, analytic models, etc.
- Symbolic Algorithm Derivation: We use a symbolic solver for affine interpolation to automate the derivation of algorithms and formulas, enabling easy model conversion between various forms like score functions, velocity fields, and noise predictions. This eliminates the need for manual derivation in both existing and new algorithms (e.g., the DDIM/DDPM coefficients).
- Easy Integration with SOTA Models: Integrate state-of-the-art models, including the Flux series, for greater flexibility and compatibility.
You can install the rectified-flow package using pip:
pip install rectified-flow
Alternatively, you can install the package from source. Run the following commands in the given order to install the dependencies:
conda create -n rf python=3.10
conda activate rf
git clone https://github.com/lqiang67/rectified-flow.git
cd rectified-flow
pip install -r requirements.txt
Then install the rectified-flow package:
pip install -e .
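Consider the task of learning an ODE model $\mathrm{d}Z_t = v_t(Z_t)\,\mathrm{d}t$ that transforms a source (noise) distribution $X_0 \sim \pi_0$ into a data distribution $X_1 \sim \pi_1$. Given pairs $(X_0, X_1)$, independent by default, rectified flow forms the straight-line interpolation $X_t = t X_1 + (1 - t) X_0$ and learns the velocity field by minimizing

$$\min_v \int_0^1 \mathbb{E}\left[\left\| \dot X_t - v_t(X_t) \right\|^2\right] \mathrm{d}t,$$

where $\dot X_t = X_1 - X_0$ is the time derivative of the interpolation.
After training the model, samples are generated by solving the ODE from $Z_0 \sim \pi_0$ at $t = 0$ to $t = 1$.
Although ultimately unnecessary in theory (see Chapter 3 of the lecture notes), the codebase supports a more general affine interpolation $X_t = \alpha_t X_1 + \beta_t X_0$.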
The RectifiedFlow class serves as an intermediary for your training and inference processes. Each distinct velocity field should have its own RectifiedFlow instance.
import torch
from rectified_flow.rectified_flow import RectifiedFlow

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YourVelocityFieldModel()  # placeholder for your own velocity network

# Initialize RectifiedFlow with custom settings
rectified_flow = RectifiedFlow(
    data_shape=(32, 32),
    velocity_field=model,
    interp="straight",
    source_distribution="normal",
    is_independent_coupling=True,
    train_time_distribution="uniform",
    train_time_weight="uniform",
    criterion="mse",
    device=device,
)

# Or use the default settings
rectified_flow = RectifiedFlow(
    data_shape=(32, 32),
    velocity_field=model,
    device=device,
)
During training, you can compute the predefined loss by passing your target data samples x_1. If samples from the source distribution x_0 are not provided, they are drawn from a standard Gaussian by default. The RectifiedFlow class supports various pre-specified loss functions and interpolation methods, and it calculates the loss accordingly.
loss = rectified_flow.get_loss(x_0=None, x_1=x_1, **kwargs)
This is implemented by:
# Step 1: Interpolation
x_t, dot_x_t = self.get_interpolation(x_0, x_1, t)
# Step 2: Velocity Calculation
v_t = self.get_velocity(x_t, t, **kwargs)
# Step 3: Time Weights
time_weights = self.train_time_weight(t)
# Step 4: Loss Computation
return self.criterion(
    v_t=v_t,
    dot_x_t=dot_x_t,
    x_t=x_t,
    t=t,
    time_weights=time_weights,
)
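The four steps above can be sketched without the library. The following is a minimal plain-Python illustration for scalar samples, assuming the straight interpolation is $X_t = t X_1 + (1 - t) X_0$; the names `straight_interpolation` and `rf_loss` are hypothetical and are not part of the rectified-flow API.

```python
from typing import Callable, Sequence


def straight_interpolation(x_0: float, x_1: float, t: float) -> tuple[float, float]:
    """Return (x_t, dot_x_t) for the straight path x_t = t * x_1 + (1 - t) * x_0."""
    x_t = t * x_1 + (1.0 - t) * x_0
    dot_x_t = x_1 - x_0  # time derivative of the straight path
    return x_t, dot_x_t


def rf_loss(
    velocity: Callable[[float, float], float],
    x_0s: Sequence[float],
    x_1s: Sequence[float],
    ts: Sequence[float],
    time_weight: Callable[[float], float] = lambda t: 1.0,
) -> float:
    """Weighted mean squared error between the model velocity and dot_x_t."""
    total = 0.0
    for x_0, x_1, t in zip(x_0s, x_1s, ts):
        x_t, dot_x_t = straight_interpolation(x_0, x_1, t)  # Step 1: interpolation
        v_t = velocity(x_t, t)                              # Step 2: velocity
        w_t = time_weight(t)                                # Step 3: time weight
        total += w_t * (v_t - dot_x_t) ** 2                 # Step 4: loss term
    return total / len(ts)


# Toy coupling (an assumption for illustration): every data point is its
# source point shifted by 2, so the true velocity is the constant 2.
x_0s = [-1.0, 0.0, 0.5, 2.0]
x_1s = [x + 2.0 for x in x_0s]
ts = [0.1, 0.4, 0.6, 0.9]

perfect_loss = rf_loss(lambda x, t: 2.0, x_0s, x_1s, ts)
zero_model_loss = rf_loss(lambda x, t: 0.0, x_0s, x_1s, ts)
```

A velocity that matches the true displacement gives zero loss, while the all-zero velocity is penalized by the squared displacement.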
After training, converting a pretrained rectified flow to another interpolation scheme (as long as alpha and beta are specified) can be done easily and automatically by:
from rectified_flow.flow_components import AffineInterp
from rectified_flow.flow_components.interpolation_convertor import AffineInterpConverter
# Convert the pretrained rectified flow to the spherical interpolation
target_interp = AffineInterp("spherical")
converted_spherical_rf = AffineInterpConverter(rectified_flow, target_interp).transform_rectified_flow()
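The idea behind such a conversion can be checked numerically in plain Python: two affine interpolations built from the same pair $(X_0, X_1)$ visit the same points up to a time reparameterization $\tau_t$ and a scaling $\omega_t$. The sketch below assumes the straight scheme $\alpha_t = t$, $\beta_t = 1 - t$ and a spherical scheme $\alpha_t = \sin(\pi t / 2)$, $\beta_t = \cos(\pi t / 2)$. Both definitions and all function names are assumptions for illustration, and this is not a description of how AffineInterpConverter is implemented.

```python
import math


def straight(x_0, x_1, t):
    """Straight interpolation: alpha_t = t, beta_t = 1 - t."""
    return t * x_1 + (1.0 - t) * x_0


def spherical(x_0, x_1, t):
    """Spherical interpolation: alpha_t = sin(pi t / 2), beta_t = cos(pi t / 2)."""
    return math.sin(math.pi * t / 2) * x_1 + math.cos(math.pi * t / 2) * x_0


def straight_time_and_scale(t):
    """Return (tau, omega) such that spherical(t) == omega * straight(tau)."""
    s, c = math.sin(math.pi * t / 2), math.cos(math.pi * t / 2)
    tau = s / (s + c)  # matches the coefficient ratio: tau / (1 - tau) == s / c
    omega = s + c      # rescales the coefficients (tau, 1 - tau) to (s, c)
    return tau, omega


# Compare the two curves on a time grid for one arbitrary pair of endpoints.
x_0, x_1 = -0.7, 1.9
max_gap = 0.0
for k in range(11):
    t = k / 10
    tau, omega = straight_time_and_scale(t)
    gap = abs(spherical(x_0, x_1, t) - omega * straight(x_0, x_1, tau))
    max_gap = max(max_gap, gap)
```

Here `max_gap` stays at floating-point rounding level, which is the pointwise statement that the two schemes differ only by a change of time and scale.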
For sampling, import the desired sampler class and pass the RectifiedFlow instance to it.
from rectified_flow.samplers import SDESampler
sde_sampler = SDESampler(rectified_flow=rectified_flow)
sde_sampler.sample_loop(
    num_samples=128,
    num_steps=100,
    seed=0,
)
# Recorded trajectory; the last entry holds the final samples
traj = sde_sampler.trajectories
img = traj[-1]
- Introduction with 2D Toy: This notebook provides a 2D toy example to illustrate the fundamental concepts of Rectified Flow. It covers the basics of interpolation processes, the training and inference of rectified flow, and reflow, which straightens flow dynamics to achieve speedup.
- Samplers: This notebook explores the samplers available in this repository using a 2D toy example. It illustrates the concepts and usage of both deterministic and stochastic samplers. Additionally, it demonstrates how to customize a sampler by inheriting from the $\texttt{Sampler}$ base class and discusses the effects of using stochastic samplers.
- Interpolation: This notebook introduces the idea that different affine interpolations can be converted from one form to another and provides a simple implementation of this transformation. It also highlights the observation that the same transformation applies to rectified flows and, in fact, to their discretized trajectories produced by natural Euler samplers.
- Flux: We provide a notebook that shows how to interact with the wrapped Flux model using different samplers. Another notebook demonstrates how to perform image editing tasks with Flux. All implementations are presented in a clear and accessible manner.
We provide Diffusers-style training scripts for UNet and DiT in this directory. The training scripts use Accelerate for multi-GPU training.
Results Using These Training Scripts:
- UNet CIFAR10: Trained for $500\text{k}$ iterations with batch_size=128. You can download the model here. $\text{FID}_{50\text{K}}=4.308$.
- DiT CIFAR10: Trained for $1000\text{k}$ iterations with batch_size=128. You can download the model here. $\text{FID}_{50\text{K}}=3.678$.
Loading a Pretrained Model:
To construct a model from a pretrained checkpoint, simply run the following code:
from rectified_flow.models.dit import DiT
model = DiT.from_pretrained(save_directory="PATH_TO_MODEL", filename="dit", use_ema=True).to(device)
The AffineInterp class manages the affine interpolation between the source distribution $\pi_0$ and the target distribution $\pi_1$. It provides two main features:
- Automatic Interpolation Handling: Given an affine interpolation $X_t = \alpha_t X_1 + \beta_t X_0$, providing the functions $\alpha_t$ and $\beta_t$ (optionally along with their time derivatives $\dot\alpha_t$ and $\dot\beta_t$), AffineInterp computes the interpolated state $X_t$ and its time derivative $\dot X_t$. If the derivative functions $\dot\alpha_t, \dot\beta_t$ are not supplied, they are computed automatically using PyTorch automatic differentiation.

import torch
from rectified_flow.flow_components import AffineInterp
a = torch.tensor(1.0)  # angle parameter; 1.0 is an arbitrary example value
alpha_function = lambda t: torch.sin(a * t) / torch.sin(a)
beta_function = lambda t: torch.sin(a * (1 - t)) / torch.sin(a)
interp = AffineInterp(alpha=alpha_function, beta=beta_function)
x_t, dot_x_t = interp.forward(x_0, x_1, t)

- Automatic Solving of Unknown Variables: Given any two of the four variables ($X_0, X_1, X_t, \dot X_t$), the class can automatically solve for the remaining unknowns using precomputed symbolic solvers derived from $X_t = \alpha_t X_1 + \beta_t X_0$ and $\dot X_t = \dot\alpha_t X_1 + \dot\beta_t X_0$. This avoids hand-deriving the coefficients in DDIM-like algorithms and simplifies conversion between quantities such as the RF velocity, score function, and predicted noise and targets.

# Solve for x_0 and x_1 given x_t and dot_x_t
interp.solve(t=t, x_t=x_t, dot_x_t=velocity)
print(interp.x_0, interp.x_1)

# The inference step of DDIM as a curved Euler sampler walking along the interpolation curves
def step(self):
    t, t_next, x_t = self.t, self.t_next, self.x_t
    v_t = self.rectified_flow.get_velocity(x_t, t)
    # Find the expected noise x_0_pred and data x_1_pred from the interpolation
    interp = self.interp_inference.solve(t, x_t=x_t, dot_x_t=v_t)
    x_1_pred = interp.x_1
    x_0_pred = interp.x_0
    # Get x_{t_next} on the interpolated curve
    self.x_t = self.interp_inference.solve(t_next, x_0=x_0_pred, x_1=x_1_pred).x_t
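At a fixed time, the solve step amounts to inverting a 2x2 linear system in $(X_0, X_1)$. A minimal plain-Python sketch for scalar states follows; `solve_endpoints`, `straight_coeffs`, and `ddim_like_step` are hypothetical helpers written for this illustration (not library functions), and the straight scheme $\alpha_t = t$, $\beta_t = 1 - t$ is assumed as the example interpolation.

```python
def solve_endpoints(a, b, da, db, x_t, dot_x_t):
    """Solve x_t = a*x_1 + b*x_0 and dot_x_t = da*x_1 + db*x_0 for (x_0, x_1)."""
    det = a * db - da * b
    if det == 0:
        raise ValueError("alpha_t and beta_t are degenerate at this time")
    # Cramer's rule on the 2x2 system
    x_1 = (db * x_t - b * dot_x_t) / det
    x_0 = (a * dot_x_t - da * x_t) / det
    return x_0, x_1


def straight_coeffs(t):
    """Return (alpha, beta, dot_alpha, dot_beta) for alpha_t = t, beta_t = 1 - t."""
    return t, 1.0 - t, 1.0, -1.0


def ddim_like_step(coeffs, t, t_next, x_t, v_t):
    """Recover the endpoints at time t, then re-interpolate at time t_next."""
    a, b, da, db = coeffs(t)
    x_0_pred, x_1_pred = solve_endpoints(a, b, da, db, x_t, v_t)
    a_next, b_next, _, _ = coeffs(t_next)
    return a_next * x_1_pred + b_next * x_0_pred


# Ground truth used to build the inputs: x_0 = 1, x_1 = 3, t = 0.25,
# so x_t = 0.25 * 3 + 0.75 * 1 = 1.5 and dot_x_t = 3 - 1 = 2.
x_0_rec, x_1_rec = solve_endpoints(*straight_coeffs(0.25), x_t=1.5, dot_x_t=2.0)
x_next = ddim_like_step(straight_coeffs, 0.25, 0.5, x_t=1.5, v_t=2.0)
```

For the straight scheme this step coincides with a plain Euler update, `x_t + (t_next - t) * v_t`; for curved schemes the two differ.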
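The velocity_field argument in the RectifiedFlow class accepts a neural network or any callable that takes x_t and t as input and returns the velocity. A model that follows a different convention can be wrapped to fit this interface: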
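- Reversing the Time Direction
In Rectified Flow, we use the convention of transforming the noise (or source) distribution $\pi_0$ at $t=0$ into the data (or target) distribution $\pi_1$ at $t=1$. If a pretrained model uses the opposite convention, wrap it to reverse the time direction: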
# Reverse ODE time direction
velocity = lambda x_t, t: -model(x_t, 1.0 - t)
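Why this wrapper works can be seen on a toy ODE. In the plain-Python sketch below, `model` is a hypothetical velocity in the opposite convention (data at time 0, noise at time 1); wrapping it as shown and integrating from 0 to 1 undoes the original flow up to discretization error. The Euler integrator and the toy model are assumptions for illustration.

```python
def euler_integrate(velocity, x, num_steps=1000):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with Euler steps."""
    dt = 1.0 / num_steps
    for k in range(num_steps):
        x = x + dt * velocity(x, k * dt)
    return x


def model(x, s):
    # Hypothetical pretrained velocity in the opposite convention:
    # s = 0 is data, s = 1 is noise.
    return s * x


# Reverse the ODE time direction so that t = 0 is noise and t = 1 is data.
velocity = lambda x_t, t: -model(x_t, 1.0 - t)

data = 2.0
noise = euler_integrate(model, data)          # data -> noise with the original model
recovered = euler_integrate(velocity, noise)  # noise -> data with the wrapped model
```

The small gap between `recovered` and `data` comes from the Euler discretization and shrinks as `num_steps` grows.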
- Reparameterizing for Noise Prediction
Some works parameterize the model to predict noise instead of velocity. Using an AffineInterpSolver, you can automatically convert noise predictions into velocity predictions, bypassing the complexity of handling DDIM coefficients.
# Convert noise prediction to velocity prediction
# Assume model is trained by minimizing ((x_0 - model(x_t, t))**2).mean(), with x_t = a_t * x_1 + b_t * x_0, where x_0 is noise and x_1 is data.
velocity = lambda x_t, t: self.interp.solve(t=t, x_t=x_t, x_0=model(x_t, t)).dot_x_t
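For a scalar state, the conversion reduces to two lines of algebra: recover $x_1$ from $x_t = \alpha_t x_1 + \beta_t x_0$, then evaluate $\dot x_t = \dot\alpha_t x_1 + \dot\beta_t x_0$. The plain-Python sketch below assumes the straight scheme $\alpha_t = t$, $\beta_t = 1 - t$; `noise_to_velocity`, `straight_coeffs`, and `noise_model` are hypothetical names, not library functions.

```python
def straight_coeffs(t):
    """Return (alpha, beta, dot_alpha, dot_beta) for alpha_t = t, beta_t = 1 - t."""
    return t, 1.0 - t, 1.0, -1.0


def noise_to_velocity(noise_pred, x_t, t, coeffs):
    """Convert a noise prediction x_0 into a velocity prediction dot_x_t."""
    a, b, da, db = coeffs(t)
    if a == 0:
        raise ValueError("x_1 cannot be recovered from x_t when alpha_t == 0")
    x_1_pred = (x_t - b * noise_pred) / a  # invert x_t = a * x_1 + b * x_0
    return da * x_1_pred + db * noise_pred


def noise_model(x_t, t):
    # Stand-in for a trained noise predictor: an oracle for the single
    # pair (x_0, x_1) = (1, 3) used in this example.
    return 1.0


# Wrap the noise predictor so it can be used wherever a velocity is expected.
velocity = lambda x_t, t: noise_to_velocity(noise_model(x_t, t), x_t, t, straight_coeffs)

# At t = 0.5 the interpolated state is x_t = 0.5 * 3 + 0.5 * 1 = 2.
v = velocity(2.0, 0.5)
```

With an exact noise prediction, the recovered velocity equals the true displacement $x_1 - x_0$ of the straight path.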
To tailor the training process to your specific requirements, you can customize these utilities by inheriting from their base classes and overriding their methods. Once customized, simply pass the instances to the RectifiedFlow class during initialization.
Example: Custom Weighting Scheme
import torch
from rectified_flow.flow_components import TrainTimeWeight

class CustomTimeWeight(TrainTimeWeight):
    def __init__(self):
        super().__init__()

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        # Weight each sampled time t by exp(t)
        wts = torch.exp(t)
        return wts

# Initialize with custom exponential weighting
custom_time_weight = CustomTimeWeight()
To create custom samplers with specific integration schemes, subclass the Sampler class and implement the step method tailored to your needs. The step method receives the current state x_t, t, and t_next from the base class and defines the integration scheme.
Example: Euler Sampler
from typing import Callable

import torch
from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.flow_components import Sampler

class EulerSampler(Sampler):
    def __init__(
        self,
        rectified_flow: RectifiedFlow,
        num_steps: int | None = None,
        time_grid: list[float] | torch.Tensor | None = None,
        record_traj_period: int = 1,
        callbacks: list[Callable] | None = None,
        num_samples: int | None = None,
    ):
        super().__init__(
            rectified_flow,
            num_steps,
            time_grid,
            record_traj_period,
            callbacks,
            num_samples,
        )

    def step(self, **model_kwargs):
        t, t_next, x_t = self.t, self.t_next, self.x_t
        v_t = self.rectified_flow.get_velocity(x_t, t, **model_kwargs)
        # Explicit Euler update along the predicted velocity
        self.x_t = x_t + (t_next - t) * v_t
After defining your custom sampler, you can perform sampling just like with a standard sampler.
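Other integration schemes follow the same pattern: only the update rule inside step changes. As a library-independent illustration, the plain-Python sketch below compares an Euler update with a Heun (second-order) update on a toy velocity field with a known solution. The function names and the toy field are assumptions, and the code does not use the Sampler base class.

```python
import math


def euler_step(velocity, x, t, t_next):
    """First-order update using the velocity at the start of the interval."""
    return x + (t_next - t) * velocity(x, t)


def heun_step(velocity, x, t, t_next):
    """Second-order update averaging the velocities at both interval ends."""
    dt = t_next - t
    v_start = velocity(x, t)
    x_euler = x + dt * v_start         # provisional Euler prediction
    v_end = velocity(x_euler, t_next)  # velocity at the predicted point
    return x + dt * 0.5 * (v_start + v_end)


def integrate(step, velocity, x, num_steps):
    """Apply the given step rule on a uniform time grid from 0 to 1."""
    for k in range(num_steps):
        x = step(velocity, x, k / num_steps, (k + 1) / num_steps)
    return x


# Toy field dx/dt = -x with exact solution x(1) = x(0) * exp(-1).
velocity = lambda x, t: -x
exact = math.exp(-1.0)
euler_error = abs(integrate(euler_step, velocity, 1.0, 10) - exact)
heun_error = abs(integrate(heun_step, velocity, 1.0, 10) - exact)
```

On this toy problem Heun is noticeably more accurate than Euler for the same number of steps, at the cost of two velocity evaluations per step.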
The following is a quick implementation of a stochastic sampler that covers the DDPM sampling algorithm:
class MyStochasticSampler(Sampler):
    def __init__(self, rectified_flow: RectifiedFlow, noise_replacement_rate: Callable | str = lambda t, t_next: 0.5, **kwargs):
        super().__init__(rectified_flow=rectified_flow, **kwargs)
        if not (self.rectified_flow.independent_coupling and self.rectified_flow.is_pi_0_zero_mean_gaussian):
            import warnings
            warnings.warn("It is only theoretically correct to use this sampler when pi0 is a zero mean Gaussian and the coupling (X0, X1) is independent. Proceed at your own risk.")
        self.noise_replacement_rate = noise_replacement_rate

    def step(self, **model_kwargs):
        """Perform a single step of the sampling process."""
        t, t_next, x_t = self.t, self.t_next, self.x_t
        v_t = self.rectified_flow.get_velocity(x_t, t, **model_kwargs)

        # Given x_t and dot_x_t = v_t, find the corresponding endpoints x_0 and x_1
        interp = self.rectified_flow.interp.solve(t, x_t=x_t, dot_x_t=v_t)
        x_1_pred = interp.x_1
        x_0_pred = interp.x_0

        # Randomize x_0_pred by replacing part of it with new noise
        noise = self.rectified_flow.sample_source_distribution(self.num_samples)
        noise_replacement_factor = self.noise_replacement_rate(t, t_next)
        x_0_pred_refreshed = (
            (1 - noise_replacement_factor) * x_0_pred +
            (1 - (1 - noise_replacement_factor) ** 2) ** 0.5 * noise
        )

        # Interpolate to find x_t at t_next
        self.x_t = self.rectified_flow.interp.solve(t_next, x_0=x_0_pred_refreshed, x_1=x_1_pred).x_t
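The refresh step in this sampler mixes the predicted noise with fresh noise using the coefficients $(1 - r)$ and $\sqrt{1 - (1 - r)^2}$, where $r$ is the noise replacement factor. Their squares sum to one, so if the predicted and fresh noise are independent with equal variance, the mixture keeps that variance. A plain-Python check of the identity follows; the helper name `refresh_coefficients` is hypothetical.

```python
def refresh_coefficients(rate):
    """Return the weights on (x_0_pred, fresh noise) for a replacement rate."""
    keep = 1.0 - rate
    fresh = (1.0 - keep ** 2) ** 0.5
    return keep, fresh


# Sum of squared weights for several replacement rates in [0, 1].
sums = []
for rate in (0.0, 0.25, 0.5, 1.0):
    keep, fresh = refresh_coefficients(rate)
    sums.append(keep ** 2 + fresh ** 2)

no_refresh = refresh_coefficients(0.0)    # keeps x_0_pred unchanged
full_refresh = refresh_coefficients(1.0)  # replaces x_0_pred entirely
```

A rate of 0 leaves the predicted noise untouched, which makes the step deterministic, while a rate of 1 discards it in favor of fresh noise.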
If you find this repository useful for your research, please consider citing:
@misc{lq2024rectifiedflow,
  author = {Qiang Liu and Runlong Liao and Bo Liu and Xixi Hu},
  title = {PyTorch RectifiedFlow},
  year = {2024},
  url = {https://github.com/lqiang67/rectified-flow}
}
Component | License |
---|---|
Codebase | MIT License |