# Heimgeist backend — FastAPI application (backend/main.py).
import asyncio
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import Any, List, Optional
import re
import html
import json
import base64
import mimetypes
import shutil
import tempfile
from pathlib import Path
from . import models, schemas
from .database import Base, engine, SessionLocal, ensure_sources_column
from .local_rag import router as local_rag_router
from .ollama_admin import inspect_ollama_startup, prepare_startup_models, pull_local_model, start_local_ollama
from .ollama_client import (
list_model_catalog as ollama_list_model_catalog,
chat as ollama_chat,
chat_stream as ollama_chat_stream,
show_model as ollama_show_model,
supports_vision as ollama_supports_vision,
)
from .whisper_admin import DEFAULT_WHISPER_MODEL, list_whisper_models, transcribe_audio_bytes
from .websearch import enrich_prompt
# Create tables + ensure migration
# Runs at import time so the schema exists before the first request;
# ensure_sources_column applies the sources-column migration on top of it.
Base.metadata.create_all(bind=engine)
ensure_sources_column(engine)
app = FastAPI(title="LLM Desktop Backend", version="0.1.0" )
def sanitize_chat_title(title: str) -> str:
    """Normalize an LLM-generated chat title.

    Unescapes HTML entities, drops <think>/<thinking> reasoning blocks, then
    repeatedly strips leading markdown headers, surrounding asterisks, and
    collapses runs of whitespace until the text stops changing.
    """
    text = html.unescape(title or "")
    text = re.sub(r'<think(?:ing)?>.*?</think(?:ing)?>', '', text, flags=re.DOTALL | re.IGNORECASE)
    text = text.strip()
    while True:
        before = text
        for pattern in (r'^\s*#+\s*', r'^\s*\*{1,2}\s*', r'\s*\*{1,2}\s*$'):
            text = re.sub(pattern, '', text)
        text = re.sub(r'\s+', ' ', text.strip())
        if not text or text == before:
            break
    return text.strip()
# CORS (dev-friendly; tighten later)
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# unsafe for production and browsers reject the combination for credentialed
# requests — restrict origins before deploying beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Local RAG endpoints live in their own router module.
app.include_router(local_rag_router)
# Matches an image data URL, capturing (1) the image mime type and (2) the
# base64 payload. \s is permitted inside the payload because clients may wrap
# base64 across lines; it is stripped before decoding.
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)
# Extensions treated as images (handled via data URLs / vision models).
_CHAT_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".heic", ".avif"}
# Non-image file types accepted as chat attachments (documents, audio, video).
_CHAT_FILE_EXTENSIONS = {
    ".pdf", ".html", ".htm", ".txt", ".md", ".rst", ".epub",
    ".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus", ".aac",
    ".mp4", ".mkv", ".mov", ".webm", ".avi", ".ts",
}
def _attachment_field(item: Any, field: str) -> Any:
if isinstance(item, dict):
return item.get(field)
return getattr(item, field, None)
def _normalize_attachment_name(value: Any, fallback: str) -> str:
name = str(value or "").strip() or fallback
return name[:255]
def _normalize_image_attachment(item: Any) -> Optional[dict]:
    """Validate an image attachment's data URL and return a clean record, or None.

    Rejects entries whose data URL does not parse, whose declared mime type is
    not image/*, or whose base64 payload fails strict decoding.
    """
    raw_url = str(_attachment_field(item, "data_url") or "").strip()
    match = _IMAGE_DATA_URL_RE.match(raw_url)
    if match is None:
        return None
    detected_mime = match.group(1).lower()
    encoded = re.sub(r"\s+", "", match.group(2))
    declared_mime = str(_attachment_field(item, "mime_type") or "").strip().lower()
    # An explicitly declared non-image mime type disqualifies the attachment.
    if declared_mime and not declared_mime.startswith("image/"):
        return None
    try:
        base64.b64decode(encoded, validate=True)
    except Exception:
        return None
    return {
        "kind": "image",
        "name": _normalize_attachment_name(_attachment_field(item, "name"), "image"),
        "mime_type": declared_mime or detected_mime,
        "data_url": f"data:{detected_mime};base64,{encoded}",
    }
