Skip to content

Commit

Permalink
Initialize model to session state
Browse files (browse the repository at this point in the history)
  • Loading branch information
Ryu1845 committed Feb 9, 2023
1 parent 3d9fa4f commit 11f904d
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -30,6 +30,7 @@ def timeit(desc=""):
help="Text to speak.",
value="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
)

voices = os.listdir("tortoise/voices") + ["random"]
voices.remove("cond_latent_example")
voice = st.selectbox(
Expand All @@ -39,7 +40,7 @@ def timeit(desc=""):
"Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
index=len(voices) - 1,
)
- preset = st.radio(
+ preset = st.selectbox(
"Preset",
(
"single_sample",
Expand All @@ -50,7 +51,7 @@ def timeit(desc=""):
"high_quality",
),
help="Which voice preset to use.",
- index=3,
+ index=1,
)
with st.expander("Advanced"):
col1, col2 = st.columns(2)
Expand All @@ -75,7 +76,6 @@ def timeit(desc=""):
)
if seed == -1:
seed = None

"""#### Directories"""
output_path = st.text_input(
"Output Path", help="Where to store outputs.", value="results/"
Expand All @@ -88,16 +88,17 @@ def timeit(desc=""):
)



with col2:
"""#### Optimizations"""
high_vram = not st.checkbox(
"Low VRAM",
- help="re-enable default offloading behaviour of tortoise",
+ help="Re-enable default offloading behaviour of tortoise",
value=True,
)
half = st.checkbox(
"Half-Precision",
- help="enable autocast to half precision for autoregressive model",
+ help="Enable autocast to half precision for autoregressive model",
value=False,
)
kv_cache = st.checkbox(
Expand All @@ -121,11 +122,12 @@ def timeit(desc=""):
value=True,
)


+ if 'tts' not in st.session_state:
+ st.session_state.tts = TextToSpeech(models_dir=model_dir, high_vram=high_vram, kv_cache=kv_cache)
+ tts = st.session_state.tts
if st.button("Start"):
with st.spinner(f"Generating {candidates} candidates for voice {voice} (seed={seed}). You can see progress in the terminal"):
os.makedirs(output_path, exist_ok=True)
- tts = TextToSpeech(models_dir=model_dir, high_vram=high_vram, kv_cache=kv_cache)

selected_voices = voice.split(",")
for k, selected_voice in enumerate(selected_voices):
Expand Down Expand Up @@ -168,7 +170,7 @@ def timeit(desc=""):
)
audio_buffer = BytesIO()
torchaudio.save(
- audio_stream,
+ audio_buffer,
g.squeeze(0).cpu(),
24000,
)
Expand Down

0 comments on commit 11f904d

Please sign in to comment.