"""FastAPI backend for the LLM desktop app.

Exposes chat, session, attachment, transcription, model-catalog, and
web-search-enrichment endpoints backed by SQLAlchemy and a local Ollama
instance.
"""

import asyncio
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import Any, List, Optional
import re
import html
import json
import base64
import mimetypes
import shutil
import tempfile
from pathlib import Path

from . import models, schemas
from .database import Base, engine, SessionLocal, ensure_sources_column
from .local_rag import router as local_rag_router
from .ollama_admin import inspect_ollama_startup, prepare_startup_models, pull_local_model, start_local_ollama
from .ollama_client import (
    list_model_catalog as ollama_list_model_catalog,
    chat as ollama_chat,
    chat_stream as ollama_chat_stream,
    show_model as ollama_show_model,
    supports_vision as ollama_supports_vision,
)
from .whisper_admin import DEFAULT_WHISPER_MODEL, list_whisper_models, transcribe_audio_bytes
from .websearch import enrich_prompt

# Create tables + ensure migration
Base.metadata.create_all(bind=engine)
ensure_sources_column(engine)

app = FastAPI(title="LLM Desktop Backend", version="0.1.0")


def sanitize_chat_title(title: str) -> str:
    """Strip markup and markdown decoration from an LLM-generated chat title.

    Unescapes HTML entities, removes model "reasoning" tags, then repeatedly
    peels leading ``#`` headers and surrounding ``*``/``**`` emphasis until
    the string stabilizes, and finally collapses runs of whitespace.
    """
    cleaned_title = html.unescape(title or "")
    # NOTE(review): the original pattern here was r'.*?', which matches the
    # empty string at every position and substitutes nothing — a no-op. The
    # DOTALL|IGNORECASE flags strongly suggest the tag text was lost and this
    # was meant to strip reasoning blocks such as <think>...</think>; restored
    # as such — confirm the intended tag name.
    cleaned_title = re.sub(r'<think>.*?</think>', '', cleaned_title, flags=re.DOTALL | re.IGNORECASE)
    cleaned_title = cleaned_title.strip()
    previous_title = None
    # Loop until no further markdown decoration can be peeled off.
    while cleaned_title and cleaned_title != previous_title:
        previous_title = cleaned_title
        cleaned_title = re.sub(r'^\s*#+\s*', '', cleaned_title)
        cleaned_title = re.sub(r'^\s*\*{1,2}\s*', '', cleaned_title)
        cleaned_title = re.sub(r'\s*\*{1,2}\s*$', '', cleaned_title)
        cleaned_title = cleaned_title.strip()
    cleaned_title = re.sub(r'\s+', ' ', cleaned_title)
    return cleaned_title.strip()


# CORS (dev-friendly; tighten later)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(local_rag_router)

# Matches a base64 image data URL, capturing the mime type and the payload.
_IMAGE_DATA_URL_RE = re.compile(r"^data:(image\/[a-z0-9.+-]+);base64,([a-z0-9+/=\s]+)$", re.IGNORECASE)

_CHAT_IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".heic", ".avif"}
_CHAT_FILE_EXTENSIONS = {
    ".pdf", ".html", ".htm", ".txt", ".md", ".rst", ".epub",
    ".mp3", ".wav", ".m4a", ".flac", ".ogg", ".opus", ".aac",
    ".mp4", ".mkv", ".mov", ".webm", ".avi", ".ts",
}


def _attachment_field(item: Any, field: str) -> Any:
    """Read ``field`` from an attachment that may be a dict or an object."""
    if isinstance(item, dict):
        return item.get(field)
    return getattr(item, field, None)


def _normalize_attachment_name(value: Any, fallback: str) -> str:
    """Return a non-empty display name for an attachment, capped at 255 chars."""
    name = str(value or "").strip() or fallback
    return name[:255]


def _normalize_image_attachment(item: Any) -> Optional[dict]:
    """Validate an image attachment and return a canonical payload, or None.

    Requires a well-formed base64 image data URL; rejects mismatched
    non-image mime types and undecodable base64 payloads.
    """
    data_url = str(_attachment_field(item, "data_url") or "").strip()
    match = _IMAGE_DATA_URL_RE.match(data_url)
    if not match:
        return None
    detected_mime = match.group(1).lower()
    payload = re.sub(r"\s+", "", match.group(2))
    mime_type = str(_attachment_field(item, "mime_type") or "").strip().lower()
    if mime_type and not mime_type.startswith("image/"):
        return None
    try:
        base64.b64decode(payload, validate=True)
    except Exception:
        return None
    return {
        "kind": "image",
        "name": _normalize_attachment_name(_attachment_field(item, "name"), "image"),
        "mime_type": mime_type or detected_mime,
        "data_url": f"data:{detected_mime};base64,{payload}",
    }


