diff --git a/backend/whisper_admin.py b/backend/whisper_admin.py
index 19413e7..3c60d20 100644
--- a/backend/whisper_admin.py
+++ b/backend/whisper_admin.py
@@ -3,6 +3,9 @@ from __future__ import annotations
 import importlib
 import importlib.util
 import os
+import shutil
+import subprocess
+import tempfile
 import threading
 from pathlib import Path
 from typing import Any, Dict, Optional
@@ -10,6 +13,20 @@ from typing import Any, Dict, Optional
 DEFAULT_WHISPER_MODEL = "base"
 
 _WHISPER_DOWNLOAD_LOCK = threading.Lock()
+_WHISPER_MODEL_LOCK = threading.Lock()
+_WHISPER_MODELS: Dict[tuple[str, str], Any] = {}
+_AUDIO_SUFFIXES = {
+    "audio/aac": ".aac",
+    "audio/flac": ".flac",
+    "audio/mp4": ".m4a",
+    "audio/mpeg": ".mp3",
+    "audio/mpga": ".mp3",
+    "audio/ogg": ".ogg",
+    "audio/wav": ".wav",
+    "audio/wave": ".wav",
+    "audio/webm": ".webm",
+    "audio/x-wav": ".wav",
+}
 
 
 def _default_download_root() -> Path:
@@ -24,6 +41,13 @@
         return None
 
 
+def _load_torch_module():
+    try:
+        return importlib.import_module("torch")
+    except Exception:
+        return None
+
+
 def whisper_runtime_error() -> Optional[str]:
     if importlib.util.find_spec("whisper") is None:
         return (
@@ -123,3 +147,138 @@ def ensure_whisper_model_downloaded(model_name: str = DEFAULT_WHISPER_MODEL) ->
         "path": str(target),
         "error": "",
     }
+
+
+def _resolve_whisper_device() -> str:
+    """Pick the best available torch device: "cuda", "mps", or "cpu"."""
+    try:
+        torch_mod = _load_torch_module()
+        if torch_mod is not None and getattr(torch_mod.cuda, "is_available", lambda: False)():
+            return "cuda"
+
+        # getattr chains tolerate torch builds that lack an MPS backend entirely.
+        backends = getattr(torch_mod, "backends", None)
+        mps_backend = getattr(backends, "mps", None)
+        if mps_backend is not None and getattr(mps_backend, "is_available", lambda: False)():
+            return "mps"
+    except Exception:
+        pass
+
+    return "cpu"
+
+
+def _resolve_ffmpeg_binary() -> Optional[str]:
+    """Locate ffmpeg via HEIMGEIST_FFMPEG_PATH, then PATH, then /usr/bin/ffmpeg."""
+    candidate = os.getenv("HEIMGEIST_FFMPEG_PATH") or shutil.which("ffmpeg") or "/usr/bin/ffmpeg"
+    if Path(candidate).exists() or shutil.which(candidate):
+        return str(candidate)
+    return None
+
+
+def _audio_suffix_for_mime_type(mime_type: str) -> str:
+    """Map a MIME type (";" parameters stripped) to a file suffix, defaulting to .webm."""
+    base_mime = str(mime_type or "").split(";", 1)[0].strip().lower()
+    return _AUDIO_SUFFIXES.get(base_mime, ".webm")
+
+
+def _load_transcription_model(model_name: str = DEFAULT_WHISPER_MODEL) -> tuple[Any, str]:
+    """Return a cached (model, device) pair, loading the Whisper model on first use."""
+    error = whisper_runtime_error()
+    if error:
+        raise RuntimeError(error)
+
+    whisper_mod = _load_whisper_module()
+    if whisper_mod is None:
+        raise RuntimeError("Failed to import the Whisper runtime.")
+
+    ensure_whisper_model_downloaded(model_name)
+    device = _resolve_whisper_device()
+    cache_key = (model_name, device)
+
+    with _WHISPER_MODEL_LOCK:
+        model = _WHISPER_MODELS.get(cache_key)
+        if model is None:
+            try:
+                model = whisper_mod.load_model(model_name, device=device)
+            except TypeError:
+                # Older whisper releases do not accept a device keyword.
+                model = whisper_mod.load_model(model_name)
+            _WHISPER_MODELS[cache_key] = model
+
+    return model, device
+
+
+def _convert_audio_to_wav(input_path: Path, output_path: Path) -> None:
+    """Transcode input_path to 16 kHz mono WAV at output_path via ffmpeg; raises on failure."""
+    ffmpeg_bin = _resolve_ffmpeg_binary()
+    if not ffmpeg_bin:
+        raise RuntimeError(
+            "Audio transcription requires ffmpeg. Heimgeist could not resolve it from HEIMGEIST_FFMPEG_PATH or PATH."
+        )
+
+    process = subprocess.run(
+        [
+            ffmpeg_bin,
+            "-y",
+            "-hide_banner",
+            "-loglevel",
+            "error",
+            "-i",
+            str(input_path),
+            "-ac",
+            "1",
+            "-ar",
+            "16000",
+            "-f",
+            "wav",
+            str(output_path),
+        ],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+    if process.returncode != 0:
+        detail = (process.stderr or process.stdout or "ffmpeg audio conversion failed").strip()
+        raise RuntimeError(detail)
+
+
+def transcribe_audio_bytes(
+    audio_bytes: bytes,
+    mime_type: str,
+    model_name: str = DEFAULT_WHISPER_MODEL,
+) -> Dict[str, Any]:
+    """Transcribe raw audio bytes; returns a dict with model, device, language, and text."""
+    if not audio_bytes:
+        raise RuntimeError("Recorded audio was empty.")
+
+    model, device = _load_transcription_model(model_name)
+    # mkstemp returns (fd, path); close each fd immediately so it does not leak.
+    input_fd, input_name = tempfile.mkstemp(suffix=_audio_suffix_for_mime_type(mime_type))
+    os.close(input_fd)
+    wav_fd, wav_name = tempfile.mkstemp(suffix=".wav")
+    os.close(wav_fd)
+    input_path = Path(input_name)
+    wav_path = Path(wav_name)
+
+    try:
+        input_path.write_bytes(audio_bytes)
+        source_path = input_path
+        if input_path.suffix.lower() != ".wav":
+            _convert_audio_to_wav(input_path, wav_path)
+            source_path = wav_path
+
+        result = model.transcribe(str(source_path), task="transcribe", fp16=device == "cuda")
+        return {
+            "model": model_name,
+            "device": device,
+            "language": str(result.get("language") or "").strip(),
+            "text": str(result.get("text") or "").strip(),
+        }
+    finally:
+        # Best-effort cleanup of both temp files.
+        for path in (input_path, wav_path):
+            try:
+                if path.exists():
+                    path.unlink()
+            except Exception:
+                pass