Add stable timings (chidiwilliams#159)
chidiwilliams authored Nov 12, 2022
1 parent a1b9097 commit 392f9cb
Showing 11 changed files with 129 additions and 86 deletions.
11 changes: 11 additions & 0 deletions .coveragerc
@@ -0,0 +1,11 @@
[run]
omit =
whisper_cpp.py
*_test.py
stable_ts/*

[html]
directory = coverage/html

[report]
fail_under = 75
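
As context for the new configuration: this commit moves the coverage threshold out of the Makefile's --cov-fail-under flag and into the [report] section here, and coverage.py picks up .coveragerc from the working directory. A minimal sketch of driving the same test run from Python, assuming pytest and pytest-cov are installed (the wrapper script is illustrative and not part of this commit):

    # run_tests.py (illustrative) — mirrors the Makefile's `pytest --cov`.
    import sys

    import pytest

    if __name__ == "__main__":
        # coverage.py reads .coveragerc from the current directory, so the
        # [run] omit patterns and [html] output directory above are applied.
        sys.exit(pytest.main(["--cov"]))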
1 change: 1 addition & 0 deletions .gitignore
@@ -3,6 +3,7 @@ __pycache__/
build/
.pytest_cache/
.coverage*
!.coveragerc
.env
htmlcov/
libwhisper.*
4 changes: 4 additions & 0 deletions .gitmodules
@@ -1,3 +1,7 @@
[submodule "whisper.cpp"]
path = whisper.cpp
url = https://github.com/chidiwilliams/whisper.cpp
[submodule "stable_ts"]
path = stable_ts
url = https://github.com/chidiwilliams/stable-ts
branch = main
2 changes: 1 addition & 1 deletion Makefile
@@ -43,7 +43,7 @@ clean:
rm -rf dist/* || true

test: whisper_cpp.py
pytest --cov --cov-fail-under=69 --cov-report html
pytest --cov

dist/Buzz: whisper_cpp.py
pyinstaller --noconfirm Buzz.spec
33 changes: 24 additions & 9 deletions gui.py
@@ -12,10 +12,10 @@
QUrl, pyqtSignal)
from PyQt6.QtGui import (QAction, QCloseEvent, QDesktopServices, QIcon,
QKeySequence, QPixmap, QTextCursor)
from PyQt6.QtWidgets import (QApplication, QComboBox, QDialog, QFileDialog,
QGridLayout, QLabel, QMainWindow, QMessageBox,
QPlainTextEdit, QProgressDialog, QPushButton,
QVBoxLayout, QWidget)
from PyQt6.QtWidgets import (QApplication, QCheckBox, QComboBox, QDialog,
QFileDialog, QGridLayout, QLabel, QMainWindow,
QMessageBox, QPlainTextEdit, QProgressDialog,
QPushButton, QVBoxLayout, QWidget)
from requests import get
from whisper import tokenizer

@@ -270,16 +270,18 @@ class FileTranscriberObject(QObject):
transcriber: FileTranscriber

def __init__(
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
task: Task, file_path: str, output_file_path: str,
output_format: OutputFormat, parent: Optional['QObject'], *args) -> None:
self, model_name: str, use_whisper_cpp: bool, language: Optional[str],
task: Task, file_path: str, output_file_path: str,
output_format: OutputFormat, word_level_timings: bool,
parent: Optional['QObject'], *args) -> None:
super().__init__(parent, *args)
self.transcriber = FileTranscriber(
model_name=model_name, use_whisper_cpp=use_whisper_cpp,
on_download_model_chunk=self.on_download_model_progress,
language=language, task=task, file_path=file_path,
output_file_path=output_file_path, output_format=output_format,
event_callback=self.on_file_transcriber_event)
event_callback=self.on_file_transcriber_event,
word_level_timings=word_level_timings)

def on_download_model_progress(self, current: int, total: int):
self.download_model_progress.emit((current, total))
@@ -380,6 +382,7 @@ class FileTranscriberWidget(QWidget):
selected_language: Optional[str] = None
selected_task = Task.TRANSCRIBE
selected_output_format = OutputFormat.TXT
enabled_word_level_timings = False
model_download_progress_dialog: Optional[DownloadModelProgressDialog] = None
transcriber_progress_dialog: Optional[TranscriberProgressDialog] = None
file_transcriber: Optional[FileTranscriberObject] = None
@@ -418,6 +421,11 @@ def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
output_formats_combo_box.output_format_changed.connect(
self.on_output_format_changed)

self.word_level_timings_checkbox = QCheckBox('Word-level timings')
self.word_level_timings_checkbox.stateChanged.connect(
self.on_word_level_timings_changed)
self.word_level_timings_checkbox.setDisabled(True)

grid = (
((0, 5, FormLabel('Task:', parent=self)), (5, 7, self.tasks_combo_box)),
((0, 5, FormLabel('Language:', parent=self)),
@@ -426,6 +434,7 @@ def __init__(self, file_path: str, parent: Optional[QWidget]) -> None:
(5, 7, self.quality_combo_box)),
((0, 5, FormLabel('Export As:', self)),
(5, 7, output_formats_combo_box)),
((5, 7, self.word_level_timings_checkbox),),
((9, 3, self.run_button),)
)

@@ -447,6 +456,8 @@ def on_task_changed(self, task: Task):

def on_output_format_changed(self, output_format: OutputFormat):
self.selected_output_format = output_format
self.word_level_timings_checkbox.setDisabled(
output_format == OutputFormat.TXT)

def on_click_run(self):
default_path = FileTranscriber.get_default_output_file_path(
@@ -469,6 +480,7 @@ def on_click_run(self):
file_path=self.file_path,
language=self.selected_language, task=self.selected_task,
output_file_path=output_file, output_format=self.selected_output_format,
word_level_timings=self.enabled_word_level_timings,
parent=self)
self.file_transcriber.download_model_progress.connect(
self.on_download_model_progress)
@@ -530,6 +542,9 @@ def reset_model_download(self):
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog = None

def on_word_level_timings_changed(self, value: int):
self.enabled_word_level_timings = value == Qt.CheckState.Checked.value


class Settings(QSettings):
ENABLE_GGML_INFERENCE = 'enable_ggml_inference'
@@ -829,7 +844,7 @@ class FileTranscriberMainWindow(MainWindow):

def __init__(self, file_path: str, parent: Optional[QWidget], *args) -> None:
super().__init__(title=get_short_file_path(
file_path), w=400, h=210, parent=parent, *args)
file_path), w=400, h=240, parent=parent, *args)

self.central_widget = FileTranscriberWidget(file_path, self)
self.central_widget.setContentsMargins(10, 10, 10, 10)
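
For readers less familiar with the Qt wiring above, here is a minimal standalone sketch of the checkbox pattern FileTranscriberWidget now uses, assuming PyQt6 is installed (the demo widget and its name are illustrative, not taken from gui.py):

    # checkbox_demo.py (illustrative) — the same stateChanged -> bool pattern as above.
    import sys
    from typing import Optional

    from PyQt6.QtCore import Qt
    from PyQt6.QtWidgets import QApplication, QCheckBox, QVBoxLayout, QWidget


    class WordTimingsDemo(QWidget):
        enabled_word_level_timings = False

        def __init__(self, parent: Optional[QWidget] = None) -> None:
            super().__init__(parent)
            self.checkbox = QCheckBox('Word-level timings')
            # stateChanged delivers an int, so compare against Qt.CheckState.Checked.value
            self.checkbox.stateChanged.connect(self.on_word_level_timings_changed)
            layout = QVBoxLayout(self)
            layout.addWidget(self.checkbox)

        def on_word_level_timings_changed(self, value: int) -> None:
            self.enabled_word_level_timings = value == Qt.CheckState.Checked.value


    if __name__ == '__main__':
        app = QApplication(sys.argv)
        demo = WordTimingsDemo()
        demo.show()
        sys.exit(app.exec())

In the diff itself, the checkbox additionally starts disabled and is only re-enabled when the selected export format is not TXT.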
18 changes: 18 additions & 0 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.9.13,<3.11"
sounddevice = "^0.4.5"
whisper = {git = "https://github.com/openai/whisper.git"}
whisper = { git = "https://github.com/openai/whisper.git" }
torch = "1.12.1"
numpy = "^1.23.3"
transformers = "^4.22.1"
1 change: 1 addition & 0 deletions stable_ts
Submodule stable_ts added at 5c1966
28 changes: 21 additions & 7 deletions transcriber.py
@@ -18,6 +18,7 @@
from sounddevice import PortAudioError

from conn import pipe_stderr, pipe_stdout
from stable_ts.stable_whisper import group_word_timestamps, modify_model
from whispr import (ModelLoader, Segment, Stopped, Task, WhisperCpp,
read_progress, whisper_cpp_params)

@@ -108,7 +109,7 @@ def process_queue(self):
audio=samples,
params=whisper_cpp_params(
language=self.language if self.language is not None else 'en',
task=self.task.value))
task=self.task.value, word_level_timings=False))

next_text: str = result.get('text')

@@ -236,6 +237,7 @@ def __init__(
model_name: str, use_whisper_cpp: bool,
language: Optional[str], task: Task, file_path: str,
output_file_path: str, output_format: OutputFormat,
word_level_timings: bool,
event_callback: Callable[[Event], None] = lambda *_: None,
on_download_model_chunk: Callable[[
int, int], None] = lambda *_: None,
@@ -246,6 +248,7 @@ def __init__(
self.task = task
self.open_file_on_complete = open_file_on_complete
self.output_format = output_format
self.word_level_timings = word_level_timings

self.model_name = model_name
self.use_whisper_cpp = use_whisper_cpp
@@ -290,6 +293,7 @@ def transcribe(self):
self.output_format,
self.language if self.language is not None else 'en',
self.task, True, True,
self.word_level_timings
))
else:
self.current_process = multiprocessing.Process(
@@ -298,6 +302,7 @@ def transcribe(self):
send_pipe, model_path, self.file_path,
self.language, self.task, self.output_file_path,
self.open_file_on_complete, self.output_format,
self.word_level_timings
))

self.current_process.start()
@@ -352,18 +357,26 @@ def get_default_output_file_path(cls, task: Task, input_file_path: str, output_f
def transcribe_whisper(
stderr_conn: Connection, model_path: str, file_path: str,
language: Optional[str], task: Task, output_file_path: str,
open_file_on_complete: bool, output_format: OutputFormat):
open_file_on_complete: bool, output_format: OutputFormat,
word_level_timings: bool):
with pipe_stderr(stderr_conn):
model = whisper.load_model(model_path)
result = whisper.transcribe(
model=model, audio=file_path, language=language, task=task.value, verbose=False)

if word_level_timings:
modify_model(model)

result = model.transcribe(
audio=file_path, language=language, task=task.value, verbose=False)

whisper_segments = group_word_timestamps(
result) if word_level_timings else result.get('segments')

segments = map(
lambda segment: Segment(
start=segment.get('start')*1000, # s to ms
end=segment.get('end')*1000, # s to ms
text=segment.get('text')),
result.get('segments'))
whisper_segments)

write_output(output_file_path, list(
segments), open_file_on_complete, output_format)
@@ -372,13 +385,14 @@ def transcribe_whisper(
def transcribe_whisper_cpp(
stderr_conn: Connection, model_path: str, audio: typing.Union[np.ndarray, str],
output_file_path: str, open_file_on_complete: bool, output_format: OutputFormat,
language: str, task: Task, print_realtime: bool, print_progress: bool):
language: str, task: Task, print_realtime: bool, print_progress: bool,
word_level_timings: bool):
# TODO: capturing output does not work because ctypes functions
# See: https://stackoverflow.com/questions/9488560/capturing-print-output-from-shared-library-called-from-python-with-ctypes-module
with pipe_stdout(stderr_conn), pipe_stderr(stderr_conn):
model = WhisperCpp(model_path)
params = whisper_cpp_params(
language, task, print_realtime, print_progress)
language, task, word_level_timings, print_realtime, print_progress)
result = model.transcribe(audio=audio, params=params)
segments: List[Segment] = result.get('segments')
write_output(
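
Putting the transcriber.py changes together: when word-level timings are requested, the loaded model is patched with stable-ts before transcription, and the per-word timestamps are regrouped into segments. A condensed, illustrative sketch of that path (simplified from the diff above; the multiprocessing setup, stderr piping, and output writing are omitted, and the helper returns plain tuples instead of Segment objects):

    # Condensed, illustrative version of the word-level branch in transcribe_whisper.
    from typing import List, Optional, Tuple

    import whisper

    from stable_ts.stable_whisper import group_word_timestamps, modify_model


    def transcribe_with_optional_word_timings(
            model_path: str, file_path: str, language: Optional[str],
            task_value: str, word_level_timings: bool) -> List[Tuple[float, float, str]]:
        model = whisper.load_model(model_path)

        if word_level_timings:
            # stable-ts patches the model so transcribe() also records
            # per-word timestamps alongside the usual phrase-level segments.
            modify_model(model)

        result = model.transcribe(
            audio=file_path, language=language, task=task_value, verbose=False)

        # With word-level timings enabled, regroup the stable-ts output into
        # word-sized segments; otherwise keep whisper's own segments.
        whisper_segments = (group_word_timestamps(result) if word_level_timings
                            else result.get('segments'))

        # Whisper reports times in seconds; the GUI and writers expect milliseconds.
        return [(segment.get('start') * 1000,
                 segment.get('end') * 1000,
                 segment.get('text'))
                for segment in whisper_segments]

The whisper.cpp path threads the same flag through whisper_cpp_params instead.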