Add support for image attachments in chat messages and update model capabilities endpoint

This commit is contained in:
2026-04-16 21:27:43 +02:00
parent d8784463b5
commit e88ac88840
5 changed files with 179 additions and 25 deletions

View File

@@ -4,16 +4,15 @@ from sqlalchemy.orm import sessionmaker, DeclarativeBase
from sqlalchemy import text
"""
Database utilities and configuration. This module defines the SQLAlchemy
engine, session factory and base class for models. It also contains a
lightweight migration helper used to evolve the schema over time. The
`ensure_sources_column` helper adds a new `sources_json` column to the
`chat_messages` table if it does not already exist. This is required
for persisting citation sources alongside assistant messages.
Database utilities and configuration. This module defines the SQLAlchemy
engine, session factory and base class for models. It also contains a
lightweight migration helper used to evolve the schema over time. The
`ensure_sources_column` helper adds the JSON-backed columns used by chat
messages when they do not already exist.
The migration uses SQLite's `ALTER TABLE` syntax and therefore should
only run once on startup. It is safe to call repeatedly: when the
column already exists, the function will simply noop.
only run once on startup. It is safe to call repeatedly: when a column
already exists, the function will simply no-op.
"""
DATABASE_URL = "sqlite:///./backend/app.db"
@@ -34,5 +33,7 @@ def ensure_sources_column(engine):
cols = [row[1] for row in conn.execute(text("PRAGMA table_info(chat_messages)"))]
if "sources_json" not in cols:
conn.execute(text("ALTER TABLE chat_messages ADD COLUMN sources_json TEXT DEFAULT '[]'"))
if "attachments_json" not in cols:
conn.execute(text("ALTER TABLE chat_messages ADD COLUMN attachments_json TEXT DEFAULT '[]'"))
except Exception as e:
print("[db] ensure_sources_column error:", e)

View File

@@ -2,15 +2,22 @@ from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import List
from typing import Any, List
import re
import html
import json
import base64
from . import models, schemas
from .database import Base, engine, SessionLocal, ensure_sources_column
from .local_rag import router as local_rag_router
from .ollama_admin import inspect_ollama_startup, prepare_startup_models, pull_local_model, start_local_ollama
from .ollama_client import list_models as ollama_list, chat as ollama_chat, chat_stream as ollama_chat_stream
from .ollama_client import (
list_models as ollama_list,
chat as ollama_chat,
chat_stream as ollama_chat_stream,
show_model as ollama_show_model,
supports_vision as ollama_supports_vision,
)
from .websearch import enrich_prompt
# Create tables + ensure migration
@@ -37,6 +44,86 @@ app.add_middleware(
)
app.include_router(local_rag_router)
# Matches `data:image/<subtype>;base64,<payload>` URLs; group 1 captures the
# MIME type, group 2 the (possibly whitespace-interleaved) base64 body.
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)
def _attachment_field(item: Any, field: str) -> Any:
if isinstance(item, dict):
return item.get(field)
return getattr(item, field, None)
def _normalize_image_attachments(items: Any) -> List[dict]:
    """Validate raw attachment payloads, keeping only well-formed inline images.

    Each surviving entry is a dict with ``name``, ``mime_type`` and a
    canonical ``data_url`` whose base64 body has all whitespace removed.
    Entries that fail any validation step are silently dropped rather than
    raising, so one bad attachment never rejects the whole request.
    """
    valid: List[dict] = []
    for raw in items or []:
        url = str(_attachment_field(raw, "data_url") or "").strip()
        matched = _IMAGE_DATA_URL_RE.match(url)
        if matched is None:
            continue  # not an image data: URL at all
        declared = str(_attachment_field(raw, "mime_type") or "").strip().lower()
        if declared and not declared.startswith("image/"):
            continue  # caller explicitly declared a non-image type
        detected = matched.group(1).lower()
        body = re.sub(r"\s+", "", matched.group(2))
        try:
            base64.b64decode(body, validate=True)  # reject malformed base64
        except Exception:
            continue
        label = str(_attachment_field(raw, "name") or "image").strip() or "image"
        valid.append({
            "name": label[:255],  # cap stored name length
            "mime_type": declared or detected,
            "data_url": f"data:{detected};base64,{body}",
        })
    return valid
def _load_message_attachments(raw_value: Any) -> List[dict]:
    """Decode a persisted ``attachments_json`` value into validated attachments.

    Accepts either the raw JSON string stored on a chat-message row or an
    already-parsed list; malformed JSON degrades to an empty list instead
    of raising.
    """
    if not isinstance(raw_value, str):
        return _normalize_image_attachments(raw_value)
    try:
        decoded = json.loads(raw_value or "[]")
    except Exception:
        decoded = []
    return _normalize_image_attachments(decoded)
