Skip to content

Commit

Permalink
resolve conflict for initializing whisper
Browse files: browse the repository at this point in the history
  • Loading branch information
jnwnlee committed Feb 13, 2024
2 parents 941c73d + 9759272 commit 86885a8
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 209 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,7 @@ fadtk/test/fad_scores
fadtk/test/samples/embeddings
fadtk/test/samples/convert
fadtk/test/samples/stats
fadtk/test/comparison.csv
fadtk/test/comparison.csv

.DS_Store
._*
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ The code in this toolkit is licensed under the [MIT License](./LICENSE). Please
@inproceedings{fadtk,
title = {Adapting Frechet Audio Distance for Generative Music Evaluation},
author = {Azalea Gui, Hannes Gamper, Sebastian Braun, Dimitra Emmanouilidou},
booktitle = {Submitted to IEEE ICASSP 2024},
booktitle = {Proc. IEEE ICASSP 2024},
year = {2024},
url = {https://arxiv.org/abs/2311.01616},
}
Expand Down
28 changes: 23 additions & 5 deletions fadtk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
from argparse import ArgumentParser

from .fad import FrechetAudioDistance, log
Expand All @@ -16,12 +17,17 @@ def main():
agupa.add_argument('model', type=str, choices=list(models.keys()), help="The embedding model to use")
agupa.add_argument('baseline', type=str, help="The baseline dataset")
agupa.add_argument('eval', type=str, help="The directory to evaluate against")
agupa.add_argument('csv', type=str, nargs='?',
help="The CSV file to append results to. "
"If this argument is not supplied, single-value results will be printed to stdout, "
"and for --indiv, the results will be saved to 'fad-individual-results.csv'")

# Add optional arguments
agupa.add_argument('-w', '--workers', type=int, default=8)
agupa.add_argument('-s', '--sox-path', type=str, default='/usr/bin/sox')
agupa.add_argument('--inf', action='store_true', help="Use FAD-inf extrapolation")
agupa.add_argument('--indiv', type=str, help="Calculate FAD for individual songs and store the results in the given file")
agupa.add_argument('--indiv', action='store_true',
help="Calculate FAD for individual songs and store the results in the given file")

args = agupa.parse_args()
model = models[args.model]
Expand All @@ -33,24 +39,36 @@ def main():
for d in [baseline, eval]:
if Path(d).is_dir():
cache_embedding_files(d, model, workers=args.workers)

# 2. Calculate FAD
fad = FrechetAudioDistance(model, audio_load_worker=args.workers, load_model=False)
if args.inf:
assert Path(eval).is_dir(), "FAD-inf requires a directory as the evaluation dataset"
score = fad.score_inf(baseline, list(Path(eval).glob('*.*')))
print("FAD-inf Information:", score)
score, inf_r2 = score.score, score.r2
elif args.indiv:
assert Path(eval).is_dir(), "Individual FAD requires a directory as the evaluation dataset"
fad.score_individual(baseline, eval, Path(args.indiv))
log.info(f"Individual FAD scores saved to {args.indiv}")
csv_path = Path(args.csv or 'fad-individual-results.csv')
fad.score_individual(baseline, eval, csv_path)
log.info(f"Individual FAD scores saved to {csv_path}")
exit(0)
else:
score = fad.score(baseline, eval)
inf_r2 = None

# 3. Print results
log.info("FAD computed.")
if args.csv:
Path(args.csv).parent.mkdir(parents=True, exist_ok=True)
if not Path(args.csv).is_file():
Path(args.csv).write_text('model,baseline,eval,score,inf_r2,time\n')
with open(args.csv, 'a') as f:
f.write(f'{model.name},{baseline},{eval},{score},{inf_r2},{time.time()}\n')
log.info(f"FAD score appended to {args.csv}")

log.info(f"The FAD {model.name} score between {baseline} and {eval} is: {score}")


if __name__ == "__main__":
main()
main()
4 changes: 2 additions & 2 deletions fadtk/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def load_wav(self, wav_file: Path):
import torchaudio
from encodec.utils import convert_audio

wav, sr = torchaudio.load(wav_file)
wav, sr = torchaudio.load(str(wav_file))
wav = convert_audio(wav, sr, self.sr, self.model.channels)

# If it's longer than 3 minutes, cut it
Expand Down Expand Up @@ -692,7 +692,7 @@ def __init__(self, size: Literal['tiny', 'base', 'small', 'medium', 'large'], au
model_identifier = f"whisper-{size}"

super().__init__(model_identifier, model_dim, 16000, audio_len=audio_len)
self.huggingface_id = f"openai/whisper-large"
self.huggingface_id = f"openai/whisper-{size}"

def load_model(self):
from transformers import AutoFeatureExtractor
Expand Down
12 changes: 12 additions & 0 deletions fadtk/test/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,20 @@
fp = Path(__file__).parent
reference = pd.read_csv(fp / 'samples_FAD_scores.csv')

# Get reference models in column names
reference_models = [c.split('_', 1)[1].replace('_fma_pop', '') for c in reference.columns if c.startswith('FAD_')]
print("Models with reference data:", reference_models)

# Compute FAD score
for model in get_all_models():
if model.name.replace('-', '_') not in reference_models:
print(f'No reference data for {model.name}, skipping')
continue

# Because of the heavy computation required to run each test, we limit the MERT models to only a few layers
if model.name.startswith('MERT') and model.name[-1] not in ['1', '4', '8', 'M']:
continue

log.info(f'Computing FAD score for {model.name}')
csv = fp / 'fad_scores' / f'{model.name}.csv'
if csv.is_file():
Expand Down
Loading

0 comments on commit 86885a8

Please sign in to comment.