Add support for image attachments in chat messages and update model capabilities endpoint
This commit is contained in:
@@ -4,16 +4,15 @@ from sqlalchemy.orm import sessionmaker, DeclarativeBase
|
||||
from sqlalchemy import text
|
||||
|
||||
"""
|
||||
Database utilities and configuration. This module defines the SQLAlchemy
|
||||
engine, session factory and base class for models. It also contains a
|
||||
lightweight migration helper used to evolve the schema over time. The
|
||||
`ensure_sources_column` helper adds a new `sources_json` column to the
|
||||
`chat_messages` table if it does not already exist. This is required
|
||||
for persisting citation sources alongside assistant messages.
|
||||
Database utilities and configuration. This module defines the SQLAlchemy
|
||||
engine, session factory and base class for models. It also contains a
|
||||
lightweight migration helper used to evolve the schema over time. The
|
||||
`ensure_sources_column` helper adds the JSON-backed columns used by chat
|
||||
messages when they do not already exist.
|
||||
|
||||
The migration uses SQLite's `ALTER TABLE` syntax and therefore should
|
||||
only run once on startup. It is safe to call repeatedly: when the
|
||||
column already exists, the function will simply no‑op.
|
||||
only run once on startup. It is safe to call repeatedly: when a column
|
||||
already exists, the function will simply no-op.
|
||||
"""
|
||||
|
||||
DATABASE_URL = "sqlite:///./backend/app.db"
|
||||
@@ -34,5 +33,7 @@ def ensure_sources_column(engine):
|
||||
cols = [row[1] for row in conn.execute(text("PRAGMA table_info(chat_messages)"))]
|
||||
if "sources_json" not in cols:
|
||||
conn.execute(text("ALTER TABLE chat_messages ADD COLUMN sources_json TEXT DEFAULT '[]'"))
|
||||
if "attachments_json" not in cols:
|
||||
conn.execute(text("ALTER TABLE chat_messages ADD COLUMN attachments_json TEXT DEFAULT '[]'"))
|
||||
except Exception as e:
|
||||
print("[db] ensure_sources_column error:", e)
|
||||
|
||||
136
backend/main.py
136
backend/main.py
@@ -2,15 +2,22 @@ from fastapi import FastAPI, Depends, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
import re
|
||||
import html
|
||||
import json
|
||||
import base64
|
||||
from . import models, schemas
|
||||
from .database import Base, engine, SessionLocal, ensure_sources_column
|
||||
from .local_rag import router as local_rag_router
|
||||
from .ollama_admin import inspect_ollama_startup, prepare_startup_models, pull_local_model, start_local_ollama
|
||||
from .ollama_client import list_models as ollama_list, chat as ollama_chat, chat_stream as ollama_chat_stream
|
||||
from .ollama_client import (
|
||||
list_models as ollama_list,
|
||||
chat as ollama_chat,
|
||||
chat_stream as ollama_chat_stream,
|
||||
show_model as ollama_show_model,
|
||||
supports_vision as ollama_supports_vision,
|
||||
)
|
||||
from .websearch import enrich_prompt
|
||||
|
||||
# Create tables + ensure migration
|
||||
@@ -37,6 +44,86 @@ app.add_middleware(
|
||||
)
|
||||
app.include_router(local_rag_router)
|
||||
|
||||
# Matches a base64-encoded image data URL, case-insensitively.
# Group 1 is the MIME type ("image/<subtype>"); group 2 is the base64 payload,
# which is allowed to contain whitespace (callers strip it before decoding).
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)
|
||||
|
||||
|
||||
def _attachment_field(item: Any, field: str) -> Any:
|
||||
if isinstance(item, dict):
|
||||
return item.get(field)
|
||||
return getattr(item, field, None)
|
||||
|
||||
|
||||
def _normalize_image_attachments(items: Any) -> List[dict]:
|
||||
cleaned: List[dict] = []
|
||||
for item in items or []:
|
||||
data_url = str(_attachment_field(item, "data_url") or "").strip()
|
||||
name = str(_attachment_field(item, "name") or "image").strip() or "image"
|
||||
mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower()
|
||||
match = _IMAGE_DATA_URL_RE.match(data_url)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
detected_mime = match.group(1).lower()
|
||||
payload = re.sub(r"\s+", "", match.group(2))
|
||||
if mime_type and not mime_type.startswith("image/"):
|
||||
continue
|
||||
|
||||
try:
|
||||
base64.b64decode(payload, validate=True)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
cleaned.append({
|
||||
"name": name[:255],
|
||||
"mime_type": mime_type or detected_mime,
|
||||
"data_url": f"data:{detected_mime};base64,{payload}",
|
||||
})
|
||||
return cleaned
|
||||
|
||||
|
||||
def _load_message_attachments(raw_value: Any) -> List[dict]:
    """Turn a stored attachments value into a normalized attachment list.

    A string is treated as JSON (an empty string counts as ``"[]"``); text
    that fails to parse is treated as an empty list. Any non-string value is
    handed to ``_normalize_image_attachments`` unchanged.
    """
    if not isinstance(raw_value, str):
        return _normalize_image_attachments(raw_value)

    try:
        decoded = json.loads(raw_value or "[]")
    except Exception:
        decoded = []
    return _normalize_image_attachments(decoded)
