From 4b3d510bcfd51c36d6e10bfc87a3aa02f55933bb Mon Sep 17 00:00:00 2001 From: Victor Giers Date: Fri, 17 Apr 2026 12:54:03 +0200 Subject: [PATCH] Refactor attachment handling and add support for file attachments --- backend/main.py | 426 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 392 insertions(+), 34 deletions(-) diff --git a/backend/main.py b/backend/main.py index f72aa85..70d69dc 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,11 +3,15 @@ from fastapi import FastAPI, Depends, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session -from typing import Any, List +from typing import Any, List, Optional import re import html import json import base64 +import mimetypes +import shutil +import tempfile +from pathlib import Path from . import models, schemas from .database import Base, engine, SessionLocal, ensure_sources_column from .local_rag import router as local_rag_router @@ -56,6 +60,12 @@ app.add_middleware( app.include_router(local_rag_router) _IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE) +_CHAT_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".heic", ".avif"} +_CHAT_FILE_EXTENSIONS = { + ".pdf", ".html", ".htm", ".txt", ".md", ".rst", ".epub", + ".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus", ".aac", + ".mp4", ".mkv", ".mov", ".webm", ".avi", ".ts", +} def _attachment_field(item: Any, field: str) -> Any: @@ -64,35 +74,90 @@ def _attachment_field(item: Any, field: str) -> Any: return getattr(item, field, None) -def _normalize_image_attachments(items: Any) -> List[dict]: +def _normalize_attachment_name(value: Any, fallback: str) -> str: + name = str(value or "").strip() or fallback + return name[:255] + + +def _normalize_image_attachment(item: Any) -> Optional[dict]: + data_url = str(_attachment_field(item, "data_url") or "").strip() + match = _IMAGE_DATA_URL_RE.match(data_url) + if not match: + return None + + detected_mime = match.group(1).lower() + payload = re.sub(r"\s+", "", match.group(2)) + mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower() + if mime_type and not mime_type.startswith("image/"): + return None + + try: + base64.b64decode(payload, validate=True) + except Exception: + return None + + return { + "kind": "image", + "name": _normalize_attachment_name(_attachment_field(item, "name"), "image"), + "mime_type": mime_type or detected_mime, + "data_url": f"data:{detected_mime};base64,{payload}", + } + + +def _normalize_file_attachment(item: Any, *, include_text: bool = False) -> Optional[dict]: + source_path_raw = str(_attachment_field(item, "source_path") or "").strip() + source_path = "" + if source_path_raw: + try: + source_path = str(Path(source_path_raw).expanduser().resolve()) + except Exception: + source_path = source_path_raw + + text_value = str(_attachment_field(item, "text") or "") + name_seed = _attachment_field(item, "name") or (Path(source_path).name if source_path else "file") + name = _normalize_attachment_name(name_seed, "file") + mime_type = str(_attachment_field(item, "mime_type") or "").split(";", 1)[0].strip().lower() + record_type = str(_attachment_field(item, "record_type") or "").strip().lower() or None + size = _attachment_field(item, "size") + try: + size = max(0, int(size)) if size is not None else None + except Exception: + size = None + + extension = Path(source_path or name).suffix.lower() + if source_path and extension not in _CHAT_FILE_EXTENSIONS: + return None + if not source_path and not text_value.strip(): + return None + + payload = { + "kind": "file", + "name": name, + "mime_type": mime_type or (mimetypes.guess_type(name)[0] or "").lower() or None, + "source_path": source_path or None, + "size": size, + "record_type": record_type, + } + if include_text and text_value.strip(): + payload["text"] = text_value.strip() + return payload + + +def _normalize_chat_attachments(items: Any, *, include_text: bool = False) -> List[dict]: cleaned: List[dict] = [] for item in items or []: + kind = str(_attachment_field(item, "kind") or "").strip().lower() data_url = str(_attachment_field(item, "data_url") or "").strip() - name = str(_attachment_field(item, "name") or "image").strip() or "image" - mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower() - match = _IMAGE_DATA_URL_RE.match(data_url) - if not match: - continue - - detected_mime = match.group(1).lower() - payload = re.sub(r"\s+", "", match.group(2)) - if mime_type and not mime_type.startswith("image/"): - continue - - try: - base64.b64decode(payload, validate=True) - except Exception: - continue - - cleaned.append({ - "name": name[:255], - "mime_type": mime_type or detected_mime, - "data_url": f"data:{detected_mime};base64,{payload}", - }) + if kind == "image" or (not kind and data_url): + normalized = _normalize_image_attachment(item) + else: + normalized = _normalize_file_attachment(item, include_text=include_text) + if normalized: + cleaned.append(normalized) return cleaned -def _load_message_attachments(raw_value: Any) -> List[dict]: +def _load_message_attachments(raw_value: Any, *, include_text: bool = False) -> List[dict]: if isinstance(raw_value, str): try: parsed = json.loads(raw_value or "[]") @@ -100,12 +165,25 @@ def _load_message_attachments(raw_value: Any) -> List[dict]: parsed = [] else: parsed = raw_value - return _normalize_image_attachments(parsed) + return _normalize_chat_attachments(parsed, include_text=include_text) + + +def _attachment_is_image(attachment: dict) -> bool: + kind = str(attachment.get("kind") or "").strip().lower() + if kind == "image": + return True + return bool(_IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())) + + +def _attachment_is_file(attachment: dict) -> bool: + return not _attachment_is_image(attachment) def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]: images: List[str] = [] for attachment in attachments: + if not _attachment_is_image(attachment): + continue match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip()) if not match: continue @@ -113,6 +191,294 @@ def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]: return images +def _attachment_history_payload(attachment: dict) -> dict: + payload = { + "kind": attachment.get("kind") or ("image" if _attachment_is_image(attachment) else "file"), + "name": attachment.get("name"), + "mime_type": attachment.get("mime_type"), + } + if _attachment_is_image(attachment): + payload["data_url"] = attachment.get("data_url") + else: + payload["source_path"] = attachment.get("source_path") + payload["size"] = attachment.get("size") + payload["record_type"] = attachment.get("record_type") + return {key: value for key, value in payload.items() if value not in (None, "", [])} + + +async def _model_supports_vision(model_name: Optional[str]) -> bool: + normalized = str(model_name or "").strip() + if not normalized: + return False + try: + model_data = await ollama_show_model(normalized) + except Exception as exc: + raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc + return ollama_supports_vision(model_data) + + +def _safe_temp_attachment_name(index: int, name: str, fallback_suffix: str = "") -> str: + safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._") or f"attachment-{index + 1}" + suffix = Path(safe_name).suffix + if fallback_suffix and not suffix: + safe_name = f"{safe_name}{fallback_suffix}" + return f"{index:03d}-{safe_name}" + + +def _materialize_attachment_source(attachment: dict, stage_root: Path, index: int) -> tuple[Path, dict]: + if _attachment_is_image(attachment): + match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip()) + if not match: + raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid image payload.") + mime_type = match.group(1).lower() + payload = re.sub(r"\s+", "", match.group(2)) + try: + image_bytes = base64.b64decode(payload, validate=True) + except Exception as exc: + raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid base64 image payload.") from exc + suffix = Path(str(attachment.get("name") or "")).suffix or mimetypes.guess_extension(mime_type) or ".img" + target_path = stage_root / _safe_temp_attachment_name(index, str(attachment.get("name") or "image"), suffix) + target_path.write_bytes(image_bytes) + return target_path, { + "source_path": str(target_path.resolve()), + "size": len(image_bytes), + } + + source_path = str(attachment.get("source_path") or "").strip() + if not source_path: + raise RuntimeError(f"{attachment.get('name') or 'file'}: local file path is required.") + + source = Path(source_path).expanduser().resolve() + if not source.exists() or not source.is_file(): + raise RuntimeError(f"{attachment.get('name') or source.name}: file is no longer available.") + if source.suffix.lower() not in _CHAT_FILE_EXTENSIONS: + raise RuntimeError(f"{source.name}: unsupported file type for chat attachment.") + + target_path = stage_root / _safe_temp_attachment_name(index, source.name) + try: + target_path.symlink_to(source) + except Exception: + shutil.copy2(source, target_path) + return target_path, { + "source_path": str(source), + "size": int(source.stat().st_size), + } + + +def _read_corpus_records(path: Path) -> List[dict]: + records: List[dict] = [] + if not path.exists(): + return records + with path.open("r", encoding="utf-8", errors="ignore") as handle: + for line in handle: + line = line.strip() + if not line: + continue + records.append(json.loads(line)) + return records + + +def _extract_attachment_texts( + attachments: List[dict], + *, + vision_model: Optional[str], + transcription_model: Optional[str], +) -> List[dict]: + from .rag.corpus_builder import run_build + + if not attachments: + return [] + + with tempfile.TemporaryDirectory(prefix="heimgeist-chat-attachments-") as tmpdir: + temp_root = Path(tmpdir) + stage_root = temp_root / "stage" + stage_root.mkdir(parents=True, exist_ok=True) + corpus_path = temp_root / "corpus.jsonl" + + index_by_key: dict[str, int] = {} + attachment_meta: List[dict] = [] + for index, attachment in enumerate(attachments): + temp_path, extra_meta = _materialize_attachment_source(attachment, stage_root, index) + index_by_key[str(temp_path)] = index + index_by_key[str(temp_path.resolve())] = index + attachment_meta.append(extra_meta) + + result = run_build( + root=stage_root, + out=corpus_path, + emit="per-file", + emit_av="joined", + lang_detect=False, + whisper_model=transcription_model or DEFAULT_WHISPER_MODEL, + vlm_model=vision_model or "qwen2.5vl:7b", + ) + + errors = [str(error).strip() for error in (result.get("errors") or []) if str(error).strip()] + if errors: + raise RuntimeError("; ".join(errors)) + + extracted: dict[int, dict[str, Any]] = {} + for record in _read_corpus_records(corpus_path): + record_id = str(record.get("id") or "").split("#", 1)[0].strip() + source_path = str(record.get("source_path") or "").strip() + matched_index = None + for candidate in (record_id, source_path): + if candidate in index_by_key: + matched_index = index_by_key[candidate] + break + if matched_index is None: + continue + + bucket = extracted.setdefault(matched_index, { + "mime_type": "", + "record_type": "", + "text_parts": [], + }) + if not bucket["mime_type"]: + bucket["mime_type"] = str(record.get("mime") or "").strip().lower() + if not bucket["record_type"]: + bucket["record_type"] = str(record.get("record_type") or "").strip().lower() + + text = str(record.get("text") or "").strip() + if text: + bucket["text_parts"].append(text) + + prepared: List[dict] = [] + for index, attachment in enumerate(attachments): + bucket = extracted.get(index) or {} + text = "\n\n".join( + part for part in bucket.get("text_parts") or [] + if str(part or "").strip() + ).strip() + if not text: + raise RuntimeError(f"{attachment.get('name') or 'attachment'}: no usable text could be extracted.") + + extra_meta = attachment_meta[index] + prepared.append({ + **attachment, + "mime_type": bucket.get("mime_type") or attachment.get("mime_type"), + "record_type": bucket.get("record_type") or attachment.get("record_type"), + "source_path": extra_meta.get("source_path") or attachment.get("source_path"), + "size": extra_meta.get("size") if extra_meta.get("size") is not None else attachment.get("size"), + "text": text, + }) + + return prepared + + +def _build_attachment_block(attachments: List[dict]) -> str: + sections: List[str] = [] + for index, attachment in enumerate(attachments, start=1): + text = str(attachment.get("text") or "").strip() + if not text: + continue + lines = [ + f"[Attachment {index}]", + f"Name: {attachment.get('name') or f'attachment-{index}'}", + f"Kind: {'image' if _attachment_is_image(attachment) else 'file'}", + ] + mime_type = str(attachment.get("mime_type") or "").strip() + record_type = str(attachment.get("record_type") or "").strip() + if mime_type: + lines.append(f"MIME: {mime_type}") + if record_type: + lines.append(f"Record type: {record_type}") + lines.append("Content:") + lines.append(text) + sections.append("\n".join(lines)) + + if not sections: + return "" + return "\n" + "\n\n".join(sections) + "\n" + + +def _compose_user_message_content(message: str, enriched_message: Optional[str], attachment_block: str) -> str: + base_content = str(enriched_message or "").strip() or str(message or "").strip() + parts: List[str] = [] + if base_content: + parts.append(base_content) + elif attachment_block: + parts.append("Please analyze the attached material.") + if attachment_block: + parts.append(attachment_block) + if not parts: + return "Please analyze the attached material." + return "\n\n".join(part for part in parts if str(part or "").strip()).strip() + + +def _attachments_have_images(attachments: List[dict]) -> bool: + return any(_attachment_is_image(attachment) for attachment in attachments or []) + + +async def _prepare_chat_message_attachments( + attachments: List[dict], + *, + request_model_supports_vision: bool, + vision_model: Optional[str], + transcription_model: Optional[str], + persist_file_text: bool, +) -> tuple[List[dict], List[str], str]: + prepared = _normalize_chat_attachments(attachments, include_text=True) + if not prepared: + return [], [], "" + + if _attachments_have_images(prepared) and not request_model_supports_vision: + if not str(vision_model or "").strip(): + raise HTTPException( + status_code=400, + detail="Image attachments require a configured vision model when the selected chat model does not support vision.", + ) + if not await _model_supports_vision(vision_model): + raise HTTPException( + status_code=400, + detail="The configured vision model does not support image inputs.", + ) + + prompt_entries: List[Optional[dict]] = [None] * len(prepared) + attachments_to_extract: List[dict] = [] + extract_indexes: List[int] = [] + + for index, attachment in enumerate(prepared): + if _attachment_is_image(attachment): + if request_model_supports_vision: + continue + attachments_to_extract.append(attachment) + extract_indexes.append(index) + continue + + text = str(attachment.get("text") or "").strip() + if text: + prompt_entries[index] = attachment + else: + attachments_to_extract.append(attachment) + extract_indexes.append(index) + + if attachments_to_extract: + extracted_attachments = await asyncio.to_thread( + _extract_attachment_texts, + attachments_to_extract, + vision_model=vision_model, + transcription_model=transcription_model, + ) + for original_index, extracted in zip(extract_indexes, extracted_attachments): + if _attachment_is_file(prepared[original_index]) and persist_file_text: + prepared[original_index] = { + **prepared[original_index], + "mime_type": extracted.get("mime_type") or prepared[original_index].get("mime_type"), + "record_type": extracted.get("record_type") or prepared[original_index].get("record_type"), + "source_path": extracted.get("source_path") or prepared[original_index].get("source_path"), + "size": extracted.get("size") if extracted.get("size") is not None else prepared[original_index].get("size"), + "text": extracted.get("text") or prepared[original_index].get("text"), + } + prompt_entries[original_index] = prepared[original_index] + else: + prompt_entries[original_index] = extracted + + prompt_attachments = [entry for entry in prompt_entries if entry and str(entry.get("text") or "").strip()] + ollama_images = _attachments_to_ollama_images(prepared) if request_model_supports_vision else [] + return prepared, ollama_images, _build_attachment_block(prompt_attachments) + + def _row_to_history_message(row: models.ChatMessage) -> dict: sources = [] try: @@ -124,17 +490,9 @@ def _row_to_history_message(row: models.ChatMessage) -> dict: attachments = _load_message_attachments(getattr(row, "attachments_json", None)) payload = {"role": row.role, "content": row.content, "sources": sources} if attachments: - payload["attachments"] = attachments + payload["attachments"] = [_attachment_history_payload(attachment) for attachment in attachments] return payload - -def _row_to_ollama_message(row: models.ChatMessage) -> dict: - message = {"role": row.role, "content": row.content} - attachments = _load_message_attachments(getattr(row, "attachments_json", None)) - if attachments: - message["images"] = _attachments_to_ollama_images(attachments) - return message - def get_db(): db = SessionLocal() try: