Refactor attachment handling and add support for file attachments

This commit is contained in:
2026-04-17 12:54:03 +02:00
parent 67195af4c1
commit 4b3d510bcf

View File

@@ -3,11 +3,15 @@ from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import Any, List
from typing import Any, List, Optional
import re
import html
import json
import base64
import mimetypes
import shutil
import tempfile
from pathlib import Path
from . import models, schemas
from .database import Base, engine, SessionLocal, ensure_sources_column
from .local_rag import router as local_rag_router
@@ -56,6 +60,12 @@ app.add_middleware(
app.include_router(local_rag_router)
# Matches an inline image data URL of the form
# ``data:image/<subtype>;base64,<payload>`` (case-insensitive).
# Group 1 is the MIME type; group 2 is the base64 payload, which may still
# contain whitespace that callers strip before decoding.
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)
# Lower-case image file suffixes (with leading dot).
# NOTE(review): not referenced in the visible part of this module - confirm where it is used.
_CHAT_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".heic", ".avif"}
# Lower-case suffixes accepted for path-based chat file attachments; attachments
# whose source path has any other suffix are rejected by the helpers below.
_CHAT_FILE_EXTENSIONS = {
".pdf", ".html", ".htm", ".txt", ".md", ".rst", ".epub",
".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus", ".aac",
".mp4", ".mkv", ".mov", ".webm", ".avi", ".ts",
}
def _attachment_field(item: Any, field: str) -> Any:
@@ -64,35 +74,90 @@ def _attachment_field(item: Any, field: str) -> Any:
return getattr(item, field, None)
def _normalize_image_attachments(items: Any) -> List[dict]:
def _normalize_attachment_name(value: Any, fallback: str) -> str:
    """Return a trimmed attachment name capped at 255 characters.

    Falsy values, and values that are blank once surrounding whitespace is
    removed, are replaced by *fallback* (used as given, without trimming).
    """
    trimmed = str(value or "").strip()
    chosen = trimmed if trimmed else fallback
    return chosen[:255]
def _normalize_image_attachment(item: Any) -> Optional[dict]:
    """Validate an inline image attachment and return its canonical form.

    Returns ``None`` when the ``data_url`` field is not a base64 image data
    URL, when a declared ``mime_type`` is not an ``image/*`` type, or when the
    payload is not valid base64. Otherwise returns a dict with the keys
    ``kind``, ``name``, ``mime_type`` and ``data_url``; the data URL is rebuilt
    from the lower-cased MIME type and the whitespace-free payload.
    """
    raw_url = str(_attachment_field(item, "data_url") or "").strip()
    parsed = _IMAGE_DATA_URL_RE.match(raw_url)
    if parsed is None:
        return None

    url_mime = parsed.group(1).lower()
    # The pattern tolerates whitespace inside the payload; drop it before decoding.
    encoded = re.sub(r"\s+", "", parsed.group(2))

    declared_mime = str(_attachment_field(item, "mime_type") or "").strip().lower()
    if declared_mime and not declared_mime.startswith("image/"):
        return None

    # Decoding is only a validity check; the decoded bytes are discarded.
    try:
        base64.b64decode(encoded, validate=True)
    except Exception:
        return None

    display_name = _normalize_attachment_name(_attachment_field(item, "name"), "image")
    normalized: dict = {"kind": "image"}
    normalized["name"] = display_name
    normalized["mime_type"] = declared_mime or url_mime
    normalized["data_url"] = f"data:{url_mime};base64,{encoded}"
    return normalized
