Skip to content

Commit

Permalink
[Playground] Tune parameters
Browse files Browse the repository at this point in the history
Topic: tune_params_1
  • Loading branch information
hmartiro committed Jan 17, 2023
1 parent 66b0f03 commit a823173
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 10 deletions.
22 changes: 17 additions & 5 deletions riffusion/streamlit/pages/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def render_audio_to_audio() -> None:
num_inference_steps = T.cast(
int,
st.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
"Steps per sample", value=25, help="Number of denoising steps per model run"
),
)

Expand Down Expand Up @@ -115,17 +115,24 @@ def render_audio_to_audio() -> None:

counter = streamlit_util.StreamlitCounter()

denoising_default = 0.55
with st.form("audio to audio form"):
if interpolate:
left, right = st.columns(2)

with left:
st.write("##### Prompt A")
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
prompt_input_a = PromptInput(
guidance=guidance,
**get_prompt_inputs(key="a", denoising_default=denoising_default),
)

with right:
st.write("##### Prompt B")
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
prompt_input_b = PromptInput(
guidance=guidance,
**get_prompt_inputs(key="b", denoising_default=denoising_default),
)
elif use_magic_mix:
prompt = st.text_input("Prompt", key="prompt_a")

Expand All @@ -150,7 +157,12 @@ def render_audio_to_audio() -> None:
else:
prompt_input_a = PromptInput(
guidance=guidance,
**get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
**get_prompt_inputs(
key="a",
include_negative_prompt=True,
cols=True,
denoising_default=denoising_default,
),
)

st.form_submit_button("Riff", type="primary", on_click=counter.increment)
Expand Down Expand Up @@ -319,7 +331,7 @@ def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
p["duration_s"] = cols[1].number_input(
"Duration [s]",
min_value=0.0,
value=15.0,
value=20.0,
)

if advanced:
Expand Down
11 changes: 8 additions & 3 deletions riffusion/streamlit/pages/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,15 @@ def render_interpolation() -> None:

with left:
st.write("##### Prompt A")
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
prompt_input_a = PromptInput(
guidance=guidance, **get_prompt_inputs(key="a", denoising_default=0.75)
)

with right:
st.write("##### Prompt B")
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
prompt_input_b = PromptInput(
guidance=guidance, **get_prompt_inputs(key="b", denoising_default=0.75)
)

st.form_submit_button("Generate", type="primary")

Expand Down Expand Up @@ -201,6 +205,7 @@ def get_prompt_inputs(
key: str,
include_negative_prompt: bool = False,
cols: bool = False,
denoising_default: float = 0.5,
) -> T.Dict[str, T.Any]:
"""
Compute prompt inputs from widgets.
Expand Down Expand Up @@ -228,7 +233,7 @@ def get_prompt_inputs(

p["denoising"] = right.number_input(
"Denoising",
value=0.75,
value=denoising_default,
key=f"denoising_{key}",
help="How much to modify the seed image",
)
Expand Down
2 changes: 1 addition & 1 deletion riffusion/streamlit/pages/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def render_text_to_audio() -> None:
st.form_submit_button("Riff", type="primary")

with st.sidebar:
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=25))
width = T.cast(int, st.number_input("Width", value=512))
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
Expand Down
2 changes: 1 addition & 1 deletion riffusion/streamlit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]

SCHEDULER_OPTIONS = [
"DPMSolverMultistepScheduler",
"PNDMScheduler",
"DDIMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]


Expand Down

0 comments on commit a823173

Please sign in to comment.