|
||||
|
||||
|
||||
def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
|
||||
images: List[str] = []
|
||||
for attachment in attachments:
|
||||
match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
|
||||
if not match:
|
||||
continue
|
||||
images.append(re.sub(r"\s+", "", match.group(2)))
|
||||
return images
|
||||
|
||||
|
||||
def _row_to_history_message(row: models.ChatMessage) -> dict:
    """Build the history payload for one stored chat message.

    The result always carries ``role``, ``content`` and ``sources``. Sources
    are decoded from ``row.sources_json`` and fall back to ``[]`` when the
    column is missing, empty or unparsable. An ``attachments`` key is added
    only when the row has at least one valid image attachment.
    """
    try:
        raw_sources = getattr(row, "sources_json", None)
        sources = json.loads(raw_sources) if raw_sources else []
    except Exception:
        sources = []

    attachments = _load_message_attachments(getattr(row, "attachments_json", None))
    message = {"role": row.role, "content": row.content, "sources": sources}
    if not attachments:
        return message
    message["attachments"] = attachments
    return message
|
||||
|
||||
|
||||
def _row_to_ollama_message(row: models.ChatMessage) -> dict:
    """Build the message dict passed to the Ollama chat call for one row.

    Always contains ``role`` and ``content``. When the row has valid image
    attachments, their base64 payloads are added under ``images``.
    """
    base = {"role": row.role, "content": row.content}
    stored = _load_message_attachments(getattr(row, "attachments_json", None))
    if not stored:
        return base
    return {**base, "images": _attachments_to_ollama_images(stored)}
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -57,6 +144,29 @@ async def get_models():
|
||||
raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")
|
||||
|
||||
|
||||
@app.get("/models/capabilities")
async def get_model_capabilities(name: str):
    # Report the capability list of one model and whether it accepts images.
    # Responds 400 for a blank name and 502 when the model lookup raises.
    model_name = str(name or "").strip()
    if model_name == "":
        raise HTTPException(status_code=400, detail="Model name is required.")

    try:
        model_data = await ollama_show_model(model_name)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")

    # Keep only non-blank capability labels, trimmed of surrounding whitespace.
    capabilities = []
    for item in model_data.get("capabilities") or []:
        label = str(item).strip()
        if label:
            capabilities.append(label)

    return {
        "name": model_name,
        "capabilities": capabilities,
        "supports_vision": ollama_supports_vision(model_data),
    }
|
||||
|
||||
|
||||
@app.get("/ollama/startup-status")
async def ollama_startup_status():
    # Pass through whatever the startup inspection helper reports.
    status = await inspect_ollama_startup()
    return status