def _normalize_file_attachment(item: Any, *, include_text: bool = False) -> Optional[dict]:
    """Validate a file attachment and return its canonical form.

    An attachment is accepted when it either points at a local ``source_path``
    whose suffix is listed in ``_CHAT_FILE_EXTENSIONS`` or, lacking a path,
    carries non-blank ``text``. Anything else yields ``None``.

    Args:
        item: Attachment record read through ``_attachment_field``.
        include_text: When true, non-blank ``text`` is copied (stripped) into
            the result under the ``text`` key.

    Returns:
        A dict with ``kind``, ``name``, ``mime_type``, ``source_path``,
        ``size`` and ``record_type`` (plus optional ``text``), or ``None``.
    """
    raw_path = str(_attachment_field(item, "source_path") or "").strip()
    resolved_path = ""
    if raw_path:
        try:
            resolved_path = str(Path(raw_path).expanduser().resolve())
        except Exception:
            # Fall back to the path exactly as supplied when it cannot be resolved.
            resolved_path = raw_path

    body_text = str(_attachment_field(item, "text") or "")
    has_text = bool(body_text.strip())

    path_name = Path(resolved_path).name if resolved_path else "file"
    display_name = _normalize_attachment_name(_attachment_field(item, "name") or path_name, "file")

    # The path's suffix wins; the display name is consulted only without a path.
    extension = Path(resolved_path or display_name).suffix.lower()
    if resolved_path:
        if extension not in _CHAT_FILE_EXTENSIONS:
            return None
    elif not has_text:
        return None

    # Drop MIME parameters such as "; charset=utf-8".
    declared_mime = str(_attachment_field(item, "mime_type") or "").split(";", 1)[0].strip().lower()
    guessed_mime = (mimetypes.guess_type(display_name)[0] or "").lower()
    record_type = str(_attachment_field(item, "record_type") or "").strip().lower() or None

    raw_size = _attachment_field(item, "size")
    size = None
    if raw_size is not None:
        try:
            size = max(0, int(raw_size))
        except Exception:
            size = None

    normalized = {
        "kind": "file",
        "name": display_name,
        "mime_type": declared_mime or guessed_mime or None,
        "source_path": resolved_path or None,
        "size": size,
        "record_type": record_type,
    }
    if include_text and has_text:
        normalized["text"] = body_text.strip()
    return normalized
def _normalize_chat_attachments(items: Any, *, include_text: bool = False) -> List[dict]:
    """Normalize a mixed list of image and file attachments.

    Each item is dispatched on its ``kind`` field: ``"image"`` items, and
    items without a kind that carry a ``data_url``, go through
    ``_normalize_image_attachment``; everything else goes through
    ``_normalize_file_attachment``. Items that fail validation are dropped.

    Args:
        items: Iterable of attachment records, or a falsy value for "none".
        include_text: Forwarded to ``_normalize_file_attachment``.

    Returns:
        The list of normalized attachment dicts, in input order.
    """
    cleaned: List[dict] = []
    for item in items or []:
        kind = str(_attachment_field(item, "kind") or "").strip().lower()
        data_url = str(_attachment_field(item, "data_url") or "").strip()
        # Image-only validation must not run here: it would reject every file
        # attachment before the dispatch below is reached.
        if kind == "image" or (not kind and data_url):
            normalized = _normalize_image_attachment(item)
        else:
            normalized = _normalize_file_attachment(item, include_text=include_text)
        if normalized:
            cleaned.append(normalized)
    return cleaned
def _load_message_attachments(raw_value: Any) -> List[dict]:
def _load_message_attachments(raw_value: Any, *, include_text: bool = False) -> List[dict]:
if isinstance(raw_value, str):
try:
parsed = json.loads(raw_value or "[]")
@@ -100,12 +165,25 @@ def _load_message_attachments(raw_value: Any) -> List[dict]:
parsed = []
else:
parsed = raw_value
return _normalize_image_attachments(parsed)
return _normalize_chat_attachments(parsed, include_text=include_text)
def _attachment_is_image(attachment: dict) -> bool:
    """Return True when the attachment is an image.

    An attachment counts as an image when its ``kind`` is ``"image"`` or,
    failing that, when its ``data_url`` matches the image data-URL pattern.
    """
    declared_kind = str(attachment.get("kind") or "").strip().lower()
    if declared_kind != "image":
        candidate_url = str(attachment.get("data_url") or "").strip()
        return bool(_IMAGE_DATA_URL_RE.match(candidate_url))
    return True
def _attachment_is_file(attachment: dict) -> bool:
    """Return True for any attachment that is not classified as an image."""
    is_image = _attachment_is_image(attachment)
    return not is_image