def _normalize_file_attachment(item: Any, *, include_text: bool = False) -> Optional[dict]:
    """Validate a file attachment and return a canonical payload, or None.

    A file attachment must either reference a local path with an allowed
    extension or carry non-empty inline text. When ``include_text`` is set,
    the stripped inline text is kept in the payload.
    """
    source_path_raw = str(_attachment_field(item, "source_path") or "").strip()
    source_path = ""
    if source_path_raw:
        try:
            source_path = str(Path(source_path_raw).expanduser().resolve())
        except Exception:
            # Path resolution can fail on odd inputs; keep the raw value.
            source_path = source_path_raw
    text_value = str(_attachment_field(item, "text") or "")
    name_seed = _attachment_field(item, "name") or (Path(source_path).name if source_path else "file")
    name = _normalize_attachment_name(name_seed, "file")
    mime_type = str(_attachment_field(item, "mime_type") or "").split(";", 1)[0].strip().lower()
    record_type = str(_attachment_field(item, "record_type") or "").strip().lower() or None
    size = _attachment_field(item, "size")
    try:
        size = max(0, int(size)) if size is not None else None
    except Exception:
        size = None
    extension = Path(source_path or name).suffix.lower()
    if source_path and extension not in _CHAT_FILE_EXTENSIONS:
        return None
    if not source_path and not text_value.strip():
        return None
    payload = {
        "kind": "file",
        "name": name,
        "mime_type": mime_type or (mimetypes.guess_type(name)[0] or "").lower() or None,
        "source_path": source_path or None,
        "size": size,
        "record_type": record_type,
    }
    if include_text and text_value.strip():
        payload["text"] = text_value.strip()
    return payload


def _normalize_chat_attachments(items: Any, *, include_text: bool = False) -> List[dict]:
    """Normalize a mixed list of attachments, silently dropping invalid ones."""
    cleaned: List[dict] = []
    for item in items or []:
        kind = str(_attachment_field(item, "kind") or "").strip().lower()
        data_url = str(_attachment_field(item, "data_url") or "").strip()
        if kind == "image" or (not kind and data_url):
            normalized = _normalize_image_attachment(item)
        else:
            normalized = _normalize_file_attachment(item, include_text=include_text)
        if normalized:
            cleaned.append(normalized)
    return cleaned


def _normalize_chat_attachment_or_raise(item: Any, *, include_text: bool = False) -> dict:
    """Normalize a single attachment or raise HTTP 400 with a specific reason."""
    kind = str(_attachment_field(item, "kind") or "").strip().lower()
    data_url = str(_attachment_field(item, "data_url") or "").strip()
    source_path = str(_attachment_field(item, "source_path") or "").strip()
    name_hint = _attachment_field(item, "name") or (Path(source_path).name if source_path else None) or "attachment"
    label = _normalize_attachment_name(name_hint, "attachment")
    if kind == "image" or (not kind and data_url):
        normalized = _normalize_image_attachment(item)
        if normalized:
            return normalized
        raise HTTPException(status_code=400, detail=f"{label}: invalid image attachment.")
    normalized = _normalize_file_attachment(item, include_text=include_text)
    if normalized:
        return normalized
    if source_path:
        extension = Path(source_path).suffix.lower()
        if extension in _CHAT_IMAGE_EXTENSIONS:
            raise HTTPException(status_code=400, detail=f"{label}: use Add Image(s) for image attachments.")
        raise HTTPException(status_code=400, detail=f"{label}: unsupported file type for chat attachment.")
    raise HTTPException(status_code=400, detail=f"{label}: local file path is required for non-image attachments.")


def _load_message_attachments(raw_value: Any, *, include_text: bool = False) -> List[dict]:
    """Decode persisted attachment JSON (or pass through a list) and normalize it."""
    if isinstance(raw_value, str):
        try:
            parsed = json.loads(raw_value or "[]")
        except Exception:
            parsed = []
    else:
        parsed = raw_value
    return _normalize_chat_attachments(parsed, include_text=include_text)


def _attachment_is_image(attachment: dict) -> bool:
    """True when the attachment is an image (by kind or by data URL shape)."""
    kind = str(attachment.get("kind") or "").strip().lower()
    if kind == "image":
        return True
    return bool(_IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip()))


def _attachment_is_file(attachment: dict) -> bool:
    """True when the attachment is not an image."""
    return not _attachment_is_image(attachment)


def _attachments_to_ollama_images(attachments: List[dict]) -> List[str]:
    """Extract raw base64 payloads for image attachments (Ollama `images` field)."""
    images: List[str] = []
    for attachment in attachments:
        if not _attachment_is_image(attachment):
            continue
        match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
        if not match:
            continue
        images.append(re.sub(r"\s+", "", match.group(2)))
    return images


