RectifiedFlow is a repository that offers a unified and minimal codebase for implementing a variety of diffusion and flow models, including Rectified Flow, DDPM, DDIM, and more. By adopting a unified Ordinary Differential Equation (ODE) perspective, RectifiedFlow provides a streamlined and convenient framework for training and inference, tailored for research purposes. This PyTorch-based library includes:
- Unified Training and Inference: Seamlessly train and sample from rectified flow (flow matching) and diffusion models from a single, coherent ODE perspective.
- Interactive Tutorials: Beginner-friendly tutorials offer hands-on experience with rectified flows. Learn how to transform rectified flow models into other models such as DDIM and see their equivalence demonstrated.
- Comprehensive Tools: Access a robust set of tools for studying Rectified Flow models, including interpolation methods and ODE/SDE solvers. Designed with the ODE framework in mind, these tools are easy to understand and use.
- Support for State-of-the-Art Models: Use cutting-edge open-source models such as Flux.1-dev. Demos show how to perform sophisticated tasks like image editing with minimal effort.
Whether you are a researcher exploring the frontiers of generative modeling or a practitioner seeking to deepen your understanding through comprehensive tutorials, RectifiedFlow provides the essential resources and functionalities to advance your projects with confidence and ease.
Please run the following commands in the given order to install the dependencies.
conda create -n rf python=3.10
conda activate rf
git clone https://github.com/lqiang67/rectified-flow.git
cd rectified-flow
pip install -r requirements.txt
Then install the rectified-flow package in editable mode:
pip install -e .
The RectifiedFlow class serves as the interface between your velocity field and the training and inference processes. Each velocity field should have its own RectifiedFlow instance.
import torch
from rectified_flow.rectified_flow import RectifiedFlow
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = YourVelocityFieldModel().to(device)  # placeholder: your network, called as model(x_t, t)
# Initialize RectifiedFlow with custom settings
rectified_flow = RectifiedFlow(
    data_shape=(32, 32),
    velocity_field=model,
    interp="straight",
    source_distribution="normal",
    is_independent_coupling=True,
    train_time_distribution="uniform",
    train_time_weight="uniform",
    criterion="mse",
    device=device,
)
# Or use the default settings
rectified_flow = RectifiedFlow(
    data_shape=(32, 32),
    velocity_field=model,
    device=device,
)
During training, you can compute the predefined loss by passing your target data samples x_1. If samples x_0 from the source distribution are not provided, they are sampled automatically. The RectifiedFlow class supports various pre-specified loss functions and interpolation methods, and computes the loss accordingly.
loss = rectified_flow.get_loss(x_0=None, x_1=x_1, **kwargs)
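For intuition, the default settings above (straight interpolation, standard normal source, independent coupling, uniform time sampling and weighting, MSE criterion) amount to regressing the velocity field onto $\dot X_t = X_1 - X_0$ at $X_t = t X_1 + (1 - t) X_0$. The pure-Python sketch below spells this out; `straight_flow_loss` is a hypothetical helper, not part of the rectified_flow API, and the library's exact reduction over batch and dimensions may differ.

```python
import random


def straight_flow_loss(velocity_fn, x_1_batch, rng):
    """Hypothetical Monte Carlo MSE loss for a straight-interpolation rectified flow.

    velocity_fn(x_t, t) returns a list of floats with the same length as x_t.
    x_1_batch is a list of data samples, each a list of floats.
    """
    total, count = 0.0, 0
    for x_1 in x_1_batch:
        # Independent coupling: draw x_0 ~ N(0, I) without looking at x_1.
        x_0 = [rng.gauss(0.0, 1.0) for _ in x_1]
        # Uniform training-time distribution on [0, 1).
        t = rng.random()
        # Straight interpolation X_t = t * X_1 + (1 - t) * X_0 and its time derivative.
        x_t = [t * a + (1.0 - t) * b for a, b in zip(x_1, x_0)]
        target = [a - b for a, b in zip(x_1, x_0)]
        pred = velocity_fn(x_t, t)
        # Uniform time weight: every squared error counts equally in the mean.
        total += sum((p - g) ** 2 for p, g in zip(pred, target))
        count += len(x_1)
    return total / count
```

With a single data point $c$, the exact rectified velocity $(c - x)/(1 - t)$ drives this loss to zero (up to rounding), while a zero velocity field leaves a clearly positive loss.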
For sampling, import the desired sampler class and pass the RectifiedFlow instance to it.
from rectified_flow.samplers import SDESampler
sde_sampler = SDESampler(rectified_flow=rectified_flow)
sde_sampler.sample_loop(
num_samples=128,
num_steps=100,
seed=0,
)
traj = sde_sampler.trajectories
img = traj[-1]
- Introduction with 2D Toy: This notebook uses a 2D toy example to illustrate the basic concepts of Rectified Flow. It covers the interpolation process $\{X_t\}$, the rectified flow $\{Z_t\}$ with velocity $\mathbb{E}[\dot{X}_t \mid X_t]$, and Reflow $\{Z^1_t\}$.
- Samplers: This notebook explores the samplers provided in this repository using a 2D toy example. It illustrates the concepts and usage of various samplers such as CurvedEuler, Overshooting, and SDESampler. It also demonstrates how to customize your own sampler by inheriting from the Sampler base class, and discusses the implications of using stochastic samplers.
- Interpolation: This notebook first illustrates that different interpolations $\{X_t\}$ can be converted into one another and presents a simple implementation of this conversion. It also shows that the very same transformation applies to the rectified flow $\{Z_t\}$, along with a few other notable findings.
- Flux: One notebook shows how to interact with the wrapped Flux model using different samplers, and another demonstrates how to perform image editing with Flux, both in a straightforward, friendly manner.
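The rectified flow velocity $v(x, t) = \mathbb{E}[\dot X_t \mid X_t = x] = \mathbb{E}[X_1 - X_0 \mid X_t = x]$ has a closed form in simple cases. The sketch below assumes a 1D analogue of the toy (not the notebook's actual setup): $X_0 \sim \mathcal{N}(0, 1)$, $X_1$ equally likely to be $-2$ or $2$, independent coupling, and straight interpolation $X_t = t X_1 + (1 - t) X_0$. Since $X_0 = (X_t - t X_1)/(1 - t)$, the velocity reduces to $v(x, t) = (\mathbb{E}[X_1 \mid X_t = x] - x)/(1 - t)$, where the posterior over $X_1 = m$ is proportional to $\mathcal{N}(x; t m, (1 - t)^2)$. All function names are hypothetical.

```python
import math


def posterior_mean_x1(x, t, means):
    """E[X_1 | X_t = x] for X_1 uniform over `means` (hypothetical toy helper)."""
    var = (1.0 - t) ** 2
    # log N(x; t * m, (1 - t)^2) up to a constant shared by all components.
    logits = [-(x - t * m) ** 2 / (2.0 * var) for m in means]
    top = max(logits)
    weights = [math.exp(l - top) for l in logits]  # subtract max for stability
    return sum(w * m for w, m in zip(weights, means)) / sum(weights)


def rectified_velocity(x, t, means):
    # v(x, t) = (E[X_1 | X_t = x] - x) / (1 - t), valid for t < 1.
    return (posterior_mean_x1(x, t, means) - x) / (1.0 - t)


def euler_sample(x_0, means, num_steps=100):
    # Integrate dZ_t/dt = v(Z_t, t) from t = 0 to t = 1 with forward Euler.
    x = x_0
    for i in range(num_steps):
        t, t_next = i / num_steps, (i + 1) / num_steps
        x = x + (t_next - t) * rectified_velocity(x, t, means)
    return x
```

In 1D the exact flow is monotone, so starting points above 0 end at $2$ and those below end at $-2$; for example, `euler_sample(1.0, [-2.0, 2.0])` lands at $2$.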
We provide Diffusers-style training scripts for UNet and DiT in this directory. The training scripts use Accelerate for multi-GPU training.
Results using these training scripts:
- UNet CIFAR10: Trained for $500\text{k}$ iterations with batch_size=128. $\text{FID}_{50\text{K}}=4.308$. You can download the model here.
- DiT CIFAR10: Trained for $1000\text{k}$ iterations with batch_size=128. $\text{FID}_{50\text{K}}=3.678$. You can download the model here.
Loading a Pretrained Model:
To construct a model from a pretrained checkpoint, simply run the following code:
import torch
from rectified_flow.models.dit import DiT
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DiT.from_pretrained(save_directory="PATH_TO_MODEL", filename="dit", use_ema=True).to(device)
The velocity_field argument in the RectifiedFlow class expects a neural network or a callable that takes x_t and t as inputs. When you want to reparameterize the model, or change the direction of the generating ODE time from $1 \to 0$ to $0 \to 1$, you can wrap the model in a function with this signature:
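model = YourPretrainedModel()  # placeholder for your pretrained network
# Change ODE time direction from 1→0 to 0→1
# If the model outputs dx/ds in its own time s = 1 - t, negate it, since dx/dt = -dx/ds
velocity = lambda x_t, t: model(x_t, 1.0 - t)
# Reparameterization example: rescale the model output by a function of t
velocity = lambda x_t, t: t**2 * model(x_t, t)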
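To see when the time-reversal wrapper also needs a sign flip, suppose (an assumption, not a statement about any particular model) that a pretrained network works in reversed time $s = 1 - t$ with $X_s = (1 - s) X_1 + s X_0$, so noise sits at $s = 1$, and predicts $dX_s/ds = X_0 - X_1$. By the chain rule $dX_t/dt = -\,dX_s/ds$, so the velocity in the $0 \to 1$ convention is `-model(x_t, 1.0 - t)`. The toy below uses a single hypothetical data point `c`, for which the reversed-time model is exact; `reversed_time_model` and `euler` are illustrative names only.

```python
def reversed_time_model(x, s, c=3.0):
    # Hypothetical pretrained model in reversed time: X_s = (1 - s) * c + s * X_0,
    # returning dX_s/ds = X_0 - c = (x - c) / s (valid for s > 0).
    return (x - c) / s


# Velocity in the 0 -> 1 convention: t = 1 - s, hence dX_t/dt = -dX_s/ds.
velocity = lambda x_t, t: -reversed_time_model(x_t, 1.0 - t)


def euler(v, x_0, num_steps=50):
    # Forward Euler from t = 0 to t = 1; v is only evaluated at t < 1.
    x = x_0
    for i in range(num_steps):
        t, t_next = i / num_steps, (i + 1) / num_steps
        x = x + (t_next - t) * v(x, t)
    return x
```

`euler(velocity, -1.5)` reaches the data point `c = 3.0`, whereas dropping the minus sign pushes the trajectory away from it.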
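In Rectified Flow, we assume that $X_0$ at $t = 0$ is drawn from the source distribution and $X_1$ at $t = 1$ is drawn from the target (data) distribution.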
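The AffineInterp class manages the affine interpolation $X_t = \alpha_t X_1 + \beta_t X_0$ between the source distribution and the target distribution, and provides the following features: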
- Automatic Interpolation Handling: Given the functions $\alpha_t$ and $\beta_t$ (and optionally $\dot{\alpha}_t$ and $\dot\beta_t$), AffineInterp computes the interpolated state $X_t$ and its time derivative $\dot X_t$. If the derivative functions $\dot\alpha_t, \dot\beta_t$ are not supplied, they are computed automatically with PyTorch automatic differentiation.
import math
import torch
from rectified_flow.flow_components import AffineInterp
a = torch.tensor(math.pi / 2)  # example angle; torch.sin needs a tensor argument
alpha_function = lambda t: torch.sin(a * t) / torch.sin(a)
beta_function = lambda t: torch.sin(a * (1 - t)) / torch.sin(a)
interp = AffineInterp(alpha=alpha_function, beta=beta_function)
# x_0, x_1: source and target samples; t: interpolation times
x_t, dot_x_t = interp.forward(x_0, x_1, t)
- Automatic Solving of Unknown Variables: Given any two of the four variables ($X_0, X_1, X_t, \dot X_t$), the class can solve for the remaining two using precomputed symbolic solvers. This is convenient for computing common quantities, such as the estimates $\hat X_0$ and $\hat X_1$ given $X_t$ and $v(X_t, t)$.
# Solve for x_0 and x_1 given x_t and dot_x_t
interp.solve(t=t, x_t=x_t, dot_x_t=velocity)  # velocity: model prediction v(x_t, t)
print(interp.x_0, interp.x_1)
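Both features come down to a little algebra. With $X_t = \alpha_t X_1 + \beta_t X_0$ and $\dot X_t = \dot\alpha_t X_1 + \dot\beta_t X_0$, knowing $X_t$ and $\dot X_t$ gives a $2 \times 2$ linear system for $(X_1, X_0)$, solvable whenever $\alpha_t \dot\beta_t - \beta_t \dot\alpha_t \neq 0$. The sketch below is a hypothetical scalar stand-in for AffineInterp, not its implementation: it works on floats (apply elementwise for tensors), approximates missing derivatives with central finite differences instead of PyTorch autograd, and solves the system with Cramer's rule.

```python
import math


class ScalarAffineInterp:
    """Hypothetical sketch of X_t = alpha(t) * X_1 + beta(t) * X_0 on floats."""

    def __init__(self, alpha, beta, dot_alpha=None, dot_beta=None, h=1e-5):
        self.alpha, self.beta = alpha, beta
        # Central finite differences stand in for autograd when derivatives are missing.
        self.dot_alpha = dot_alpha or (lambda t: (alpha(t + h) - alpha(t - h)) / (2 * h))
        self.dot_beta = dot_beta or (lambda t: (beta(t + h) - beta(t - h)) / (2 * h))

    def forward(self, x_0, x_1, t):
        x_t = self.alpha(t) * x_1 + self.beta(t) * x_0
        dot_x_t = self.dot_alpha(t) * x_1 + self.dot_beta(t) * x_0
        return x_t, dot_x_t

    def solve_x0_x1(self, t, x_t, dot_x_t):
        # Solve [[a, b], [da, db]] @ [x_1, x_0] = [x_t, dot_x_t] by Cramer's rule.
        a, b = self.alpha(t), self.beta(t)
        da, db = self.dot_alpha(t), self.dot_beta(t)
        det = a * db - b * da
        x_1 = (x_t * db - b * dot_x_t) / det
        x_0 = (a * dot_x_t - da * x_t) / det
        return x_0, x_1


# Spherical-style interpolation with an assumed angle a = pi / 2.
a = math.pi / 2
interp = ScalarAffineInterp(
    alpha=lambda t: math.sin(a * t) / math.sin(a),
    beta=lambda t: math.sin(a * (1 - t)) / math.sin(a),
)
x_t, dot_x_t = interp.forward(x_0=0.3, x_1=-1.2, t=0.37)
x_0_hat, x_1_hat = interp.solve_x0_x1(t=0.37, x_t=x_t, dot_x_t=dot_x_t)
```

Here the round trip recovers $X_0 = 0.3$ and $X_1 = -1.2$, and the finite-difference $\dot X_t$ matches the analytic derivative closely.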
To tailor the training process to your specific requirements, you can customize these utilities by inheriting from their base classes and overriding their methods. Once customized, simply pass the instances to the RectifiedFlow class during initialization.
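Example: Custom Weighting Scheme
import torch
from rectified_flow.flow_components import TrainTimeWeight
class CustomTimeWeight(TrainTimeWeight):
    def __init__(self):
        super().__init__()
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        # Weight each training sample by exp(t), emphasizing times closer to t = 1
        wts = torch.exp(t)
        return wts
# Initialize with custom exponential weighting and pass it to RectifiedFlow
custom_time_weight = CustomTimeWeight()
rectified_flow = RectifiedFlow(
    data_shape=(32, 32),
    velocity_field=model,
    train_time_weight=custom_time_weight,
    device=device,
)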
To create custom samplers with specific integration schemes, subclass the Sampler class and implement the step method. The step method reads the current state x_t, the time t, and the next time t_next from the base class (as self.x_t, self.t, and self.t_next) and defines the integration scheme by updating self.x_t.
Example: Euler Sampler
from typing import Callable
import torch
from rectified_flow.flow_components import Sampler
from rectified_flow.rectified_flow import RectifiedFlow
class EulerSampler(Sampler):
    def __init__(
        self,
        rectified_flow: RectifiedFlow,
        num_steps: int | None = None,
        time_grid: list[float] | torch.Tensor | None = None,
        record_traj_period: int = 1,
        callbacks: list[Callable] | None = None,
        num_samples: int | None = None,
    ):
        super().__init__(
            rectified_flow,
            num_steps,
            time_grid,
            record_traj_period,
            callbacks,
            num_samples,
        )
    def step(self, **model_kwargs):
        # Current time, next time, and state, maintained by the base class
        t, t_next, x_t = self.t, self.t_next, self.x_t
        # Velocity predicted by the model at (x_t, t)
        v_t = self.rectified_flow.get_velocity(x_t, t, **model_kwargs)
        # Forward Euler update: x_{t_next} = x_t + (t_next - t) * v_t
        self.x_t = x_t + (t_next - t) * v_t
After defining your custom sampler, you can perform sampling just like with a standard sampler.
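The division of labor between the base class and step can be illustrated without the library. The sketch below is a hypothetical, stripped-down stand-in (MiniSampler is not the repository's Sampler class): the base class walks a uniform time grid, sets self.t, self.t_next, and self.x_t before each call to step, and records the trajectory, while subclasses only implement step. A second-order Heun subclass shows how a different integration scheme plugs in; the velocity is a plain callable v(x, t) on floats rather than rectified_flow.get_velocity.

```python
import math
from typing import Callable


class MiniSampler:
    """Hypothetical minimal sampler base class that drives step() over a time grid."""

    def __init__(self, velocity: Callable[[float, float], float], num_steps: int):
        self.velocity = velocity
        self.time_grid = [i / num_steps for i in range(num_steps + 1)]

    def step(self) -> None:
        raise NotImplementedError  # subclasses define the integration scheme

    def sample_loop(self, x_0: float) -> list[float]:
        self.x_t = x_0
        self.trajectories = [x_0]
        for t, t_next in zip(self.time_grid[:-1], self.time_grid[1:]):
            self.t, self.t_next = t, t_next
            self.step()  # updates self.x_t
            self.trajectories.append(self.x_t)
        return self.trajectories


class MiniEuler(MiniSampler):
    def step(self) -> None:
        t, t_next, x_t = self.t, self.t_next, self.x_t
        self.x_t = x_t + (t_next - t) * self.velocity(x_t, t)


class MiniHeun(MiniSampler):
    def step(self) -> None:
        # Heun: average the slope at t and at an Euler-predicted point at t_next.
        t, t_next, x_t = self.t, self.t_next, self.x_t
        dt = t_next - t
        v_t = self.velocity(x_t, t)
        v_next = self.velocity(x_t + dt * v_t, t_next)
        self.x_t = x_t + dt * 0.5 * (v_t + v_next)


# Test problem with a known solution: dx/dt = -x, so x(1) = x(0) * exp(-1).
v = lambda x, t: -x
euler_traj = MiniEuler(v, num_steps=100).sample_loop(1.0)
heun_traj = MiniHeun(v, num_steps=100).sample_loop(1.0)
```

Both subclasses reuse the same loop; the Heun variant costs two velocity evaluations per step but lands noticeably closer to $e^{-1}$ than Euler at the same step count.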
If you find this repository useful for your research, please consider citing:
@misc{lq2024rectifiedflow,
  author = {Qiang Liu and Runlong Liao and Bo Liu and Xixi Hu},
  title = {PyTorch RectifiedFlow},
  year = {2024},
  url = {https://github.com/lqiang67/rectified-flow}
}
Component | License |
---|---|
Codebase | MIT License |