Skip to content

Commit

Permalink
Support customizable affine interpolations
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkMnDragon committed Dec 27, 2024
1 parent b25fffc commit 7c884ef
Showing 1 changed file with 53 additions and 35 deletions.
88 changes: 53 additions & 35 deletions rectified_flow/flow_components/interpolation_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,49 +199,67 @@ class AffineInterp(nn.Module):

def __init__(
self,
name: str = "straight",
name: str | None = None,
alpha: Callable | None = None,
beta: Callable | None = None,
dot_alpha: Callable | None = None,
dot_beta: Callable | None = None,
):

super().__init__()

if name.lower() in ["straight", "lerp"]:
# Special case for "straight" interpolation
alpha = lambda t: t
beta = lambda t: 1 - t
dot_alpha = lambda t: torch.ones_like(t)
dot_beta = lambda t: -torch.ones_like(t)
name = "straight"
elif name.lower() in ["slerp", "spherical"]:
# Special case of "spherical" interpolation
alpha = lambda t: torch.sin(t * torch.pi / 2.0)
beta = lambda t: torch.cos(t * torch.pi / 2.0)
dot_alpha = lambda t: torch.cos(t * torch.pi / 2.0) * torch.pi / 2.0
dot_beta = lambda t: -torch.sin(t * torch.pi / 2.0) * torch.pi / 2.0
name = "spherical"
elif name.lower() in ["ddim", "ddpm"]:
# DDIM/DDPM scheme; see Eq 7 in https://arxiv.org/pdf/2209.03003
# Note in VP-ODE, dot_beta_t will explode at t = 1
a = 19.9
b = 0.1
alpha = lambda t: torch.exp(-a * (1 - t) ** 2 / 4.0 - b * (1 - t) / 2.0)
beta = lambda t: torch.sqrt(1 - alpha(t) ** 2)
name = "DDIM"
elif isinstance(name, str):
raise ValueError(f"Unknown interpolation scheme: {name}. Provide custom interpolation functions for alpha and beta.")
elif (
alpha is None
or beta is None
):
# Custom interpolation functions
raise NotImplementedError(
"Custom interpolation functions must be provided for alpha and beta."
)
if alpha is not None or beta is not None:
if name and name.lower() in ["straight", "lerp", "slerp", "spherical", "ddim", "ddpm"]:
raise ValueError(
f"You provided a predefined interpolation name '{name}' and also custom alpha/beta. "
"Only one option is allowed."
)
if alpha is None or beta is None:
raise ValueError("Custom interpolation requires both alpha and beta functions.")

self.name = name if name is not None else "custom"
self.alpha = alpha
self.beta = beta
self.dot_alpha = dot_alpha
self.dot_beta = dot_beta

else:
if name is None:
raise ValueError(
"No interpolation scheme name provided, and no custom alpha/beta supplied."
)

lower_name = name.lower()

if lower_name in ["straight", "lerp"]:
# Straight line interpolation
self.name = "straight"
self.alpha = lambda t: t
self.beta = lambda t: 1 - t
self.dot_alpha = lambda t: torch.ones_like(t)
self.dot_beta = lambda t: -torch.ones_like(t)

elif lower_name in ["slerp", "spherical"]:
# Spherical interpolation
self.name = "spherical"
self.alpha = lambda t: torch.sin(t * torch.pi / 2.0)
self.beta = lambda t: torch.cos(t * torch.pi / 2.0)
self.dot_alpha = lambda t: (torch.pi / 2.0) * torch.cos(t * torch.pi / 2.0)
self.dot_beta = lambda t: -(torch.pi / 2.0) * torch.sin(t * torch.pi / 2.0)

elif lower_name in ["ddim", "ddpm"]:
# DDIM/DDPM scheme
self.name = "DDIM"
a = 19.9
b = 0.1
self.alpha = lambda t: torch.exp(-a * (1 - t) ** 2 / 4.0 - b * (1 - t) / 2.0)
self.beta = lambda t: torch.sqrt(1 - self.alpha(t) ** 2)

else:
raise ValueError(
f"Unknown interpolation scheme name '{name}'. Provide a known scheme name "
"or supply custom alpha/beta functions."
)

self.name = name
self.alpha = lambda t: alpha(self.ensure_tensor(t))
self.beta = lambda t: beta(self.ensure_tensor(t))
if dot_alpha is not None:
Expand Down

0 comments on commit 7c884ef

Please sign in to comment.