def _attachment_history_payload(attachment: dict) -> dict:
    """Build the compact attachment dict returned by the /history endpoint."""
    payload = {
        "kind": attachment.get("kind") or ("image" if _attachment_is_image(attachment) else "file"),
        "name": attachment.get("name"),
        "mime_type": attachment.get("mime_type"),
    }
    if _attachment_is_image(attachment):
        payload["data_url"] = attachment.get("data_url")
    else:
        payload["source_path"] = attachment.get("source_path")
        payload["size"] = attachment.get("size")
        payload["record_type"] = attachment.get("record_type")
    # Drop empty values so history stays compact.
    return {key: value for key, value in payload.items() if value not in (None, "", [])}


async def _model_supports_vision(model_name: Optional[str]) -> bool:
    """Ask Ollama whether ``model_name`` accepts image inputs.

    Raises HTTP 502 if Ollama cannot be reached; returns False for an empty name.
    """
    normalized = str(model_name or "").strip()
    if not normalized:
        return False
    try:
        model_data = await ollama_show_model(normalized)
    except Exception as exc:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {exc}") from exc
    return ollama_supports_vision(model_data)


def _safe_temp_attachment_name(index: int, name: str, fallback_suffix: str = "") -> str:
    """Produce a filesystem-safe, index-prefixed staging filename."""
    safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._") or f"attachment-{index + 1}"
    suffix = Path(safe_name).suffix
    if fallback_suffix and not suffix:
        safe_name = f"{safe_name}{fallback_suffix}"
    return f"{index:03d}-{safe_name}"


def _materialize_attachment_source(attachment: dict, stage_root: Path, index: int) -> tuple[Path, dict]:
    """Stage an attachment on disk under ``stage_root`` for corpus extraction.

    Images are decoded from their data URL and written out; files are
    symlinked (or copied when symlinking fails) from their source path.
    Returns the staged path plus metadata (source_path, size).
    Raises RuntimeError with a user-facing message on any invalid input.
    """
    if _attachment_is_image(attachment):
        match = _IMAGE_DATA_URL_RE.match(str(attachment.get("data_url") or "").strip())
        if not match:
            raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid image payload.")
        mime_type = match.group(1).lower()
        payload = re.sub(r"\s+", "", match.group(2))
        try:
            image_bytes = base64.b64decode(payload, validate=True)
        except Exception as exc:
            raise RuntimeError(f"{attachment.get('name') or 'image'}: invalid base64 image payload.") from exc
        suffix = Path(str(attachment.get("name") or "")).suffix or mimetypes.guess_extension(mime_type) or ".img"
        target_path = stage_root / _safe_temp_attachment_name(index, str(attachment.get("name") or "image"), suffix)
        target_path.write_bytes(image_bytes)
        return target_path, {
            "source_path": str(target_path.resolve()),
            "size": len(image_bytes),
        }
    source_path = str(attachment.get("source_path") or "").strip()
    if not source_path:
        raise RuntimeError(f"{attachment.get('name') or 'file'}: local file path is required.")
    source = Path(source_path).expanduser().resolve()
    if not source.exists() or not source.is_file():
        raise RuntimeError(f"{attachment.get('name') or source.name}: file is no longer available.")
    if source.suffix.lower() not in _CHAT_FILE_EXTENSIONS:
        raise RuntimeError(f"{source.name}: unsupported file type for chat attachment.")
    target_path = stage_root / _safe_temp_attachment_name(index, source.name)
    try:
        target_path.symlink_to(source)
    except Exception:
        # Symlinks may be unavailable (e.g. permissions, Windows); copy instead.
        shutil.copy2(source, target_path)
    return target_path, {
        "source_path": str(source),
        "size": int(source.stat().st_size),
    }


def _read_corpus_records(path: Path) -> List[dict]:
    """Read a JSONL corpus file, returning one dict per non-empty line."""
    records: List[dict] = []
    if not path.exists():
        return records
    with path.open("r", encoding="utf-8", errors="ignore") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # Best-effort read (mirrors errors="ignore" above): skip a
                # malformed line instead of failing the whole extraction.
                continue
    return records


