Skip to content

Commit

Permalink
fix heun scheduler (huggingface#1512)
Browse files Browse the repository at this point in the history
  • Loading branch information
patil-suraj authored Dec 1, 2022
1 parent e65b71a commit 0f1c246
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/diffusers/schedulers/scheduling_heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,13 @@ def step(

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
pred_original_sample = sample - sigma_input * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1)
)
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
Expand All @@ -207,7 +210,7 @@ def step(
self.sample = sample
else:
# 2. 2nd order / Heun's method
derivative = (sample - pred_original_sample) / sigma_hat
derivative = (sample - pred_original_sample) / sigma_next
derivative = (self.prev_derivative + derivative) / 2

# 3. Retrieve 1st order derivative
Expand Down

0 comments on commit 0f1c246

Please sign in to comment.