def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
images: List[str] = []
for attachment in attachments:
if not _attachment_is_image(attachment):
continue
match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
if not match:
continue
@@ -113,6 +191,294 @@ def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
return images
def _attachment_history_payload(attachment: dict) -> dict:
    """Build the attachment dict exposed in chat history.

    Images keep their ``data_url``; files keep ``source_path``, ``size`` and
    ``record_type``. Entries whose value is ``None``, an empty string or an
    empty list are left out of the result.
    """
    is_image = _attachment_is_image(attachment)
    inferred_kind = "image" if is_image else "file"

    fields = {
        "kind": attachment.get("kind") or inferred_kind,
        "name": attachment.get("name"),
        "mime_type": attachment.get("mime_type"),
    }
    if is_image:
        fields["data_url"] = attachment.get("data_url")
    else:
        for key in ("source_path", "size", "record_type"):
            fields[key] = attachment.get(key)

    return {key: value for key, value in fields.items() if value not in (None, "", [])}
async def _model_supports_vision(model_name: Optional[str]) -> bool:
    """Report whether the named model accepts image input.

    A blank or missing model name yields ``False`` without any lookup.
    Otherwise the model data returned by ``ollama_show_model`` is passed to
    ``ollama_supports_vision`` (both defined outside this block).

    Raises:
        HTTPException: With status 502 when the model lookup fails.
    """
    target = str(model_name or "").strip()
    if target == "":
        return False
    try:
        details = await ollama_show_model(target)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc
    return ollama_supports_vision(details)
def _safe_temp_attachment_name(index: int, name: str, fallback_suffix: str = "") -> str:
    """Build a filesystem-safe staging file name for an attachment.

    Runs of characters outside ``[A-Za-z0-9._-]`` collapse to ``_``, leading
    and trailing dots/underscores are removed, and the result is prefixed with
    the zero-padded *index* so staged names stay unique and ordered.

    Args:
        index: Zero-based position of the attachment in its batch.
        name: Original (untrusted) attachment name.
        fallback_suffix: Extension appended when the sanitized name has none.

    Returns:
        A name of at most 255 characters. Over-long names are shortened from
        the stem so the extension survives.
    """
    max_component_length = 255
    safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._") or f"attachment-{index + 1}"
    if fallback_suffix and not Path(safe_name).suffix:
        safe_name = f"{safe_name}{fallback_suffix}"

    prefix = f"{index:03d}-"
    budget = max_component_length - len(prefix)
    if len(safe_name) > budget:
        # Most filesystems cap one path component at 255 bytes; without this,
        # a name that is already near that limit fails to stage once prefixed.
        suffix = Path(safe_name).suffix
        if 0 < len(suffix) < budget:
            safe_name = safe_name[: budget - len(suffix)] + suffix
        else:
            safe_name = safe_name[:budget]
    return f"{prefix}{safe_name}"
def _materialize_attachment_source(attachment: dict, stage_root: Path, index: int) -> tuple[Path, dict]:
    """Place one attachment inside *stage_root* and describe the staged source.

    Images are decoded from their data URL and written to a new file. Files
    are symlinked from their ``source_path``, falling back to a copy when the
    link cannot be created.

    Args:
        attachment: Normalized attachment dict.
        stage_root: Existing directory that receives the staged file.
        index: Position of the attachment, used to build a unique file name.

    Returns:
        ``(staged_path, meta)`` where *meta* holds ``source_path`` and ``size``.

    Raises:
        RuntimeError: For an invalid image payload, a missing path, a file
            that no longer exists, or an unsupported file suffix.
    """
    declared_name = attachment.get("name")

    if _attachment_is_image(attachment):
        label = declared_name or "image"
        parsed = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
        if parsed is None:
            raise RuntimeError(f"{label}: invalid image payload.")
        mime_type = parsed.group(1).lower()
        encoded = re.sub(r"\s+", "", parsed.group(2))
        try:
            image_bytes = base64.b64decode(encoded, validate=True)
        except Exception as exc:
            raise RuntimeError(f"{label}: invalid base64 image payload.") from exc

        # Prefer the extension of the supplied name, then one derived from the MIME type.
        suffix = Path(str(declared_name or "")).suffix or mimetypes.guess_extension(mime_type) or ".img"
        staged_name = _safe_temp_attachment_name(index, str(declared_name or "image"), suffix)
        target_path = stage_root / staged_name
        target_path.write_bytes(image_bytes)
        image_meta = {
            "source_path": str(target_path.resolve()),
            "size": len(image_bytes),
        }
        return target_path, image_meta

    source_path = str(attachment.get("source_path") or "").strip()
    if not source_path:
        raise RuntimeError(f"{declared_name or 'file'}: local file path is required.")

    source = Path(source_path).expanduser().resolve()
    if not (source.exists() and source.is_file()):
        raise RuntimeError(f"{declared_name or source.name}: file is no longer available.")
    if source.suffix.lower() not in _CHAT_FILE_EXTENSIONS:
        raise RuntimeError(f"{source.name}: unsupported file type for chat attachment.")

    target_path = stage_root / _safe_temp_attachment_name(index, source.name)
    try:
        target_path.symlink_to(source)
    except Exception:
        # Linking failed for some reason; stage a real copy instead.
        shutil.copy2(source, target_path)

    file_meta = {
        "source_path": str(source),
        "size": int(source.stat().st_size),
    }
    return target_path, file_meta