def _extract_attachment_texts(
    attachments: List[dict],
    *,
    vision_model: Optional[str],
    transcription_model: Optional[str],
) -> List[dict]:
    """Extract text from attachments via the RAG corpus builder.

    Stages every attachment into a temporary directory, runs the corpus
    builder over it (using the given vision / transcription models), and
    returns the attachments enriched with extracted text and metadata.
    Raises RuntimeError when the builder reports errors or an attachment
    yields no usable text.
    """
    from .rag.corpus_builder import run_build

    if not attachments:
        return []
    with tempfile.TemporaryDirectory(prefix="heimgeist-chat-attachments-") as tmpdir:
        temp_root = Path(tmpdir)
        stage_root = temp_root / "stage"
        stage_root.mkdir(parents=True, exist_ok=True)
        corpus_path = temp_root / "corpus.jsonl"
        # Map both the staged path and its resolved form back to the
        # attachment index, since corpus records may use either.
        index_by_key: dict[str, int] = {}
        attachment_meta: List[dict] = []
        for index, attachment in enumerate(attachments):
            temp_path, extra_meta = _materialize_attachment_source(attachment, stage_root, index)
            index_by_key[str(temp_path)] = index
            index_by_key[str(temp_path.resolve())] = index
            attachment_meta.append(extra_meta)
        result = run_build(
            root=stage_root,
            out=corpus_path,
            emit="per-file",
            emit_av="joined",
            lang_detect=False,
            whisper_model=transcription_model or DEFAULT_WHISPER_MODEL,
            vlm_model=vision_model or "qwen2.5vl:7b",
        )
        errors = [str(error).strip() for error in (result.get("errors") or []) if str(error).strip()]
        if errors:
            raise RuntimeError("; ".join(errors))
        extracted: dict[int, dict[str, Any]] = {}
        for record in _read_corpus_records(corpus_path):
            # Records may be chunked as "<id>#<n>"; match on the base id.
            record_id = str(record.get("id") or "").split("#", 1)[0].strip()
            source_path = str(record.get("source_path") or "").strip()
            matched_index = None
            for candidate in (record_id, source_path):
                if candidate in index_by_key:
                    matched_index = index_by_key[candidate]
                    break
            if matched_index is None:
                continue
            bucket = extracted.setdefault(matched_index, {
                "mime_type": "",
                "record_type": "",
                "text_parts": [],
            })
            if not bucket["mime_type"]:
                bucket["mime_type"] = str(record.get("mime") or "").strip().lower()
            if not bucket["record_type"]:
                bucket["record_type"] = str(record.get("record_type") or "").strip().lower()
            text = str(record.get("text") or "").strip()
            if text:
                bucket["text_parts"].append(text)
        prepared: List[dict] = []
        for index, attachment in enumerate(attachments):
            bucket = extracted.get(index) or {}
            text = "\n\n".join(
                part for part in bucket.get("text_parts") or []
                if str(part or "").strip()
            ).strip()
            if not text:
                raise RuntimeError(f"{attachment.get('name') or 'attachment'}: no usable text could be extracted.")
            extra_meta = attachment_meta[index]
            prepared.append({
                **attachment,
                "mime_type": bucket.get("mime_type") or attachment.get("mime_type"),
                "record_type": bucket.get("record_type") or attachment.get("record_type"),
                "source_path": extra_meta.get("source_path") or attachment.get("source_path"),
                "size": extra_meta.get("size") if extra_meta.get("size") is not None else attachment.get("size"),
                "text": text,
            })
        return prepared


def _build_attachment_block(attachments: List[dict]) -> str:
    """Render attachments with extracted text as a labeled prompt section."""
    sections: List[str] = []
    for index, attachment in enumerate(attachments, start=1):
        text = str(attachment.get("text") or "").strip()
        if not text:
            continue
        lines = [
            f"[Attachment {index}]",
            f"Name: {attachment.get('name') or f'attachment-{index}'}",
            f"Kind: {'image' if _attachment_is_image(attachment) else 'file'}",
        ]
        mime_type = str(attachment.get("mime_type") or "").strip()
        record_type = str(attachment.get("record_type") or "").strip()
        if mime_type:
            lines.append(f"MIME: {mime_type}")
        if record_type:
            lines.append(f"Record type: {record_type}")
        lines.append("Content:")
        lines.append(text)
        sections.append("\n".join(lines))
    if not sections:
        return ""
    return "\n" + "\n\n".join(sections) + "\n"


def _compose_user_message_content(message: str, enriched_message: Optional[str], attachment_block: str) -> str:
    """Combine the (possibly enriched) user message with the attachment block.

    Falls back to a generic analysis request when there is no message text.
    """
    base_content = str(enriched_message or "").strip() or str(message or "").strip()
    parts: List[str] = []
    if base_content:
        parts.append(base_content)
    elif attachment_block:
        parts.append("Please analyze the attached material.")
    if attachment_block:
        parts.append(attachment_block)
    if not parts:
        return "Please analyze the attached material."
    return "\n\n".join(part for part in parts if str(part or "").strip()).strip()


def _attachments_have_images(attachments: List[dict]) -> bool:
    """True when any attachment in the list is an image."""
    return any(_attachment_is_image(attachment) for attachment in attachments or [])