|
||||
@@ -115,15 +225,7 @@ def history(session_id: str, db: Session = Depends(get_db)):
|
||||
.order_by(models.ChatMessage.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
msgs = []
|
||||
for r in rows:
|
||||
sources = []
|
||||
try:
|
||||
if getattr(r, "sources_json", None):
|
||||
sources = json.loads(r.sources_json or "[]")
|
||||
except Exception:
|
||||
sources = []
|
||||
msgs.append({"role": r.role, "content": r.content, "sources": sources})
|
||||
msgs = [_row_to_history_message(r) for r in rows]
|
||||
return {"messages": msgs}
|
||||
|
||||
@app.post("/chat")
|
||||
@@ -137,7 +239,13 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
|
||||
db.refresh(session)
|
||||
|
||||
# Store the BASE user prompt
|
||||
user_row = models.ChatMessage(session_pk=session.id, role='user', content=req.message)
|
||||
user_attachments = _normalize_image_attachments(req.attachments)
|
||||
user_row = models.ChatMessage(
|
||||
session_pk=session.id,
|
||||
role='user',
|
||||
content=req.message,
|
||||
attachments_json=json.dumps(user_attachments or []),
|
||||
)
|
||||
db.add(user_row)
|
||||
db.commit()
|
||||
|
||||
@@ -148,7 +256,7 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
|
||||
.order_by(models.ChatMessage.created_at.asc())
|
||||
.all()[-20:]
|
||||
)
|
||||
messages = [{"role": m.role, "content": m.content} for m in last_msgs]
|
||||
messages = [_row_to_ollama_message(m) for m in last_msgs]
|
||||
|
||||
# Patch last user with enriched_message only for LLM call
|
||||
if req.enriched_message:
|
||||
@@ -305,7 +413,7 @@ async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Sessio
|
||||
db.delete(m)
|
||||
db.commit()
|
||||
|
||||
conversation = [{"role": m.role, "content": m.content} for m in msgs[: last_user_idx + 1]]
|
||||
conversation = [_row_to_ollama_message(m) for m in msgs[: last_user_idx + 1]]
|
||||
|
||||
if req.enriched_message:
|
||||
for j in range(len(conversation) - 1, -1, -1):
|
||||
|
||||
@@ -22,6 +22,8 @@ class ChatMessage(Base):
|
||||
content = Column(Text, nullable=False)
|
||||
# JSON-encoded list of citation URLs; null/empty => no chips
|
||||
sources_json = Column(Text, nullable=True, default='[]')
|
||||
# JSON-encoded list of inline image attachments for user messages.
|
||||
attachments_json = Column(Text, nullable=True, default='[]')
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
session = relationship("ChatSession", back_populates="messages")
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
|
||||
import httpx
|
||||
import json
|
||||
from typing import Dict, Any, List, AsyncGenerator
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Any, List, AsyncGenerator, Tuple
|
||||
|
||||
from .app_settings import get_ollama_api_url
|
||||
|
||||
# In-process cache of model detail responses.
# Key: (Ollama base URL without trailing slash, stripped model name).
# Value: (time.monotonic() timestamp taken for the lookup, response JSON).
_MODEL_DETAILS_CACHE: Dict[Tuple[str, str], Tuple[float, Dict[str, Any]]] = {}
|
||||
# Maximum age, in seconds, at which a cached model detail entry is still served.
_MODEL_DETAILS_TTL_S = 15.0
|
||||
|
||||
async def list_models() -> Dict[str, Any]:
|
||||
ollama_url = get_ollama_api_url()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
@@ -15,7 +20,38 @@ async def list_models() -> Dict[str, Any]:
|
||||
models = [m.get('name') for m in data.get('models', [])]
|
||||
return {"models": models}
|
||||
|
||||
async def chat(model: str, messages: List[Dict[str, str]]) -> str:
|
||||
async def show_model(model: str, *, refresh: bool = False) -> Dict[str, Any]:
    """Fetch a model's details from Ollama's ``/api/show`` endpoint.

    Responses are cached per (base URL, model name) and reused while they are
    younger than ``_MODEL_DETAILS_TTL_S`` seconds. Pass ``refresh=True`` to
    bypass the cache; the fresh response still replaces the cached entry.

    Raises whatever ``httpx`` raises for transport errors or, via
    ``raise_for_status``, for error status codes; nothing is cached then.
    """
    ollama_url = get_ollama_api_url()
    cache_key = (ollama_url.rstrip('/'), str(model or '').strip())
    now = time.monotonic()

    if not refresh:
        entry = _MODEL_DETAILS_CACHE.get(cache_key)
        if entry and (now - entry[0]) < _MODEL_DETAILS_TTL_S:
            return entry[1]

    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.post(f"{ollama_url}/api/show", json={"model": model})
        response.raise_for_status()
        details = response.json()

    # Stamped with the pre-request time, so an entry never outlives the TTL
    # measured from when the lookup began.
    _MODEL_DETAILS_CACHE[cache_key] = (now, details)
    return details