def _read_corpus_records(path: Path) -> List[dict]:
    """Read a JSON Lines corpus file into a list of record dicts.

    Blank lines are skipped, and so are lines whose JSON value is not an
    object, because callers use ``dict`` methods on every record. Undecodable
    bytes are ignored while reading.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        The parsed records in file order; an empty list when the file does
        not exist.

    Raises:
        json.JSONDecodeError: When a non-blank line is not valid JSON.
    """
    records: List[dict] = []
    try:
        # Open directly rather than checking existence first, so the file
        # cannot vanish between the check and the open.
        handle = path.open("r", encoding="utf-8", errors="ignore")
    except (FileNotFoundError, NotADirectoryError):
        return records
    with handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            record = json.loads(stripped)
            if isinstance(record, dict):
                records.append(record)
    return records
def _extract_attachment_texts(
    attachments: List[dict],
    *,
    vision_model: Optional[str],
    transcription_model: Optional[str],
) -> List[dict]:
    """Extract text for each attachment by running the corpus builder.

    Every attachment is staged in a temporary directory, ``run_build`` is run
    over that directory, and the resulting corpus records are matched back to
    the attachments by staged path.

    Args:
        attachments: Normalized attachment dicts.
        vision_model: Model name passed as ``vlm_model``; defaults to
            ``"qwen2.5vl:7b"`` when falsy.
        transcription_model: Model name passed as ``whisper_model``; defaults
            to ``DEFAULT_WHISPER_MODEL`` when falsy.

    Returns:
        One dict per input attachment, in order, with ``mime_type``,
        ``record_type``, ``source_path``, ``size`` and ``text`` filled in.

    Raises:
        RuntimeError: When staging fails, the build reports errors, or an
            attachment produced no text.
    """
    from .rag.corpus_builder import run_build

    if not attachments:
        return []

    with tempfile.TemporaryDirectory(prefix="heimgeist-chat-attachments-") as workdir:
        work_root = Path(workdir)
        stage_root = work_root / "stage"
        stage_root.mkdir(parents=True, exist_ok=True)
        corpus_path = work_root / "corpus.jsonl"

        # Register both the staged path and its resolved form as lookup keys,
        # so a record can be matched whichever of the two it reports.
        position_by_path: dict[str, int] = {}
        staged_meta: List[dict] = []
        for position, attachment in enumerate(attachments):
            staged_path, meta = _materialize_attachment_source(attachment, stage_root, position)
            position_by_path[str(staged_path)] = position
            position_by_path[str(staged_path.resolve())] = position
            staged_meta.append(meta)

        build_result = run_build(
            root=stage_root,
            out=corpus_path,
            emit="per-file",
            emit_av="joined",
            lang_detect=False,
            whisper_model=transcription_model or DEFAULT_WHISPER_MODEL,
            vlm_model=vision_model or "qwen2.5vl:7b",
        )
        failures: List[str] = []
        for error in build_result.get("errors") or []:
            message = str(error).strip()
            if message:
                failures.append(message)
        if failures:
            raise RuntimeError("; ".join(failures))

        collected: dict[int, dict[str, Any]] = {}
        for record in _read_corpus_records(corpus_path):
            # Anything after "#" in the record id is ignored for matching.
            record_id = str(record.get("id") or "").split("#", 1)[0].strip()
            record_path = str(record.get("source_path") or "").strip()
            if record_id in position_by_path:
                position = position_by_path[record_id]
            elif record_path in position_by_path:
                position = position_by_path[record_path]
            else:
                continue

            bucket = collected.get(position)
            if bucket is None:
                bucket = {"mime_type": "", "record_type": "", "text_parts": []}
                collected[position] = bucket
            # The first non-empty value seen for an attachment wins.
            bucket["mime_type"] = bucket["mime_type"] or str(record.get("mime") or "").strip().lower()
            bucket["record_type"] = bucket["record_type"] or str(record.get("record_type") or "").strip().lower()
            piece = str(record.get("text") or "").strip()
            if piece:
                bucket["text_parts"].append(piece)

        prepared: List[dict] = []
        for position, attachment in enumerate(attachments):
            bucket = collected.get(position) or {}
            usable_parts = [
                part for part in bucket.get("text_parts") or []
                if str(part or "").strip()
            ]
            text = "\n\n".join(usable_parts).strip()
            if not text:
                raise RuntimeError(f"{attachment.get('name') or 'attachment'}: no usable text could be extracted.")

            meta = staged_meta[position]
            staged_size = meta.get("size")
            enriched = dict(attachment)
            enriched["mime_type"] = bucket.get("mime_type") or attachment.get("mime_type")
            enriched["record_type"] = bucket.get("record_type") or attachment.get("record_type")
            enriched["source_path"] = meta.get("source_path") or attachment.get("source_path")
            enriched["size"] = staged_size if staged_size is not None else attachment.get("size")
            enriched["text"] = text
            prepared.append(enriched)
        return prepared