async def _prepare_chat_message_attachments(
    attachments: List[dict],
    *,
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    persist_file_text: bool,
) -> tuple[List[dict], List[str], str]:
    """Validate and prepare attachments for a chat turn.

    Returns ``(prepared_attachments, ollama_images, attachment_block)``:
    the normalized attachments to persist, the base64 image payloads to pass
    to a vision-capable chat model, and the textual prompt block built from
    extracted attachment text. Images are passed through when the chat model
    supports vision; otherwise they are described via a separate vision model
    during text extraction. Raises HTTP 400 on invalid/unsupported input.
    """
    prepared = [
        _normalize_chat_attachment_or_raise(item, include_text=True)
        for item in (attachments or [])
    ]
    if not prepared:
        return [], [], ""
    if _attachments_have_images(prepared) and not request_model_supports_vision:
        if not str(vision_model or "").strip():
            raise HTTPException(
                status_code=400,
                detail="Image attachments require a configured vision model when the selected chat model does not support vision.",
            )
        if not await _model_supports_vision(vision_model):
            raise HTTPException(
                status_code=400,
                detail="The configured vision model does not support image inputs.",
            )
    prompt_entries: List[Optional[dict]] = [None] * len(prepared)
    attachments_to_extract: List[dict] = []
    extract_indexes: List[int] = []
    for index, attachment in enumerate(prepared):
        if _attachment_is_image(attachment):
            if request_model_supports_vision:
                # Vision model consumes the image directly; no text extraction.
                continue
            attachments_to_extract.append(attachment)
            extract_indexes.append(index)
            continue
        text = str(attachment.get("text") or "").strip()
        if text:
            # Already has inline text — no extraction needed.
            prompt_entries[index] = attachment
        else:
            attachments_to_extract.append(attachment)
            extract_indexes.append(index)
    if attachments_to_extract:
        extracted_attachments = await asyncio.to_thread(
            _extract_attachment_texts,
            attachments_to_extract,
            vision_model=vision_model,
            transcription_model=transcription_model,
        )
        for original_index, extracted in zip(extract_indexes, extracted_attachments):
            if _attachment_is_file(prepared[original_index]) and persist_file_text:
                # Merge extracted metadata/text back into the persisted record.
                prepared[original_index] = {
                    **prepared[original_index],
                    "mime_type": extracted.get("mime_type") or prepared[original_index].get("mime_type"),
                    "record_type": extracted.get("record_type") or prepared[original_index].get("record_type"),
                    "source_path": extracted.get("source_path") or prepared[original_index].get("source_path"),
                    "size": extracted.get("size") if extracted.get("size") is not None else prepared[original_index].get("size"),
                    "text": extracted.get("text") or prepared[original_index].get("text"),
                }
                prompt_entries[original_index] = prepared[original_index]
            else:
                prompt_entries[original_index] = extracted
    prompt_attachments = [entry for entry in prompt_entries if entry and str(entry.get("text") or "").strip()]
    ollama_images = _attachments_to_ollama_images(prepared) if request_model_supports_vision else []
    return prepared, ollama_images, _build_attachment_block(prompt_attachments)


def _row_to_history_message(row: models.ChatMessage) -> dict:
    """Convert a ChatMessage row into the /history response shape."""
    sources = []
    try:
        if getattr(row, "sources_json", None):
            sources = json.loads(row.sources_json or "[]")
    except Exception:
        sources = []
    attachments = _load_message_attachments(getattr(row, "attachments_json", None))
    payload = {"role": row.role, "content": row.content, "sources": sources}
    if attachments:
        payload["attachments"] = [_attachment_history_payload(attachment) for attachment in attachments]
    return payload


async def _build_ollama_user_message(
    message: str,
    attachments: List[dict],
    *,
    enriched_message: Optional[str],
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    persist_file_text: bool,
) -> tuple[dict, List[dict]]:
    """Build the Ollama user-message payload plus the attachments to persist."""
    prepared_attachments, ollama_images, attachment_block = await _prepare_chat_message_attachments(
        attachments,
        request_model_supports_vision=request_model_supports_vision,
        vision_model=vision_model,
        transcription_model=transcription_model,
        persist_file_text=persist_file_text,
    )
    payload = {
        "role": "user",
        "content": _compose_user_message_content(message, enriched_message, attachment_block),
    }
    if ollama_images:
        payload["images"] = ollama_images
    return payload, prepared_attachments


