Skip to content

Commit

Permalink
Minor change
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkMnDragon committed Dec 29, 2024
1 parent 6feb0ac commit 719a040
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions rectified_flow/flow_components/interpolation_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def __init__(
if alpha is None or beta is None:
raise ValueError("Custom interpolation requires both alpha and beta functions.")

self.name = name if name is not None else "custom"
self.alpha = alpha
self.beta = beta
self.dot_alpha = dot_alpha
self.dot_beta = dot_beta
name = name if name is not None else "custom"
alpha = alpha
beta = beta
dot_alpha = dot_alpha
dot_beta = dot_beta

else:
if name is None:
Expand All @@ -232,36 +232,42 @@ def __init__(

if lower_name in ["straight", "lerp"]:
# Straight line interpolation
self.name = "straight"
self.alpha = lambda t: t
self.beta = lambda t: 1 - t
self.dot_alpha = lambda t: torch.ones_like(t)
self.dot_beta = lambda t: -torch.ones_like(t)
name = "straight"
alpha = lambda t: t
beta = lambda t: 1 - t
dot_alpha = lambda t: torch.ones_like(t)
dot_beta = lambda t: -torch.ones_like(t)

elif lower_name in ["slerp", "spherical"]:
# Spherical interpolation
self.name = "spherical"
self.alpha = lambda t: torch.sin(t * torch.pi / 2.0)
self.beta = lambda t: torch.cos(t * torch.pi / 2.0)
self.dot_alpha = lambda t: (torch.pi / 2.0) * torch.cos(t * torch.pi / 2.0)
self.dot_beta = lambda t: -(torch.pi / 2.0) * torch.sin(t * torch.pi / 2.0)
name = "spherical"
alpha = lambda t: torch.sin(t * torch.pi / 2.0)
beta = lambda t: torch.cos(t * torch.pi / 2.0)
dot_alpha = lambda t: (torch.pi / 2.0) * torch.cos(t * torch.pi / 2.0)
dot_beta = lambda t: -(torch.pi / 2.0) * torch.sin(t * torch.pi / 2.0)

elif lower_name in ["ddim", "ddpm"]:
# DDIM/DDPM scheme
self.name = "DDIM"
name = "DDIM"
a = 19.9
b = 0.1
self.alpha = lambda t: torch.exp(-a * (1 - t) ** 2 / 4.0 - b * (1 - t) / 2.0)
self.beta = lambda t: torch.sqrt(1 - self.alpha(t) ** 2)
self.alpha = None
self.beta = None
alpha = lambda t: torch.exp(-a * (1 - t) ** 2 / 4.0 - b * (1 - t) / 2.0)
beta = lambda t: torch.sqrt(1 - self.alpha(t) ** 2)
dot_alpha = None
dot_beta = None

else:
raise ValueError(
f"Unknown interpolation scheme name '{name}'. Provide a known scheme name "
"or supply custom alpha/beta functions."
)


self.name = name
self.alpha = lambda t: alpha(self.ensure_tensor(t))
self.beta = lambda t: beta(self.ensure_tensor(t))
self.dot_alpha = None if dot_alpha is None else lambda t: dot_alpha(self.ensure_tensor(t))
self.dot_beta = None if dot_beta is None else lambda t: dot_beta(self.ensure_tensor(t))

self.solver = AffineInterpSolver()
self.a_t = None
self.b_t = None
Expand Down

0 comments on commit 719a040

Please sign in to comment.