Skip to content

Commit

Permalink
Merge pull request huggingface#138 from eustlb/fix-concatenate-datasets
Browse files Browse the repository at this point in the history
[pseudo-labelling] fix concatenate datasets
  • Loading branch information
eustlb authored Jun 12, 2024
2 parents e7138f2 + 270a146 commit a5ed489
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions training/run_pseudo_labelling.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from huggingface_hub import HfFolder, create_repo, get_full_repo_name, snapshot_download, upload_folder
from torch.utils.data import DataLoader
from tqdm import tqdm
from soundfile import LibsndfileError
from datasets.arrow_dataset import table_iter
from transformers import (
HfArgumentParser,
Seq2SeqTrainingArguments,
Expand Down Expand Up @@ -628,51 +630,44 @@ def main():
raw_datasets = raw_datasets.sort(speaker_id_column_name)

def concatenate_dataset(batch):
    """Merge consecutive same-speaker samples of a batch into longer segments.

    Rows are read one at a time so that a sample whose audio cannot be read
    (``LibsndfileError``) is logged and skipped instead of failing the whole
    batch. The remaining samples are then greedily appended to the current
    segment while the speaker is unchanged and the combined length stays
    within ``max_input_length``.

    Args:
        batch: The batch handed over by a batched ``map`` call. It must expose
            ``pa_table`` and ``formatter``.
            NOTE(review): presumably a lazily-formatted ``datasets`` batch, so
            that audio decoding happens on column access - confirm.

    Returns:
        The same ``batch`` object with these columns overwritten:

        * ``audio_column_name``: one ``{"array", "sampling_rate"}`` dict per
          concatenated segment,
        * ``text_column_name``: the space-joined transcriptions per segment,
        * ``id_column_name``: the speaker id of each segment,
        * ``"condition_on_prev"``: 1 when a segment was started only because
          the length limit was hit (same speaker as the previous segment),
          0 otherwise (first segment, or speaker change).

        If no sample of the batch could be read, all four columns are empty
        lists.
    """
    audio_arrays, texts, speaker_ids = [], [], []

    # skip corrupted samples
    for row in table_iter(batch.pa_table, batch_size=1):
        row = batch.formatter.format_row(row)
        try:
            sample_audio = row[audio_column_name]['array']
            sample_text = row[text_column_name]
            sample_speaker_id = row[speaker_id_column_name] if speaker_id_column_name else None
        except LibsndfileError:
            logger.warning(f"{row[id_column_name]} is corrupted! Skipping sample.")
            continue
        audio_arrays.append(sample_audio)
        texts.append(sample_text)
        speaker_ids.append(sample_speaker_id)

    # Start from empty outputs rather than indexing element 0: once corrupted
    # rows are skipped, a batch may legitimately contain no usable sample.
    concat_audio, concat_text, concat_speaker_id, condition_on_prev = [], [], [], []

    for audio_array, text, speaker_id in zip(audio_arrays, texts, speaker_ids):
        has_previous = bool(concat_audio)
        # we have no information about whether the segments follow on sequentially,
        # so we just ensure the same speaker as we concatenate across files
        is_same_speaker = has_previous and speaker_id == concat_speaker_id[-1]
        is_concatenable = has_previous and len(audio_array) + len(concat_audio[-1]) <= max_input_length
        if is_same_speaker and is_concatenable:
            # extend the current segment in place
            concat_audio[-1] = np.append(concat_audio[-1], audio_array)
            # extra spaces in the transcription are harmless for the WER computation
            concat_text[-1] = concat_text[-1] + " " + text
        else:
            # first sample, speaker change, or length limit reached: open a new segment
            concat_audio.append(audio_array)
            concat_text.append(text)
            concat_speaker_id.append(speaker_id)
            condition_on_prev.append(1 if is_same_speaker else 0)

    batch[audio_column_name] = [{"array": array, "sampling_rate": sampling_rate} for array in concat_audio]
    batch[text_column_name] = concat_text
    # NOTE(review): the id column receives the speaker ids, as in the original code - confirm this is intended.
    batch[id_column_name] = concat_speaker_id
    batch["condition_on_prev"] = condition_on_prev

    return batch
Expand Down Expand Up @@ -987,16 +982,17 @@ def add_concatenated_text(eval_preds, condition_on_prev):
concatenated_prev.append(prompt_ids)
return {"condition_on_prev": concatenated_prev}

with accelerator.main_process_first():
raw_datasets[split] = raw_datasets[split].map(
add_concatenated_text,
input_columns=["eval_preds", "condition_on_prev"],
remove_columns=["eval_preds"],
desc="Setting condition on prev...",
batched=True,
batch_size=preprocessing_batch_size,
num_proc=num_workers,
)
if data_args.concatenate_audio:
with accelerator.main_process_first():
raw_datasets[split] = raw_datasets[split].map(
add_concatenated_text,
input_columns=["eval_preds", "condition_on_prev"],
remove_columns=["eval_preds"],
desc="Setting condition on prev...",
batched=True,
batch_size=preprocessing_batch_size,
num_proc=num_workers,
)

logger.info("***** Running Labelling *****")
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
Expand Down

0 comments on commit a5ed489

Please sign in to comment.