async def _build_ollama_messages_from_rows(
    rows: List[models.ChatMessage],
    *,
    request_model_supports_vision: bool,
    vision_model: Optional[str],
    transcription_model: Optional[str],
    override_user_row_index: Optional[int] = None,
    override_user_content: Optional[str] = None,
) -> List[dict]:
    """Rebuild an Ollama conversation from persisted message rows.

    ``override_user_row_index`` / ``override_user_content`` allow replacing
    one user row's content with an enriched version (used by regenerate).
    """
    conversation: List[dict] = []
    for row_index, row in enumerate(rows):
        if row.role != "user":
            conversation.append({"role": row.role, "content": row.content})
            continue
        attachments = _load_message_attachments(getattr(row, "attachments_json", None), include_text=True)
        enriched_message = override_user_content if override_user_row_index == row_index else None
        message_payload, _prepared = await _build_ollama_user_message(
            row.content,
            attachments,
            enriched_message=enriched_message,
            request_model_supports_vision=request_model_supports_vision,
            vision_model=vision_model,
            transcription_model=transcription_model,
            persist_file_text=False,
        )
        conversation.append(message_payload)
    return conversation


def get_db():
    """FastAPI dependency yielding a request-scoped SQLAlchemy session."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


@app.get("/health")
def health():
    """Liveness probe."""
    return {"ok": True}


@app.post("/audio/transcribe", response_model=schemas.AudioTranscriptionResponse)
async def transcribe_audio_route(req: schemas.AudioTranscriptionRequest):
    """Transcribe base64-encoded audio with the configured Whisper model."""
    mime_type = str(req.mime_type or "").split(";", 1)[0].strip().lower()
    if not mime_type.startswith("audio/"):
        raise HTTPException(status_code=400, detail="An audio mime type is required.")
    payload = re.sub(r"\s+", "", str(req.audio_base64 or ""))
    if not payload:
        raise HTTPException(status_code=400, detail="Audio payload is required.")
    try:
        audio_bytes = base64.b64decode(payload, validate=True)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 audio payload.") from exc
    try:
        result = await asyncio.to_thread(
            transcribe_audio_bytes,
            audio_bytes,
            mime_type,
            req.model or DEFAULT_WHISPER_MODEL,
            req.language,
        )
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Audio transcription failed: {exc}") from exc
    return {
        "text": str(result.get("text") or "").strip(),
        "language": str(result.get("language") or "").strip() or None,
        "model": str(result.get("model") or req.model or DEFAULT_WHISPER_MODEL),
    }


@app.get("/models")
async def get_models():
    """Return the Ollama model catalog merged with available Whisper models."""
    try:
        ollama_data, whisper_data = await asyncio.gather(
            ollama_list_model_catalog(),
            asyncio.to_thread(list_whisper_models),
        )
        return {
            **ollama_data,
            "whisper_models": whisper_data.get("models", []),
            "whisper_error": whisper_data.get("error", ""),
        }
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")


@app.get("/models/capabilities")
async def get_model_capabilities(name: str):
    """Return the capability list (and vision flag) for one Ollama model."""
    model_name = str(name or "").strip()
    if not model_name:
        raise HTTPException(status_code=400, detail="Model name is required.")
    try:
        model_data = await ollama_show_model(model_name)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama not available: {e}")
    capabilities = [
        str(item).strip()
        for item in (model_data.get("capabilities") or [])
        if str(item).strip()
    ]
    return {
        "name": model_name,
        "capabilities": capabilities,
        "supports_vision": ollama_supports_vision(model_data),
    }


@app.get("/ollama/startup-status")
async def ollama_startup_status():
    """Report whether a local Ollama instance is running / reachable."""
    return await inspect_ollama_startup()


@app.post("/ollama/start")
async def ollama_start_route():
    """Start the local Ollama daemon."""
    try:
        return await start_local_ollama()
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/ollama/pull")
async def ollama_pull_route(req: schemas.OllamaPullRequest):
    """Pull a model into the local Ollama instance."""
    try:
        return await pull_local_model(req.model)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.post("/startup/prepare-models")
async def startup_prepare_models_route():
    """Ensure the models required at startup are available locally."""
    try:
        return await prepare_startup_models()
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc


@app.get("/sessions", response_model=schemas.SessionsResponse)
def get_sessions(db: Session = Depends(get_db)):
    """List chat sessions, newest first, with sanitized titles."""
    sessions = db.query(models.ChatSession).order_by(models.ChatSession.created_at.desc()).all()
    return {
        "sessions": [
            {
                "id": session.id,
                "session_id": session.session_id,
                "name": sanitize_chat_title(session.name),
                "created_at": session.created_at,
            }
            for session in sessions
        ]
    }


@app.post("/sessions", response_model=schemas.ChatSession)
def create_session(req: schemas.CreateSessionRequest, db: Session = Depends(get_db)):
    """Create a new chat session with the given external session_id."""
    new_session = models.ChatSession(session_id=req.session_id)
    db.add(new_session)
    db.commit()
    db.refresh(new_session)
    return new_session


@app.get("/history", response_model=schemas.HistoryResponse)
def history(session_id: str, db: Session = Depends(get_db)):
    """Return all messages of a session in chronological order."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if not session:
        # Unknown session is treated as an empty history, not an error.
        return {"messages": []}
    rows = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    msgs = [_row_to_history_message(r) for r in rows]
    return {"messages": msgs}


