Skip to content

Commit

Permalink
Add external csv evaluation script
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson authored and reuben committed Apr 5, 2022
1 parent 404d1d6 commit 77c7352
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 1 deletion.
114 changes: 114 additions & 0 deletions training/coqui_stt_training/evaluate_from_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import argparse

import pandas as pd
from tqdm import tqdm

from clearml import Task
from coqui_stt_training.util.evaluate_tools import calculate_and_print_report


def evaluate_from_csv(
    transcriptions_csv,
    ground_truth_csv,
    audio_path_trans_column="wav_filename",
    audio_path_gt_column="wav_filename",
    text_trans_column="transcript",
    transcription_gt_column="transcript",
):
    """Compare a transcriptions CSV against a ground-truth CSV and print a report.

    Both CSVs are joined on their audio-path columns; the merged pairs of
    (ground truth, prediction) are handed to ``calculate_and_print_report``.

    Args:
        transcriptions_csv: Path to the CSV with model transcriptions.
        ground_truth_csv: Path to the CSV with reference transcripts.
        audio_path_trans_column: Audio-path column name in the transcriptions CSV.
        audio_path_gt_column: Audio-path column name in the ground-truth CSV.
        text_trans_column: Text column name in the transcriptions CSV.
        transcription_gt_column: Text column name in the ground-truth CSV.

    Returns:
        The samples produced by ``calculate_and_print_report``, or an error
        string when the transcriptions CSV does not cover every ground-truth
        audio file (kept as a string return for backward compatibility).
    """
    # load csvs
    df_gt = pd.read_csv(ground_truth_csv, sep=",")
    df_transcriptions = pd.read_csv(transcriptions_csv, sep=",")

    # Rename so the ground-truth text column cannot collide with the
    # transcription text column after the merge.
    df_gt = df_gt.rename(columns={transcription_gt_column: "transcription_gt"})
    # Rename so both frames share the same audio-path column for the merge.
    df_gt = df_gt.rename(columns={audio_path_gt_column: audio_path_trans_column})

    # Batched inference can emit duplicate rows, so drop all duplicates.
    df_gt = df_gt.drop_duplicates(audio_path_trans_column)
    df_transcriptions = df_transcriptions.drop_duplicates(audio_path_trans_column)

    # Sort to guarantee the same order in both frames.
    df_gt = df_gt.sort_values(by=[audio_path_trans_column])
    df_transcriptions = df_transcriptions.sort_values(by=[audio_path_trans_column])

    # Check that every ground-truth file has a transcription.
    # NOTE: a pure row-count comparison (as before) cannot distinguish
    # "missing" from "extra" files; behavior is kept for compatibility.
    if len(df_transcriptions) != len(df_gt):
        return "ERROR: The following audios are missing in your CSV file: " + str(
            set(df_gt[audio_path_trans_column].values.tolist())
            - set(df_transcriptions[audio_path_trans_column].values.tolist())
        )

    # Drop everything except the audio and text columns of the transcriptions.
    df_transcriptions = df_transcriptions.filter(
        [audio_path_trans_column, text_trans_column]
    )

    # merge dataframes
    df_merged = pd.merge(df_gt, df_transcriptions, on=audio_path_trans_column)

    # Vectorized cleanup instead of a per-row iterrows() loop:
    # rows without a ground truth are ignored; missing predictions
    # count as empty strings.
    df_merged = df_merged.dropna(subset=["transcription_gt"])
    df_merged[text_trans_column] = df_merged[text_trans_column].fillna("")

    wav_filenames = df_merged[audio_path_trans_column].tolist()
    ground_truths = df_merged["transcription_gt"].tolist()
    predictions = df_merged[text_trans_column].tolist()
    # No per-sample loss is available when evaluating from a CSV.
    losses = [0.0] * len(predictions)

    # Print test summary
    samples = calculate_and_print_report(
        wav_filenames, ground_truths, predictions, losses, transcriptions_csv
    )

    return samples


if __name__ == "__main__":
    # CLI entry point: evaluate a transcriptions CSV against a ground-truth
    # CSV and log the run to ClearML.
    parser = argparse.ArgumentParser(
        description="Evaluation report using transcription CSV"
    )
    parser.add_argument(
        "--transcriptions_csv",
        type=str,
        required=True,  # fail fast with a clear argparse error instead of pd.read_csv(None)
        help="Path to the CSV transcriptions file.",
    )
    parser.add_argument(
        "--ground_truth_csv",
        type=str,
        required=True,
        help="Path to the CSV source file.",
    )
    parser.add_argument(
        "--clearml_task",
        type=str,
        required=True,
        help="The Experiment name (Task Name) for the ClearML.",
    )
    parser.add_argument(
        "--clearml_project",
        type=str,
        default="STT-Evaluation",
        help="Project Name for the ClearML. Default: STT-Evaluation",
    )

    args = parser.parse_args()

    # init ClearML (Task.init registers the run; the handle is kept so the
    # task stays alive for the duration of the evaluation)
    run = Task.init(project_name=args.clearml_project, task_name=args.clearml_task)

    evaluate_from_csv(args.transcriptions_csv, args.ground_truth_csv)
5 changes: 4 additions & 1 deletion training/coqui_stt_training/util/evaluate_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
from .text import levenshtein
from tempfile import gettempdir
import tensorflow.compat.v1 as tfv1

tfv1.logging.set_verbosity(tfv1.logging.ERROR)


def tfv1_plot_scalars(scalar_dict):
with tfv1.summary.FileWriter(os.path.join(gettempdir(), "logs")) as writer:
with tfv1.Graph().as_default(), tfv1.Session(config=tfv1.ConfigProto(log_device_placement=False)) as sess:
with tfv1.Graph().as_default(), tfv1.Session(
config=tfv1.ConfigProto(log_device_placement=False)
) as sess:
for key in scalar_dict.keys():
summary = tfv1.summary.scalar(name=key, tensor=scalar_dict[key])
writer.add_summary(sess.run(summary))
Expand Down

0 comments on commit 77c7352

Please sign in to comment.