def _normalize_file_attachment(item: Any, *, include_text: bool = False) -> Optional[dict]:
    """Validate a non-image chat attachment and return a clean record.

    Returns None when the attachment has neither a local path with a supported
    extension nor inline text. When *include_text* is True, a non-empty
    ``text`` field is carried over into the result.
    """
    source_path_raw = str(_attachment_field(item, "source_path") or "").strip()
    source_path = ""
    if source_path_raw:
        try:
            # Canonicalize the user-supplied path; keep the raw value if
            # resolution fails.
            source_path = str(Path(source_path_raw).expanduser().resolve())
        except Exception:
            source_path = source_path_raw
    text_value = str(_attachment_field(item, "text") or "")
    name_seed = _attachment_field(item, "name") or (Path(source_path).name if source_path else "file")
    name = _normalize_attachment_name(name_seed, "file")
    # Keep only the media type; drop parameters such as "; charset=...".
    mime_type = str(_attachment_field(item, "mime_type") or "").split(";", 1)[0].strip().lower()
    record_type = str(_attachment_field(item, "record_type") or "").strip().lower() or None
    size = _attachment_field(item, "size")
    try:
        size = max(0, int(size)) if size is not None else None
    except Exception:
        size = None
    extension = Path(source_path or name).suffix.lower()
    # A path with an unsupported extension is rejected outright; path-less
    # entries are only acceptable if they already carry extracted text.
    if source_path and extension not in _CHAT_FILE_EXTENSIONS:
        return None
    if not source_path and not text_value.strip():
        return None
    payload = {
        "kind": "file",
        "name": name,
        "mime_type": mime_type or (mimetypes.guess_type(name)[0] or "").lower() or None,
        "source_path": source_path or None,
        "size": size,
        "record_type": record_type,
    }
    if include_text and text_value.strip():
        payload["text"] = text_value.strip()
    return payload
def _normalize_chat_attachments(items: Any, *, include_text: bool = False) -> List[dict]:
    """Normalize a heterogeneous attachment list, silently dropping invalid entries."""
    results: List[dict] = []
    for entry in items or []:
        entry_kind = str(_attachment_field(entry, "kind") or "").strip().lower()
        entry_data_url = str(_attachment_field(entry, "data_url") or "").strip()
        # An explicit "image" kind, or a kind-less entry carrying a data URL,
        # goes through the image path; everything else is a file.
        treat_as_image = entry_kind == "image" or (not entry_kind and entry_data_url)
        if treat_as_image:
            normalized = _normalize_image_attachment(entry)
        else:
            normalized = _normalize_file_attachment(entry, include_text=include_text)
        if normalized:
            results.append(normalized)
    return results
def _normalize_chat_attachment_or_raise(item: Any, *, include_text: bool = False) -> dict:
    """Normalize a single attachment, raising HTTP 400 with a labeled reason on failure."""
    kind = str(_attachment_field(item, "kind") or "").strip().lower()
    data_url = str(_attachment_field(item, "data_url") or "").strip()
    source_path = str(_attachment_field(item, "source_path") or "").strip()
    # Prefer an explicit name, then the path's basename, for the error label.
    name_hint = _attachment_field(item, "name") or (Path(source_path).name if source_path else None) or "attachment"
    label = _normalize_attachment_name(name_hint, "attachment")
    if kind == "image" or (not kind and data_url):
        normalized = _normalize_image_attachment(item)
        if normalized:
            return normalized
        raise HTTPException(status_code=400, detail=f"{label}: invalid image attachment.")
    normalized = _normalize_file_attachment(item, include_text=include_text)
    if normalized:
        return normalized
    # Build the most specific failure reason available.
    if source_path:
        extension = Path(source_path).suffix.lower()
        if extension in _CHAT_IMAGE_EXTENSIONS:
            raise HTTPException(status_code=400, detail=f"{label}: use Add Image(s) for image attachments.")
        raise HTTPException(status_code=400, detail=f"{label}: unsupported file type for chat attachment.")
    raise HTTPException(status_code=400, detail=f"{label}: local file path is required for non-image attachments.")
def _load_message_attachments(raw_value: Any, *, include_text: bool = False) -> List[dict]:
    """Decode a stored attachments payload (JSON string or already-parsed list) and normalize it."""
    if not isinstance(raw_value, str):
        parsed = raw_value
    else:
        try:
            parsed = json.loads(raw_value or "[]")
        except Exception:
            parsed = []
    return _normalize_chat_attachments(parsed, include_text=include_text)
def _attachment_is_image(attachment: dict) -> bool:
    """True when the attachment is declared as, or parses as, an image."""
    if str(attachment.get("kind") or "").strip().lower() == "image":
        return True
    data_url = str(attachment.get("data_url") or "").strip()
    return _IMAGE_DATA_URL_RE.match(data_url) is not None
def _attachment_is_file(attachment: dict) -> bool:
    """True for any attachment that is not an image."""
    return not _attachment_is_image(attachment)
def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
    """Extract the whitespace-free base64 payload of every valid image attachment."""
    payloads: List[str] = []
    for entry in attachments:
        if not _attachment_is_image(entry):
            continue
        parsed = _IMAGE_DATA_URL_RE.match(str(entry.get("data_url") or "").strip())
        if parsed:
            payloads.append(re.sub(r"\s+", "", parsed.group(2)))
    return payloads