def _build_attachment_block(attachments: List[dict]) -> str:
    """Render extracted attachment text as a ``<chat_attachments>`` prompt block.

    Attachments without non-blank ``text`` are skipped, but numbering follows
    the position in *attachments*, so skipped entries leave gaps. Returns an
    empty string when nothing is rendered.
    """
    rendered: List[str] = []
    for number, attachment in enumerate(attachments, start=1):
        content = str(attachment.get("text") or "").strip()
        if content == "":
            continue

        display_name = attachment.get("name") or f"attachment-{number}"
        kind_label = "image" if _attachment_is_image(attachment) else "file"
        section = [
            f"[Attachment {number}]",
            f"Name: {display_name}",
            f"Kind: {kind_label}",
        ]

        mime_type = str(attachment.get("mime_type") or "").strip()
        if mime_type:
            section.append(f"MIME: {mime_type}")
        record_type = str(attachment.get("record_type") or "").strip()
        if record_type:
            section.append(f"Record type: {record_type}")

        section.extend(["Content:", content])
        rendered.append("\n".join(section))

    if rendered:
        return "<chat_attachments>\n" + "\n\n".join(rendered) + "\n</chat_attachments>"
    return ""
def _compose_user_message_content(message: str, enriched_message: Optional[str], attachment_block: str) -> str:
    """Assemble the user message text sent to the model.

    The non-blank *enriched_message* takes precedence over *message*; when
    both are blank a generic request to analyze the attachments is used. A
    non-blank *attachment_block* is appended after a blank line.
    """
    default_prompt = "Please analyze the attached material."
    lead = str(enriched_message or "").strip() or str(message or "").strip() or default_prompt

    pieces: List[str] = [lead]
    if attachment_block:
        pieces.append(attachment_block)

    # Whitespace-only pieces are left out of the joined text.
    kept = [piece for piece in pieces if str(piece or "").strip()]
    return "\n\n".join(kept).strip()
def _attachments_have_images(attachments: List[dict]) -> bool:
    """Return True when at least one attachment is an image."""
    for attachment in attachments or []:
        if _attachment_is_image(attachment):
            return True
    return False
