Refactor attachment handling and add support for file attachments
This commit is contained in:
426
backend/main.py
426
backend/main.py
@@ -3,11 +3,15 @@ from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional
|
||||
import re
|
||||
import html
|
||||
import json
|
||||
import base64
|
||||
import mimetypes
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from . import models, schemas
|
||||
from .database import Base, engine, SessionLocal, ensure_sources_column
|
||||
from .local_rag import router as local_rag_router
|
||||
@@ -56,6 +60,12 @@ app.add_middleware(
|
||||
app.include_router(local_rag_router)
|
||||
|
||||
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)
|
||||
_CHAT_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".heic", ".avif"}
|
||||
_CHAT_FILE_EXTENSIONS = {
|
||||
".pdf", ".html", ".htm", ".txt", ".md", ".rst", ".epub",
|
||||
".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus", ".aac",
|
||||
".mp4", ".mkv", ".mov", ".webm", ".avi", ".ts",
|
||||
}
|
||||
|
||||
|
||||
def _attachment_field(item: Any, field: str) -> Any:
|
||||
@@ -64,35 +74,90 @@ def _attachment_field(item: Any, field: str) -> Any:
|
||||
return getattr(item, field, None)
|
||||
|
||||
|
||||
def _normalize_image_attachments(items: Any) -> List[dict]:
|
||||
def _normalize_attachment_name(value: Any, fallback: str) -> str:
|
||||
name = str(value or "").strip() or fallback
|
||||
return name[:255]
|
||||
|
||||
|
||||
def _normalize_image_attachment(item: Any) -> Optional[dict]:
|
||||
data_url = str(_attachment_field(item, "data_url") or "").strip()
|
||||
match = _IMAGE_DATA_URL_RE.match(data_url)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
detected_mime = match.group(1).lower()
|
||||
payload = re.sub(r"\s+", "", match.group(2))
|
||||
mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower()
|
||||
if mime_type and not mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
try:
|
||||
base64.b64decode(payload, validate=True)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return {
|
||||
"kind": "image",
|
||||
"name": _normalize_attachment_name(_attachment_field(item, "name"), "image"),
|
||||
"mime_type": mime_type or detected_mime,
|
||||
"data_url": f"data:{detected_mime};base64,{payload}",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_file_attachment(item: Any, *, include_text: bool = False) -> Optional[dict]:
|
||||
source_path_raw = str(_attachment_field(item, "source_path") or "").strip()
|
||||
source_path = ""
|
||||
if source_path_raw:
|
||||
try:
|
||||
source_path = str(Path(source_path_raw).expanduser().resolve())
|
||||
except Exception:
|
||||
source_path = source_path_raw
|
||||
|
||||
text_value = str(_attachment_field(item, "text") or "")
|
||||
name_seed = _attachment_field(item, "name") or (Path(source_path).name if source_path else "file")
|
||||
name = _normalize_attachment_name(name_seed, "file")
|
||||
mime_type = str(_attachment_field(item, "mime_type") or "").split(";", 1)[0].strip().lower()
|
||||
record_type = str(_attachment_field(item, "record_type") or "").strip().lower() or None
|
||||
size = _attachment_field(item, "size")
|
||||
try:
|
||||
size = max(0, int(size)) if size is not None else None
|
||||
except Exception:
|
||||
size = None
|
||||
|
||||
extension = Path(source_path or name).suffix.lower()
|
||||
if source_path and extension not in _CHAT_FILE_EXTENSIONS:
|
||||
return None
|
||||
if not source_path and not text_value.strip():
|
||||
return None
|
||||
|
||||
payload = {
|
||||
"kind": "file",
|
||||
"name": name,
|
||||
"mime_type": mime_type or (mimetypes.guess_type(name)[0] or "").lower() or None,
|
||||
"source_path": source_path or None,
|
||||
"size": size,
|
||||
"record_type": record_type,
|
||||
}
|
||||
if include_text and text_value.strip():
|
||||
payload["text"] = text_value.strip()
|
||||
return payload
|
||||
|
||||
|
||||
def _normalize_chat_attachments(items: Any, *, include_text: bool = False) -> List[dict]:
|
||||
cleaned: List[dict] = []
|
||||
for item in items or []:
|
||||
kind = str(_attachment_field(item, "kind") or "").strip().lower()
|
||||
data_url = str(_attachment_field(item, "data_url") or "").strip()
|
||||
name = str(_attachment_field(item, "name") or "image").strip() or "image"
|
||||
mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower()
|
||||
match = _IMAGE_DATA_URL_RE.match(data_url)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
detected_mime = match.group(1).lower()
|
||||
payload = re.sub(r"\s+", "", match.group(2))
|
||||
if mime_type and not mime_type.startswith("image/"):
|
||||
continue
|
||||
|
||||
try:
|
||||
base64.b64decode(payload, validate=True)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
cleaned.append({
|
||||
"name": name[:255],
|
||||
"mime_type": mime_type or detected_mime,
|
||||
"data_url": f"data:{detected_mime};base64,{payload}",
|
||||
})
|
||||
if kind == "image" or (not kind and data_url):
|
||||
normalized = _normalize_image_attachment(item)
|
||||
else:
|
||||
normalized = _normalize_file_attachment(item, include_text=include_text)
|
||||
if normalized:
|
||||
cleaned.append(normalized)
|
||||
return cleaned
|
||||
|
||||
|
||||
def _load_message_attachments(raw_value: Any) -> List[dict]:
|
||||
def _load_message_attachments(raw_value: Any, *, include_text: bool = False) -> List[dict]:
|
||||
if isinstance(raw_value, str):
|
||||
try:
|
||||
parsed = json.loads(raw_value or "[]")
|
||||
@@ -100,12 +165,25 @@ def _load_message_attachments(raw_value: Any) -> List[dict]:
|
||||
parsed = []
|
||||
else:
|
||||
parsed = raw_value
|
||||
return _normalize_image_attachments(parsed)
|
||||
return _normalize_chat_attachments(parsed, include_text=include_text)
|
||||
|
||||
|
||||
def _attachment_is_image(attachment: dict) -> bool:
|
||||
kind = str(attachment.get("kind") or "").strip().lower()
|
||||
if kind == "image":
|
||||
return True
|
||||
return bool(_IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip()))
|
||||
|
||||
|
||||
def _attachment_is_file(attachment: dict) -> bool:
|
||||
return not _attachment_is_image(attachment)
|
||||
|
||||
|
||||
def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
|
||||
images: List[str] = []
|
||||
for attachment in attachments:
|
||||
if not _attachment_is_image(attachment):
|
||||
continue
|
||||
match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
|
||||
if not match:
|
||||
continue
|
||||
@@ -113,6 +191,294 @@ def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
|
||||
return images
|
||||
|
||||
|
||||
def _attachment_history_payload(attachment: dict) -> dict:
|
||||
payload = {
|
||||
"kind": attachment.get("kind") or ("image" if _attachment_is_image(attachment) else "file"),
|
||||
"name": attachment.get("name"),
|
||||
"mime_type": attachment.get("mime_type"),
|
||||
}
|
||||
if _attachment_is_image(attachment):
|
||||
payload["data_url"] = attachment.get("data_url")
|
||||
else:
|
||||
payload["source_path"] = attachment.get("source_path")
|
||||
payload["size"] = attachment.get("size")
|
||||
payload["record_type"] = attachment.get("record_type")
|
||||
return {key: value for key, value in payload.items() if value not in (None, "", [])}
|
||||
|
||||
|
||||
async def _model_supports_vision(model_name: Optional[str]) -> bool:
|
||||
normalized = str(model_name or "").strip()
|
||||
if not normalized:
|
||||
return False
|
||||
try:
|
||||
model_data = await ollama_show_model(normalized)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc
|
||||
return ollama_supports_vision(model_data)
|
||||
|
||||
|
||||
def _safe_temp_attachment_name(index: int, name: str, fallback_suffix: str = "") -> str:
|
||||
safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._") or f"attachment-{index + 1}"
|
||||
suffix = Path(safe_name).suffix
|
||||
if fallback_suffix and not suffix:
|
||||
safe_name = f"{safe_name}{fallback_suffix}"
|
||||
return f"{index:03d}-{safe_name}"
|
||||
|
||||
|
||||
def _materialize_attachment_source(attachment: dict, stage_root: Path, index: int) -> tuple[Path, dict]:
|
||||
if _attachment_is_image(attachment):
|
||||
match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
|
||||
if not match:
|
||||
raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid image payload.")
|
||||
mime_type = match.group(1).lower()
|
||||
payload = re.sub(r"\s+", "", match.group(2))
|
||||
try:
|
||||
image_bytes = base64.b64decode(payload, validate=True)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid base64 image payload.") from exc
|
||||
suffix = Path(str(attachment.get("name") or "")).suffix or mimetypes.guess_extension(mime_type) or ".img"
|
||||
target_path = stage_root / _safe_temp_attachment_name(index, str(attachment.get("name") or "image"), suffix)
|
||||
target_path.write_bytes(image_bytes)
|
||||
return target_path, {
|
||||
"source_path": str(target_path.resolve()),
|
||||
"size": len(image_bytes),
|
||||
}
|
||||
|
||||
source_path = str(attachment.get("source_path") or "").strip()
|
||||
if not source_path:
|
||||
raise RuntimeError(f"{attachment.get('name') or 'file'}: local file path is required.")
|
||||
|
||||
source = Path(source_path).expanduser().resolve()
|
||||
if not source.exists() or not source.is_file():
|
||||
raise RuntimeError(f"{attachment.get('name') or source.name}: file is no longer available.")
|
||||
if source.suffix.lower() not in _CHAT_FILE_EXTENSIONS:
|
||||
raise RuntimeError(f"{source.name}: unsupported file type for chat attachment.")
|
||||
|
||||
target_path = stage_root / _safe_temp_attachment_name(index, source.name)
|
||||
try:
|
||||
target_path.symlink_to(source)
|
||||
except Exception:
|
||||
shutil.copy2(source, target_path)
|
||||
return target_path, {
|
||||
"source_path": str(source),
|
||||
"size": int(source.stat().st_size),
|
||||
}
|
||||
|
||||
|
||||
def _read_corpus_records(path: Path) -> List[dict]:
|
||||
records: List[dict] = []
|
||||
if not path.exists():
|
||||
return records
|
||||
with path.open("r", encoding="utf-8", errors="ignore") as handle:
|
||||
for line in handle:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
records.append(json.loads(line))
|
||||
return records
|
||||
|
||||
|
||||
def _extract_attachment_texts(
|
||||
attachments: List[dict],
|
||||
*,
|
||||
vision_model: Optional[str],
|
||||
transcription_model: Optional[str],
|
||||
) -> List[dict]:
|
||||
from .rag.corpus_builder import run_build
|
||||
|
||||
if not attachments:
|
||||
return []
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix="heimgeist-chat-attachments-") as tmpdir:
|
||||
temp_root = Path(tmpdir)
|
||||
stage_root = temp_root / "stage"
|
||||
stage_root.mkdir(parents=True, exist_ok=True)
|
||||
corpus_path = temp_root / "corpus.jsonl"
|
||||
|
||||
index_by_key: dict[str, int] = {}
|
||||
attachment_meta: List[dict] = []
|
||||
for index, attachment in enumerate(attachments):
|
||||
temp_path, extra_meta = _materialize_attachment_source(attachment, stage_root, index)
|
||||
index_by_key[str(temp_path)] = index
|
||||
index_by_key[str(temp_path.resolve())] = index
|
||||
attachment_meta.append(extra_meta)
|
||||
|
||||
result = run_build(
|
||||
root=stage_root,
|
||||
out=corpus_path,
|
||||
emit="per-file",
|
||||
emit_av="joined",
|
||||
lang_detect=False,
|
||||
whisper_model=transcription_model or DEFAULT_WHISPER_MODEL,
|
||||
vlm_model=vision_model or "qwen2.5vl:7b",
|
||||
)
|
||||
|
||||
errors = [str(error).strip() for error in (result.get("errors") or []) if str(error).strip()]
|
||||
if errors:
|
||||
raise RuntimeError("; ".join(errors))
|
||||
|
||||
extracted: dict[int, dict[str, Any]] = {}
|
||||
for record in _read_corpus_records(corpus_path):
|
||||
record_id = str(record.get("id") or "").split("#", 1)[0].strip()
|
||||
source_path = str(record.get("source_path") or "").strip()
|
||||
matched_index = None
|
||||
for candidate in (record_id, source_path):
|
||||
if candidate in index_by_key:
|
||||
matched_index = index_by_key[candidate]
|
||||
break
|
||||
if matched_index is None:
|
||||
continue
|
||||
|
||||
bucket = extracted.setdefault(matched_index, {
|
||||
"mime_type": "",
|
||||
"record_type": "",
|
||||
"text_parts": [],
|
||||
})
|
||||
if not bucket["mime_type"]:
|
||||
bucket["mime_type"] = str(record.get("mime") or "").strip().lower()
|
||||
if not bucket["record_type"]:
|
||||
bucket["record_type"] = str(record.get("record_type") or "").strip().lower()
|
||||
|
||||
text = str(record.get("text") or "").strip()
|
||||
if text:
|
||||
bucket["text_parts"].append(text)
|
||||
|
||||
prepared: List[dict] = []
|
||||
for index, attachment in enumerate(attachments):
|
||||
bucket = extracted.get(index) or {}
|
||||
text = "\n\n".join(
|
||||
part for part in bucket.get("text_parts") or []
|
||||
if str(part or "").strip()
|
||||
).strip()
|
||||
if not text:
|
||||
raise RuntimeError(f"{attachment.get('name') or 'attachment'}: no usable text could be extracted.")
|
||||
|
||||
extra_meta = attachment_meta[index]
|
||||
prepared.append({
|
||||
**attachment,
|
||||
"mime_type": bucket.get("mime_type") or attachment.get("mime_type"),
|
||||
"record_type": bucket.get("record_type") or attachment.get("record_type"),
|
||||
"source_path": extra_meta.get("source_path") or attachment.get("source_path"),
|
||||
"size": extra_meta.get("size") if extra_meta.get("size") is not None else attachment.get("size"),
|
||||
"text": text,
|
||||
})
|
||||
|
||||
return prepared
|
||||
|
||||
|
||||
def _build_attachment_block(attachments: List[dict]) -> str:
|
||||
sections: List[str] = []
|
||||
for index, attachment in enumerate(attachments, start=1):
|
||||
text = str(attachment.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
lines = [
|
||||
f"[Attachment {index}]",
|
||||
f"Name: {attachment.get('name') or f'attachment-{index}'}",
|
||||
f"Kind: {'image' if _attachment_is_image(attachment) else 'file'}",
|
||||
]
|
||||
mime_type = str(attachment.get("mime_type") or "").strip()
|
||||
record_type = str(attachment.get("record_type") or "").strip()
|
||||
if mime_type:
|
||||
lines.append(f"MIME: {mime_type}")
|
||||
if record_type:
|
||||
lines.append(f"Record type: {record_type}")
|
||||
lines.append("Content:")
|
||||
lines.append(text)
|
||||
sections.append("\n".join(lines))
|
||||
|
||||
if not sections:
|
||||
return ""
|
||||
return "<chat_attachments>\n" + "\n\n".join(sections) + "\n</chat_attachments>"
|
||||
|
||||
|
||||
def _compose_user_message_content(message: str, enriched_message: Optional[str], attachment_block: str) -> str:
|
||||
base_content = str(enriched_message or "").strip() or str(message or "").strip()
|
||||
parts: List[str] = []
|
||||
if base_content:
|
||||
parts.append(base_content)
|
||||
elif attachment_block:
|
||||
parts.append("Please analyze the attached material.")
|
||||
if attachment_block:
|
||||
parts.append(attachment_block)
|
||||
if not parts:
|
||||
return "Please analyze the attached material."
|
||||
return "\n\n".join(part for part in parts if str(part or "").strip()).strip()
|
||||
|
||||
|
||||
def _attachments_have_images(attachments: List[dict]) -> bool:
|
||||
return any(_attachment_is_image(attachment) for attachment in attachments or [])
|
||||
|
||||
|
||||
async def _prepare_chat_message_attachments(
|
||||
attachments: List[dict],
|
||||
*,
|
||||
request_model_supports_vision: bool,
|
||||
vision_model: Optional[str],
|
||||
transcription_model: Optional[str],
|
||||
persist_file_text: bool,
|
||||
) -> tuple[List[dict], List[str], str]:
|
||||
prepared = _normalize_chat_attachments(attachments, include_text=True)
|
||||
if not prepared:
|
||||
return [], [], ""
|
||||
|
||||
if _attachments_have_images(prepared) and not request_model_supports_vision:
|
||||
if not str(vision_model or "").strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Image attachments require a configured vision model when the selected chat model does not support vision.",
|
||||
)
|
||||
if not await _model_supports_vision(vision_model):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="The configured vision model does not support image inputs.",
|
||||
)
|
||||
|
||||
prompt_entries: List[Optional[dict]] = [None] * len(prepared)
|
||||
attachments_to_extract: List[dict] = []
|
||||
extract_indexes: List[int] = []
|
||||
|
||||
for index, attachment in enumerate(prepared):
|
||||
if _attachment_is_image(attachment):
|
||||
if request_model_supports_vision:
|
||||
continue
|
||||
attachments_to_extract.append(attachment)
|
||||
extract_indexes.append(index)
|
||||
continue
|
||||
|
||||
text = str(attachment.get("text") or "").strip()
|
||||
if text:
|
||||
prompt_entries[index] = attachment
|
||||
else:
|
||||
attachments_to_extract.append(attachment)
|
||||
extract_indexes.append(index)
|
||||
|
||||
if attachments_to_extract:
|
||||
extracted_attachments = await asyncio.to_thread(
|
||||
_extract_attachment_texts,
|
||||
attachments_to_extract,
|
||||
vision_model=vision_model,
|
||||
transcription_model=transcription_model,
|
||||
)
|
||||
for original_index, extracted in zip(extract_indexes, extracted_attachments):
|
||||
if _attachment_is_file(prepared[original_index]) and persist_file_text:
|
||||
prepared[original_index] = {
|
||||
**prepared[original_index],
|
||||
"mime_type": extracted.get("mime_type") or prepared[original_index].get("mime_type"),
|
||||
"record_type": extracted.get("record_type") or prepared[original_index].get("record_type"),
|
||||
"source_path": extracted.get("source_path") or prepared[original_index].get("source_path"),
|
||||
"size": extracted.get("size") if extracted.get("size") is not None else prepared[original_index].get("size"),
|
||||
"text": extracted.get("text") or prepared[original_index].get("text"),
|
||||
}
|
||||
prompt_entries[original_index] = prepared[original_index]
|
||||
else:
|
||||
prompt_entries[original_index] = extracted
|
||||
|
||||
prompt_attachments = [entry for entry in prompt_entries if entry and str(entry.get("text") or "").strip()]
|
||||
ollama_images = _attachments_to_ollama_images(prepared) if request_model_supports_vision else []
|
||||
return prepared, ollama_images, _build_attachment_block(prompt_attachments)
|
||||
|
||||
|
||||
def _row_to_history_message(row: models.ChatMessage) -> dict:
|
||||
sources = []
|
||||
try:
|
||||
@@ -124,17 +490,9 @@ def _row_to_history_message(row: models.ChatMessage) -> dict:
|
||||
attachments = _load_message_attachments(getattr(row, "attachments_json", None))
|
||||
payload = {"role": row.role, "content": row.content, "sources": sources}
|
||||
if attachments:
|
||||
payload["attachments"] = attachments
|
||||
payload["attachments"] = [_attachment_history_payload(attachment) for attachment in attachments]
|
||||
return payload
|
||||
|
||||
|
||||
def _row_to_ollama_message(row: models.ChatMessage) -> dict:
|
||||
message = {"role": row.role, "content": row.content}
|
||||
attachments = _load_message_attachments(getattr(row, "attachments_json", None))
|
||||
if attachments:
|
||||
message["images"] = _attachments_to_ollama_images(attachments)
|
||||
return message
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user