def _attachment_history_payload(attachment: dict) -> dict:
    """Project an attachment to the compact shape returned by /history."""
    is_image = _attachment_is_image(attachment)
    payload = {
        "kind": attachment.get("kind") or ("image" if is_image else "file"),
        "name": attachment.get("name"),
        "mime_type": attachment.get("mime_type"),
    }
    if is_image:
        payload["data_url"] = attachment.get("data_url")
    else:
        payload["source_path"] = attachment.get("source_path")
        payload["size"] = attachment.get("size")
        payload["record_type"] = attachment.get("record_type")
    # Drop empty fields so history entries stay compact.
    return {key: value for key, value in payload.items() if value not in (None, "", [])}
async def _model_supports_vision(model_name: Optional[str]) -> bool:
    """Ask Ollama whether *model_name* accepts image input; raises 502 when Ollama is unreachable."""
    name = str(model_name or "").strip()
    if not name:
        return False
    try:
        details = await ollama_show_model(name)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc
    return ollama_supports_vision(details)
def _safe_temp_attachment_name(index: int, name: str, fallback_suffix: str = "") -> str:
safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._") or f"attachment-{index + 1}"
suffix = Path(safe_name).suffix
if fallback_suffix and not suffix:
safe_name = f"{safe_name}{fallback_suffix}"
return f"{index:03d}-{safe_name}"
def _materialize_attachment_source(attachment: dict, stage_root: Path, index: int) -> tuple[Path, dict]:
    """Stage one attachment under *stage_root* for corpus extraction.

    Images are decoded from their data URL and written as real files; other
    attachments are symlinked (or copied when symlinking fails) from their
    local source path. Returns the staged path plus metadata (resolved
    source_path and byte size) to merge back into the attachment record.

    Raises RuntimeError with a user-facing message on any invalid payload.
    """
    if _attachment_is_image(attachment):
        match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
        if not match:
            raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid image payload.")
        mime_type = match.group(1).lower()
        payload = re.sub(r"\s+", "", match.group(2))
        try:
            image_bytes = base64.b64decode(payload, validate=True)
        except Exception as exc:
            raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid base64 image payload.") from exc
        # Prefer the original filename's extension, then one guessed from the
        # mime type, then a generic marker.
        suffix = Path(str(attachment.get("name") or "")).suffix or mimetypes.guess_extension(mime_type) or ".img"
        target_path = stage_root / _safe_temp_attachment_name(index, str(attachment.get("name") or "image"), suffix)
        target_path.write_bytes(image_bytes)
        return target_path, {
            "source_path": str(target_path.resolve()),
            "size": len(image_bytes),
        }
    source_path = str(attachment.get("source_path") or "").strip()
    if not source_path:
        raise RuntimeError(f"{attachment.get('name') or 'file'}: local file path is required.")
    source = Path(source_path).expanduser().resolve()
    if not source.exists() or not source.is_file():
        raise RuntimeError(f"{attachment.get('name') or source.name}: file is no longer available.")
    if source.suffix.lower() not in _CHAT_FILE_EXTENSIONS:
        raise RuntimeError(f"{source.name}: unsupported file type for chat attachment.")
    target_path = stage_root / _safe_temp_attachment_name(index, source.name)
    try:
        # Symlink to avoid copying potentially large media; fall back to a
        # real copy where symlinks are unavailable.
        target_path.symlink_to(source)
    except Exception:
        shutil.copy2(source, target_path)
    return target_path, {
        "source_path": str(source),
        "size": int(source.stat().st_size),
    }
