forked from feizc/FluxMusic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
__main__.py
183 lines (156 loc) · 4.47 KB
/
__main__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!D:\GitDownload\SupThirdParty\audioldm2\venv\Scripts\python.exe
import os
import torch
import logging
from audioldm2 import text_to_audio, build_model, save_wave, get_time, read_list
import argparse
# Export TOKENIZERS_PARALLELISM=true for this process.
# NOTE(review): presumably read by the HuggingFace `tokenizers` backend that
# the model stack uses -- confirm against that library's documentation.
os.environ.update(TOKENIZERS_PARALLELISM="true")

# Raise the threshold of the "matplotlib" logger so that only WARNING and
# more severe records from it are emitted.
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.WARNING)
# Command-line interface of the generation script.  Every option is optional
# (argparse's default for "-x/--xx" style flags) and has an explicit default.
parser = argparse.ArgumentParser()

# --- What to generate -------------------------------------------------------
parser.add_argument(
    "-t", "--text",
    type=str,
    default="",
    help="Text prompt to the model for audio generation",
)
parser.add_argument(
    "--transcription",
    type=str,
    default="",
    help="Transcription for Text-to-Speech",
)
parser.add_argument(
    "-tl", "--text_list",
    type=str,
    default="",
    help="A file that contains text prompt to the model for audio generation",
)

# --- Where to write the result ----------------------------------------------
parser.add_argument(
    "-s", "--save_path",
    type=str,
    default="./output",
    help="The path to save model output",
)

# --- Which model to run, and where ------------------------------------------
parser.add_argument(
    "--model_name",
    type=str,
    default="audioldm_48k",
    choices=[
        "audioldm_48k",
        "audioldm_16k_crossattn_t5",
        "audioldm2-full",
        "audioldm2-music-665k",
        "audioldm2-full-large-1150k",
        "audioldm2-speech-ljspeech",
        "audioldm2-speech-gigaspeech",
    ],
    help="The checkpoint you gonna use",
)
parser.add_argument(
    "-d", "--device",
    type=str,
    default="auto",
    help="The device for computation. If not specified, the script will automatically choose the device based on your environment.",
)

# --- Sampling controls --------------------------------------------------------
parser.add_argument(
    "-b", "--batchsize",
    type=int,
    default=1,
    help="Generate how many samples at the same time",
)
parser.add_argument(
    "--ddim_steps",
    type=int,
    default=200,
    help="The sampling step for DDIM",
)
parser.add_argument(
    "-gs", "--guidance_scale",
    type=float,
    default=3.5,
    help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
)
parser.add_argument(
    "-dur", "--duration",
    type=float,
    default=10.0,
    help="The duration of the samples",
)
parser.add_argument(
    "-n", "--n_candidate_gen_per_text",
    type=int,
    default=3,
    help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="Change this value (any integer number) will lead to a different generation result.",
)
# Parse the command line and derive the settings used by the generation loop:
# save_path, text, random_seed, duration, sample_rate, guidance_scale,
# n_candidate_gen_per_text and transcription.  args.model_name may be
# rewritten here when text-to-speech is requested.
args = parser.parse_args()

# NOTE(review): selects PyTorch's "high" float32 matmul precision mode; see
# the torch.set_float32_matmul_precision documentation for the exact
# speed/precision trade-off.
torch.set_float32_matmul_precision("high")

# Each run writes into its own sub-directory of --save_path, named by
# get_time().
save_path = os.path.join(args.save_path, get_time())

text = args.text
random_seed = args.seed
guidance_scale = args.guidance_scale
n_candidate_gen_per_text = args.n_candidate_gen_per_text
transcription = args.transcription

# Resolve text-to-speech mode FIRST.  It may replace args.model_name, and the
# duration / sample-rate decisions below must be taken from the model that is
# actually going to be built, not from the one originally requested.
if transcription:
    if "speech" not in args.model_name:
        print(
            "Warning: You choose to perform Text-to-Speech by providing the transcription.However you do not choose the correct model name (audioldm2-speech-gigaspeech or audioldm2-speech-ljspeech).")
        print("Warning: We will use audioldm2-speech-gigaspeech by default")
        args.model_name = "audioldm2-speech-gigaspeech"
    if not text:
        # The warning reports the speaker description that is really used.
        default_speaker_text = "A female reporter is speaking full of emotion"
        print(
            "Warning: You should provide text as a input to describe the speaker. Use default (%s)" % default_speaker_text)
        text = default_speaker_text

# AudioLDM2 checkpoints are pinned to 10 seconds, whatever --duration says.
duration = args.duration
if "audioldm2" in args.model_name:
    print(
        "Warning: For AudioLDM2 we currently only support 10s of generation. Please use audioldm_48k or audioldm_16k_crossattn_t5 if you want a different duration.")
    duration = 10

# Output sample rate: 48 kHz for checkpoints with "48k" in their name,
# 16 kHz for everything else.
sample_rate = 48000 if "48k" in args.model_name else 16000

os.makedirs(save_path, exist_ok=True)
# Build the requested checkpoint, then generate and save one audio clip per
# prompt.
audioldm2 = build_model(model_name=args.model_name, device=args.device)

# Prompts come either from the --text_list file or from the single text
# prompt.
if args.text_list:
    print("Generate audio based on the text prompts in %s" % args.text_list)
    prompt_todo = read_list(args.text_list)
else:
    prompt_todo = [text]

for text in prompt_todo:
    # An entry may carry an explicit output name as "<prompt>|<name>".
    # Split on the LAST "|" only: for entries with exactly one "|" this is
    # what a plain split did, and entries containing several "|" no longer
    # abort the whole run with a ValueError while unpacking.
    if "|" in text:
        text, _, name = text.rpartition("|")
    else:
        # No explicit name: fall back to the first 128 characters of the prompt.
        name = text[:128]

    if transcription:
        name += "-TTS-%s" % transcription

    # NOTE(review): the original inline comment on `transcription=` read
    # "To avoid the model to ignore the last vocab", but the transcription is
    # passed through unchanged here -- confirm whether padding was intended.
    waveform = text_to_audio(
        audioldm2,
        text,
        transcription=transcription,
        seed=random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        batchsize=args.batchsize,
    )

    save_wave(waveform, save_path, name=name, samplerate=sample_rate)