@app.post("/chat")
async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
    """Handle one chat turn: persist the user message, call Ollama, persist the reply.

    Supports streaming (plain-text chunked response) and non-streaming modes.
    The user message is persisted before the model call, so it survives an
    Ollama failure.
    """
    # Find or create session
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == req.session_id).first()
    if not session:
        session = models.ChatSession(session_id=req.session_id)
        db.add(session)
        db.commit()
        db.refresh(session)
    request_model_supports_vision = await _model_supports_vision(req.model)
    prior_rows = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    # Cap the conversation context at the 19 most recent persisted messages.
    history_rows = prior_rows[-19:]
    history_messages = await _build_ollama_messages_from_rows(
        history_rows,
        request_model_supports_vision=request_model_supports_vision,
        vision_model=req.vision_model,
        transcription_model=req.transcription_model,
    )
    current_user_message, user_attachments = await _build_ollama_user_message(
        req.message,
        req.attachments or [],
        enriched_message=req.enriched_message,
        request_model_supports_vision=request_model_supports_vision,
        vision_model=req.vision_model,
        transcription_model=req.transcription_model,
        persist_file_text=True,
    )
    session_pk = session.id
    user_row = models.ChatMessage(
        session_pk=session_pk,
        role='user',
        content=req.message,
        attachments_json=json.dumps(user_attachments or []),
    )
    db.add(user_row)
    db.commit()
    messages = [*history_messages, current_user_message]
    # Sources to persist with the assistant reply
    sources = req.sources or []
    if req.stream:
        async def stream_generator():
            full_reply = ""
            try:
                async for chunk in ollama_chat_stream(req.model, messages):
                    full_reply += chunk
                    yield chunk
            except Exception as e:
                # Surface the failure in-band; headers are already sent.
                yield f"Ollama error: {e}"
            # Persist assistant reply (include sources_json). Use a fresh
            # session: the request-scoped one is closed once streaming starts.
            db_sess = None
            try:
                db_sess = SessionLocal()
                db_sess.add(models.ChatMessage(
                    session_pk=session_pk,
                    role='assistant',
                    content=full_reply,
                    sources_json=json.dumps(sources or []),
                ))
                db_sess.commit()
            finally:
                if db_sess is not None:
                    db_sess.close()
        return StreamingResponse(stream_generator(), media_type="text/plain")
    else:
        try:
            reply = await ollama_chat(req.model, messages)
        except Exception as e:
            raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
        as_row = models.ChatMessage(
            session_pk=session_pk,
            role='assistant',
            content=reply,
            sources_json=json.dumps(sources or [])
        )
        db.add(as_row)
        db.commit()
        return {"reply": reply}


@app.post("/generate-title", response_model=schemas.GenerateTitleResponse)
async def generate_title(req: schemas.GenerateTitleRequest, db: Session = Depends(get_db)):
    """Ask the model for a short session title, sanitize it, and save it."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == req.session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    prompt = f"Generate a very short, concise title (5 words or less) for a chat conversation that begins with this user message: \"{req.message}\". Do not use quotation marks in the title."
    try:
        title = await ollama_chat(req.model, [{"role": "user", "content": prompt}])
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
    print(f"Original title from LLM: {title}")  # Debugging line to see the raw title
    cleaned_title = sanitize_chat_title(title)
    print(f"Cleaned title before saving: {cleaned_title}")  # Debugging line to see the cleaned title
    session.name = cleaned_title
    db.commit()
    return {"title": cleaned_title}


@app.delete("/sessions/{session_id}")
def delete_session(session_id: str, db: Session = Depends(get_db)):
    """Delete a session and all of its messages."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    # Delete associated messages
    db.query(models.ChatMessage).filter(models.ChatMessage.session_pk == session.id).delete()
    db.delete(session)
    db.commit()
    return {"ok": True}