def _read_corpus_records(path: Path) -> List[dict]:
records: List[dict] = []
if not path.exists():
return records
with path.open("r", encoding="utf-8", errors="ignore") as handle:
for line in handle:
line = line.strip()
if not line:
continue
records.append(json.loads(line))
return records
def _extract_attachment_texts(
    attachments: List[dict],
    *,
    vision_model: Optional[str],
    transcription_model: Optional[str],
) -> List[dict]:
    """Extract text from attachments via the RAG corpus builder.

    Stages every attachment into a temp directory, runs the corpus build
    (whisper_model for transcription, vlm_model for images), then maps the
    produced JSONL records back to their attachments by staged path. Returns
    one enriched attachment dict per input, each with a non-empty "text".

    Raises RuntimeError when the build reports errors or when an attachment
    yields no usable text.
    """
    # Imported lazily to keep module import light.
    from .rag.corpus_builder import run_build
    if not attachments:
        return []
    with tempfile.TemporaryDirectory(prefix="heimgeist-chat-attachments-") as tmpdir:
        temp_root = Path(tmpdir)
        stage_root = temp_root / "stage"
        stage_root.mkdir(parents=True, exist_ok=True)
        corpus_path = temp_root / "corpus.jsonl"
        # Map both the staged path and its resolved form to the attachment
        # index, since records may reference either spelling.
        index_by_key: dict[str, int] = {}
        attachment_meta: List[dict] = []
        for index, attachment in enumerate(attachments):
            temp_path, extra_meta = _materialize_attachment_source(attachment, stage_root, index)
            index_by_key[str(temp_path)] = index
            index_by_key[str(temp_path.resolve())] = index
            attachment_meta.append(extra_meta)
        result = run_build(
            root=stage_root,
            out=corpus_path,
            emit="per-file",
            emit_av="joined",
            lang_detect=False,
            whisper_model=transcription_model or DEFAULT_WHISPER_MODEL,
            vlm_model=vision_model or "qwen2.5vl:7b",
        )
        errors = [str(error).strip() for error in (result.get("errors") or []) if str(error).strip()]
        if errors:
            raise RuntimeError("; ".join(errors))
        # Group record text by attachment index; record ids may carry a
        # "#fragment" suffix which is stripped before matching.
        extracted: dict[int, dict[str, Any]] = {}
        for record in _read_corpus_records(corpus_path):
            record_id = str(record.get("id") or "").split("#", 1)[0].strip()
            source_path = str(record.get("source_path") or "").strip()
            matched_index = None
            for candidate in (record_id, source_path):
                if candidate in index_by_key:
                    matched_index = index_by_key[candidate]
                    break
            if matched_index is None:
                continue
            bucket = extracted.setdefault(matched_index, {
                "mime_type": "",
                "record_type": "",
                "text_parts": [],
            })
            # First record to carry each value wins.
            if not bucket["mime_type"]:
                bucket["mime_type"] = str(record.get("mime") or "").strip().lower()
            if not bucket["record_type"]:
                bucket["record_type"] = str(record.get("record_type") or "").strip().lower()
            text = str(record.get("text") or "").strip()
            if text:
                bucket["text_parts"].append(text)
        prepared: List[dict] = []
        for index, attachment in enumerate(attachments):
            bucket = extracted.get(index) or {}
            text = "\n\n".join(
                part for part in bucket.get("text_parts") or []
                if str(part or "").strip()
            ).strip()
            if not text:
                raise RuntimeError(f"{attachment.get('name') or 'attachment'}: no usable text could be extracted.")
            extra_meta = attachment_meta[index]
            # Merge extraction results over the original record; fall back to
            # the original's values where extraction produced nothing.
            prepared.append({
                **attachment,
                "mime_type": bucket.get("mime_type") or attachment.get("mime_type"),
                "record_type": bucket.get("record_type") or attachment.get("record_type"),
                "source_path": extra_meta.get("source_path") or attachment.get("source_path"),
                "size": extra_meta.get("size") if extra_meta.get("size") is not None else attachment.get("size"),
                "text": text,
            })
        return prepared
def _build_attachment_block(attachments: List[dict]) -> str:
    """Render attachments that carry extracted text into a <chat_attachments> prompt block."""
    rendered: List[str] = []
    for position, entry in enumerate(attachments, start=1):
        body = str(entry.get("text") or "").strip()
        if not body:
            continue
        section = [
            f"[Attachment {position}]",
            f"Name: {entry.get('name') or f'attachment-{position}'}",
            f"Kind: {'image' if _attachment_is_image(entry) else 'file'}",
        ]
        mime = str(entry.get("mime_type") or "").strip()
        if mime:
            section.append(f"MIME: {mime}")
        rtype = str(entry.get("record_type") or "").strip()
        if rtype:
            section.append(f"Record type: {rtype}")
        section.append("Content:")
        section.append(body)
        rendered.append("\n".join(section))
    if not rendered:
        return ""
    return "<chat_attachments>\n" + "\n\n".join(rendered) + "\n</chat_attachments>"
def _compose_user_message_content(message: str, enriched_message: Optional[str], attachment_block: str) -> str:
base_content = str(enriched_message or "").strip() or str(message or "").strip()
parts: List[str] = []
if base_content:
parts.append(base_content)
elif attachment_block:
parts.append("Please analyze the attached material.")
if attachment_block:
parts.append(attachment_block)
if not parts:
return "Please analyze the attached material."
return "\n\n".join(part for part in parts if str(part or "").strip()).strip()
def _attachments_have_images(attachments: List[dict]) -> bool:
    """True when at least one attachment is an image."""
    for attachment in attachments or []:
        if _attachment_is_image(attachment):
            return True
    return False
