commit 44e65299e5d35ce3293701dfd561fdcad84de773
Author: Victor Giers
Date:   Thu Aug 14 09:18:01 2025 +0200

    initial commit

diff --git a/icon.png b/icon.png
new file mode 100644
index 0000000..3361ec7
Binary files /dev/null and b/icon.png differ
diff --git a/srt-translate.py b/srt-translate.py
new file mode 100755
index 0000000..205b4f0
--- /dev/null
+++ b/srt-translate.py
@@ -0,0 +1,2163 @@
+#!/usr/bin/env python3
+"""
+SRT Translator — NLLB, SeamlessM4T, or Ollama
+GUI + CLI, multithreaded, cancelable, context-aware (Line/Cue/Smart)
+
+New in this build:
+- FIX: Leading commas sometimes survived. Now:
+  * _fix_leading_punct() recognizes Unicode commas (,、) and runs until stable.
+  * _remove_marker_and_tagn_tokens() strips any remaining leading commas per line as a last resort.
+  * NBSP / narrow NBSP / ideographic spaces are normalized to regular spaces.
+- FIX: Ollama 'groups' robustness.
+  * If the model doesn't return a valid "groups" array, we transparently fall back to per-group
+    "translate_cues" requests and coerce counts to the expected shape.
+  * No more RuntimeError spam; threads finish gracefully.
+- FIX: The marker remover accidentally matched TAG (and would nuke 'STAGE' etc.). Now only BR|CUE;
+  tag removal stays via TAG0..TAG9.
+- (from earlier) restore_tags bug fixed; per-line cleanup removes BR/CUE tokens and TAG0..TAG9;
+  spaces collapsed; punctuation spacing normalized.
+- NEW: Ollama repair pass for untranslated lines:
+  * Language-name prompts (e.g., "English" instead of "eng_Latn") to avoid false "already in target language".
+  * Automatic per-line detection of unchanged output + strict retry with a temperature schedule
+    and optional model fallback.
+- NEW: Cancel reliability
+  * Removed the pause feature (UI & logic).
+  * Cooperative queue polling with short timeouts and an immediate drain on cancel.
+  * Micro-batching in all workers (Transformers & Ollama), with event checks between micro-batches.
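+
+Example invocations (illustrative; the flag spellings are assumptions inferred
+from the fields read off `args` in run_cli, not copied from the argparse setup):
+    ./srt-translate.py movie.srt German --engine nllb --context smart
+    ./srt-translate.py movie.srt German --engine ollama --ollama-model qwen3:32b-instruct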
+""" + +import argparse +import os +import re +import sys +import time +import json +import threading +import queue +import http.client +import subprocess +import html as _html +from typing import List, Dict, Tuple, Optional + +# ---- Optional CLI progress (tqdm) ---- +try: + from tqdm.auto import tqdm +except Exception: + tqdm = None + +# ---- Optional auto-detect ---- +try: + from langdetect import detect as _ld_detect +except Exception: + _ld_detect = None + +CONFIG_PATH = os.path.join(os.path.expanduser("~"), ".srt_translator_config.json") +SENTINEL_BR = "⟦BR⟧" # within cue: between original lines +SENTINEL_CUE = "⟦CUE⟧" # within group: between original cues +SENTINEL_TAG = "⟦TAG⟧" + +# ===================== Language code mapping ===================== +LANG2CODE = { + "english": "eng_Latn", "german": "deu_Latn", "french": "fra_Latn", "spanish": "spa_Latn", + "portuguese": "por_Latn", "italian": "ita_Latn", "dutch": "nld_Latn", "russian": "rus_Cyrl", + "japanese": "jpn_Jpan", "korean": "kor_Hang", "chinese": "zho_Hans", "chinese (simplified)": "zho_Hans", + "chinese (traditional)": "zho_Hant", "arabic": "arb_Arab", "hebrew": "heb_Hebr", "turkish": "tur_Latn", + "polish": "pol_Latn", "czech": "ces_Latn", "greek": "ell_Grek", "swedish": "swe_Latn", + "danish": "dan_Latn", "norwegian": "nob_Latn", "finnish": "fin_Latn", "romanian": "ron_Latn", + "hungarian": "hun_Latn", "ukrainian": "ukr_Cyrl", "bulgarian": "bul_Cyrl", "serbian": "srp_Latn", + "croatian": "hrv_Latn", "slovak": "slk_Latn", "slovenian": "slv_Latn", "lithuanian": "lit_Latn", + "estonian": "est_Latn", "vietnamese": "vie_Latn", "thai": "tha_Thai", "hindi": "hin_Deva", + "bengali": "ben_Beng", "indonesian": "ind_Latn", "malay": "zsm_Latn", "filipino": "tgl_Latn", + # Aliases + "jp": "jpn_Jpan", "de": "deu_Latn", "en": "eng_Latn", "fr": "fra_Latn", "es": "spa_Latn", + "pt": "por_Latn", "ru": "rus_Cyrl", "zh": "zho_Hans", "ko": "kor_Hang", "ar": "arb_Arab", +} + +# ===================== SRT parsing ===================== +CUE_RE = re.compile( + r"(?P\d+)\r?\n" + r"(?P\d{2}:\d{2}:\d{2},\d{3}\s-->\s\d{2}:\d{2}:\d{2},\d{3})(?:.*)?\r?\n" + r"(?P(?:.+(?:\r?\n)?)+?)" + r"\r?\n(?=\d+\r?\n|\Z)", + re.UNICODE +) + +def parse_srt(srt_text: str) -> List[Dict[str, str]]: + cues = [] + for m in CUE_RE.finditer(srt_text): + text = m.group("text") + cues.append({"num": m.group("num"), "ts": m.group("ts"), "text": text.rstrip("\n")}) + return cues + +def cues_to_srt(cues: List[Dict[str, str]]) -> str: + parts = [] + for c in cues: + parts.append(f"{c['num']}\n{c['ts']}\n{c['text']}\n") + return "\n".join(parts).strip() + "\n" + +# ===================== Tag protection ===================== +TAG_RE = re.compile(r"\n\r]+?>") # , etc. +ASS_TAG_RE = re.compile(r"\{\\[^}]*\}") # {\an8}, etc. + +def protect_tags(text: str) -> Tuple[str, Dict[str, str]]: + """ + Replace inline tags with placeholders so the translator doesn't mutate them. + We DO NOT mask SENTINEL_BR/CUE so we can split later. 
+ """ + placeholders: Dict[str, str] = {} + idx = 0 + def _sub(reobj, s): + nonlocal idx + def repl(m): + nonlocal idx + token = f"⟦TAG{idx}⟧" + placeholders[token] = m.group(0) + idx += 1 + return token + return reobj.sub(repl, s) + masked = _sub(TAG_RE, text) + masked = _sub(ASS_TAG_RE, masked) + return masked, placeholders + +_TAGNUM_RE = re.compile(r"TAG(\d+)") + +def restore_tags(text: str, placeholders: Dict[str, str]) -> str: + """ + Robustly restore placeholders: + - ⟦TAG3⟧ (ideal) + - TAG3 (if model stripped brackets) + - ⟦ TAG3 ⟧ (extra spaces) + - quoted variants like "TAG3" + """ + out = text + for token, original in placeholders.items(): + m = _TAGNUM_RE.search(token) + if not m: + continue + k = m.group(1) + variants = [ + token, + f"TAG{k}", + f"⟦TAG{k}⟧", + f"⟦ TAG{k} ⟧", + f"'TAG{k}'", f"\"TAG{k}\"", + f"‘TAG{k}’", f"“TAG{k}”", + ] + for v in variants: + out = out.replace(v, original) + pattern = rf"(?:⟦\s*)?TAG{k}(?:\s*⟧)?" + out = re.sub(pattern, lambda _m, _orig=original: _orig, out) + return out + +# ===================== Language utils ===================== +def norm_lang_to_code(name_or_code: str, default: Optional[str] = None) -> Optional[str]: + if name_or_code is None: + return default + s = name_or_code.strip() + if re.match(r"^[a-z]{3}_[A-Za-z]{4}$", s): + return s + return LANG2CODE.get(s.lower(), default) + +def autodetect_code(text_sample: str) -> str: + if not _ld_detect: + return "eng_Latn" + try: + lang = _ld_detect(text_sample) + return LANG2CODE.get(lang.lower(), { + "en": "eng_Latn", "de": "deu_Latn", "fr": "fra_Latn", "es": "spa_Latn", + "pt": "por_Latn", "it": "ita_Latn", "nl": "nld_Latn", "ru": "rus_Cyrl", + "ja": "jpn_Jpan", "ko": "kor_Hang", "zh-cn": "zho_Hans", "zh-tw": "zho_Hant", "ar": "arb_Arab" + }.get(lang.lower(), "eng_Latn")) + except Exception: + return "eng_Latn" + +# ---------- Language names & detection helpers for Ollama (NEW) ---------- +def _lang_name_for_prompt(name_or_code: Optional[str]) -> str: + if not name_or_code: + return "Unknown" + s = name_or_code.strip() + if re.match(r"^[a-z]{3}_[A-Za-z]{4}$", s): + for k, v in LANG2CODE.items(): + if len(k) <= 2: + continue + if v == s: + return k[:1].upper() + k[1:] + return s + return s[:1].upper() + s[1:] + +def _build_code2iso2_map() -> Dict[str, str]: + m: Dict[str, str] = {} + for alias, code in LANG2CODE.items(): + if len(alias) <= 2: + m[code] = alias + return m + +CODE2ISO2 = _build_code2iso2_map() + +_NONWORD_RE = re.compile(r"[^\w\u00C0-\uFFFF]+", flags=re.UNICODE) +_TAG_SENTINEL_RE = re.compile(r"⟦TAG\d+⟧|⟦BR⟧|⟦CUE⟧") + +def _normalize_for_compare(s: str) -> str: + s = _TAG_SENTINEL_RE.sub("", s or "") + s = _html.unescape(s) + s = s.lower() + s = _NONWORD_RE.sub("", s) + return s + +def _looks_untranslated(src_masked: str, out_raw: str, src_code: Optional[str], tgt_code: Optional[str]) -> bool: + s_norm = _normalize_for_compare(src_masked) + t_norm = _normalize_for_compare(out_raw) + if len(s_norm) >= 4 and s_norm == t_norm: + return True + if _ld_detect: + try: + probe = _TAG_SENTINEL_RE.sub("", out_raw or "") + if len(probe) >= 12: + ld = _ld_detect(probe) + src_iso = CODE2ISO2.get(src_code or "", "") + tgt_iso = CODE2ISO2.get(tgt_code or "", "") + if ld and src_iso and ld == src_iso and (not tgt_iso or ld != tgt_iso): + return True + except Exception: + pass + return False + +# ===================== Sentence boundary heuristics ===================== +_END_PUNCT_RE = re.compile(r'[\.!\?…。!?][»”"”\)\]\}]*\s*$') +_CONTINUATION_HINT_RE = re.compile(r'([,;:—\-…]|--)\s*$') + 
+def _strip_tags_placeholders(s: str) -> str:
+    s = re.sub(r'⟦TAG\d+⟧', '', s)
+    s = re.sub(r'</?[^<>\n\r]+?>', '', s)
+    return s.strip()
+
+def looks_like_sentence_end(text: str) -> bool:
+    t = _strip_tags_placeholders(text)
+    if not t:
+        return True
+    if _END_PUNCT_RE.search(t):
+        return True
+    if _CONTINUATION_HINT_RE.search(t):
+        return False
+    return False
+
+# ===================== Safe splitting & cleanups =====================
+def _nearest_boundary(text: str, approx_idx: int) -> int:
+    if not text:
+        return 0
+    n = len(text)
+    approx_idx = max(1, min(n-1, approx_idx))
+    left = approx_idx
+    right = approx_idx
+    while left > 0 or right < n:
+        if left > 0 and (text[left-1].isspace() or not text[left-1].isalnum() or not text[left].isalnum()):
+            return left
+        if right < n and (text[right-1].isspace() or not text[right-1].isalnum() or not text[right].isalnum()):
+            return right
+        left -= 1
+        right += 1
+    return approx_idx
+
+def _split_safely(text: str, parts: int) -> List[str]:
+    text = text.strip()
+    if parts <= 1:
+        return [text]
+    total = len(text)
+    out = []
+    start = 0
+    for i in range(parts - 1):
+        approx = int((total / parts) * (i + 1))
+        idx = _nearest_boundary(text, approx_idx=approx)
+        piece = text[start:idx].strip()
+        out.append(piece)
+        start = idx
+    out.append(text[start:].strip())
+    cleaned = []
+    for p in out:
+        if p == "" and cleaned:
+            cleaned[-1] = (cleaned[-1] + " ").strip()
+        else:
+            cleaned.append(p)
+    return cleaned if cleaned else [""]
+
+def _fix_midword_hyphen(s: str) -> str:
+    s = s.replace("\u00AD", "")
+    s = re.sub(r"-\s+\b", "", s)
+    return s
+
+# ----- Marker normalization & splitting (UPPERCASE ONLY) -----
+_WRAPPERS = r"""["'“”„»«\[\]\{\}\(\)<>\s]*"""
+BR_UP_RE = re.compile(rf"(?<!\w){_WRAPPERS}BR{_WRAPPERS}(?!\w)")
+CUE_UP_RE = re.compile(rf"(?<!\w){_WRAPPERS}CUE{_WRAPPERS}(?!\w)")
+
+def _strip_unk_tokens(s: str) -> str:
+    s = re.sub(r"<\s*unk[^>]*>?", "", s, flags=re.IGNORECASE)
+    return s
+
+def _normalize_markers(s: str) -> str:
+    s = _strip_unk_tokens(s)
+    s = BR_UP_RE.sub(f" {SENTINEL_BR} ", s)
+    s = CUE_UP_RE.sub(f" {SENTINEL_CUE} ", s)
+    return s
+
+def _split_on_br_normalized(s: str) -> List[str]:
+    parts = [p.strip() for p in s.split(SENTINEL_BR)]
+    return [p for p in parts if p != ""]
+
+def _split_on_cue_normalized(s: str) -> List[str]:
+    parts = [p.strip() for p in s.split(SENTINEL_CUE)]
+    return [p for p in parts if p != ""]
+
+# ----- Leading punctuation fix (iterate until stable; includes Unicode commas) -----
+_LEADING_PUNCT_RE = re.compile(
+    r"""^\s*([,.;:!?…,、;:]+|[’'”")\]\}》〉」』】〕]+)\s*""", re.UNICODE
+)
+
+def _fix_leading_punct(lines: List[str]) -> List[str]:
+    if not lines:
+        return lines
+    out = lines[:]
+    changed = True
+    while changed:
+        changed = False
+        for i in range(1, len(out)):
+            m = _LEADING_PUNCT_RE.match(out[i])
+            if m:
+                token = m.group(1)
+                out[i-1] = (out[i-1].rstrip() + token).rstrip()
+                out[i] = out[i][m.end():].lstrip()
+                changed = True
+    return out
+
+# ----- Spurious brackets removal & html unescape -----
+def _unescape_and_strip_artifacts(text: str) -> str:
+    s = _html.unescape(text)
+    s = s.replace("⟦", "").replace("⟧", "")
+    s = s.replace("\u00A0", " ").replace("\u202F", " ").replace("\u3000", " ")
+    s = re.sub(r"[ \t]+", " ", s)
+    return s.strip()
+
+def _strip_spurious_pairs(orig: str, trans: str) -> str:
+    out = trans
+    if "[]" in out and "[]" not in orig:
+        out = out.replace("[]", "")
+    if "{}" in out and "{}" not in orig:
+        out = out.replace("{}", "")
+    return out.strip()
+
+# ----- Single-word repetition squash (SeamlessM4T only) -----
+_WORD_RE = re.compile(r"\b\w+\b", flags=re.UNICODE)
+_REPEAT_LINE_RE =
re.compile(r"^\s*(\w+)(?:\W+\1\b){2,}\s*$", flags=re.IGNORECASE | re.UNICODE) + +def _is_one_word_line(s: str) -> bool: + s2 = _strip_tags_placeholders(s) + return len(_WORD_RE.findall(s2)) == 1 + +def _squelch_single_word_repeat_if_needed(orig_line: str, translated_line: str) -> str: + if not _is_one_word_line(orig_line): + return translated_line + m = _REPEAT_LINE_RE.match(translated_line) + if m: + return m.group(1) + words = _WORD_RE.findall(translated_line) + if words: + lowered = [w.lower() for w in words] + if len(set(lowered)) == 1 and len(lowered) >= 5: + return words[0] + return translated_line + +# ===================== NEW: marker/token removal at per-line finalization ===================== +_MARKER_TOKEN_RE = re.compile( + r'(?:(?<=^)|(?<=\s))' + r'(?:["\'“”‘’])?' + r'\S*?(?:BR|CUE)\S*?' + r'(?:["\'“”‘’])?' + r'(?:\s*,)?' + r'(?=(?:\s|$))' +) + +_TAGN_TOKEN_RE = re.compile( + r'(?:(?<=^)|(?<=\s))' + r'(?:["\'“”‘’])?' + r'\S*?TAG[0-9]\S*?' + r'(?:["\'“”‘’])?' + r'(?:\s*,)?' + r'(?=(?:\s|$))', + flags=re.IGNORECASE +) + +def _remove_marker_and_tagn_tokens(line: str) -> str: + out = line + for _ in range(3): + new = _MARKER_TOKEN_RE.sub("", out) + new = _TAGN_TOKEN_RE.sub("", new) + if new == out: + break + out = new + out = re.sub(r"\s+,", ",", out) + out = re.sub(r"\s+([.;:!?…,、;:])", r"\1", out) + out = re.sub(r"[ \t]{2,}", " ", out) + out = re.sub(r"^\s*[,,、]+", "", out) + return out.strip() + +# ===================== Device selection (NLLB & Seamless) ===================== +def pick_device_for_workers(workers: int, device_mode: str): + workers = max(1, workers) + device_mode = (device_mode or "auto").lower() + + has_mps = has_cuda = has_hip = False + try: + import torch + has_mps = getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() + has_cuda = torch.cuda.is_available() + has_hip = getattr(torch.version, "hip", None) is not None + except Exception: + pass + + is_mac = (sys.platform == "darwin") + + def _cpu(): + return {"device": -1}, workers + + def _mps(): + return {"device": "mps"}, 1 + + def _cuda_or_rocm(): + return {"device": 0}, 1 + + if device_mode == "cpu": + return _cpu() + + if device_mode == "gpu": + if is_mac and has_mps: + return _mps() + if has_cuda or has_hip: + return _cuda_or_rocm() + return _cpu() + + if is_mac: + if has_mps: + return _mps() + if has_cuda or has_hip: + return _cuda_or_rocm() + return _cpu() + else: + if has_cuda or has_hip: + return _cuda_or_rocm() + if has_mps: + return _mps() + return _cpu() + +# ===================== NLLB via Transformers ===================== +def get_nllb_translator(model_name: str, src_code: str, tgt_code: str, device_kwargs: dict): + try: + from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline as _pipeline + import torch # noqa + except Exception: + from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + from transformers.pipelines import pipeline as _pipeline + import torch # noqa + + tok = AutoTokenizer.from_pretrained(model_name) + mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name) + mdl.eval() + return _pipeline("translation", model=mdl, tokenizer=tok, + src_lang=src_code, tgt_lang=tgt_code, **device_kwargs) + +def ensure_model_downloaded(repo_id: str, revision: str = "main", tqdm_class=None): + from huggingface_hub import snapshot_download + snapshot_download(repo_id=repo_id, revision=revision, tqdm_class=tqdm_class) + +# ===================== SeamlessM4T (text-to-text) ===================== +def _to_seamless_lang(nllb_code: str) -> str: + if not nllb_code: 
+ return "eng" + return nllb_code.split("_", 1)[0] + +def get_seamless_translator(model_name: str, src_code: str, tgt_code: str, device_kwargs: dict): + import torch + from transformers import AutoProcessor, SeamlessM4Tv2ForTextToText + + processor = AutoProcessor.from_pretrained(model_name) + model = SeamlessM4Tv2ForTextToText.from_pretrained(model_name) + model.eval() + dev = device_kwargs.get("device", -1) + + if dev == "mps": + device = torch.device("mps") + elif isinstance(dev, int) and dev >= 0 and torch.cuda.is_available(): + device = torch.device(f"cuda:{dev}") + else: + device = torch.device("cpu") + model.to(device) + + src3 = _to_seamless_lang(src_code) + tgt3 = _to_seamless_lang(tgt_code) + + def _translate(batch_texts: List[str]) -> List[str]: + texts = [" " if (t is None or str(t).strip() == "") else str(t) for t in batch_texts] + inputs = processor(text=texts, src_lang=src3, return_tensors="pt", padding=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + import torch as _torch + with _torch.no_grad(): + out = model.generate(**inputs, tgt_lang=tgt3) + seqs = getattr(out, "sequences", out) + decoded = processor.batch_decode(seqs, skip_special_tokens=True) + return decoded + + return _translate + +def ensure_seamless_downloaded(repo_id: str = "facebook/seamless-m4t-v2-large", revision: str = "main", tqdm_class=None): + from huggingface_hub import snapshot_download + snapshot_download(repo_id=repo_id, revision=revision, tqdm_class=tqdm_class) + +# ===================== Ollama JSON chat (robust) ===================== +def _ollama_system_prompt(): + return ( + "You are a professional subtitle translator.\n" + "Return STRICT JSON ONLY. No Markdown, no code fences, no commentary.\n" + "CRITICAL RULES:\n" + "1) Preserve the exact array shapes and counts provided (lines per cue, cues per group).\n" + "2) Do NOT add/remove/swap items; do NOT merge or split lines.\n" + "3) Keep placeholder tokens like ⟦TAG0⟧ EXACTLY unchanged.\n" + "4) Do not hyphenate or break words across lines.\n" + "5) Keep numbers and time references as-is.\n" + "6) Unless a string is a proper name/brand/code or already in the target language, you MUST translate it.\n" + "Reply with pure JSON conforming to the requested schema.\n" + ) + +def _strip_code_fences(s: str) -> str: + s = s.strip() + if s.startswith("```"): + first = s.find("```") + if first != -1: + s = s[first+3:] + if "\n" in s: + s = s.split("\n", 1)[1] + if s.rstrip().endswith("```"): + s = s.rstrip()[:-3] + return s.strip() + +def _extract_balanced_json(s: str): + import json as _json + s = s.strip() + try: + return _json.loads(s) + except Exception: + pass + s = _strip_code_fences(s) + try: + return _json.loads(s) + except Exception: + pass + + start_obj = s.find("{") + start_arr = s.find("[") + if start_obj == -1 and start_arr == -1: + raise ValueError("No JSON start found") + start = min([i for i in [start_obj, start_arr] if i != -1]) + kind = "obj" if start == start_obj else "arr" + + depth_brace = depth_bracket = 0 + in_string = False + esc = False + for i in range(start, len(s)): + ch = s[i] + if in_string: + if esc: + esc = False + elif ch == "\\": + esc = True + elif ch == '"': + in_string = False + else: + if ch == '"': + in_string = True + elif ch == "{": + depth_brace += 1 + elif ch == "}": + depth_brace -= 1 + if kind == "obj" and depth_brace == 0 and depth_bracket == 0: + frag = s[start:i+1] + try: + return _json.loads(frag) + except Exception: + break + elif ch == "[": + depth_bracket += 1 + elif ch == "]": + depth_bracket 
-= 1 + if kind == "arr" and depth_bracket == 0 and depth_brace == 0: + frag = s[start:i+1] + try: + return _json.loads(frag) + except Exception: + break + frag = s[start:] + frag = re.sub(r",(\s*[}\]])", r"\1", frag) + return _json.loads(frag) + +def ollama_chat_json( + model: str, + system_prompt: str, + user_prompt: str, + host: str = "localhost", + port: int = 11434, + temperature: float = 0.2, + max_retries: int = 4, + timeout: int = 600, + cancel_event: Optional[threading.Event] = None, +): + last_err = None + for attempt in range(max_retries): + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + prompt = user_prompt if attempt == 0 else ( + user_prompt + "\n\nSTRICT RETRY: Previous reply was NOT valid JSON. " + "Respond with JSON ONLY (no code fences), matching the schema and counts exactly." + ) + body = json.dumps({ + "model": model, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ], + "format": "json", + "stream": False, + "options": {"temperature": temperature} + }) + try: + conn = http.client.HTTPConnection(host, port, timeout=timeout) + conn.request("POST", "/api/chat", body=body, headers={"Content-Type": "application/json"}) + resp = conn.getresponse() + data = resp.read() + conn.close() + if resp.status != 200: + last_err = RuntimeError(f"Ollama HTTP {resp.status}: {data[:200]}") + time.sleep(1.1 * (attempt + 1)) + continue + j = json.loads(data.decode("utf-8")) + content = j.get("message", {}).get("content", "") + try: + return json.loads(content) + except Exception: + return _extract_balanced_json(content) + except Exception as e: + last_err = e + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + time.sleep(1.1 * (attempt + 1)) + raise last_err or RuntimeError("Ollama call failed") + +# ===================== Ollama prompt builders (UPDATED) ===================== +def _ollama_user_prompt_lines(target_lang: str, src_hint: str, lines: List[str], force: bool=False) -> str: + tgt_name = _lang_name_for_prompt(target_lang) + src_name = _lang_name_for_prompt(src_hint) + rules = [ + f"Translate each item to {tgt_name}.", + f"Source language hint: {src_name}.", + "Preserve array length exactly as provided.", + "No Markdown, no code fences, return JSON only." + ] + if force: + rules.append( + f"If an item is not already in {tgt_name} and contains translatable words, you MUST translate it; " + f"do NOT return the source sentence unchanged (proper names/brands/codes may remain)." + ) + payload = { + "task": "translate_lines", + "target_language": tgt_name, + "source_language_hint": src_name, + "schema": {"lines": ["", "..."]}, + "expected_counts": {"lines": len(lines)}, + "rules": rules, + "lines": lines, + } + return json.dumps(payload, ensure_ascii=False) + +def _ollama_user_prompt_cues(target_lang: str, src_hint: str, cues: List[List[str]]) -> str: + tgt_name = _lang_name_for_prompt(target_lang) + src_name = _lang_name_for_prompt(src_hint) + payload = { + "task": "translate_cues", + "target_language": tgt_name, + "source_language_hint": src_name, + "schema": {"cues": [["", "..."]]}, + "expected_counts": {"cues": [len(c) for c in cues]}, + "rules": [ + f"Translate to {tgt_name}. Source language hint: {src_name}.", + "Preserve lines per cue exactly.", + "No Markdown, no code fences, return JSON only." 
+ ], + "cues": cues, + } + return json.dumps(payload, ensure_ascii=False) + +def _ollama_user_prompt_groups(target_lang: str, src_hint: str, groups: List[List[List[str]]]) -> str: + tgt_name = _lang_name_for_prompt(target_lang) + src_name = _lang_name_for_prompt(src_hint) + payload = { + "task": "translate_groups", + "target_language": tgt_name, + "source_language_hint": src_name, + "schema": {"groups": [[["", "..."]]]}, + "expected_counts": {"groups": [[len(cue) for cue in group] for group in groups]}, + "rules": [ + f"Translate to {tgt_name}. Source language hint: {src_name}.", + "Preserve cues per group and lines per cue exactly.", + "No Markdown, no code fences, return JSON only." + ], + "groups": groups, + } + return json.dumps(payload, ensure_ascii=False) + +# ---------- Ollama strict retry for lines (NEW) ---------- +def _ollama_retry_translate_lines(model: str, host: str, port: int, + target_language: str, src_hint: str, + masked_lines: List[str], + cancel_event: Optional[threading.Event] = None) -> List[str]: + if not masked_lines: + return [] + + sys_prompt = _ollama_system_prompt() + models = [m.strip() for m in str(model).split(",") if m.strip()] or [model] + temps = (0.1, 0.3, 0.0) + + for T in temps: + for m in models: + if cancel_event is not None and cancel_event.is_set(): + return masked_lines + try: + prompt = _ollama_user_prompt_lines(target_language, src_hint, masked_lines, force=True) + obj = ollama_chat_json(m, sys_prompt, prompt, host=host, port=port, + temperature=T, max_retries=3, cancel_event=cancel_event) + lines = obj.get("lines", []) + if isinstance(lines, list) and len(lines) == len(masked_lines): + return lines + except Exception: + continue + return masked_lines + +# ===================== Builders ===================== +def build_line_items(cue_lines: List[List[str]]) -> Tuple[List[str], List[Dict[str, str]]]: + flat_masked: List[str] = [] + masks: List[Dict[str, str]] = [] + for lines in cue_lines: + for ln in lines: + masked, ph = protect_tags(ln) + flat_masked.append(masked) + masks.append(ph) + return flat_masked, masks + +def build_cue_items_for_nllb(cue_lines: List[List[str]]) -> Tuple[List[str], List[Dict[str, str]], List[int]]: + joined: List[str] = [] + masks: List[Dict[str, str]] = [] + counts: List[int] = [] + for lines in cue_lines: + combined = f"\n{SENTINEL_BR}\n".join(lines) if lines else "" + masked, ph = protect_tags(combined) + joined.append(masked) + masks.append(ph) + counts.append(len(lines)) + return joined, masks, counts + +def build_cue_items_for_llm(cue_lines: List[List[str]]) -> Tuple[List[List[str]], List[List[Dict[str, str]]]]: + cues_text: List[List[str]] = [] + cues_masks: List[List[Dict[str, str]]] = [] + for lines in cue_lines: + masked_lines, masks = [], [] + for ln in lines: + m, ph = protect_tags(ln) + masked_lines.append(m) + masks.append(ph) + cues_text.append(masked_lines) + cues_masks.append(masks) + return cues_text, cues_masks + +def build_smart_groups_for_nllb(cue_lines: List[List[str]], max_span_cues: int = 4 + ) -> Tuple[List[str], List[Dict[str, str]], List[List[int]]]: + n = len(cue_lines) + groups_text: List[str] = [] + groups_masks: List[Dict[str, str]] = [] + groups_counts: List[List[int]] = [] + + i = 0 + while i < n: + group_cues = [] + counts = [] + span = 0 + while i < n and span < max_span_cues: + lines = cue_lines[i] + counts.append(len(lines)) + cue_text = f"\n{SENTINEL_BR}\n".join(lines) + group_cues.append(cue_text) + if looks_like_sentence_end("\n".join(lines)): + i += 1 + break + span += 1 + i += 1 
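+        # The inner loop above closes a group at a sentence boundary or after
+        # max_span_cues cues; the group's cues are then joined with the CUE
+        # sentinel so NLLB translates one context window that can be split
+        # apart again after translation.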
+ combined_group = f"\n{SENTINEL_CUE}\n".join(group_cues) + masked, ph = protect_tags(combined_group) + groups_text.append(masked) + groups_masks.append(ph) + groups_counts.append(counts) + return groups_text, groups_masks, groups_counts + +def build_smart_groups_for_llm(cue_lines: List[List[str]], max_span_cues: int = 4 + ) -> Tuple[List[List[List[str]]], List[List[List[Dict[str,str]]]], List[List[int]]]: + n = len(cue_lines) + groups_text: List[List[List[str]]] = [] + groups_masks: List[List[List[Dict[str,str]]]] = [] + groups_counts: List[List[int]] = [] + + i = 0 + while i < n: + group_cues: List[List[str]] = [] + group_masks: List[List[Dict[str,str]]] = [] + counts: List[int] = [] + span = 0 + while i < n and span < max_span_cues: + lines = cue_lines[i] + counts.append(len(lines)) + masked_lines, masks = [], [] + for ln in lines: + m, ph = protect_tags(ln) + masked_lines.append(m) + masks.append(ph) + group_cues.append(masked_lines) + group_masks.append(masks) + if looks_like_sentence_end("\n".join(lines)): + i += 1 + break + span += 1 + i += 1 + groups_text.append(group_cues) + groups_masks.append(group_masks) + groups_counts.append(counts) + return groups_text, groups_masks, groups_counts + +# ===================== Queue helper ===================== +def _drain_queue(q: "queue.Queue"): + while True: + try: + _ = q.get_nowait() + q.task_done() + except queue.Empty: + break + +# ===================== Translation workers (NLLB) ===================== +def nllb_translate_lines(flat_lines, masks, model_name, src_code, tgt_code, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None, device_kwargs=None): + n = len(flat_lines) + results: List[Optional[str]] = [None] * n + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + masked_batch = [flat_lines[j] for j in idx_chunk] + batch_masks = [masks[j] for j in idx_chunk] + q.put((idx_chunk, masked_batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); return + translator = get_nllb_translator(model_name, src_code, tgt_code, device_kwargs or {"device": -1}) + micro = 8 + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, masked_batch, batch_masks = item + to_xlate_full = [" " if s.strip() == "" else s for s in masked_batch] + for s in range(0, len(to_xlate_full), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_in = to_xlate_full[s:s+micro] + sub_masks = batch_masks[s:s+micro] + outs = translator(sub_in) + for pos, out in enumerate(outs): + i = sub_idx[pos] + restored = restore_tags(out["translation_text"], sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + results[i] = _fix_midword_hyphen(restored) + if on_progress: on_progress(1) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r 
is not None else "" for r in results] + +def nllb_translate_cues(flat_cues, cue_masks, counts, model_name, src_code, tgt_code, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None, device_kwargs=None): + n = len(flat_cues) + results: List[Optional[List[str]]] = [None] * n + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + masked_batch = [flat_cues[j] for j in idx_chunk] + batch_masks = [cue_masks[j] for j in idx_chunk] + q.put((idx_chunk, masked_batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); return + translator = get_nllb_translator(model_name, src_code, tgt_code, device_kwargs or {"device": -1}) + micro = 6 + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, masked_batch, batch_masks = item + to_xlate_full = [" " if s.strip()=="" else s for s in masked_batch] + for s in range(0, len(to_xlate_full), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_in = to_xlate_full[s:s+micro] + sub_masks = batch_masks[s:s+micro] + outs = translator(sub_in) + for pos, out in enumerate(outs): + i = sub_idx[pos] + restored = restore_tags(out["translation_text"], sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + n_lines = counts[i] + lines = resplit_translated_cue_text(restored, n_lines) + results[i] = _fix_leading_punct([_fix_midword_hyphen(p) for p in lines]) + if on_progress: on_progress(n_lines) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else [""]*c for r,c in zip(results, counts)] + +def nllb_translate_groups(groups_text, groups_masks, groups_counts, model_name, src_code, tgt_code, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None, device_kwargs=None): + n = len(groups_text) + results: List[Optional[List[List[str]]]] = [None] * n + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + masked_batch = [groups_text[j] for j in idx_chunk] + batch_masks = [groups_masks[j] for j in idx_chunk] + q.put((idx_chunk, masked_batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); return + translator = get_nllb_translator(model_name, src_code, tgt_code, device_kwargs or {"device": -1}) + micro = 3 + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, masked_batch, batch_masks = item + to_xlate_full = [" " if 
s.strip()=="" else s for s in masked_batch] + for s in range(0, len(to_xlate_full), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_in = to_xlate_full[s:s+micro] + sub_masks = batch_masks[s:s+micro] + outs = translator(sub_in) + for pos, out in enumerate(outs): + gi = sub_idx[pos] + restored = restore_tags(out["translation_text"], sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + cue_chunks = _split_on_cue_normalized(restored) + counts = groups_counts[gi] + if len(cue_chunks) != len(counts): + merged = " ".join(cue_chunks).strip() + cue_chunks = _split_safely(merged, len(counts)) + rebuilt_cues = [] + for chunk, n_lines in zip(cue_chunks, counts): + lines = resplit_translated_cue_text(chunk, n_lines) + rebuilt_cues.append(_fix_leading_punct([_fix_midword_hyphen(p) for p in lines])) + results[gi] = rebuilt_cues + if on_progress: on_progress(sum(counts)) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else [[""]*n for n in counts] for r,counts in zip(results, groups_counts)] + +# ===================== Translation workers (SeamlessM4T) ===================== +def seamless_translate_lines(flat_lines, masks, model_name, src_code, tgt_code, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None, device_kwargs=None): + n = len(flat_lines) + results: List[Optional[str]] = [None] * n + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + masked_batch = [flat_lines[j] for j in idx_chunk] + batch_masks = [masks[j] for j in idx_chunk] + q.put((idx_chunk, masked_batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); return + translator = get_seamless_translator(model_name, src_code, tgt_code, device_kwargs or {"device": -1}) + micro = 8 + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, masked_batch, batch_masks = item + to_xlate_full = [" " if s.strip()=="" else s for s in masked_batch] + for s in range(0, len(to_xlate_full), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_in = to_xlate_full[s:s+micro] + sub_masks = batch_masks[s:s+micro] + outs = translator(sub_in) + for pos, out in enumerate(outs): + i = sub_idx[pos] + restored = restore_tags(out, sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + results[i] = _fix_midword_hyphen(restored) + if on_progress: on_progress(1) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else "" for r in results] + 
+def seamless_translate_cues(flat_cues, cue_masks, counts, model_name, src_code, tgt_code, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None, device_kwargs=None): + n = len(flat_cues) + results: List[Optional[List[str]]] = [None] * n + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + masked_batch = [flat_cues[j] for j in idx_chunk] + batch_masks = [cue_masks[j] for j in idx_chunk] + q.put((idx_chunk, masked_batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); return + translator = get_seamless_translator(model_name, src_code, tgt_code, device_kwargs or {"device": -1}) + micro = 6 + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, masked_batch, batch_masks = item + to_xlate_full = [" " if s.strip()=="" else s for s in masked_batch] + for s in range(0, len(to_xlate_full), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_in = to_xlate_full[s:s+micro] + sub_masks = batch_masks[s:s+micro] + outs = translator(sub_in) + for pos, out in enumerate(outs): + i = sub_idx[pos] + restored = restore_tags(out, sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + n_lines = counts[i] + lines = resplit_translated_cue_text(restored, n_lines) + results[i] = _fix_leading_punct([_fix_midword_hyphen(p) for p in lines]) + if on_progress: on_progress(n_lines) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else [""]*c for r,c in zip(results, counts)] + +def seamless_translate_groups(groups_text, groups_masks, groups_counts, model_name, src_code, tgt_code, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None, device_kwargs=None): + n = len(groups_text) + results: List[Optional[List[List[str]]]] = [None] * n + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + masked_batch = [groups_text[j] for j in idx_chunk] + batch_masks = [groups_masks[j] for j in idx_chunk] + q.put((idx_chunk, masked_batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); return + translator = get_seamless_translator(model_name, src_code, tgt_code, device_kwargs or {"device": -1}) + micro = 3 + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, masked_batch, batch_masks = item + to_xlate_full = [" " if s.strip()=="" else s for s in masked_batch] + 
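+            # micro = 3 here: smart groups are the longest inputs, so keep each
+            # generate() batch small and re-check the cancel event in between.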
for s in range(0, len(to_xlate_full), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_in = to_xlate_full[s:s+micro] + sub_masks = batch_masks[s:s+micro] + outs = translator(sub_in) + for pos, out in enumerate(outs): + gi = sub_idx[pos] + restored = restore_tags(out, sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + cue_chunks = _split_on_cue_normalized(restored) + counts = groups_counts[gi] + if len(cue_chunks) != len(counts): + merged = " ".join(cue_chunks).strip() + cue_chunks = _split_safely(merged, len(counts)) + rebuilt_cues = [] + for chunk, n_lines in zip(cue_chunks, counts): + lines = resplit_translated_cue_text(chunk, n_lines) + rebuilt_cues.append(_fix_leading_punct([_fix_midword_hyphen(p) for p in lines])) + results[gi] = rebuilt_cues + if on_progress: on_progress(sum(counts)) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else [[""]*n for n in counts] for r,counts in zip(results, groups_counts)] + +# ===================== Ollama workers (UPDATED with Repair-Pass & cancel-aware) ===================== +def ollama_translate_lines(flat_lines, masks, target_language, src_code, model, host, port, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None): + n = len(flat_lines) + results: List[Optional[str]] = [None] * n + tgt_code = norm_lang_to_code(target_language, default=None) + + q: "queue.Queue[Optional[Tuple[List[int], List[str], List[Dict[str,str]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + batch = [flat_lines[j] for j in idx_chunk] + batch_masks = [masks[j] for j in idx_chunk] + q.put((idx_chunk, batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + micro = 8 + sys_prompt = _ollama_system_prompt() + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, batch, batch_masks = item + + # micro-batched calls to limit blocking time + for s in range(0, len(batch), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_batch = batch[s:s+micro] + sub_masks = batch_masks[s:s+micro] + + prompt = _ollama_user_prompt_lines(target_language, src_code, sub_batch, force=False) + out_obj = ollama_chat_json(model, sys_prompt, prompt, host=host, port=port, + temperature=0.1, cancel_event=cancel_event) + lines_out = out_obj.get("lines", []) + if not isinstance(lines_out, list) or len(lines_out) != len(sub_batch): + raise RuntimeError("Ollama returned invalid lines array.") + + # Repair pass + to_fix_idx: List[int] = [] + to_fix_src: List[str] = [] + for pos, t in enumerate(lines_out): + src_masked = sub_batch[pos] + if _looks_untranslated(src_masked, t, src_code, tgt_code): + to_fix_idx.append(pos); to_fix_src.append(src_masked) + if to_fix_src: + fixed = _ollama_retry_translate_lines(model, host, port, target_language, src_code, to_fix_src, 
cancel_event=cancel_event) + for k, new_t in enumerate(fixed): + lines_out[to_fix_idx[k]] = new_t + + # restore + finalize + for pos, t in enumerate(lines_out): + i = sub_idx[pos] + restored = restore_tags(t, sub_masks[pos]) + restored = _normalize_markers(restored) + restored = _unescape_and_strip_artifacts(restored) + results[i] = _fix_midword_hyphen(restored) + if on_progress: on_progress(1) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else "" for r in results] + +def ollama_translate_cues(cues_text, cues_masks, target_language, src_code, model, host, port, + workers, batch_size, on_progress, pause_event_unused, cancel_event=None): + n = len(cues_text) + results: List[Optional[List[str]]] = [None] * n + tgt_code = norm_lang_to_code(target_language, default=None) + + q: "queue.Queue[Optional[Tuple[List[int], List[List[str]], List[List[Dict[str,str]]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + batch = [cues_text[j] for j in idx_chunk] + batch_masks = [cues_masks[j] for j in idx_chunk] + q.put((idx_chunk, batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + micro = 6 + sys_prompt = _ollama_system_prompt() + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + idx_chunk, batch, batch_masks = item + + for s in range(0, len(batch), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_batch = batch[s:s+micro] + sub_masks = batch_masks[s:s+micro] + + prompt = _ollama_user_prompt_cues(target_language, src_code, sub_batch) + out_obj = ollama_chat_json(model, sys_prompt, prompt, host=host, port=port, + temperature=0.1, cancel_event=cancel_event) + cues_out = out_obj.get("cues", []) + if (not isinstance(cues_out, list)) or (len(cues_out) != len(sub_batch)): + raise RuntimeError("Ollama returned invalid cues array.") + + retry_pairs: List[Tuple[int,int]] = [] + retry_texts: List[str] = [] + for pos, cue_lines_trans in enumerate(cues_out): + masks_per_line = sub_masks[pos] + if len(cue_lines_trans) != len(masks_per_line): + merged = " ".join(cue_lines_trans) + cue_lines_trans = _split_safely(merged, len(masks_per_line)) + cues_out[pos] = cue_lines_trans + for li, t in enumerate(cue_lines_trans): + src_masked = sub_batch[pos][li] + if _looks_untranslated(src_masked, t, src_code, tgt_code): + retry_pairs.append((pos, li)) + retry_texts.append(src_masked) + + if retry_texts: + fixed = _ollama_retry_translate_lines(model, host, port, target_language, src_code, retry_texts, cancel_event=cancel_event) + for (pos, li), new_t in zip(retry_pairs, fixed): + cues_out[pos][li] = new_t + + for pos, cue_lines_trans in enumerate(cues_out): + masks_per_line = sub_masks[pos] + rebuilt = [] + for ln, ph in zip(cue_lines_trans, masks_per_line): + t = restore_tags(ln, ph) + t = _normalize_markers(t) + t = _unescape_and_strip_artifacts(t) + rebuilt.append(_fix_midword_hyphen(t)) + i = sub_idx[pos] + results[i] = _fix_leading_punct(rebuilt) + if on_progress: 
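+                    # progress is counted in source lines, not cues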
on_progress(len(rebuilt)) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else [""]*len(cues_text[idx]) for idx,r in enumerate(results)] + +def ollama_translate_groups(groups_text, groups_masks, groups_counts, target_language, src_code, + model, host, port, workers, batch_size, on_progress, pause_event_unused, cancel_event=None): + n = len(groups_text) + results: List[Optional[List[List[str]]]] = [None] * n + tgt_code = norm_lang_to_code(target_language, default=None) + + q: "queue.Queue[Optional[Tuple[List[int], List[List[List[str]]], List[List[List[Dict[str,str]]]]]]]" = queue.Queue() + for i in range(0, n, batch_size): + idx_chunk = list(range(i, min(n, i+batch_size))) + batch = [groups_text[j] for j in idx_chunk] + batch_masks = [groups_masks[j] for j in idx_chunk] + q.put((idx_chunk, batch, batch_masks)) + for _ in range(workers): + q.put(None) + + def worker_main(): + micro = 4 + sys_prompt = _ollama_system_prompt() + while True: + if cancel_event is not None and cancel_event.is_set(): + _drain_queue(q); break + try: + item = q.get(timeout=0.2) + except queue.Empty: + continue + if item is None: + q.task_done(); break + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); break + + idx_chunk, batch, batch_masks = item + for s in range(0, len(batch), micro): + if cancel_event is not None and cancel_event.is_set(): + q.task_done(); _drain_queue(q); return + sub_idx = idx_chunk[s:s+micro] + sub_batch = batch[s:s+micro] + sub_masks = batch_masks[s:s+micro] + + prompt = _ollama_user_prompt_groups(target_language, src_code, sub_batch) + out_obj = ollama_chat_json(model, sys_prompt, prompt, host=host, port=port, + temperature=0.1, cancel_event=cancel_event) + groups_out = out_obj.get("groups", []) + + if (not isinstance(groups_out, list)) or (len(groups_out) != len(sub_batch)): + groups_out = [] + for group in sub_batch: + prompt2 = _ollama_user_prompt_cues(target_language, src_code, group) + try: + obj2 = ollama_chat_json(model, sys_prompt, prompt2, host=host, port=port, + temperature=0.1, cancel_event=cancel_event) + cues_out2 = obj2.get("cues", []) + if not isinstance(cues_out2, list) or len(cues_out2) != len(group): + cues_out2 = group + except Exception: + cues_out2 = group + groups_out.append(cues_out2) + + retry_triples: List[Tuple[int,int,int]] = [] + retry_texts: List[str] = [] + + for pos, group_out in enumerate(groups_out): + masks_for_group = sub_masks[pos] + if len(group_out) != len(masks_for_group): + merged_cues_text = [" ".join(c) if isinstance(c, list) else str(c) for c in group_out] + merged_all = " ".join(merged_cues_text) + split_cues = _split_safely(merged_all, len(masks_for_group)) + coerced_group = [] + for chunk, line_masks in zip(split_cues, masks_for_group): + coerced_group.append(_split_safely(chunk, len(line_masks))) + group_out = coerced_group + groups_out[pos] = group_out + + for ci, (cue_out, line_masks) in enumerate(zip(group_out, masks_for_group)): + if len(cue_out) != len(line_masks): + merged = " ".join(cue_out if isinstance(cue_out, list) else [str(cue_out)]) + cue_out = _split_safely(merged, len(line_masks)) + group_out[ci] = cue_out + for li, ln in enumerate(cue_out): + src_masked = sub_batch[pos][ci][li] + if _looks_untranslated(src_masked, ln, src_code, tgt_code): + retry_triples.append((pos, ci, li)) 
+ retry_texts.append(src_masked) + + if retry_texts: + fixed = _ollama_retry_translate_lines(model, host, port, target_language, src_code, retry_texts, cancel_event=cancel_event) + for (pos, ci, li), new_t in zip(retry_triples, fixed): + groups_out[pos][ci][li] = new_t + + for pos, group_out in enumerate(groups_out): + masks_for_group = sub_masks[pos] + rebuilt_group: List[List[str]] = [] + for cue_out, line_masks in zip(group_out, masks_for_group): + rebuilt_lines = [] + for ln, ph in zip(cue_out, line_masks): + t = restore_tags(ln, ph) + t = _normalize_markers(t) + t = _unescape_and_strip_artifacts(t) + rebuilt_lines.append(_fix_midword_hyphen(t)) + rebuilt_group.append(_fix_leading_punct(rebuilt_lines)) + gi = sub_idx[pos] + results[gi] = rebuilt_group + if on_progress: on_progress(sum(len(x) for x in rebuilt_group)) + q.task_done() + + threads = [threading.Thread(target=worker_main, daemon=True) for _ in range(workers)] + [t.start() for t in threads]; q.join() + if cancel_event is not None and cancel_event.is_set(): + raise RuntimeError("Cancelled") + return [r if r is not None else [[""]*n for n in counts] for r,counts in zip(results, groups_counts)] + +# ===================== Rebuild helpers ===================== +def resplit_translated_cue_text(translated: str, n_lines: int) -> List[str]: + normalized = _normalize_markers(translated) + parts = _split_on_br_normalized(normalized) + if len(parts) == n_lines: + return [_fix_midword_hyphen(p) for p in parts] + + nl_parts = [ln.strip() for ln in re.split(r"\r?\n", normalized) if ln.strip()] + if len(nl_parts) == n_lines: + return [_fix_midword_hyphen(p) for p in nl_parts] + + merged = " ".join(parts if parts else [normalized]).strip() + return [_fix_midword_hyphen(p) for p in _split_safely(merged, n_lines)] + +def _postprocess_line(orig_line: str, trans_line: str, squelch_for_seamless: bool=False) -> str: + t = _strip_spurious_pairs(orig_line, trans_line) + t = _unescape_and_strip_artifacts(t) + t = _remove_marker_and_tagn_tokens(t) + if squelch_for_seamless: + t = _squelch_single_word_repeat_if_needed(orig_line, t) + return t + +def rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=False): + rebuilt_cues = [] + idx = 0 + for c, lines in zip(cues, cue_lines): + n = len(lines) + tlines = translated_flat[idx:idx + n] + idx += n + tlines = [_postprocess_line(o, _fix_midword_hyphen(t), squelch_single_word) for o, t in zip(lines, tlines)] + tlines = _fix_leading_punct(tlines) + c_new = dict(c); c_new["text"] = "\n".join(tlines) + rebuilt_cues.append(c_new) + return rebuilt_cues + +def rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=False): + rebuilt_cues = [] + for c, tparts, orig_parts in zip(cues, translated_per_cue, cue_lines): + tparts = [_postprocess_line(o, _fix_midword_hyphen(t), squelch_single_word) for o, t in zip(orig_parts, tparts)] + tparts = _fix_leading_punct(tparts) + c_new = dict(c); c_new["text"] = "\n".join(tparts) + rebuilt_cues.append(c_new) + return rebuilt_cues + +# ===================== Config ===================== +def load_config(): + try: + if os.path.exists(CONFIG_PATH): + with open(CONFIG_PATH, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + pass + return {} + +def save_config(cfg: dict): + try: + with open(CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(cfg, f, ensure_ascii=False, indent=2) + except Exception: + pass + +# ===================== Ollama model listing ===================== +def 
list_ollama_models_http(host="localhost", port=11434) -> List[str]: + try: + conn = http.client.HTTPConnection(host, port, timeout=10) + conn.request("GET", "/api/tags") + resp = conn.getresponse() + data = resp.read() + conn.close() + if resp.status != 200: + return [] + j = json.loads(data.decode("utf-8")) + models = j.get("models", []) + names = [] + for m in models: + name = m.get("name") + if isinstance(name, str): + names.append(name) + return sorted(set(names)) + except Exception: + return [] + +def list_ollama_models_cli() -> List[str]: + try: + p = subprocess.run(["ollama", "list", "--json"], capture_output=True, text=True, timeout=10) + if p.returncode == 0 and p.stdout.strip(): + txt = p.stdout.strip() + models = [] + try: + j = json.loads(txt) + if isinstance(j, dict) and "models" in j: + for m in j["models"]: + name = m.get("name") + if isinstance(name, str): + models.append(name) + elif isinstance(j, list): + for m in j: + name = m.get("name") + if isinstance(name, str): + models.append(name) + except Exception: + for line in txt.splitlines(): + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + name = obj.get("name") + if isinstance(name, str): + models.append(name) + except Exception: + pass + if models: + return sorted(set(models)) + except Exception: + pass + try: + p = subprocess.run(["ollama", "list"], capture_output=True, text=True, timeout=10) + if p.returncode == 0 and p.stdout: + models = [] + for line in p.stdout.splitlines(): + line = line.strip() + if line.lower().startswith(("name","models")) or not line: + continue + parts = line.split() + if parts: + models.append(parts[0]) + if models: + return sorted(set(models)) + except Exception: + pass + return [] + +def get_all_ollama_models(host="localhost", port=11434) -> List[str]: + names = list_ollama_models_http(host, port) + if names: + return names + return list_ollama_models_cli() + +# ===================== CLI flow ===================== +def run_cli(args): + if not os.path.exists(args.srt_file): + print(f"File not found: {args.srt_file}", file=sys.stderr); sys.exit(1) + + with open(args.srt_file, "r", encoding="utf-8", errors="replace") as f: + srt_text = f.read() + + cues = parse_srt(srt_text) + if not cues: + print("No SRT cues detected. 
+def get_all_ollama_models(host="localhost", port=11434) -> List[str]:
+    names = list_ollama_models_http(host, port)
+    if names:
+        return names
+    return list_ollama_models_cli()
+
+# ===================== CLI flow =====================
+def run_cli(args):
+    if not os.path.exists(args.srt_file):
+        print(f"File not found: {args.srt_file}", file=sys.stderr); sys.exit(1)
+
+    with open(args.srt_file, "r", encoding="utf-8", errors="replace") as f:
+        srt_text = f.read()
+
+    cues = parse_srt(srt_text)
+    if not cues:
+        print("No SRT cues detected. Is the file valid?", file=sys.stderr); sys.exit(1)
+
+    cue_lines: List[List[str]] = [c["text"].splitlines() if c["text"] else [""] for c in cues]
+    tgt_code = _resolve_target(args.target_language)
+    src_code = _resolve_source(args.src, [ln for lines in cue_lines for ln in lines])
+    base, _ = os.path.splitext(args.srt_file)
+    out_path = args.out or f"{base}.{tgt_code.lower()}.srt"
+
+    total_lines = sum(len(ls) for ls in cue_lines)
+    pbar = tqdm(total=total_lines, unit="line", dynamic_ncols=True, desc="Translating") if (tqdm and not args.no_progress) else None
+    on_progress = (lambda n=1: pbar.update(n)) if pbar else None
+
+    cancel_event = None  # CLI: no cancel
+
+    if args.engine == "nllb":
+        print(f"Ensuring model is cached: {args.model} …")
+        ensure_model_downloaded(args.model)
+        dev_kwargs, workers = pick_device_for_workers(max(1, args.workers), args.device)
+        print(f"Engine=NLLB | model={args.model} | src={src_code}→{tgt_code} | device={dev_kwargs.get('device')} | threads={args.workers} effective={workers} | context={args.context}")
+
+        if args.context == "line":
+            flat_lines, masks = build_line_items(cue_lines)
+            translated_flat = nllb_translate_lines(flat_lines, masks, args.model, src_code, tgt_code,
+                                                   workers, args.batch, on_progress, None, cancel_event, dev_kwargs)
+            rebuilt_cues = rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=False)
+        elif args.context == "cue":
+            flat_cues, cue_masks, counts = build_cue_items_for_nllb(cue_lines)
+            translated_per_cue = nllb_translate_cues(flat_cues, cue_masks, counts, args.model, src_code, tgt_code,
+                                                     workers, min(args.batch, 32), on_progress, None, cancel_event, dev_kwargs)
+            rebuilt_cues = rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=False)
+        else:
+            groups_text, groups_masks, groups_counts = build_smart_groups_for_nllb(cue_lines, max_span_cues=args.max_span_cues)
+            translated_groups = nllb_translate_groups(groups_text, groups_masks, groups_counts, args.model, src_code, tgt_code,
+                                                      workers, min(args.batch, 16), on_progress, None, cancel_event, dev_kwargs)
+            percue_lines: List[List[str]] = []
+            for grp in translated_groups: percue_lines.extend(grp)
+            rebuilt_cues = rebuild_from_cue_parts(cues, percue_lines, cue_lines, squelch_single_word=False)
+
+    elif args.engine == "seamless":
+        print(f"Ensuring SeamlessM4T is cached: {args.seamless_model} …")
+        ensure_seamless_downloaded(args.seamless_model)
+        dev_kwargs, workers = pick_device_for_workers(max(1, args.workers), args.device)
+        print(f"Engine=SeamlessM4T | model={args.seamless_model} | src={src_code}→{tgt_code} | device={dev_kwargs.get('device')} | threads={args.workers} effective={workers} | context={args.context}")
+
+        if args.context == "line":
+            flat_lines, masks = build_line_items(cue_lines)
+            translated_flat = seamless_translate_lines(flat_lines, masks, args.seamless_model, src_code, tgt_code,
+                                                       workers, args.batch, on_progress, None, cancel_event, dev_kwargs)
+            rebuilt_cues = rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=True)
+        elif args.context == "cue":
+            flat_cues, cue_masks, counts = build_cue_items_for_nllb(cue_lines)
+            translated_per_cue = seamless_translate_cues(flat_cues, cue_masks, counts, args.seamless_model, src_code, tgt_code,
+                                                         workers, min(args.batch, 32), on_progress, None, cancel_event, dev_kwargs)
+            rebuilt_cues = rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=True)
+        else:
+            groups_text, groups_masks, groups_counts = build_smart_groups_for_nllb(cue_lines, max_span_cues=args.max_span_cues)
+            translated_groups = seamless_translate_groups(groups_text, groups_masks, groups_counts, args.seamless_model, src_code, tgt_code,
+                                                          workers, min(args.batch, 16), on_progress, None, cancel_event, dev_kwargs)
+            percue_lines: List[List[str]] = []
+            for grp in translated_groups: percue_lines.extend(grp)
+            rebuilt_cues = rebuild_from_cue_parts(cues, percue_lines, cue_lines, squelch_single_word=True)
+
+    else:  # Ollama
+        host, port = args.ollama_host, args.ollama_port
+        model = args.ollama_model
+        print(f"Engine=Ollama | model={model} | src={src_code}→{tgt_code} | host={host}:{port} | threads={args.workers} | context={args.context}")
+
+        if args.context == "line":
+            flat_lines, masks = build_line_items(cue_lines)
+            translated_flat = ollama_translate_lines(flat_lines, masks, args.target_language, src_code, model, host, port,
+                                                     max(1, args.workers), args.batch, on_progress, None, cancel_event)
+            rebuilt_cues = rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=False)
+        elif args.context == "cue":
+            cues_text, cues_masks = build_cue_items_for_llm(cue_lines)
+            translated_per_cue = ollama_translate_cues(cues_text, cues_masks, args.target_language, src_code, model, host, port,
+                                                       max(1, args.workers), min(args.batch, 24), on_progress, None, cancel_event)
+            rebuilt_cues = rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=False)
+        else:
+            groups_text, groups_masks, groups_counts = build_smart_groups_for_llm(cue_lines, max_span_cues=args.max_span_cues)
+            translated_groups = ollama_translate_groups(groups_text, groups_masks, groups_counts, args.target_language, src_code,
+                                                        model, host, port, max(1, args.workers), min(args.batch, 12), on_progress, None, cancel_event)
+            percue_lines: List[List[str]] = []
+            for grp in translated_groups: percue_lines.extend(grp)
+            rebuilt_cues = rebuild_from_cue_parts(cues, percue_lines, cue_lines, squelch_single_word=False)
+
+    if pbar: pbar.close()
+    with open(out_path, "w", encoding="utf-8") as f:
+        f.write(cues_to_srt(rebuilt_cues))
+    print(f"✅ Wrote: {out_path}")
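+
+# Example invocations (file names are illustrative):
+#   ./srt-translate.py movie.srt German
+#   ./srt-translate.py movie.srt deu_Latn --engine seamless --device gpu --context cue
+#   ./srt-translate.py movie.srt German --engine ollama --ollama-model qwen3:32b-instruct --out movie.de.srt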
+
+# ===================== GUI flow =====================
+def run_gui():
+    import tkinter as tk
+    from tkinter import ttk, filedialog, messagebox
+
+    root = tk.Tk()
+    root.title("SRT Translator — NLLB / SeamlessM4T / Ollama")
+
+    # Detect accelerators for NLLB/Seamless
+    try:
+        import torch
+        has_mps = getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
+        has_cuda = torch.cuda.is_available()
+    except Exception:
+        has_mps = has_cuda = False
+
+    # Load persisted config
+    cfg = load_config()
+    last_engine = cfg.get("engine", "NLLB")
+    last_ollama_model = cfg.get("ollama_model", "qwen3:32b-instruct")
+    last_ollama_host = cfg.get("ollama_host", "localhost")
+    last_ollama_port = int(cfg.get("ollama_port", 11434))
+
+    # Vars
+    srt_path_var = tk.StringVar()
+    lang_var = tk.StringVar(value="German")
+    context_var = tk.StringVar(value="Smart")
+    engine_var = tk.StringVar(value=last_engine if last_engine in ["NLLB", "SeamlessM4T", "Ollama"] else "NLLB")
+    device_var = tk.StringVar(value="Auto")
+    threads_var = tk.IntVar(value=max(1, (os.cpu_count() or 4) - 2))
+    max_span_var = tk.IntVar(value=4)
+    status_var = tk.StringVar(value="Select file, language, context, engine, device, threads.")
+    progress_var = tk.IntVar(value=0)
+    progress_max = tk.IntVar(value=100)
+    ollama_model_var = tk.StringVar(value=last_ollama_model)
+    ollama_host_var = tk.StringVar(value=last_ollama_host)
+    ollama_port_var = tk.IntVar(value=last_ollama_port)
+
+    ollama_models_list: List[str] = []
+
+    running_flag = tk.BooleanVar(value=False)
+    cancel_event = threading.Event(); cancel_event.clear()
+
+    pad = {"padx": 8, "pady": 6}
+    frm = ttk.Frame(root); frm.pack(fill="both", expand=True, **pad)
+
+    # File row
+    file_row = ttk.Frame(frm); file_row.pack(fill="x", **pad)
+    ttk.Label(file_row, text="SRT file:").pack(side="left")
+    file_entry = ttk.Entry(file_row, textvariable=srt_path_var); file_entry.pack(side="left", fill="x", expand=True, padx=6)
+    ttk.Button(file_row, text="Browse…", command=lambda: _browse_file(srt_path_var)).pack(side="left")
+
+    # Language + Context
+    row2 = ttk.Frame(frm); row2.pack(fill="x", **pad)
+    ttk.Label(row2, text="Target language:").pack(side="left")
+    languages = sorted({k for k in LANG2CODE.keys() if len(k) > 2})
+    lang_combo = ttk.Combobox(row2, textvariable=lang_var, values=languages)
+    lang_combo.pack(side="left", fill="x", expand=True, padx=6)
+
+    # ttk.Label(row2, text="Context:").pack(side="left", padx=(12, 2))
+    # context_combo = ttk.Combobox(row2, textvariable=context_var, values=["Smart", "Cue", "Line"], state="readonly", width=8)
+    # context_combo.pack(side="left")
+    # ttk.Label(row2, text="Max span:").pack(side="left", padx=(12, 2))
+    # span_spin = ttk.Spinbox(row2, from_=2, to=12, textvariable=max_span_var, width=5)
+    # span_spin.pack(side="left")
+    # ttk.Label(row2, text="(You can type an NLLB code like deu_Latn)").pack(side="left", padx=(12, 0))
+
+    # Engine + Device + Threads
+    row_engine = ttk.Frame(frm); row_engine.pack(fill="x", **pad)
+    ttk.Label(row_engine, text="Engine:").pack(side="left")
+    engine_combo = ttk.Combobox(row_engine, textvariable=engine_var, values=["NLLB", "SeamlessM4T", "Ollama"], state="readonly", width=12)
+    engine_combo.pack(side="left", padx=6)
+    ttk.Label(row_engine, text="Device:").pack(side="left")
+    device_combo = ttk.Combobox(row_engine, textvariable=device_var, values=["Auto", "CPU"] + (["GPU"] if (has_mps or has_cuda) else []), state="readonly", width=8)
+    device_combo.pack(side="left", padx=6)
+    # ttk.Label(row_engine, text="Threads:").pack(side="left")
+    # threads_spin = ttk.Spinbox(row_engine, from_=1, to=max(1, (os.cpu_count() or 8)), textvariable=threads_var, width=6)
+    # threads_spin.pack(side="left")
+
+    # Ollama model row (shown only when engine=Ollama)
+    row_ollama = ttk.Frame(frm)
+    ttk.Label(row_ollama, text="Ollama model:").pack(side="left")
+    ollama_model_combo = ttk.Combobox(row_ollama, textvariable=ollama_model_var, values=[], width=30)
+    ollama_model_combo.pack(side="left", padx=6)
+    # refresh_btn = ttk.Button(row_ollama, text="Refresh", command=lambda: refresh_models())
+    # refresh_btn.pack(side="left")
+
+    row_ollama2 = ttk.Frame(frm)
+    ttk.Label(row_ollama2, text="Ollama host:").pack(side="left")
+    ollama_host_entry = ttk.Entry(row_ollama2, textvariable=ollama_host_var, width=12); ollama_host_entry.pack(side="left")
+    ttk.Label(row_ollama2, text="Port:").pack(side="left", padx=(12, 2))
+    ollama_port_entry = ttk.Spinbox(row_ollama2, from_=1, to=65535, textvariable=ollama_port_var, width=6); ollama_port_entry.pack(side="left")
+
+    def on_device_change(event=None):
+        return
+        # if device_var.get().lower() == "gpu" and engine_var.get() in ["NLLB", "SeamlessM4T"]:
+        #     threads_spin.configure(state="disabled")
+        # else:
+        #     threads_spin.configure(state="normal")
+
+    def refresh_models():
+        nonlocal ollama_models_list
+        host = ollama_host_var.get().strip() or "localhost"
+        port = int(ollama_port_var.get() or 11434)
+        status_var.set(f"Fetching models from {host}:{port} …")
+        root.update_idletasks()
+        models = get_all_ollama_models(host, port)
+        if not models:
+            from tkinter import messagebox
+            messagebox.showwarning("Ollama", "No models found (check the Ollama service or host/port).")
+            return
+        ollama_models_list = models
+        ollama_model_combo.configure(values=ollama_models_list)
+        current = ollama_model_var.get().strip()
+        if current not in ollama_models_list:
+            ollama_model_var.set(ollama_models_list[0])
+        status_var.set(f"{len(ollama_models_list)} models loaded.")
+
+        cfg = load_config()
+        cfg["ollama_host"] = host
+        cfg["ollama_port"] = port
+        cfg["ollama_model"] = ollama_model_var.get().strip()
+        save_config(cfg)
+
+    def on_engine_change(event=None):
+        is_ollama = (engine_var.get() == "Ollama")
+
+        if is_ollama:
+            # show Ollama rows
+            row_ollama.pack(fill="x", **pad)
+            row_ollama2.pack(fill="x", **pad)
+
+            # ALWAYS enable Ollama controls when switching to Ollama
+            ollama_model_combo.configure(state="normal")  # or "readonly" if you prefer no typing
+            ollama_host_entry.configure(state="normal")
+            ollama_port_entry.configure(state="normal")
+
+            # Ollama ignores device selection
+            device_combo.configure(state="disabled")
+
+            # Populate models if needed
+            if not ollama_model_combo.cget("values"):
+                refresh_models()
+        else:
+            # hide Ollama rows and disable its controls when not on Ollama
+            row_ollama.pack_forget()
+            row_ollama2.pack_forget()
+            ollama_model_combo.configure(state="disabled")
+            ollama_host_entry.configure(state="disabled")
+            ollama_port_entry.configure(state="disabled")
+
+            # Re-enable device selection for non-Ollama engines
+            device_combo.configure(state="readonly")
+
+    device_combo.bind("<<ComboboxSelected>>", on_device_change)
+    engine_combo.bind("<<ComboboxSelected>>", on_engine_change)
+    on_engine_change(); on_device_change()
+
+    # Progress bar
+    pbar = ttk.Progressbar(frm, orient="horizontal", mode="determinate", maximum=progress_max.get(), variable=progress_var)
+
+    # Controls
+    ctrl = ttk.Frame(frm); ctrl.pack(fill="x", **pad)
+    start_btn = ttk.Button(ctrl, text="Translate"); start_btn.pack(side="left")
+    # status label (optional UI element)
+    # status_label = ttk.Label(frm, textvariable=status_var, foreground="#555"); status_label.pack(fill="x", **pad)
+
+    def set_controls(enabled: bool):
+        state = "normal" if enabled else "disabled"
+        file_entry.configure(state=state); lang_combo.configure(state=state)
+        # context_combo.configure(state=state); span_spin.configure(state=state)
+        engine_combo.configure(state=state)
+        model_state = "normal" if engine_var.get() == "Ollama" else "disabled"
+        refresh_state = model_state
+        ollama_model_combo.configure(state=model_state); ollama_host_entry.configure(state=model_state); ollama_port_entry.configure(state=model_state)
+        # try: refresh_btn.configure(state=refresh_state)
+        # except Exception: pass
+        device_combo.configure(state=("readonly" if engine_var.get() != "Ollama" else "disabled"))
+        # threads_spin.configure(state=("disabled" if (engine_var.get() in ["NLLB", "SeamlessM4T"] and device_var.get().lower() == "gpu") else "normal"))
+        start_btn.configure(state=state)
+
+    def ui_show_progress(total: Optional[int]):
+        pbar.configure(mode="determinate", maximum=total if total else 100)
+        progress_max.set(total or 100); progress_var.set(0); pbar.pack(fill="x", **pad)
+
+    def ui_hide_progress():
+        pbar.stop(); pbar.pack_forget()
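+
+    # make_tk_tqdm() builds a tqdm subclass that mirrors Hugging Face download
+    # progress in the Tk progress bar: totals from all concurrent bars are
+    # aggregated via class-level counters, UI updates are marshalled onto the
+    # main thread with root.after(), and update() raises RuntimeError("Cancelled")
+    # once cancel_event is set so a download can be aborted mid-stream.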
+    # Tk-aware tqdm for model download progress
+    def make_tk_tqdm():
+        try:
+            from tqdm.auto import tqdm as base_tqdm
+        except Exception:
+            class Dummy:
+                def __init__(self, *a, total=None, **kw): self.total = total
+                def update(self, n=1): pass
+                def close(self): pass
+            return Dummy
+
+        class TkTqdm(base_tqdm):
+            _global_total = 0
+            _global_n = 0
+            _lock = threading.Lock()
+            def __init__(self, *a, **kw):
+                super().__init__(*a, **kw)
+                with TkTqdm._lock:
+                    if self.total:
+                        TkTqdm._global_total += int(self.total)
+                root.after(0, lambda: ui_show_progress(max(TkTqdm._global_total, 1)))
+            def update(self, n=1):
+                if cancel_event.is_set():
+                    raise RuntimeError("Cancelled")
+                res = super().update(n)
+                with TkTqdm._lock:
+                    TkTqdm._global_n += int(n)
+                    val = max(0, TkTqdm._global_n)
+                root.after(0, lambda: progress_var.set(val))
+                return res
+        return TkTqdm
+
+    def on_cancel():
+        if not running_flag.get(): return
+        if not __import__("tkinter").messagebox.askyesno("Cancel", "Really cancel? The unfinished translation will be discarded."):
+            return
+        cancel_event.set()
+        status_var.set("Cancelling…")
+
+    def start_translation():
+        if running_flag.get(): return
+        path = srt_path_var.get().strip()
+        if not path or not os.path.exists(path):
+            from tkinter import messagebox
+            messagebox.showerror("No file", "Please choose a valid .srt file."); return
+
+        tgt_input = lang_var.get().strip()
+        tgt_code = _resolve_target_gui(tgt_input, __import__("tkinter").messagebox)
+        if not tgt_code: return
+
+        context_mode = context_var.get().lower()
+        max_span = int(max_span_var.get() or 4)
+        engine = engine_var.get()
+
+        cfg = load_config()
+        cfg["engine"] = engine
+        if engine == "Ollama":
+            cfg["ollama_model"] = ollama_model_var.get().strip()
+            cfg["ollama_host"] = ollama_host_var.get().strip() or "localhost"
+            cfg["ollama_port"] = int(ollama_port_var.get() or 11434)
+        save_config(cfg)
+
+        try:
+            with open(path, "r", encoding="utf-8", errors="replace") as f:
+                srt_text = f.read()
+        except Exception as e:
+            from tkinter import messagebox
+            messagebox.showerror("Read error", str(e)); return
+
+        cues = parse_srt(srt_text)
+        if not cues:
+            from tkinter import messagebox
+            messagebox.showerror("Format", "No SRT cues detected. Is the file valid?"); return
+
+        cue_lines: List[List[str]] = [c["text"].splitlines() if c["text"] else [""] for c in cues]
+        flat_all_lines = [ln for lines in cue_lines for ln in lines]
+        sample = "\n".join([l for l in flat_all_lines if l.strip()][:20])
+        src_code = autodetect_code(sample)
+
+        base, _ = os.path.splitext(path)
+        out_path = f"{base}.{tgt_code.lower()}.srt"
+
+        requested_threads = max(1, int(threads_var.get() or 1))
+        device_mode = device_var.get().lower()
+        dev_kwargs, effective_workers = pick_device_for_workers(requested_threads, device_mode) if engine in ["NLLB", "SeamlessM4T"] else ({"device": -1}, requested_threads)
+        total_lines = sum(len(ls) for ls in cue_lines)
+
+        running_flag.set(True); cancel_event.clear()
+        set_controls(False)
+        start_btn.configure(text="Cancel", command=on_cancel)
+        start_btn.configure(state="normal")
+        status_var.set("Preparing…")
+
+        ui_show_progress(100)
+
+        def worker():
+            try:
+                if engine == "NLLB":
+                    status_var.set("Downloading NLLB model (if needed)…")
+                    TkTqdm = make_tk_tqdm()
+                    ensure_model_downloaded("facebook/nllb-200-distilled-600M", tqdm_class=TkTqdm)
+                elif engine == "SeamlessM4T":
+                    status_var.set("Downloading SeamlessM4T model (if needed)…")
+                    TkTqdm = make_tk_tqdm()
+                    ensure_seamless_downloaded("facebook/seamless-m4t-v2-large", tqdm_class=TkTqdm)
+
+                def start_stage2():
+                    status_var.set(
+                        f"Translating… (engine={engine}"
+                        + (f", device={dev_kwargs.get('device')}, mode={device_mode}" if engine in ["NLLB", "SeamlessM4T"] else "")
+                        + f", threads={requested_threads} effective={effective_workers}, context={context_mode}, maxspan={max_span})"
+                    )
+                    pbar.configure(mode="determinate", maximum=total_lines)
+                    progress_max.set(total_lines); progress_var.set(0); pbar.pack(fill="x", **pad)
+                root.after(0, start_stage2)
+
+                processed = {"n": 0}
+                def on_progress(n=1):
+                    processed["n"] += n
+                    root.after(0, lambda: progress_var.set(processed["n"]))
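+
+                # Context granularity: "line" translates each subtitle line on
+                # its own, "cue" keeps the lines of one cue together, and
+                # "smart" groups up to max_span consecutive cues so the model
+                # sees surrounding dialogue. Wider context uses smaller batches.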
+                # Stage 2: Translation
+                if engine == "NLLB":
+                    if context_mode == "line":
+                        flat_lines, masks = build_line_items(cue_lines)
+                        translated_flat = nllb_translate_lines(flat_lines, masks,
+                                                               "facebook/nllb-200-distilled-600M", src_code, tgt_code,
+                                                               effective_workers, 32, on_progress, None, cancel_event, dev_kwargs)
+                        rebuilt_cues = rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=False)
+                    elif context_mode == "cue":
+                        flat_cues, cue_masks, counts = build_cue_items_for_nllb(cue_lines)
+                        translated_per_cue = nllb_translate_cues(flat_cues, cue_masks, counts,
+                                                                 "facebook/nllb-200-distilled-600M", src_code, tgt_code,
+                                                                 effective_workers, 16, on_progress, None, cancel_event, dev_kwargs)
+                        rebuilt_cues = rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=False)
+                    else:
+                        groups_text, groups_masks, groups_counts = build_smart_groups_for_nllb(cue_lines, max_span_cues=max_span)
+                        translated_groups = nllb_translate_groups(groups_text, groups_masks, groups_counts,
+                                                                  "facebook/nllb-200-distilled-600M", src_code, tgt_code,
+                                                                  effective_workers, 12, on_progress, None, cancel_event, dev_kwargs)
+                        percue_lines: List[List[str]] = []
+                        for grp in translated_groups: percue_lines.extend(grp)
+                        rebuilt_cues = rebuild_from_cue_parts(cues, percue_lines, cue_lines, squelch_single_word=False)
+
+                elif engine == "SeamlessM4T":
+                    model_name = "facebook/seamless-m4t-v2-large"
+                    if context_mode == "line":
+                        flat_lines, masks = build_line_items(cue_lines)
+                        translated_flat = seamless_translate_lines(flat_lines, masks, model_name, src_code, tgt_code,
+                                                                   effective_workers, 32, on_progress, None, cancel_event, dev_kwargs)
+                        rebuilt_cues = rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=True)
+                    elif context_mode == "cue":
+                        flat_cues, cue_masks, counts = build_cue_items_for_nllb(cue_lines)
+                        translated_per_cue = seamless_translate_cues(flat_cues, cue_masks, counts, model_name, src_code, tgt_code,
+                                                                     effective_workers, 16, on_progress, None, cancel_event, dev_kwargs)
+                        rebuilt_cues = rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=True)
+                    else:
+                        groups_text, groups_masks, groups_counts = build_smart_groups_for_nllb(cue_lines, max_span_cues=max_span)
+                        translated_groups = seamless_translate_groups(groups_text, groups_masks, groups_counts, model_name, src_code, tgt_code,
+                                                                      effective_workers, 12, on_progress, None, cancel_event, dev_kwargs)
+                        percue_lines: List[List[str]] = []
+                        for grp in translated_groups: percue_lines.extend(grp)
+                        rebuilt_cues = rebuild_from_cue_parts(cues, percue_lines, cue_lines, squelch_single_word=True)
+
+                else:  # Ollama
+                    model = ollama_model_var.get().strip() or "qwen3:32b-instruct"
+                    host = ollama_host_var.get().strip() or "localhost"
+                    port = int(ollama_port_var.get() or 11434)
+                    if context_mode == "line":
+                        flat_lines, masks = build_line_items(cue_lines)
+                        translated_flat = ollama_translate_lines(flat_lines, masks, lang_var.get(), src_code, model, host, port,
+                                                                 requested_threads, 32, on_progress, None, cancel_event)
+                        rebuilt_cues = rebuild_from_flat_lines(cues, cue_lines, translated_flat, squelch_single_word=False)
+                    elif context_mode == "cue":
+                        cues_text, cues_masks = build_cue_items_for_llm(cue_lines)
+                        translated_per_cue = ollama_translate_cues(cues_text, cues_masks, lang_var.get(), src_code, model, host, port,
+                                                                   requested_threads, 16, on_progress, None, cancel_event)
+                        rebuilt_cues = rebuild_from_cue_parts(cues, translated_per_cue, cue_lines, squelch_single_word=False)
+                    else:
+                        groups_text, groups_masks, groups_counts = build_smart_groups_for_llm(cue_lines, max_span_cues=max_span)
+                        translated_groups = ollama_translate_groups(groups_text, groups_masks, groups_counts, lang_var.get(), src_code,
+                                                                    model, host, port, requested_threads, 12, on_progress, None, cancel_event)
+                        percue_lines: List[List[str]] = []
+                        for grp in translated_groups: percue_lines.extend(grp)
+                        rebuilt_cues = rebuild_from_cue_parts(cues, percue_lines, cue_lines, squelch_single_word=False)
+
+                with open(out_path, "w", encoding="utf-8") as f:
+                    f.write(cues_to_srt(rebuilt_cues))
+
+            except RuntimeError as e:
+                if "Cancelled" in str(e):
+                    return _fail("❌ Cancelled.")
+                return _fail(f"Translation failed: {e}")
+            except Exception as e:
+                return _fail(f"Translation failed: {e}")
+
+            _done(out_path)
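+
+        # _fail/_done run at the end of the worker: both restore the idle UI
+        # state and flip the button back from "Cancel" to "Translate".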
+        def _fail(msg: str):
+            running_flag.set(False)
+            set_controls(True); status_var.set(msg); ui_hide_progress()
+            start_btn.configure(text="Translate", command=start_translation)
+
+        def _done(path_out: str):
+            running_flag.set(False)
+            set_controls(True); status_var.set(f"✅ Done. Wrote: {path_out}"); ui_hide_progress()
+            start_btn.configure(text="Translate", command=start_translation)
+
+        t = threading.Thread(target=worker, daemon=True); t.start()
+
+    start_btn.configure(command=start_translation)
+    root.mainloop()
+
+# ===================== Utils =====================
+def _browse_file(var):
+    from tkinter import filedialog
+    path = filedialog.askopenfilename(title="Select .srt file",
+                                      filetypes=[("SubRip Subtitle", "*.srt"), ("All files", "*.*")])
+    if path: var.set(path)
+
+def _resolve_target(target_language: str) -> str:
+    tgt_code = norm_lang_to_code(target_language)
+    if tgt_code is None:
+        if re.match(r"^[a-z]{3}_[A-Za-z]{4}$", target_language.strip()):
+            tgt_code = target_language.strip()
+        else:
+            raise ValueError(f"Unrecognized target language: {target_language}")
+    return tgt_code
+
+def _resolve_target_gui(target_language: str, messagebox):
+    try:
+        return _resolve_target(target_language)
+    except Exception:
+        messagebox.showerror("Language", f"Unrecognized target language: {target_language}")
+        return None
+
+def _resolve_source(src: Optional[str], flat_lines: List[str]) -> str:
+    if src:
+        code = norm_lang_to_code(src)
+        if code is None:
+            if re.match(r"^[a-z]{3}_[A-Za-z]{4}$", src.strip()):
+                code = src.strip()
+            else:
+                raise ValueError(f"Unrecognized source language: {src}")
+        return code
+    sample = "\n".join([l for l in flat_lines if l.strip()][:20])
+    return autodetect_code(sample)
+
+# ===================== Main =====================
+def main():
+    if len(sys.argv) == 1:
+        run_gui(); return
+
+    ap = argparse.ArgumentParser(description="Translate an SRT with NLLB, SeamlessM4T, or Ollama, preserving structure & inline tags.")
+    ap.add_argument("srt_file", help="Path to input .srt")
+    ap.add_argument("target_language", help='Target language (name like "German" or NLLB code like "deu_Latn")')
+    ap.add_argument("--src", default=None, help="Source language (name or NLLB code). If omitted, auto-detect.")
+    # Engine
+    ap.add_argument("--engine", choices=["nllb", "seamless", "ollama"], default="nllb", help="Translation engine")
+    # NLLB opts
+    ap.add_argument("--model", default="facebook/nllb-200-distilled-600M", help="Transformers model repo (NLLB)")
+    # Seamless opts
+    ap.add_argument("--seamless-model", default="facebook/seamless-m4t-v2-large", help="Transformers model repo (SeamlessM4T)")
+    # Device (used for NLLB & SeamlessM4T)
+    ap.add_argument("--device", choices=["auto", "cpu", "gpu"], default="auto", help="Compute device for NLLB/SeamlessM4T")
+    # Ollama opts
+    ap.add_argument("--ollama-model", default="qwen3:32b-instruct", help="Ollama model tag")
+    ap.add_argument("--ollama-host", default="localhost")
+    ap.add_argument("--ollama-port", type=int, default=11434)
+    # Shared
+    ap.add_argument("--batch", type=int, default=32, help="Batch size per worker (lines/cues/groups)")
+    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2), help="Number of parallel workers")
+    ap.add_argument("--context", choices=["line", "cue", "smart"], default="smart", help="Context granularity")
+    ap.add_argument("--max-span-cues", type=int, default=4, help="Smart mode: max cues per group")
+    ap.add_argument("--out", default=None, help="Output .srt path (default: <input>.<target_code>.srt)")
+    ap.add_argument("--no-progress", action="store_true", help="Disable CLI progress bar")
+
+    args = ap.parse_args()
+    run_cli(args)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file