@app.put("/sessions/{session_id}/rename")
def rename_session(session_id: str, req: schemas.GenerateTitleResponse, db: Session = Depends(get_db)):
    """Rename a session (the title is sanitized before saving)."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    session.name = sanitize_chat_title(req.title)
    db.commit()
    return {"ok": True}


@app.put("/sessions/{session_id}/messages/{index}")
def update_user_message(session_id: str, index: int, req: schemas.EditMessageRequest, db: Session = Depends(get_db)):
    """Edit a user message in place and discard everything after it."""
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    msgs = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    if index < 0 or index >= len(msgs):
        raise HTTPException(status_code=404, detail="Message index out of range")
    # Only user messages can be edited per spec
    if msgs[index].role != "user":
        raise HTTPException(status_code=400, detail="Only user messages can be edited")
    # Update the content
    msgs[index].content = req.message
    # Drop everything after the edited message
    for m in msgs[index + 1:]:
        db.delete(m)
    db.commit()
    return {"ok": True}


@app.post("/sessions/{session_id}/regenerate")
async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Session = Depends(get_db)):
    """Regenerate the assistant reply after the last user message at/before ``index``.

    Rebuilds the conversation up to that user message (optionally with an
    enriched version of its content), prunes everything after it, then asks
    Ollama for a fresh reply (streaming or not) and persists it.
    """
    idx = req.index
    model = req.model
    stream = bool(req.stream)
    sources = req.sources or []
    session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    msgs = (
        db.query(models.ChatMessage)
        .filter(models.ChatMessage.session_pk == session.id)
        .order_by(models.ChatMessage.created_at.asc())
        .all()
    )
    if idx < 0 or idx >= len(msgs):
        raise HTTPException(status_code=400, detail="Invalid message index")
    # last user idx at/before idx
    last_user_idx = idx
    for i in range(idx, -1, -1):
        if msgs[i].role == "user":
            last_user_idx = i
            break
    request_model_supports_vision = await _model_supports_vision(model)
    conversation = await _build_ollama_messages_from_rows(
        msgs[: last_user_idx + 1],
        request_model_supports_vision=request_model_supports_vision,
        vision_model=req.vision_model,
        transcription_model=req.transcription_model,
        override_user_row_index=last_user_idx,
        override_user_content=req.enriched_message,
    )
    # prune after that user only after conversation building succeeds
    if last_user_idx < len(msgs) - 1:
        for m in msgs[last_user_idx + 1:]:
            db.delete(m)
        db.commit()
    session_pk = session.id
    if stream:
        async def stream_generator():
            full_reply = ""
            try:
                async for chunk in ollama_chat_stream(model, conversation):
                    full_reply += chunk
                    yield chunk
            except Exception as e:
                yield f"Ollama error: {e}"
            # persist (with sources); fresh session — the request-scoped one
            # is closed by the time the stream body runs.
            # Initialize before try so the finally-close cannot hit a
            # NameError if SessionLocal() itself raises.
            db_sess = None
            try:
                db_sess = SessionLocal()
                db_sess.add(models.ChatMessage(
                    session_pk=session_pk,
                    role="assistant",
                    content=full_reply,
                    sources_json=json.dumps(sources or [])
                ))
                db_sess.commit()
            finally:
                if db_sess is not None:
                    try:
                        db_sess.close()
                    except Exception:
                        pass
        return StreamingResponse(stream_generator(), media_type="text/plain")
    try:
        reply = await ollama_chat(model, conversation)
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
    db.add(models.ChatMessage(
        session_pk=session_pk,
        role="assistant",
        content=reply,
        sources_json=json.dumps(sources or [])
    ))
    db.commit()
    return {"reply": reply}


# -----------------------------------------------------------------------------
# Web search enrichment endpoint
@app.post("/websearch", response_model=schemas.WebSearchRequest and schemas.WebSearchResponse)
async def websearch_route(req: schemas.WebSearchRequest):
    """
    Return an enriched prompt (with citations) for a given user prompt.
    Optionally uses the last `history_limit` turns from `req.messages`.
    """
    try:
        messages = (req.messages or [])[-int(req.history_limit or 8):]
        enriched, sources = await enrich_prompt(
            user_prompt=req.prompt,
            model=req.model,
            messages=[{"role": m.role, "content": m.content} for m in messages],
            searx_url=req.searx_url,
            engines=req.engines,
        )
        context_block = ""
        # NOTE(review): the original check was `if "" in enriched`, which is
        # always true (and index("") is 0), so context_block was always the
        # whole enriched prompt. The marker text appears to have been lost;
        # assuming enrich_prompt embeds a "<context>" section — confirm
        # against .websearch.enrich_prompt.
        marker = "<context>"
        if marker in enriched:
            context_block = enriched[enriched.index(marker):].strip()
        return {"enriched_prompt": enriched, "sources": sources, "context_block": context_block}
    except Exception:
        # Best-effort endpoint: fall back to the raw prompt on any failure.
        return {"enriched_prompt": req.prompt, "sources": [], "context_block": ""}


# To run standalone: python -m uvicorn backend.main:app --host 127.0.0.1 --port 8000