async def _prepare_chat_message_attachments(
    attachments: List[dict],
    *,
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    persist_file_text: bool,
) -> tuple[List[dict], List[str], str]:
    """Normalize, validate, and (where needed) text-extract message attachments.

    Returns (prepared_attachments, ollama_images, attachment_block):
      * prepared_attachments — normalized records to persist with the message;
      * ollama_images — base64 payloads for Ollama's "images" field (only
        populated when the chat model itself supports vision);
      * attachment_block — the <chat_attachments> prompt text for attachments
        whose content must be injected as text.

    Raises HTTPException(400) on invalid attachments or a missing/unsuitable
    vision model, and 502 (via _model_supports_vision) when Ollama is down.
    """
    prepared = [
        _normalize_chat_attachment_or_raise(item, include_text=True)
        for item in (attachments or [])
    ]
    if not prepared:
        return [], [], ""
    # Images on a non-vision chat model must be routed through a separately
    # configured vision model.
    if _attachments_have_images(prepared) and not request_model_supports_vision:
        if not str(vision_model or "").strip():
            raise HTTPException(
                status_code=400,
                detail="Image attachments require a configured vision model when the selected chat model does not support vision.",
            )
        if not await _model_supports_vision(vision_model):
            raise HTTPException(
                status_code=400,
                detail="The configured vision model does not support image inputs.",
            )
    # Decide per attachment whether its text is already available or must be
    # extracted via the corpus builder.
    prompt_entries: List[Optional[dict]] = [None] * len(prepared)
    attachments_to_extract: List[dict] = []
    extract_indexes: List[int] = []
    for index, attachment in enumerate(prepared):
        if _attachment_is_image(attachment):
            if request_model_supports_vision:
                continue  # sent natively via "images"; no text block needed
            attachments_to_extract.append(attachment)
            extract_indexes.append(index)
            continue
        text = str(attachment.get("text") or "").strip()
        if text:
            prompt_entries[index] = attachment
        else:
            attachments_to_extract.append(attachment)
            extract_indexes.append(index)
    if attachments_to_extract:
        # Extraction is blocking work; run it in a worker thread.
        extracted_attachments = await asyncio.to_thread(
            _extract_attachment_texts,
            attachments_to_extract,
            vision_model=vision_model,
            transcription_model=transcription_model,
        )
        for original_index, extracted in zip(extract_indexes, extracted_attachments):
            if _attachment_is_file(prepared[original_index]) and persist_file_text:
                # Fold extraction results back into the record that will be
                # persisted, so later turns can reuse the text.
                prepared[original_index] = {
                    **prepared[original_index],
                    "mime_type": extracted.get("mime_type") or prepared[original_index].get("mime_type"),
                    "record_type": extracted.get("record_type") or prepared[original_index].get("record_type"),
                    "source_path": extracted.get("source_path") or prepared[original_index].get("source_path"),
                    "size": extracted.get("size") if extracted.get("size") is not None else prepared[original_index].get("size"),
                    "text": extracted.get("text") or prepared[original_index].get("text"),
                }
                prompt_entries[original_index] = prepared[original_index]
            else:
                prompt_entries[original_index] = extracted
    prompt_attachments = [entry for entry in prompt_entries if entry and str(entry.get("text") or "").strip()]
    ollama_images = _attachments_to_ollama_images(prepared) if request_model_supports_vision else []
    return prepared, ollama_images, _build_attachment_block(prompt_attachments)
def _row_to_history_message(row: models.ChatMessage) -> dict:
    """Convert a persisted ChatMessage row into the /history response shape."""
    try:
        raw_sources = getattr(row, "sources_json", None)
        sources = json.loads(raw_sources or "[]") if raw_sources else []
    except Exception:
        sources = []
    payload = {"role": row.role, "content": row.content, "sources": sources}
    attachments = _load_message_attachments(getattr(row, "attachments_json", None))
    if attachments:
        payload["attachments"] = [_attachment_history_payload(entry) for entry in attachments]
    return payload
async def _build_ollama_user_message(
    message: str,
    attachments: List[dict],
    *,
    enriched_message: Optional[str],
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    persist_file_text: bool,
) -> tuple[dict, List[dict]]:
    """Build the Ollama user-turn payload (content plus optional images) and
    return it together with the normalized attachment records."""
    prepared, images, block = await _prepare_chat_message_attachments(
        attachments,
        request_model_supports_vision=request_model_supports_vision,
        vision_model=vision_model,
        transcription_model=transcription_model,
        persist_file_text=persist_file_text,
    )
    content = _compose_user_message_content(message, enriched_message, block)
    payload: dict = {"role": "user", "content": content}
    if images:
        payload["images"] = images
    return payload, prepared