async def _prepare_chat_message_attachments(
    attachments: List[dict],
    *,
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    persist_file_text: bool,
) -> tuple[List[dict], List[str], str]:
    """Normalize attachments and derive what is sent to the chat model.

    Images are passed through as Ollama image payloads when the chat model
    supports vision; otherwise they are turned into text via the configured
    vision model. Files that carry no text yet go through text extraction.

    Args:
        attachments: Raw attachment records from the request.
        request_model_supports_vision: Whether the chat model accepts images.
        vision_model: Fallback model used to describe images.
        transcription_model: Model name forwarded to text extraction.
        persist_file_text: When true, extracted file text and metadata are
            merged into the returned attachment list.

    Returns:
        ``(prepared_attachments, ollama_images, attachment_block)``.

    Raises:
        HTTPException: With status 400 when images need a vision model that
            is not configured or does not support image input.
    """
    prepared = _normalize_chat_attachments(attachments, include_text=True)
    if not prepared:
        return [], [], ""

    if _attachments_have_images(prepared) and not request_model_supports_vision:
        if not str(vision_model or "").strip():
            raise HTTPException(
                status_code=400,
                detail="Image attachments require a configured vision model when the selected chat model does not support vision.",
            )
        vision_ok = await _model_supports_vision(vision_model)
        if not vision_ok:
            raise HTTPException(
                status_code=400,
                detail="The configured vision model does not support image inputs.",
            )

    # One slot per attachment; slots stay None for attachments that add no prompt text.
    prompt_entries: List[Optional[dict]] = [None] * len(prepared)
    pending: List[tuple[int, dict]] = []
    for position, attachment in enumerate(prepared):
        if _attachment_is_image(attachment):
            # Images are sent natively when the chat model handles them.
            if not request_model_supports_vision:
                pending.append((position, attachment))
            continue
        if str(attachment.get("text") or "").strip():
            prompt_entries[position] = attachment
        else:
            pending.append((position, attachment))

    if pending:
        extracted_attachments = await asyncio.to_thread(
            _extract_attachment_texts,
            [attachment for _, attachment in pending],
            vision_model=vision_model,
            transcription_model=transcription_model,
        )
        for (position, _), extracted in zip(pending, extracted_attachments):
            current = prepared[position]
            if not (_attachment_is_file(current) and persist_file_text):
                prompt_entries[position] = extracted
                continue

            extracted_size = extracted.get("size")
            merged = dict(current)
            merged["mime_type"] = extracted.get("mime_type") or current.get("mime_type")
            merged["record_type"] = extracted.get("record_type") or current.get("record_type")
            merged["source_path"] = extracted.get("source_path") or current.get("source_path")
            merged["size"] = extracted_size if extracted_size is not None else current.get("size")
            merged["text"] = extracted.get("text") or current.get("text")
            prepared[position] = merged
            prompt_entries[position] = merged

    prompt_attachments = [
        entry for entry in prompt_entries
        if entry and str(entry.get("text") or "").strip()
    ]
    if request_model_supports_vision:
        ollama_images = _attachments_to_ollama_images(prepared)
    else:
        ollama_images = []
    return prepared, ollama_images, _build_attachment_block(prompt_attachments)
def _row_to_history_message(row: models.ChatMessage) -> dict:
sources = []
try:
@@ -124,17 +490,9 @@ def _row_to_history_message(row: models.ChatMessage) -> dict:
attachments = _load_message_attachments(getattr(row, "attachments_json", None))
payload = {"role": row.role, "content": row.content, "sources": sources}
if attachments:
payload["attachments"] = attachments
payload["attachments"] = [_attachment_history_payload(attachment) for attachment in attachments]
return payload
def _row_to_ollama_message(row: models.ChatMessage) -> dict:
    """Convert a stored chat row into an Ollama message dict.

    The result always has ``role`` and ``content``; ``images`` is added when
    the row's ``attachments_json`` yields at least one attachment.
    """
    ollama_message = {"role": row.role, "content": row.content}
    stored_attachments = _load_message_attachments(getattr(row, "attachments_json", None))
    if not stored_attachments:
        return ollama_message
    ollama_message["images"] = _attachments_to_ollama_images(stored_attachments)
    return ollama_message
def get_db():
db = SessionLocal()
try: