Extend whisper_admin.py with audio processing and transcription features

This commit is contained in:
2026-04-16 22:06:57 +02:00
parent d5cdb85629
commit 86936dcb96

View File

@@ -3,6 +3,9 @@ from __future__ import annotations
import importlib
import importlib.util
import os
import shutil
import subprocess
import tempfile
import threading
from pathlib import Path
from typing import Any, Dict, Optional
@@ -10,6 +13,20 @@ from typing import Any, Dict, Optional
# Whisper checkpoint used when callers do not specify a model name.
DEFAULT_WHISPER_MODEL = "base"
# NOTE(review): presumably serializes model downloads in
# ensure_whisper_model_downloaded (its use is outside this view) — confirm.
_WHISPER_DOWNLOAD_LOCK = threading.Lock()
# Guards the _WHISPER_MODELS cache against concurrent load/insert.
_WHISPER_MODEL_LOCK = threading.Lock()
# Cache of loaded Whisper models, keyed by (model_name, device).
_WHISPER_MODELS: Dict[tuple[str, str], Any] = {}
# Base MIME type -> temp-file suffix used when spooling recorded audio to disk.
_AUDIO_SUFFIXES = {
"audio/aac": ".aac",
"audio/flac": ".flac",
"audio/mp4": ".m4a",
"audio/mpeg": ".mp3",
"audio/mpga": ".mp3",
"audio/ogg": ".ogg",
"audio/wav": ".wav",
"audio/wave": ".wav",
"audio/webm": ".webm",
"audio/x-wav": ".wav",
}
def _default_download_root() -> Path:
@@ -24,6 +41,13 @@ def _load_whisper_module():
return None
def _load_torch_module():
try:
return importlib.import_module("torch")
except Exception:
return None
def whisper_runtime_error() -> Optional[str]:
if importlib.util.find_spec("whisper") is None:
return (
@@ -123,3 +147,126 @@ def ensure_whisper_model_downloaded(model_name: str = DEFAULT_WHISPER_MODEL) ->
"path": str(target),
"error": "",
}
def _resolve_whisper_device() -> str:
    """Pick the best available torch device: "cuda", then "mps", else "cpu".

    Returns:
        "cuda" when a CUDA device is available, "mps" when the Apple Metal
        backend is available, otherwise "cpu". Also returns "cpu" when torch
        cannot be imported or any probe raises.
    """
    torch_mod = _load_torch_module()
    # Fix: bail out explicitly when torch is absent instead of continuing with
    # None and relying on getattr(None, ...) defaults to fall through.
    if torch_mod is None:
        return "cpu"
    try:
        # Probe torch.cuda via getattr so a stripped-down torch build without
        # the attribute cannot raise AttributeError here.
        cuda_ns = getattr(torch_mod, "cuda", None)
        if cuda_ns is not None and getattr(cuda_ns, "is_available", lambda: False)():
            return "cuda"
        backends = getattr(torch_mod, "backends", None)
        mps_backend = getattr(backends, "mps", None)
        if mps_backend is not None and getattr(mps_backend, "is_available", lambda: False)():
            return "mps"
    except Exception:
        # Device probing must never break transcription; fall back to CPU.
        pass
    return "cpu"
def _resolve_ffmpeg_binary() -> Optional[str]:
candidate = os.getenv("HEIMGEIST_FFMPEG_PATH") or shutil.which("ffmpeg") or "/usr/bin/ffmpeg"
if not candidate:
return None
if Path(candidate).exists() or shutil.which(candidate):
return str(candidate)
return None
def _audio_suffix_for_mime_type(mime_type: str) -> str:
    """Map *mime_type* to a temp-file suffix, defaulting to ".webm".

    Any parameters (e.g. "audio/webm;codecs=opus") are stripped before the
    lookup, and the comparison is case-insensitive.
    """
    normalized = str(mime_type or "")
    base, _, _params = normalized.partition(";")
    return _AUDIO_SUFFIXES.get(base.strip().lower(), ".webm")
def _load_transcription_model(model_name: str = DEFAULT_WHISPER_MODEL) -> tuple[Any, str]:
    """Return a cached Whisper model together with the device it runs on.

    Ensures the checkpoint is downloaded, resolves the best device, then
    loads the model (once per (model, device) pair) under the module lock.

    Raises:
        RuntimeError: when the Whisper runtime is missing or fails to import.
    """
    runtime_error = whisper_runtime_error()
    if runtime_error:
        raise RuntimeError(runtime_error)
    whisper_mod = _load_whisper_module()
    if whisper_mod is None:
        raise RuntimeError("Failed to import the Whisper runtime.")
    ensure_whisper_model_downloaded(model_name)
    device = _resolve_whisper_device()
    cache_key = (model_name, device)
    # The lock is held across load_model so two threads never load the same
    # checkpoint concurrently.
    with _WHISPER_MODEL_LOCK:
        cached = _WHISPER_MODELS.get(cache_key)
        if cached is None:
            try:
                cached = whisper_mod.load_model(model_name, device=device)
            except TypeError:
                # Older whisper versions do not accept a `device` kwarg.
                cached = whisper_mod.load_model(model_name)
            _WHISPER_MODELS[cache_key] = cached
    return cached, device
def _convert_audio_to_wav(input_path: Path, output_path: Path) -> None:
    """Transcode *input_path* to 16 kHz mono WAV at *output_path* via ffmpeg.

    Raises:
        RuntimeError: when ffmpeg cannot be located, or when the conversion
            exits non-zero (the ffmpeg diagnostic text becomes the message).
    """
    ffmpeg_bin = _resolve_ffmpeg_binary()
    if not ffmpeg_bin:
        raise RuntimeError(
            "Audio transcription requires ffmpeg. Heimgeist could not resolve it from HEIMGEIST_FFMPEG_PATH or PATH."
        )
    command = [
        ffmpeg_bin,
        "-y",                 # overwrite the output file if it exists
        "-hide_banner",
        "-loglevel", "error",  # keep stderr limited to actual errors
        "-i", str(input_path),
        "-ac", "1",            # downmix to mono
        "-ar", "16000",        # resample to 16 kHz
        "-f", "wav",
        str(output_path),
    ]
    completed = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if completed.returncode != 0:
        message = (completed.stderr or completed.stdout or "ffmpeg audio conversion failed").strip()
        raise RuntimeError(message)
def transcribe_audio_bytes(
audio_bytes: bytes,
mime_type: str,
model_name: str = DEFAULT_WHISPER_MODEL,
) -> Dict[str, Any]:
if not audio_bytes:
raise RuntimeError("Recorded audio was empty.")
model, device = _load_transcription_model(model_name)
input_path = Path(tempfile.mkstemp(suffix=_audio_suffix_for_mime_type(mime_type))[1])
wav_path = Path(tempfile.mkstemp(suffix=".wav")[1])
try:
input_path.write_bytes(audio_bytes)
source_path = input_path
if input_path.suffix.lower() != ".wav":
_convert_audio_to_wav(input_path, wav_path)
source_path = wav_path
result = model.transcribe(str(source_path), task="transcribe", fp16=device == "cuda")
return {
"model": model_name,
"device": device,
"language": str(result.get("language") or "").strip(),
"text": str(result.get("text") or "").strip(),
}
finally:
for path in (input_path, wav_path):
try:
if path.exists():
path.unlink()
except Exception:
pass