async def _build_ollama_messages_from_rows(
    rows: List[models.ChatMessage],
    *,
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    override_user_row_index: Optional[int] = None,
    override_user_content: Optional[str] = None,
) -> List[dict]:
    """Convert persisted chat rows into an Ollama messages list, in order.

    User rows are rebuilt through the attachment pipeline (images / text
    blocks reattached); other roles pass through as plain content. When
    *override_user_row_index* matches a row, *override_user_content* replaces
    that user message's content (used by /regenerate for enriched prompts).
    """
    conversation: List[dict] = []
    for row_index, row in enumerate(rows):
        if row.role != "user":
            conversation.append({"role": row.role, "content": row.content})
            continue
        attachments = _load_message_attachments(getattr(row, "attachments_json", None), include_text=True)
        enriched_message = override_user_content if override_user_row_index == row_index else None
        # persist_file_text=False: history rows already carry any extracted
        # text, so nothing new needs to be written back.
        message_payload, _prepared = await _build_ollama_user_message(
            row.content,
            attachments,
            enriched_message=enriched_message,
            request_model_supports_vision=request_model_supports_vision,
            vision_model=vision_model,
            transcription_model=transcription_model,
            persist_file_text=False,
        )
        conversation.append(message_payload)
    return conversation
def get_db():
    """FastAPI dependency yielding a database session, always closed after the request."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
@app.get("/health")
def health():
    """Liveness probe."""
    return {"ok": True}
@app.post("/audio/transcribe", response_model=schemas.AudioTranscriptionResponse)
async def transcribe_audio_route(req: schemas.AudioTranscriptionRequest):
    """Transcribe base64-encoded audio with Whisper.

    Validates the mime type and base64 payload (400 on bad input), runs the
    blocking transcription in a worker thread, and returns text/language/model.
    """
    # Keep only the media type; drop ";codecs=..."-style parameters.
    mime_type = str(req.mime_type or "").split(";", 1)[0].strip().lower()
    if not mime_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="An audio mime type is required.")
    # Strip whitespace so line-wrapped base64 still decodes under validate=True.
    payload = re.sub(r"\s+", "", str(req.audio_base64 or ""))
    if not payload:
        raise HTTPException(status_code=400, detail="Audio payload is required.")
    try:
        audio_bytes = base64.b64decode(payload, validate=True)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 audio payload.") from exc
    try:
        result = await asyncio.to_thread(
            transcribe_audio_bytes,
            audio_bytes,
            mime_type,
            req.model or DEFAULT_WHISPER_MODEL,
            req.language,
        )
    except RuntimeError as exc:
        # RuntimeError from the transcriber is treated as a client error.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Audio transcription failed: {exc}") from exc
    return {
        "text": str(result.get("text") or "").strip(),
        "language": str(result.get("language") or "").strip() or None,
        "model": str(result.get("model") or req.model or DEFAULT_WHISPER_MODEL),
    }
@app.get("/models")
async def get_models():
    """List the Ollama model catalog plus installed Whisper models.

    Runs the Ollama catalog call and the blocking Whisper model scan
    concurrently; returns 502 when either lookup fails.
    """
    try:
        ollama_data, whisper_data = await asyncio.gather(
            ollama_list_model_catalog(),
            asyncio.to_thread(list_whisper_models),
        )
        return {
            **ollama_data,
            "whisper_models": whisper_data.get("models", []),
            "whisper_error": whisper_data.get("error", ""),
        }
    except Exception as e:
        # Chain the cause for debuggability, matching the file's other handlers.
        raise HTTPException(status_code=502, detail=f"Ollama not available: {e}") from e
@app.get("/models/capabilities")
async def get_model_capabilities(name: str):
    """Report a model's declared capabilities and whether it supports vision."""
    model_name = str(name or "").strip()
    if not model_name:
        raise HTTPException(status_code=400, detail="Model name is required.")
    try:
        model_data = await ollama_show_model(model_name)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")
    raw_items = model_data.get("capabilities") or []
    capabilities = [text for text in (str(item).strip() for item in raw_items) if text]
    return {
        "name": model_name,
        "capabilities": capabilities,
        "supports_vision": ollama_supports_vision(model_data),
    }
@app.get("/ollama/startup-status")
async def ollama_startup_status():
    """Expose the startup inspection result from the Ollama admin module."""
    status = await inspect_ollama_startup()
    return status
@app.post("/ollama/start")
async def ollama_start_route():
    """Start a local Ollama instance; FileNotFoundError maps to 404, RuntimeError to 400."""
    try:
        result = await start_local_ollama()
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return result
@app.post("/ollama/pull")
async def ollama_pull_route(req: schemas.OllamaPullRequest):
    """Pull a model into local Ollama; FileNotFoundError maps to 404, RuntimeError to 400."""
    try:
        result = await pull_local_model(req.model)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return result
@app.post("/startup/prepare-models")
async def startup_prepare_models_route():
    """Prepare startup models; FileNotFoundError maps to 404, RuntimeError to 400."""
    try:
        result = await prepare_startup_models()
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return result
@app.get("/sessions", response_model=schemas.SessionsResponse)
def get_sessions(db: Session = Depends(get_db)):
    """List all chat sessions, newest first, with sanitized display names."""
    rows = db.query(models.ChatSession).order_by(models.ChatSession.created_at.desc()).all()
    payload = []
    for row in rows:
        payload.append({
            "id": row.id,
            "session_id": row.session_id,
            "name": sanitize_chat_title(row.name),
            "created_at": row.created_at,
        })
    return {"sessions": payload}