def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
    """Extract bare base64 payloads (no data-URL prefix) for the Ollama API."""
    payloads: List[str] = []
    for entry in attachments:
        matched = _IMAGE_DATA_URL_RE.match(str(entry.get("data_url") or "").strip())
        if matched is not None:
            # Ollama's `images` field wants raw base64 with no whitespace.
            payloads.append(re.sub(r"\s+", "", matched.group(2)))
    return payloads
def _row_to_history_message(row: models.ChatMessage) -> dict:
    """Shape a stored chat row into the payload returned by /history.

    ``sources`` is always present (possibly empty); ``attachments`` is only
    included when the row actually carries validated image attachments.
    """
    cited = []
    try:
        if getattr(row, "sources_json", None):
            cited = json.loads(row.sources_json or "[]")
    except Exception:
        cited = []  # tolerate corrupt JSON rather than failing the request
    message = {"role": row.role, "content": row.content, "sources": cited}
    images = _load_message_attachments(getattr(row, "attachments_json", None))
    if images:
        message["attachments"] = images
    return message
def _row_to_ollama_message(row: models.ChatMessage) -> dict:
    """Convert a stored chat row into an Ollama chat message.

    Image attachments, when present, are forwarded via the ``images`` field
    as bare base64 strings — the shape the Ollama chat API expects.
    """
    payload = {"role": row.role, "content": row.content}
    images = _load_message_attachments(getattr(row, "attachments_json", None))
    if images:
        payload["images"] = _attachments_to_ollama_images(images)
    return payload
def get_db():
db = SessionLocal()
try:
@@ -57,6 +144,29 @@ async def get_models():
raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")
@app.get("/models/capabilities")
async def get_model_capabilities(name: str):
    """Report the capability tags Ollama advertises for model *name*.

    Returns the model name, its cleaned capability list, and a convenience
    ``supports_vision`` flag derived from the model metadata. Responds 400
    for a blank name and 502 when Ollama cannot be reached.
    """
    requested = str(name or "").strip()
    if not requested:
        raise HTTPException(status_code=400, detail="Model name is required.")
    try:
        details = await ollama_show_model(requested)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")
    tags = []
    for raw in details.get("capabilities") or []:
        tag = str(raw).strip()
        if tag:  # drop empty/whitespace-only entries
            tags.append(tag)
    return {
        "name": requested,
        "capabilities": tags,
        "supports_vision": ollama_supports_vision(details),
    }
@app.get("/ollama/startup-status")
async def ollama_startup_status():
    # Thin passthrough: expose the startup inspection result as-is to the UI.
    return await inspect_ollama_startup()
@@ -115,15 +225,7 @@ def history(session_id: str, db: Session = Depends(get_db)):
.order_by(models.ChatMessage.created_at.asc())
.all()
)
msgs = []
for r in rows:
sources = []
try:
if getattr(r, "sources_json", None):
sources = json.loads(r.sources_json or "[]")
except Exception:
sources = []
msgs.append({"role": r.role, "content": r.content, "sources": sources})
msgs = [_row_to_history_message(r) for r in rows]
return {"messages": msgs}
@app.post("/chat")
@@ -137,7 +239,13 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
db.refresh(session)
# Store the BASE user prompt
user_row = models.ChatMessage(session_pk=session.id, role='user', content=req.message)
user_attachments = _normalize_image_attachments(req.attachments)
user_row = models.ChatMessage(
session_pk=session.id,
role='user',
content=req.message,
attachments_json=json.dumps(user_attachments or []),
)
db.add(user_row)
db.commit()
@@ -148,7 +256,7 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
.order_by(models.ChatMessage.created_at.asc())
.all()[-20:]
)
messages = [{"role": m.role, "content": m.content} for m in last_msgs]
messages = [_row_to_ollama_message(m) for m in last_msgs]
# Patch last user with enriched_message only for LLM call
if req.enriched_message:
@@ -305,7 +413,7 @@ async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Sessio
db.delete(m)
db.commit()
conversation = [{"role": m.role, "content": m.content} for m in msgs[: last_user_idx + 1]]
conversation = [_row_to_ollama_message(m) for m in msgs[: last_user_idx + 1]]
if req.enriched_message:
for j in range(len(conversation) - 1, -1, -1):

View File

@@ -22,6 +22,8 @@ class ChatMessage(Base):
content = Column(Text, nullable=False)
# JSON-encoded list of citation URLs; null/empty => no chips
sources_json = Column(Text, nullable=True, default='[]')
# JSON-encoded list of inline image attachments for user messages.
attachments_json = Column(Text, nullable=True, default='[]')
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
session = relationship("ChatSession", back_populates="messages")

