# File: Heimgeist/backend/main.py
# (553 lines, 19 KiB, Python)

import asyncio
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import Any, List
import re
import html
import json
import base64
from . import models, schemas
from .database import Base, engine, SessionLocal, ensure_sources_column
from .local_rag import router as local_rag_router
from .ollama_admin import inspect_ollama_startup, prepare_startup_models, pull_local_model, start_local_ollama
from .ollama_client import (
list_model_catalog as ollama_list_model_catalog,
chat as ollama_chat,
chat_stream as ollama_chat_stream,
show_model as ollama_show_model,
supports_vision as ollama_supports_vision,
)
from .whisper_admin import DEFAULT_WHISPER_MODEL, list_whisper_models, transcribe_audio_bytes
from .websearch import enrich_prompt
# Create tables + ensure migration.
# NOTE: both calls execute at import time, so importing this module talks to
# the database bound to `engine`.
Base.metadata.create_all(bind=engine)
# Helper imported from .database; presumably adds the `sources` column to
# tables created by an older schema — confirm against database.py.
ensure_sources_column(engine)
# The ASGI application object; every route below is registered on it.
app = FastAPI(title="LLM Desktop Backend", version="0.1.0" )
def sanitize_chat_title(title: str) -> str:
    """Return *title* stripped of model artifacts and Markdown decoration.

    The cleanup, in order:
      1. HTML entities are unescaped (a falsy title is treated as "").
      2. ``<think>...</think>`` / ``<thinking>...</thinking>`` blocks are
         removed (case-insensitive, may span lines).
      3. Leading ``#`` runs, leading ``*``/``**`` and trailing ``*``/``**``
         are peeled off repeatedly until nothing changes.
      4. Whitespace runs are collapsed to a single space.
    """
    text = html.unescape(title or "")
    text = re.sub(
        r'<think(?:ing)?>.*?</think(?:ing)?>',
        '',
        text,
        flags=re.DOTALL | re.IGNORECASE,
    ).strip()

    # Decorations can be nested ("# **Title**"), so keep peeling until a pass
    # leaves the text untouched or nothing is left.
    decoration_patterns = (r'^\s*#+\s*', r'^\s*\*{1,2}\s*', r'\s*\*{1,2}\s*$')
    while text:
        before = text
        for pattern in decoration_patterns:
            text = re.sub(pattern, '', text)
        text = text.strip()
        if text == before:
            break

    return re.sub(r'\s+', ' ', text).strip()
# CORS (dev-friendly; tighten later)
# NOTE(review): every origin, method and header is allowed *and* credentials
# are enabled, which is the most permissive configuration available. The CORS
# specification does not allow a literal "*" origin on credentialed requests —
# confirm how the middleware resolves this combination, and restrict
# `allow_origins` to the known front-end origin(s) before exposing this
# backend beyond localhost.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount the routes defined in .local_rag on the same application.
app.include_router(local_rag_router)
# Matches a complete base64 image data URL (case-insensitive).
#   group(1): the MIME type, e.g. "image/png"
#   group(2): the base64 payload, which may contain whitespace
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)
def _attachment_field(item: Any, field: str) -> Any:
if isinstance(item, dict):
return item.get(field)
return getattr(item, field, None)
def _normalize_image_attachments(items: Any) -> List[dict]:
cleaned: List[dict] = []
for item in items or []:
data_url = str(_attachment_field(item, "data_url") or "").strip()
name = str(_attachment_field(item, "name") or "image").strip() or "image"
mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower()
match = _IMAGE_DATA_URL_RE.match(data_url)
if not match:
continue
detected_mime = match.group(1).lower()
payload = re.sub(r"\s+", "", match.group(2))
if mime_type and not mime_type.startswith("image/"):
continue
try:
base64.b64decode(payload, validate=True)
except Exception:
continue
cleaned.append({
"name": name[:255],
"mime_type": mime_type or detected_mime,
"data_url": f"data:{detected_mime};base64,{payload}",
})
return cleaned
def _load_message_attachments(raw_value: Any) -> List[dict]:
    """Decode a stored attachments value and normalize it.

    A string is parsed as JSON (an empty string counts as ``"[]"``, and
    unparsable text as an empty list); any other value is passed to the
    normalizer unchanged.
    """
    if not isinstance(raw_value, str):
        return _normalize_image_attachments(raw_value)

    try:
        decoded = json.loads(raw_value or "[]")
    except Exception:
        decoded = []
    return _normalize_image_attachments(decoded)