@app.post("/sessions", response_model=schemas.ChatSession)
def create_session(req: schemas.CreateSessionRequest, db: Session = Depends(get_db)):
    """Create a chat session for the client-supplied session_id."""
    session = models.ChatSession(session_id=req.session_id)
    db.add(session)
    db.commit()
    db.refresh(session)
    return session
@app.get("/history", response_model=schemas.HistoryResponse)
def history(session_id: str, db: Session = Depends(get_db)):
    """Return the full message history for a session (empty list when unknown)."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if session is None:
        return {"messages": []}
    rows = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    return {"messages": [_row_to_history_message(row) for row in rows]}
@app.post("/chat")
async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
    """Handle one chat turn: persist the user message, call Ollama, persist the reply.

    Streaming responses persist the assistant reply after the stream finishes,
    on a fresh DB session (not the request-scoped `db`).
    """
    # Find or create session
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == req.session_id).first()
    if not session:
        session = models.ChatSession(session_id=req.session_id)
        db.add(session)
        db.commit()
        db.refresh(session)
    # Raises 502 before anything is persisted if Ollama is unreachable.
    request_model_supports_vision = await _model_supports_vision(req.model)
    prior_rows = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    # Context window: at most the 19 most recent rows plus the new message.
    history_rows = prior_rows[-19:]
    history_messages = await _build_ollama_messages_from_rows(
        history_rows,
        request_model_supports_vision=request_model_supports_vision,
        vision_model=req.vision_model,
        transcription_model=req.transcription_model,
    )
    # persist_file_text=True so extracted attachment text is stored with the
    # message and later turns can skip re-extraction.
    current_user_message, user_attachments = await _build_ollama_user_message(
        req.message,
        req.attachments or [],
        enriched_message=req.enriched_message,
        request_model_supports_vision=request_model_supports_vision,
        vision_model=req.vision_model,
        transcription_model=req.transcription_model,
        persist_file_text=True,
    )
    session_pk = session.id
    # The raw message is stored; the enriched/attachment-expanded content is
    # only sent to the model.
    user_row = models.ChatMessage(
        session_pk=session_pk,
        role='user',
        content=req.message,
        attachments_json=json.dumps(user_attachments or []),
    )
    db.add(user_row)
    db.commit()
    messages = [*history_messages, current_user_message]
    # Sources to persist with the assistant reply
    sources = req.sources or []
    if req.stream:
        async def stream_generator():
            full_reply = ""
            try:
                async for chunk in ollama_chat_stream(req.model, messages):
                    full_reply += chunk
                    yield chunk
            except Exception as e:
                # HTTP status is already 200 by now; surface the error in-band.
                yield f"Ollama error: {e}"
            # Persist assistant reply (include sources_json)
            db_sess = None
            try:
                db_sess = SessionLocal()
                db_sess.add(models.ChatMessage(
                    session_pk=session_pk,
                    role='assistant',
                    content=full_reply,
                    sources_json=json.dumps(sources or []),
                ))
                db_sess.commit()
            finally:
                if db_sess is not None:
                    db_sess.close()
        return StreamingResponse(stream_generator(), media_type="text/plain")
    else:
        try:
            reply = await ollama_chat(req.model, messages)
        except Exception as e:
            raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
        as_row = models.ChatMessage(
            session_pk=session_pk, role='assistant', content=reply,
            sources_json=json.dumps(sources or [])
        )
        db.add(as_row)
        db.commit()
        return {"reply": reply}
@app.post("/generate-title", response_model=schemas.GenerateTitleResponse)
async def generate_title(req: schemas.GenerateTitleRequest, db: Session = Depends(get_db)):
    """Ask the model for a short chat title, sanitize it, persist it, and return it.

    Raises 404 when the session does not exist and 502 when Ollama fails.
    """
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == req.session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    prompt = f"Generate a very short, concise title (5 words or less) for a chat conversation that begins with this user message: \"{req.message}\". Do not use quotation marks in the title."
    try:
        title = await ollama_chat(req.model, [{"role": "user", "content": prompt}])
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}") from e
    # Strip markdown/think-tags the model may emit despite the instructions.
    cleaned_title = sanitize_chat_title(title)
    session.name = cleaned_title
    db.commit()
    return {"title": cleaned_title}
@app.delete("/sessions/{session_id}")
def delete_session(session_id: str, db: Session = Depends(get_db)):
    """Delete a session and every message that belongs to it."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    # Remove child messages first, then the session row itself.
    db.query(models.ChatMessage).filter(models.ChatMessage.session_pk == session.id).delete()
    db.delete(session)
    db.commit()
    return {"ok": True}