View File

@@ -1,10 +1,15 @@
import httpx
import json
from typing import Dict, Any, List, AsyncGenerator
import re
import time
from typing import Dict, Any, List, AsyncGenerator, Tuple
from .app_settings import get_ollama_api_url
# Per-process cache of `/api/show` responses, keyed by (base URL, model name)
# and storing (monotonic timestamp, response payload) pairs.
_MODEL_DETAILS_CACHE: Dict[Tuple[str, str], Tuple[float, Dict[str, Any]]] = {}
# How long a cached model-details entry stays fresh, in seconds.
_MODEL_DETAILS_TTL_S = 15.0
async def list_models() -> Dict[str, Any]:
ollama_url = get_ollama_api_url()
async with httpx.AsyncClient(timeout=30.0) as client:
@@ -15,7 +20,38 @@ async def list_models() -> Dict[str, Any]:
models = [m.get('name') for m in data.get('models', [])]
return {"models": models}
async def chat(model: str, messages: List[Dict[str, str]]) -> str:
async def show_model(model: str, *, refresh: bool = False) -> Dict[str, Any]:
    """Fetch `/api/show` details for *model*, with a short-lived process cache.

    Results are cached per (base URL, model) for ``_MODEL_DETAILS_TTL_S``
    seconds; pass ``refresh=True`` to bypass the cache and re-query Ollama.
    Raises ``httpx.HTTPStatusError`` on a non-2xx response.
    """
    base_url = get_ollama_api_url()
    key = (base_url.rstrip('/'), str(model or '').strip())
    now = time.monotonic()
    if not refresh:
        entry = _MODEL_DETAILS_CACHE.get(key)
        if entry is not None and (now - entry[0]) < _MODEL_DETAILS_TTL_S:
            return entry[1]
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(f"{base_url}/api/show", json={"model": model})
        response.raise_for_status()
        details = response.json()
        _MODEL_DETAILS_CACHE[key] = (now, details)
        return details
def supports_vision(model_data: Dict[str, Any]) -> bool:
    """Heuristically decide whether a model accepts image inputs.

    Prefers the explicit ``capabilities`` list returned by `/api/show`;
    otherwise falls back to scanning ``model_info`` keys for vision or
    multimodal markers (e.g. ``clip.vision.*`` or ``mm.tokens_per_image``).
    """
    declared = model_data.get("capabilities") or []
    if any(str(cap).strip().lower() == "vision" for cap in declared):
        return True
    info = model_data.get("model_info") or {}
    if not isinstance(info, dict):
        return False
    for raw_key in info:
        key = str(raw_key).strip().lower()
        vision_marker = ".vision." in key or key.endswith(".vision")
        # "mm" keys and per-image token counts show up on multimodal models.
        mm_marker = key.endswith("tokens_per_image") or re.search(r"\bmm\b", key) is not None
        if vision_marker or mm_marker:
            return True
    return False
async def chat(model: str, messages: List[Dict[str, Any]]) -> str:
ollama_url = get_ollama_api_url()
payload = {
"model": model,
@@ -36,7 +72,7 @@ async def chat(model: str, messages: List[Dict[str, str]]) -> str:
return msgs[-1].get("content", "")
return data.get("content", "")
async def chat_stream(model: str, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
async def chat_stream(model: str, messages: List[Dict[str, Any]]) -> AsyncGenerator[str, None]:
ollama_url = get_ollama_api_url()
payload = {
"model": model,

View File

@@ -2,10 +2,16 @@ from pydantic import BaseModel, ConfigDict
from typing import List, Optional
from datetime import datetime
class ImageAttachment(BaseModel):
    """An inline image attached to a chat message, carried as a data URL."""

    # Display name for the attachment (e.g. the original filename).
    name: str
    # Declared MIME type; may be omitted, in which case the server falls
    # back to the type detected from the data URL.
    mime_type: Optional[str] = None
    # Full `data:image/...;base64,...` URL containing the image bytes.
    data_url: str
class Message(BaseModel):
    """A single chat message as exposed over the API."""

    # Message author — presumably "user" or "assistant"; verify against caller.
    role: str
    content: str
    # Citation source URLs attached to the message, when any.
    sources: Optional[List[str]] = None
    # Inline image attachments, present only when images were sent with
    # the message.
    attachments: Optional[List[ImageAttachment]] = None
class ChatRequest(BaseModel):
session_id: str
@@ -14,6 +20,7 @@ class ChatRequest(BaseModel):
enriched_message: Optional[str] = None
stream: Optional[bool] = False
sources: Optional[List[str]] = None
attachments: Optional[List[ImageAttachment]] = None
class ChatResponse(BaseModel):
reply: str