|
||||
|
||||
def supports_vision(model_data: Dict[str, Any]) -> bool:
    """Decide from a model-details dict whether the model takes image input.

    Returns ``True`` when ``capabilities`` lists ``"vision"`` (ignoring case
    and surrounding whitespace), or when ``model_info`` is a dict with a key
    that contains ``.vision.``, ends with ``.vision`` or ``tokens_per_image``,
    or contains ``mm`` as a standalone word. Otherwise returns ``False``.
    """
    declared = {
        str(entry).strip().lower()
        for entry in (model_data.get("capabilities") or [])
    }
    if "vision" in declared:
        return True

    info = model_data.get("model_info") or {}
    if not isinstance(info, dict):
        return False

    for raw_key in info:
        key = str(raw_key).strip().lower()
        if ".vision." in key or key.endswith((".vision", "tokens_per_image")):
            return True
        if re.search(r"\bmm\b", key) is not None:
            return True
    return False
|
||||
|
||||
async def chat(model: str, messages: List[Dict[str, Any]]) -> str:
|
||||
ollama_url = get_ollama_api_url()
|
||||
payload = {
|
||||
"model": model,
|
||||
@@ -36,7 +72,7 @@ async def chat(model: str, messages: List[Dict[str, str]]) -> str:
|
||||
return msgs[-1].get("content", "")
|
||||
return data.get("content", "")
|
||||
|
||||
async def chat_stream(model: str, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
|
||||
async def chat_stream(model: str, messages: List[Dict[str, Any]]) -> AsyncGenerator[str, None]:
|
||||
ollama_url = get_ollama_api_url()
|
||||
payload = {
|
||||
"model": model,
|
||||
|
||||
@@ -2,10 +2,16 @@ from pydantic import BaseModel, ConfigDict
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
class ImageAttachment(BaseModel):
    # One inline image sent along with a chat message.

    # Display name supplied by the client.
    name: str
    # Declared MIME type; optional.
    mime_type: Optional[str] = None
    # Image content as a data URL string.
    # NOTE(review): the backend appears to accept only base64 "data:image/..."
    # URLs when normalizing attachments - confirm against the request handler.
    data_url: str
|
||||
|
||||
class Message(BaseModel):
    # A single chat message as exchanged through the API.

    # Speaker role and message text.
    role: str
    content: str
    # Optional list of source strings attached to the message.
    # NOTE(review): presumably citation URLs - confirm against the storage model.
    sources: Optional[List[str]] = None
    # Optional inline image attachments.
    attachments: Optional[List[ImageAttachment]] = None
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
session_id: str
|
||||
@@ -14,6 +20,7 @@ class ChatRequest(BaseModel):
|
||||
enriched_message: Optional[str] = None
|
||||
stream: Optional[bool] = False
|
||||
sources: Optional[List[str]] = None
|
||||
attachments: Optional[List[ImageAttachment]] = None
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
reply: str
|
||||
|
||||
Reference in New Issue
Block a user