@app.put("/sessions/{session_id}/rename")
def rename_session(session_id: str, req: schemas.GenerateTitleResponse, db: Session = Depends(get_db)):
    """Rename a session, sanitizing the submitted title first."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    session.name = sanitize_chat_title(req.title)
    db.commit()
    return {"ok": True}
@app.put("/sessions/{session_id}/messages/{index}")
def update_user_message(session_id: str, index: int, req: schemas.EditMessageRequest, db: Session = Depends(get_db)):
    """Edit a user message in place and discard every message after it."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")
    rows = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    if not 0 <= index < len(rows):
        raise HTTPException(status_code=404, detail="Message index out of range")
    # Only user messages can be edited per spec
    if rows[index].role != "user":
        raise HTTPException(status_code=400, detail="Only user messages can be edited")
    rows[index].content = req.message
    # Everything after the edited message is now stale; drop it.
    for stale in rows[index + 1:]:
        db.delete(stale)
    db.commit()
    return {"ok": True}
# Regenerate the assistant reply from the last user message at/before the given index.
@app.post("/sessions/{session_id}/regenerate")
async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Session = Depends(get_db)):
    """Re-run the model from the last user message at/before ``req.index``.

    Rebuilds the conversation up to that user message (optionally overriding
    its content with ``req.enriched_message``), deletes every later row, then
    requests a fresh assistant reply — streamed or not per ``req.stream``.
    """
    idx = req.index
    model = req.model
    stream = bool(req.stream)
    sources = req.sources or []
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    msgs = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    if idx < 0 or idx >= len(msgs):
        raise HTTPException(status_code=400, detail="Invalid message index")
    # last user idx at/before idx (falls back to idx itself when none found)
    last_user_idx = idx
    for i in range(idx, -1, -1):
        if msgs[i].role == "user":
            last_user_idx = i
            break
    request_model_supports_vision = await _model_supports_vision(model)
    conversation = await _build_ollama_messages_from_rows(
        msgs[: last_user_idx + 1],
        request_model_supports_vision=request_model_supports_vision,
        vision_model=req.vision_model,
        transcription_model=req.transcription_model,
        override_user_row_index=last_user_idx,
        override_user_content=req.enriched_message,
    )
    # prune after that user only after conversation building succeeds
    if last_user_idx < len(msgs) - 1:
        for m in msgs[last_user_idx + 1:]:
            db.delete(m)
        db.commit()
    session_pk = session.id
    if stream:
        async def stream_generator():
            full_reply = ""
            try:
                async for chunk in ollama_chat_stream(model, conversation):
                    full_reply += chunk
                    yield chunk
            except Exception as e:
                # HTTP status is already 200; surface the error in-band.
                yield f"Ollama error: {e}"
            # persist (with sources) on a fresh session, not the request-scoped one
            try:
                db_sess = SessionLocal()
                db_sess.add(models.ChatMessage(
                    session_pk=session_pk, role="assistant", content=full_reply,
                    sources_json=json.dumps(sources or [])
                ))
                db_sess.commit()
            finally:
                # db_sess may be unbound if SessionLocal() itself raised.
                try:
                    db_sess.close()
                except Exception:
                    pass
        return StreamingResponse(stream_generator(), media_type="text/plain")
    try:
        reply = await ollama_chat(model, conversation)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
    db.add(models.ChatMessage(
        session_pk=session_pk, role="assistant", content=reply,
        sources_json=json.dumps(sources or [])
    ))
    db.commit()
    return {"reply": reply}
# -----------------------------------------------------------------------------
# Web search enrichment endpoint
@app.post("/websearch", response_model=schemas.WebSearchResponse)
async def websearch_route(req: schemas.WebSearchRequest):
    """
    Return an enriched prompt (with citations) for a given user prompt.
    Optionally uses the last `history_limit` turns from `req.messages`.
    Falls back to the unmodified prompt when enrichment fails for any reason.
    """
    try:
        limit = int(req.history_limit or 8)
        recent = (req.messages or [])[-limit:]
        enriched, sources = await enrich_prompt(
            user_prompt=req.prompt,
            model=req.model,
            messages=[{"role": turn.role, "content": turn.content} for turn in recent],
            searx_url=req.searx_url,
            engines=req.engines,
        )
        marker = "<websearch_context>"
        context_block = enriched[enriched.index(marker):].strip() if marker in enriched else ""
        return {"enriched_prompt": enriched, "sources": sources, "context_block": context_block}
    except Exception:
        # Deliberate best-effort: enrichment failure must never break the chat flow.
        return {"enriched_prompt": req.prompt, "sources": [], "context_block": ""}
# To run standalone: python -m uvicorn backend.main:app --host 127.0.0.1 --port 8000