def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
images: List[str] = []
for attachment in attachments:
match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
if not match:
continue
images.append(re.sub(r"\s+", "", match.group(2)))
return images
def _row_to_history_message(row: models.ChatMessage) -> dict:
    """Convert a stored message row into the payload used by ``/history``.

    The result always carries ``role``, ``content`` and ``sources``;
    ``attachments`` is added only when the row has valid image attachments.
    Sources that are missing or fail to decode become an empty list.
    """
    try:
        raw_sources = getattr(row, "sources_json", None)
        sources = json.loads(raw_sources) if raw_sources else []
    except Exception:
        sources = []

    images = _load_message_attachments(getattr(row, "attachments_json", None))
    message = {"role": row.role, "content": row.content, "sources": sources}
    if images:
        message["attachments"] = images
    return message
def _row_to_ollama_message(row: models.ChatMessage) -> dict:
    """Convert a stored message row into a chat message for the Ollama client.

    ``images`` (bare base64 payloads) is included only when the row has
    valid image attachments.
    """
    ollama_message = {"role": row.role, "content": row.content}
    stored = _load_message_attachments(getattr(row, "attachments_json", None))
    if not stored:
        return ollama_message
    ollama_message["images"] = _attachments_to_ollama_images(stored)
    return ollama_message
def get_db():
    """Yield a database session and close it once the consumer is done.

    Used as a FastAPI dependency (``Depends(get_db)``) by the routes below.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs on normal completion and on errors alike.
        session.close()
@app.get("/health")
def health():
    """Liveness probe; always answers ``{"ok": True}``."""
    status = {"ok": True}
    return status
@app.post("/audio/transcribe", response_model=schemas.AudioTranscriptionResponse)
async def transcribe_audio_route(req: schemas.AudioTranscriptionRequest):
    """Transcribe a base64-encoded audio clip.

    Responds with HTTP 400 when the MIME type is not ``audio/*``, the payload
    is empty or not valid base64, or the transcriber raises ``RuntimeError``;
    any other transcription failure becomes HTTP 500.

    Returns:
        ``{"text", "language", "model"}`` — ``language`` is ``None`` when the
        transcriber reported none.
    """
    # Only the media type matters; drop parameters after the first ";".
    declared_type = str(req.mime_type or "")
    mime_type = declared_type.split(";", 1)[0].strip().lower()
    if not mime_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="An audio mime type is required.")

    encoded_audio = re.sub(r"\s+", "", str(req.audio_base64 or ""))
    if not encoded_audio:
        raise HTTPException(status_code=400, detail="Audio payload is required.")

    try:
        audio_bytes = base64.b64decode(encoded_audio, validate=True)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 audio payload.") from exc

    try:
        # The transcriber is a blocking call, so run it off the event loop.
        result = await asyncio.to_thread(
            transcribe_audio_bytes,
            audio_bytes,
            mime_type,
            req.model or DEFAULT_WHISPER_MODEL,
            req.language,
        )
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Audio transcription failed: {exc}") from exc

    text = str(result.get("text") or "").strip()
    language = str(result.get("language") or "").strip() or None
    model_used = str(result.get("model") or req.model or DEFAULT_WHISPER_MODEL)
    return {"text": text, "language": language, "model": model_used}
@app.get("/models")
async def get_models():
    """Return the Ollama model catalog merged with the Whisper model list.

    The Ollama catalog is fetched concurrently with the (blocking) Whisper
    listing, which runs in a worker thread.

    Returns:
        The Ollama catalog dict extended with ``whisper_models`` and
        ``whisper_error``.

    Raises:
        HTTPException: 502 when either lookup (or merging the results) fails.
    """
    try:
        ollama_data, whisper_data = await asyncio.gather(
            ollama_list_model_catalog(),
            asyncio.to_thread(list_whisper_models),
        )
        return {
            **ollama_data,
            "whisper_models": whisper_data.get("models", []),
            "whisper_error": whisper_data.get("error", ""),
        }
    except Exception as exc:
        # Chain the cause, as the other routes in this module do, so the
        # original traceback stays attached to the 502.
        raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc
@app.get("/models/capabilities")
async def get_model_capabilities(name: str):
    """Report the capabilities of one Ollama model.

    Args:
        name: Model name; surrounding whitespace is ignored.

    Returns:
        ``{"name", "capabilities", "supports_vision"}`` where
        ``capabilities`` holds the non-blank capability strings reported for
        the model.

    Raises:
        HTTPException: 400 when the name is blank, 502 when the model lookup
            fails.
    """
    model_name = str(name or "").strip()
    if not model_name:
        raise HTTPException(status_code=400, detail="Model name is required.")

    try:
        model_data = await ollama_show_model(model_name)
    except Exception as exc:
        # Chain the cause, as the other routes in this module do.
        raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc

    # Normalize each entry once and keep only the non-blank ones.
    capabilities = [
        label
        for item in (model_data.get("capabilities") or [])
        if (label := str(item).strip())
    ]
    return {
        "name": model_name,
        "capabilities": capabilities,
        "supports_vision": ollama_supports_vision(model_data),
    }
@app.get("/ollama/startup-status")
async def ollama_startup_status():
    """Return whatever ``inspect_ollama_startup`` reports."""
    startup_report = await inspect_ollama_startup()
    return startup_report
@app.post("/ollama/start")
async def ollama_start_route():
    """Start the local Ollama instance and return the helper's result.

    ``FileNotFoundError`` is reported as HTTP 404 and ``RuntimeError`` as
    HTTP 400, each with the exception text as detail.
    """
    try:
        outcome = await start_local_ollama()
    except (FileNotFoundError, RuntimeError) as exc:
        status = 404 if isinstance(exc, FileNotFoundError) else 400
        raise HTTPException(status_code=status, detail=str(exc)) from exc
    return outcome
@app.post("/ollama/pull")
async def ollama_pull_route(req: schemas.OllamaPullRequest):
    """Pull ``req.model`` through the local Ollama and return the result.

    ``FileNotFoundError`` is reported as HTTP 404 and ``RuntimeError`` as
    HTTP 400, each with the exception text as detail.
    """
    try:
        outcome = await pull_local_model(req.model)
    except (FileNotFoundError, RuntimeError) as exc:
        status = 404 if isinstance(exc, FileNotFoundError) else 400
        raise HTTPException(status_code=status, detail=str(exc)) from exc
    return outcome
@app.post("/startup/prepare-models")
async def startup_prepare_models_route():
    """Run ``prepare_startup_models`` and return its result.

    ``FileNotFoundError`` is reported as HTTP 404 and ``RuntimeError`` as
    HTTP 400, each with the exception text as detail.
    """
    try:
        outcome = await prepare_startup_models()
    except (FileNotFoundError, RuntimeError) as exc:
        status = 404 if isinstance(exc, FileNotFoundError) else 400
        raise HTTPException(status_code=status, detail=str(exc)) from exc
    return outcome
@app.get("/sessions", response_model=schemas.SessionsResponse)
def get_sessions(db: Session = Depends(get_db)):
    """List all chat sessions, newest first.

    Session names are passed through ``sanitize_chat_title`` on the way out.
    """
    newest_first = models.ChatSession.created_at.desc()
    rows = db.query(models.ChatSession).order_by(newest_first).all()

    summaries = []
    for row in rows:
        summaries.append({
            "id": row.id,
            "session_id": row.session_id,
            "name": sanitize_chat_title(row.name),
            "created_at": row.created_at,
        })
    return {"sessions": summaries}
@app.post("/sessions", response_model=schemas.ChatSession)
def create_session(req: schemas.CreateSessionRequest, db: Session = Depends(get_db)):
    """Persist a new chat session for ``req.session_id`` and return it."""
    created = models.ChatSession(session_id=req.session_id)
    db.add(created)
    db.commit()
    # Reload so database-generated columns are populated in the response.
    db.refresh(created)
    return created
@app.get("/history", response_model=schemas.HistoryResponse)
def history(session_id: str, db: Session = Depends(get_db)):
    """Return a session's messages in chronological order.

    An unknown ``session_id`` yields an empty message list, not an error.
    """
    owner = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == session_id)
        .first()
    )

    transcript = []
    if owner:
        ordered = (
            db.query(models.ChatMessage)
            .filter(models.ChatMessage.session_pk == owner.id)
            .order_by(models.ChatMessage.created_at.asc())
        )
        transcript = [_row_to_history_message(row) for row in ordered.all()]
    return {"messages": transcript}
@app.post("/chat")
async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
    """Store a user message, ask the model for a reply and persist the reply.

    The session named by ``req.session_id`` is created on first use. The
    user's original message is what gets stored; ``req.enriched_message``,
    when given, replaces the content of the most recent user message only in
    the payload sent to the model. The model sees at most the last 20 stored
    messages of the session.

    Returns:
        ``{"reply": ...}``, or a ``text/plain`` ``StreamingResponse`` of reply
        chunks when ``req.stream`` is set. In streaming mode a model error is
        emitted as a final ``"Ollama error: ..."`` chunk and whatever was
        received before it is still persisted.

    Raises:
        HTTPException: 502 when the non-streaming model call fails.
    """
    # Find or create session
    session = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == req.session_id)
        .first()
    )
    if not session:
        session = models.ChatSession(session_id=req.session_id)
        db.add(session)
        db.commit()
        db.refresh(session)
    # Capture the key now: the streaming generator runs after this handler
    # has returned and must not depend on request-scoped ORM state.
    session_pk = session.id

    # Store the BASE user prompt
    user_attachments = _normalize_image_attachments(req.attachments)
    db.add(models.ChatMessage(
        session_pk=session_pk,
        role='user',
        content=req.message,
        attachments_json=json.dumps(user_attachments or []),
    ))
    db.commit()

    # Build minimal context (last 20)
    last_msgs = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session_pk)
        .order_by(models.ChatMessage.created_at.asc())
        .all()[-20:]
    )
    messages = [_row_to_ollama_message(m) for m in last_msgs]

    # Patch last user with enriched_message only for LLM call
    if req.enriched_message:
        for i in range(len(messages) - 1, -1, -1):
            if messages[i].get("role") == "user":
                messages[i] = {**messages[i], "content": req.enriched_message}
                break

    # Sources to persist with the assistant reply
    sources = req.sources or []

    if req.stream:
        async def stream_generator():
            parts: List[str] = []
            try:
                async for chunk in ollama_chat_stream(req.model, messages):
                    parts.append(chunk)
                    yield chunk
            except Exception as e:
                yield f"Ollama error: {e}"

            # Persist with a dedicated session (same approach as the
            # regenerate route): the request-scoped `db` may already have
            # been closed by the dependency teardown at this point.
            stream_db = SessionLocal()
            try:
                stream_db.add(models.ChatMessage(
                    session_pk=session_pk, role='assistant', content="".join(parts),
                    sources_json=json.dumps(sources),
                ))
                stream_db.commit()
            finally:
                stream_db.close()

        return StreamingResponse(stream_generator(), media_type="text/plain")

    try:
        reply = await ollama_chat(req.model, messages)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}") from e

    db.add(models.ChatMessage(
        session_pk=session_pk, role='assistant', content=reply,
        sources_json=json.dumps(sources),
    ))
    db.commit()
    return {"reply": reply}
@app.post("/generate-title", response_model=schemas.GenerateTitleResponse)
async def generate_title(req: schemas.GenerateTitleRequest, db: Session = Depends(get_db)):
    """Ask the model for a short chat title and store it on the session.

    The raw model output is cleaned with ``sanitize_chat_title`` before it is
    saved as the session name.

    Returns:
        ``{"title": <cleaned title>}``.

    Raises:
        HTTPException: 404 when the session does not exist, 502 when the
            model call fails.
    """
    # Local import: this module does not import logging at file level.
    import logging

    logger = logging.getLogger(__name__)

    session = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == req.session_id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    prompt = f"Generate a very short, concise title (5 words or less) for a chat conversation that begins with this user message: \"{req.message}\". Do not use quotation marks in the title."
    try:
        title = await ollama_chat(req.model, [{"role": "user", "content": prompt}])
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}") from e

    # Debug-level logging instead of print(): titles derive from user
    # messages and should not be written to stdout unconditionally.
    logger.debug("Original title from LLM: %r", title)
    cleaned_title = sanitize_chat_title(title)
    logger.debug("Cleaned title before saving: %r", cleaned_title)

    session.name = cleaned_title
    db.commit()
    return {"title": cleaned_title}
@app.delete("/sessions/{session_id}")
def delete_session(session_id: str, db: Session = Depends(get_db)):
    """Delete a session together with all of its messages.

    Responds with HTTP 404 when the session does not exist.
    """
    target = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == session_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Session not found")

    # Remove the session's messages first, then the session row itself;
    # both are committed together.
    owned_messages = db.query(models.ChatMessage).filter(
        models.ChatMessage.session_pk == target.id
    )
    owned_messages.delete()
    db.delete(target)
    db.commit()
    return {"ok": True}
@app.put("/sessions/{session_id}/rename")
def rename_session(session_id: str, req: schemas.GenerateTitleResponse, db: Session = Depends(get_db)):
    """Rename a session to the sanitized form of ``req.title``.

    Responds with HTTP 404 when the session does not exist.
    """
    target = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == session_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=404, detail="Session not found")

    new_name = sanitize_chat_title(req.title)
    target.name = new_name
    db.commit()
    return {"ok": True}
@app.put("/sessions/{session_id}/messages/{index}")
def update_user_message(session_id: str, index: int, req: schemas.EditMessageRequest, db: Session = Depends(get_db)):
    """Edit a user message and discard everything that came after it.

    ``index`` is the zero-based position of the message within the session,
    ordered by creation time. Responds with HTTP 404 for an unknown session
    or an out-of-range index, and HTTP 400 when the message at ``index`` is
    not a user message.
    """
    owner = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == session_id)
        .first()
    )
    if not owner:
        raise HTTPException(status_code=404, detail="Session not found")

    timeline = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == owner.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    if not 0 <= index < len(timeline):
        raise HTTPException(status_code=404, detail="Message index out of range")

    edited = timeline[index]
    # Only user messages can be edited per spec
    if edited.role != "user":
        raise HTTPException(status_code=400, detail="Only user messages can be edited")

    edited.content = req.message
    # Later messages were produced from the old text, so drop them.
    for stale in timeline[index + 1:]:
        db.delete(stale)
    db.commit()
    return {"ok": True}
# Regenerate the assistant reply that follows a given message.
@app.post("/sessions/{session_id}/regenerate")
async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Session = Depends(get_db)):
    """Regenerate the assistant reply for an earlier point in a session.

    Starting at ``req.index`` the handler walks backwards to the nearest user
    message (if none is found, ``req.index`` itself is used), deletes every
    message after it, and asks the model to answer the remaining
    conversation. ``req.enriched_message``, when given, replaces the content
    of the last user message only in the payload sent to the model.

    Returns:
        ``{"reply": ...}``, or a ``text/plain`` ``StreamingResponse`` of reply
        chunks when ``req.stream`` is truthy. In streaming mode a model error
        is emitted as a final ``"Ollama error: ..."`` chunk and whatever was
        received before it is still persisted.

    Raises:
        HTTPException: 404 for an unknown session, 400 for an out-of-range
            index, 502 when the non-streaming model call fails.
    """
    idx = req.index
    model = req.model
    stream = bool(req.stream)
    sources = req.sources or []

    session = (
        db.query(models.ChatSession)
        .filter(models.ChatSession.session_id == session_id)
        .first()
    )
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    msgs = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    if idx < 0 or idx >= len(msgs):
        raise HTTPException(status_code=400, detail="Invalid message index")

    # last user idx at/before idx (falls back to idx when there is none)
    last_user_idx = next(
        (i for i in range(idx, -1, -1) if msgs[i].role == "user"),
        idx,
    )

    # prune after that user
    stale = msgs[last_user_idx + 1:]
    if stale:
        for m in stale:
            db.delete(m)
        db.commit()

    conversation = [_row_to_ollama_message(m) for m in msgs[: last_user_idx + 1]]
    if req.enriched_message:
        for j in range(len(conversation) - 1, -1, -1):
            if conversation[j].get("role") == "user":
                conversation[j] = {**conversation[j], "content": req.enriched_message}
                break

    # Captured up front because the streaming generator runs after this
    # handler has returned.
    session_pk = session.id

    if stream:
        async def stream_generator():
            parts: List[str] = []
            try:
                async for chunk in ollama_chat_stream(model, conversation):
                    parts.append(chunk)
                    yield chunk
            except Exception as e:
                yield f"Ollama error: {e}"

            # persist (with sources) through a dedicated session. It is
            # created before the try so `finally` never sees an unbound name.
            db_sess = SessionLocal()
            try:
                db_sess.add(models.ChatMessage(
                    session_pk=session_pk, role="assistant", content="".join(parts),
                    sources_json=json.dumps(sources),
                ))
                db_sess.commit()
            finally:
                db_sess.close()

        return StreamingResponse(stream_generator(), media_type="text/plain")

    try:
        reply = await ollama_chat(model, conversation)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}") from e

    db.add(models.ChatMessage(
        session_pk=session_pk, role="assistant", content=reply,
        sources_json=json.dumps(sources),
    ))
    db.commit()
    return {"reply": reply}
# -----------------------------------------------------------------------------
# Web search enrichment endpoint
@app.post("/websearch", response_model=schemas.WebSearchResponse)
async def websearch_route(req: schemas.WebSearchRequest):
    """
    Return an enriched prompt (with citations) for a given user prompt.

    Optionally uses the last `history_limit` turns from `req.messages`. A
    missing, zero or negative `history_limit` falls back to 8 turns.

    This endpoint is best-effort: on any failure it logs a warning and
    returns the original prompt with no sources instead of an error.

    Returns:
        ``{"enriched_prompt", "sources", "context_block"}`` where
        ``context_block`` is the part of the enriched prompt starting at
        ``<websearch_context>`` (empty when that marker is absent).
    """
    # Local import: this module does not import logging at file level.
    import logging

    default_limit = 8
    try:
        limit = int(req.history_limit or default_limit)
        if limit <= 0:
            # A non-positive value would turn "last N" into "skip the first
            # N" (negative) or "everything" ([-0:]), so use the default.
            limit = default_limit
        messages = (req.messages or [])[-limit:]

        enriched, sources = await enrich_prompt(
            user_prompt=req.prompt,
            model=req.model,
            messages=[{"role": m.role, "content": m.content} for m in messages],
            searx_url=req.searx_url,
            engines=req.engines,
        )

        context_block = ""
        if "<websearch_context>" in enriched:
            context_block = enriched[enriched.index("<websearch_context>"):].strip()
        return {"enriched_prompt": enriched, "sources": sources, "context_block": context_block}
    except Exception:
        # Keep the fallback, but leave a trace so failures can be diagnosed.
        logging.getLogger(__name__).warning(
            "Web search enrichment failed; returning the original prompt",
            exc_info=True,
        )
        return {"enriched_prompt": req.prompt, "sources": [], "context_block": ""}
# To run standalone: python -m uvicorn backend.main:app --host 127.0.0.1 --port 8000