Skip to content

Commit

Permalink
replace None timestamp in WhisperHF.transcribe()
Browse files Browse the repository at this point in the history
-added logic to replace `None` timestamps returned by Hugging Face Whisper models with adjacent timestamps to prevent cases such as jianfch#306 (comment)
  • Loading branch information
jianfch committed Feb 5, 2024
1 parent 3fafd04 commit 8bbe0c5
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions stable_whisper/whisper_word_level/hf_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,64 @@ def _inner_transcribe(
)['chunks']
if verbose is not None:
print(f'Transcription completed.')

def replace_none_ts(parts):
    """
    Replace ``None`` ``'start'``/``'end'`` values in ``parts`` in place using adjacent timestamps.

    Parts are visited in order and each missing value is filled as follows:

    - missing ``'start'``: the previous part's ``'end'``; for the first part, its own ``'end'``
      (or 0 if that is missing too) minus the estimated duration, floored at ``0.0``.
    - missing ``'end'``: the next part's ``'start'`` when it has one; otherwise ``'start'`` plus
      the estimated duration, capped at the first timestamp found in a later part, or at the
      total audio duration when no later part has a timestamp and the duration is known.

    The estimated duration is twice the median duration of the parts that have both timestamps
    at the moment it is first needed, or ``2.0`` when no part has both.

    Presence of a timestamp is tested with ``is None`` instead of truthiness, so a timestamp of
    ``0.0`` counts as a real value rather than as a missing one.

    Parameters
    ----------
    parts : list of dict
        Dicts with ``'start'`` and ``'end'`` keys holding a number or ``None``. Modified in place.
    """
    # ``audio`` and ``self`` come from the enclosing scope; the total duration can only be
    # derived here when the audio is an in-memory array.
    total_dur = round(audio.shape[-1] / self.sampling_rate, 3) if isinstance(audio, np.ndarray) else None
    _medium_dur = _ts_nonzero_mask = None

    def ts_nonzero_mask() -> np.ndarray:
        """Return the cached boolean mask that is True where a part has at least one timestamp."""
        nonlocal _ts_nonzero_mask
        if _ts_nonzero_mask is None:
            _ts_nonzero_mask = np.array(
                [p['end'] is not None or p['start'] is not None for p in parts],
                dtype=bool,
            )
        return _ts_nonzero_mask

    def medium_dur() -> float:
        """Return the cached estimated duration (2x the median known duration, default 2.0)."""
        nonlocal _medium_dur
        if _medium_dur is None:
            nonzero_durs = np.array(
                [
                    p['end'] - p['start']
                    for p in parts
                    if p['end'] is not None and p['start'] is not None
                ]
            )
            _medium_dur = np.median(nonzero_durs) * 2 if len(nonzero_durs) else 2.0
        return _medium_dur

    def _curr_max_end(start: float, next_idx: int) -> float:
        """Return ``start`` plus the estimated duration, capped at the next known timestamp."""
        max_end = total_dur
        if next_idx != len(parts):
            # Indices are relative to the slice, hence the ``+ next_idx`` below.
            mask = np.flatnonzero(ts_nonzero_mask()[next_idx:])
            if len(mask):
                _part = parts[mask[0] + next_idx]
                # Explicit None check: a start of 0.0 is a valid cap and must not fall through.
                max_end = _part['start'] if _part['start'] is not None else _part['end']

        new_end = round(start + medium_dur(), 3)
        if max_end is None:
            # No later timestamp and no known total duration: nothing to cap against.
            return new_end
        if new_end > max_end:
            return max_end
        return new_end

    # ``i`` is 1-based: ``parts[i]`` is the next part and ``parts[i - 2]`` is the previous one.
    for i, part in enumerate(parts, 1):
        if part['start'] is None:
            is_first = i == 1
            if is_first:
                own_end = part['end'] if part['end'] is not None else 0
                new_start = round(own_end - medium_dur(), 3)
                part['start'] = max(new_start, 0.0)
            else:
                # The previous part's end was already filled in by an earlier iteration.
                part['start'] = parts[i - 2]['end']
        if part['end'] is None:
            no_next_start = i == len(parts) or parts[i]['start'] is None
            part['end'] = _curr_max_end(part['start'], i) if no_next_start else parts[i]['start']

if word_timestamps:
words = [
dict(start=word['timestamp'][0], end=word['timestamp'][1], word=word['text'])
for word in result
]
replace_none_ts(words)
return [words]
segs = [
dict(start=seg['timestamp'][0], end=seg['timestamp'][1], text=seg['text'])
for seg in result
]
replace_none_ts(segs)
return segs

def transcribe(
Expand Down

0 comments on commit 8bbe0c5

Please sign in to comment.