diff --git a/backend/database.py b/backend/database.py index cbd8c5b..e15d462 100644 --- a/backend/database.py +++ b/backend/database.py @@ -1,6 +1,20 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, DeclarativeBase +from sqlalchemy import text + +""" +Database utilities and configuration. This module defines the SQLAlchemy +engine, session factory and base class for models. It also contains a +lightweight migration helper used to evolve the schema over time. The +`ensure_sources_column` helper adds a new `sources_json` column to the +`chat_messages` table if it does not already exist. This is required +for persisting citation sources alongside assistant messages. + +The migration uses SQLite's `ALTER TABLE` syntax and therefore should +only run once on startup. It is safe to call repeatedly: when the +column already exists, the function will simply no‑op. +""" DATABASE_URL = "sqlite:///./backend/app.db" @@ -12,3 +26,13 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) class Base(DeclarativeBase): pass + + +def ensure_sources_column(engine): + try: + with engine.connect() as conn: + cols = [row[1] for row in conn.execute(text("PRAGMA table_info(chat_messages)"))] + if "sources_json" not in cols: + conn.execute(text("ALTER TABLE chat_messages ADD COLUMN sources_json TEXT DEFAULT '[]'")) + except Exception as e: + print("[db] ensure_sources_column error:", e) diff --git a/backend/main.py b/backend/main.py index 4bc3240..b6d7b8a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -3,14 +3,17 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from typing import List -import re # Import the regex module -import html # Import the html module for unescaping +import re +import html +import json from . import models, schemas -from .database import Base, engine, SessionLocal +from .database import Base, engine, SessionLocal, ensure_sources_column from .ollama_client import list_models as ollama_list, chat as ollama_chat, chat_stream as ollama_chat_stream +from .websearch import enrich_prompt -# Create tables +# Create tables + ensure migration Base.metadata.create_all(bind=engine) +ensure_sources_column(engine) app = FastAPI(title="LLM Desktop Backend", version="0.1.0" ) @@ -60,8 +63,21 @@ def history(session_id: str, db: Session = Depends(get_db)): session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first() if not session: return {"messages": []} - rows = db.query(models.ChatMessage) .filter(models.ChatMessage.session_pk == session.id) .order_by(models.ChatMessage.created_at.asc()) .all() - msgs = [{"role": r.role, "content": r.content} for r in rows] + rows = ( + db.query(models.ChatMessage) + .filter(models.ChatMessage.session_pk == session.id) + .order_by(models.ChatMessage.created_at.asc()) + .all() + ) + msgs = [] + for r in rows: + sources = [] + try: + if getattr(r, "sources_json", None): + sources = json.loads(r.sources_json or "[]") + except Exception: + sources = [] + msgs.append({"role": r.role, "content": r.content, "sources": sources}) return {"messages": msgs} @app.post("/chat") @@ -74,16 +90,31 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)): db.commit() db.refresh(session) - # Save user message + # Store the BASE user prompt user_row = models.ChatMessage(session_pk=session.id, role='user', content=req.message) db.add(user_row) db.commit() - # Build minimal conversation context (last 20 messages) - last_msgs = db.query(models.ChatMessage) .filter(models.ChatMessage.session_pk == session.id) .order_by(models.ChatMessage.created_at.asc()) .all()[-20:] - + # Build minimal context (last 20) + last_msgs = ( + db.query(models.ChatMessage) + .filter(models.ChatMessage.session_pk == session.id) + .order_by(models.ChatMessage.created_at.asc()) + .all()[-20:] + ) messages = [{"role": m.role, "content": m.content} for m in last_msgs] + # Patch last user with enriched_message only for LLM call + if req.enriched_message: + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "user": + messages = messages.copy() + messages[i] = {**messages[i], "content": req.enriched_message} + break + + # Sources to persist with the assistant reply + sources = req.sources or [] + if req.stream: async def stream_generator(): full_reply = "" @@ -92,14 +123,16 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)): full_reply += chunk yield chunk except Exception as e: - # How to handle errors in a stream? Could yield an error message. yield f"Ollama error: {e}" - # Save full reply after stream is complete - as_row = models.ChatMessage(session_pk=session.id, role='assistant', content=full_reply) + # Persist assistant reply (include sources_json) + as_row = models.ChatMessage( + session_pk=session.id, role='assistant', content=full_reply, + sources_json=json.dumps(sources or []) + ) db.add(as_row) db.commit() - + return StreamingResponse(stream_generator(), media_type="text/plain") else: try: @@ -107,11 +140,12 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)): except Exception as e: raise HTTPException(status_code=502, detail=f"Ollama error: {e}") - # Save assistant reply - as_row = models.ChatMessage(session_pk=session.id, role='assistant', content=reply) + as_row = models.ChatMessage( + session_pk=session.id, role='assistant', content=reply, + sources_json=json.dumps(sources or []) + ) db.add(as_row) db.commit() - return {"reply": reply} @app.post("/generate-title", response_model=schemas.GenerateTitleResponse) @@ -200,13 +234,10 @@ def update_user_message(session_id: str, index: int, req: schemas.EditMessageReq # ADD or REPLACE this whole function @app.post("/sessions/{session_id}/regenerate") async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Session = Depends(get_db)): - """ - Regenerate an assistant response for the conversation state at/before req.index. - If req.index points at an assistant message, we regenerate from the preceding user message. - """ idx = req.index model = req.model stream = bool(req.stream) + sources = req.sources or [] session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first() if not session: @@ -218,43 +249,49 @@ async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Sessio .order_by(models.ChatMessage.created_at.asc()) .all() ) - if idx < 0 or idx >= len(msgs): raise HTTPException(status_code=400, detail="Invalid message index") - # Find the last user message at/before idx + # last user idx at/before idx last_user_idx = idx for i in range(idx, -1, -1): if msgs[i].role == "user": last_user_idx = i break - # Prune everything after last_user_idx + # prune after that user if last_user_idx < len(msgs) - 1: for m in msgs[last_user_idx + 1:]: db.delete(m) db.commit() - # Build the conversation up to & incl. the last user message conversation = [{"role": m.role, "content": m.content} for m in msgs[: last_user_idx + 1]] - # Avoid DetachedInstanceError during streaming + if req.enriched_message: + for j in range(len(conversation) - 1, -1, -1): + if conversation[j].get("role") == "user": + conversation = conversation.copy() + conversation[j] = {**conversation[j], "content": req.enriched_message} + break + session_pk = session.id if stream: async def stream_generator(): full_reply = "" try: - # ollama_chat_stream must already exist in your codebase (used by /chat) async for chunk in ollama_chat_stream(model, conversation): full_reply += chunk yield chunk except Exception as e: yield f"Ollama error: {e}" - # Persist with a fresh DB session (streaming context) + # persist (with sources) try: db_sess = SessionLocal() - db_sess.add(models.ChatMessage(session_pk=session_pk, role="assistant", content=full_reply)) + db_sess.add(models.ChatMessage( + session_pk=session_pk, role="assistant", content=full_reply, + sources_json=json.dumps(sources or []) + )) db_sess.commit() finally: try: @@ -264,16 +301,38 @@ async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Sessio return StreamingResponse(stream_generator(), media_type="text/plain") - # Non-streaming try: - # ollama_chat must already exist in your codebase (used by /chat) reply = await ollama_chat(model, conversation) except Exception as e: raise HTTPException(status_code=502, detail=f"Ollama error: {e}") - db.add(models.ChatMessage(session_pk=session_pk, role="assistant", content=reply)) + db.add(models.ChatMessage( + session_pk=session_pk, role="assistant", content=reply, + sources_json=json.dumps(sources or []) + )) db.commit() return {"reply": reply} +# ----------------------------------------------------------------------------- +# Web search enrichment endpoint +@app.post("/websearch", response_model=schemas.WebSearchResponse) +async def websearch_route(req: schemas.WebSearchRequest): + """ + Return an enriched prompt (with citations) for a given user prompt. + Optionally uses the last `history_limit` turns from `req.messages`. + """ + try: + messages = (req.messages or [])[-int(req.history_limit or 8):] + enriched, sources = await enrich_prompt( + user_prompt=req.prompt, + model=req.model, + messages=[{"role": m.role, "content": m.content} for m in messages], + searx_url=req.searx_url, + engines=req.engines, + ) + return {"enriched_prompt": enriched, "sources": sources} + except Exception: + return {"enriched_prompt": req.prompt, "sources": []} + # To run standalone: python -m uvicorn backend.main:app --host 127.0.0.1 --port 8000 diff --git a/backend/models.py b/backend/models.py index b886467..8c66c37 100644 --- a/backend/models.py +++ b/backend/models.py @@ -20,6 +20,8 @@ class ChatMessage(Base): session_pk = Column(Integer, ForeignKey('chat_sessions.id'), nullable=False) role = Column(String(16), nullable=False) # 'user' | 'assistant' content = Column(Text, nullable=False) + # JSON-encoded list of citation URLs; null/empty => no chips + sources_json = Column(Text, nullable=True, default='[]') created_at = Column(DateTime, default=datetime.utcnow, nullable=False) session = relationship("ChatSession", back_populates="messages") diff --git a/backend/requirements.txt b/backend/requirements.txt index f5cdd04..0d395cc 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -4,3 +4,8 @@ uvicorn[standard]==0.30.1 SQLAlchemy==2.0.32 httpx==0.27.0 pydantic==2.7.4 + +# Web search enrichment dependencies +beautifulsoup4==4.12.3 +httpx[http2]>=0.27.0 +numpy \ No newline at end of file diff --git a/backend/schemas.py b/backend/schemas.py index fdd53b9..0cd9a88 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -5,12 +5,15 @@ from datetime import datetime class Message(BaseModel): role: str content: str + sources: Optional[List[str]] = None class ChatRequest(BaseModel): session_id: str model: str message: str + enriched_message: Optional[str] = None stream: Optional[bool] = False + sources: Optional[List[str]] = None class ChatResponse(BaseModel): reply: str @@ -47,4 +50,20 @@ class EditMessageRequest(BaseModel): class RegenerateRequest(BaseModel): index: int model: Optional[str] = None - stream: bool = True \ No newline at end of file + enriched_message: Optional[str] = None + stream: bool = True + sources: Optional[List[str]] = None + +# Request payload for the web search enrichment endpoint. +class WebSearchRequest(BaseModel): + prompt: str + model: str + messages: Optional[List[Message]] = None + history_limit: Optional[int] = 8 + searx_url: Optional[str] = None + engines: Optional[List[str]] = None + +# Response payload for the web search enrichment endpoint. +class WebSearchResponse(BaseModel): + enriched_prompt: str + sources: List[str] = [] diff --git a/backend/websearch.py b/backend/websearch.py new file mode 100644 index 0000000..09238a3 --- /dev/null +++ b/backend/websearch.py @@ -0,0 +1,642 @@ +import asyncio +from typing import List, Tuple, Dict, Any, Optional +import httpx +from bs4 import BeautifulSoup +import re +import json +import traceback +import hashlib + +from .ollama_client import chat as ollama_chat + +# Configure your local SearXNG instance URL (no trailing slash) +SEARX_URL = "http://localhost:8888" + +# ----- Utilities ---------------------------------------------------------------- + +def clean_text(html: str, max_len: int = 120_000) -> str: + # Prefer lxml parser if available (significantly faster); fall back gracefully. + try: + soup = BeautifulSoup(html, "lxml") + except Exception: + soup = BeautifulSoup(html, "html.parser") + + for tag in soup.select("script,style,noscript"): + tag.decompose() + + # Fast text extraction + text = soup.get_text("\n", strip=True) + text = re.sub(r"[ \t]+", " ", text) + text = re.sub(r"\n{3,}", "\n\n", text) + + return text[:max_len] if len(text) > max_len else text + +def render_recent_context(messages: Optional[List[Dict[str, str]]], char_limit: int = 1200) -> str: + """ + Turn the last few turns into a compact excerpt: + user: ... + assistant: ... + Truncate aggressively for resource-friendliness. + """ + if not messages: + return "" + # Keep as-is order (oldest->newest). We assume caller already sliced last N. + lines: List[str] = [] + for m in messages: + role = m.get("role", "user") + content = (m.get("content") or "").strip() + if not content: + continue + # compactify each line + content = re.sub(r"\s+", " ", content) + lines.append(f"{role}: {content}") + joined = "\n".join(lines) + if len(joined) > char_limit: + return joined[-char_limit:] # keep tail (most recent part) + return joined + +# ----- SearXNG search & fetching ------------------------------------------------- + +from urllib.parse import urlparse +import itertools + +# HTTP client tuning (keep-alive pool + HTTP/2) +HTTP_LIMITS = httpx.Limits(max_keepalive_connections=32, max_connections=64) +HTTP_TIMEOUT = httpx.Timeout(connect=5.0, read=10.0, write=10.0, pool=5.0) +DEFAULT_HEADERS = { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome Safari", + "Accept": "text/html,application/xhtml+xml;q=0.9,*/*;q=0.8", +} + +# Fetch policy +MAX_PAGES_FETCH = 8 # lower cap of pages we’ll fetch per enrichment +FETCH_CONCURRENCY = 8 # concurrent page fetches +MAX_BYTES_PER_PAGE = 1_500_000 # ~1.5 MB read cap per page (streamed) +MIN_TEXT_LENGTH = 500 # discard useless pages early + +SKIP_EXTS = { + ".pdf", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg", + ".mp4", ".mp3", ".mov", ".avi", ".zip", ".gz", ".7z", ".tar", ".rar", + ".woff", ".woff2", ".ttf", ".otf" +} + +_PAGE_CACHE: Dict[str, str] = {} # naive in-proc cache (speeds repeated calls) +_EMB_CACHE: Dict[str, List[float]] = {} # NEW: embedding cache by SHA1 of snippet + +def _is_probably_html_url(url: str) -> bool: + try: + path = urlparse(url).path.lower() + ext = (path.rsplit(".", 1)[-1] if "." in path else "") + return (("." not in path) or (ext and f".{ext}" not in SKIP_EXTS)) + except Exception: + return True + +async def searx_search( + client: httpx.AsyncClient, + query: str, + max_results: int = 6, + *, + searx_url: Optional[str] = None, + engines: Optional[List[str]] = None, +) -> List[Dict[str, Any]]: + base_url = (searx_url or SEARX_URL).rstrip("/") + params = {"q": query, "format": "json", "safesearch": 1} + if engines: + # SearXNG accepts a comma-separated 'engines' filter + params["engines"] = ",".join(engines) + + try: + r = await client.get(f"{base_url}/search", params=params) + r.raise_for_status() + except Exception as e: + print(f"[web] searx_search error for {base_url} q='{query}': {repr(e)}") + return [] + + try: + data = r.json() + except Exception as e: + print(f"[web] searx_search JSON parse failed q='{query}': {repr(e)} | content-type={r.headers.get('content-type')}") + return [] + + seen = set() + out = [] + for item in data.get("results", []): + url = item.get("url") + if not url or url in seen: + continue + seen.add(url) + out.append({"url": url, "title": item.get("title", ""), "engine": item.get("engine")}) + if len(out) >= max_results: + break + if not out: + print(f"[web] searx_search returned 0 urls for q='{query}' (engines={engines})") + return out + +async def searx_search_many( + client: httpx.AsyncClient, + queries: List[str], + per_query_max: int = 6, + overall_max: int = 32, + *, + searx_url: Optional[str] = None, + engines: Optional[List[str]] = None, +) -> List[str]: + print(f"[web] searx_search_many: {len(queries)} queries -> {searx_url or SEARX_URL} engines={engines or 'default'}") + tasks = [searx_search(client, q, max_results=per_query_max, searx_url=searx_url, engines=engines) for q in queries] + results = await asyncio.gather(*tasks, return_exceptions=True) + + urls: List[str] = [] + seen = set() + for q, res in zip(queries, results): + if isinstance(res, Exception): + print(f"[web] searx_search_many task error q='{q}': {repr(res)}") + continue + for item in res: + u = item.get("url") + if not u or u in seen: + continue + if not _is_probably_html_url(u): + print(f"[web] skip non-HTML url: {u}") + continue + seen.add(u) + urls.append(u) + if len(urls) >= overall_max: + print(f"[web] searx_search_many hit overall_max={overall_max}") + return urls + print(f"[web] searx_search_many collected {len(urls)} urls") + return urls + +async def fetch_page(client: httpx.AsyncClient, url: str) -> str: + if url in _PAGE_CACHE: + return _PAGE_CACHE[url] + if not _is_probably_html_url(url): + return "" + + try: + async with client.stream("GET", url, headers=DEFAULT_HEADERS) as r: + ctype = (r.headers.get("content-type") or "").lower() + if ctype and not ("text/html" in ctype or "application/xhtml+xml" in ctype or ctype.startswith("text/")): + print(f"[web] skip content-type {ctype} url={url}") + return "" + + buf = bytearray() + async for chunk in r.aiter_bytes(): + if not chunk: + break + buf.extend(chunk) + if len(buf) >= MAX_BYTES_PER_PAGE: + break + + try: + text = buf.decode(r.encoding or "utf-8", errors="ignore") + except Exception: + text = buf.decode("utf-8", errors="ignore") + + _PAGE_CACHE[url] = text + return text + except Exception as e: + print(f"[web] fetch_page error url={url}: {repr(e)}") + return "" + +async def gather_pages( + client: httpx.AsyncClient, + urls: List[str], + max_pages: int = MAX_PAGES_FETCH, +) -> Dict[str, str]: + urls = list(itertools.islice(urls, 0, max_pages)) + sem = asyncio.Semaphore(FETCH_CONCURRENCY) + out: Dict[str, str] = {} + counters = {"tried": 0, "ok": 0, "too_short": 0, "empty": 0} + + async def _one(u: str): + async with sem: + counters["tried"] += 1 + html = await fetch_page(client, u) + if not html: + counters["empty"] += 1 + return + txt = clean_text(html) + if len(txt) < MIN_TEXT_LENGTH: + counters["too_short"] += 1 + return + out[u] = txt + counters["ok"] += 1 + + await asyncio.gather(*(_one(u) for u in urls)) + print(f"[web] gather_pages summary: {counters} / returned={len(out)}") + return out + +# ----- LLM helpers --------------------------------------------------------------- + +async def generate_queries(prompt: str, model: str, context_excerpt: str) -> List[str]: + """ + Single LLM call that *already* incorporates recent chat context to + resolve ellipses/pronouns. No extra rounds. + """ + ctx_block = f"\nRecent conversation (latest last):\n{context_excerpt}\n" if context_excerpt else "" + ask = f"""You are a search query generator. +Given the new user message and recent conversation, produce 3 diverse, terse web search queries that best address the user's *current* intent. +Rules: +- Resolve ambiguous references using the recent conversation when present. +- No quotes, no site: operators, no overly long queries. +- Return each query on its own line. + +User message: +{prompt} + +{ctx_block}""" + resp = await ollama_chat(model, [{"role": "user", "content": ask}]) + lines = [l.strip(" -\t") for l in resp.splitlines() if l.strip()] + uniq, seen = [], set() + for l in lines: + k = l.lower() + if k in seen: + continue + uniq.append(l) + seen.add(k) + if len(uniq) >= 3: + break + return uniq or [prompt] + +async def rerank( + prompt: str, + docs: List[Tuple[str, str]], + model: str, # kept for signature compatibility (unused here) + context_excerpt: str, + embed_model: str = "bge-m3:latest" # prefer explicit tag; we will auto-fallback +) -> List[Tuple[str, str, float]]: + """ + Embedding-based reranker (bge-m3 via Ollama) using cosine similarity. + + Robustness upgrades: + - Try embed_model; if needed fallback to the ':latest' or untagged alias. + - Normalize multiple possible response schemas from Ollama. + - If the query vector is empty, retry the query alone. + - If only the query is returned, retry passages in a second call. + - Detailed logging of counts and dimensions. + """ + import time + t0 = time.perf_counter() + + # --- optional fast cosine via NumPy --------------------------------------- + try: + import numpy as _np # optional + except Exception: + _np = None + + def _cosine(a: List[float], b: List[float]) -> float: + if _np is not None: + va = _np.asarray(a, dtype="float32") + vb = _np.asarray(b, dtype="float32") + n = min(va.size, vb.size) + if n == 0: + return 0.0 + va = va[:n]; vb = vb[:n] + na = _np.linalg.norm(va); nb = _np.linalg.norm(vb) + if na <= 0.0 or nb <= 0.0: + return 0.0 + return float(va.dot(vb) / (na * nb)) + # pure Python fallback + n = min(len(a), len(b)) + if n == 0: + return 0.0 + dot = na = nb = 0.0 + for i in range(n): + x = a[i]; y = b[i] + dot += x * y + na += x * x + nb += y * y + if na <= 0.0 or nb <= 0.0: + return 0.0 + return dot / ((na ** 0.5) * (nb ** 0.5)) + + # --- build query + passages ------------------------------------------------ + ctx_tail = (("\ncontext: " + context_excerpt[-400:]) if context_excerpt else "") + q_text = f"query: {prompt.strip()}{ctx_tail}" + + DOC_SNIPPET_CHARS = 800 # keep short & fast + passages: List[str] = [] + raw_snippets: List[str] = [] # for cache keys + for (_u, t) in docs: + snippet = t.replace("\n", " ") + if len(snippet) > DOC_SNIPPET_CHARS: + snippet = snippet[:DOC_SNIPPET_CHARS] + passages.append(f"passage: {snippet}") + raw_snippets.append(snippet) + + # --- embedding cache ------------------------------------------------------- + emb_cache: Dict[str, List[float]] = globals().get("_EMB_CACHE", {}) + try: + import hashlib + keys = [hashlib.sha1(s.encode("utf-8", errors="ignore")).hexdigest() for s in raw_snippets] + except Exception: + keys = [f"nocache-{i}" for i in range(len(raw_snippets))] + + # Prepare inputs: + # 0 -> query; >=1 -> only passages NOT present in cache + to_embed_inputs: List[str] = [q_text] + pos_to_passage_idx: Dict[int, int] = {} + p_cached = 0 + for i, (p, k) in enumerate(zip(passages, keys)): + if k in emb_cache: + p_cached += 1 + continue + pos = len(to_embed_inputs) + to_embed_inputs.append(p) + pos_to_passage_idx[pos] = i + + # --- helpers to call Ollama embeddings ------------------------------------ + async def _embed_inputs(inputs: List[str], model_name: str) -> Tuple[List[List[float]], Dict[str, Any]]: + """ + Call Ollama /api/embeddings once per input with {"model", "prompt"}. + Keeps order; returns [[]] for failures at a given index. + """ + timeout = httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0) + sem = asyncio.Semaphore(8) # cap concurrency a bit + + async def _one(text: str) -> Tuple[List[float], Optional[str]]: + payload = {"model": model_name, "prompt": text} + try: + async with sem: + async with httpx.AsyncClient(timeout=timeout) as client: + r = await client.post("http://localhost:11434/api/embeddings", json=payload) + r.raise_for_status() + data = r.json() + except httpx.HTTPStatusError as e: + return [], f"http_error:{e}" + except Exception as e: + return [], f"request_error:{e}" + + # normalize common shapes + if isinstance(data, dict): + if "error" in data: + return [], f"model_error:{data.get('error')}" + if "embedding" in data and isinstance(data["embedding"], list): + return data["embedding"], None + if "data" in data and isinstance(data["data"], list) and data["data"]: + em = data["data"][0].get("embedding", []) + return (em if isinstance(em, list) else []), None + if "embeddings" in data and isinstance(data["embeddings"], list): + # some servers may return {"embeddings":[vector]} even for single prompt + emb0 = data["embeddings"][0] + if isinstance(emb0, dict): + emb0 = emb0.get("embedding", []) + return (emb0 if isinstance(emb0, list) else []), None + + return [], "parse_error" + + tasks = [_one(t) for t in inputs] + results = await asyncio.gather(*tasks, return_exceptions=False) + + embs: List[List[float]] = [] + errs: List[str] = [] + for emb, err in results: + embs.append(emb) + if err: + errs.append(err) + + meta: Dict[str, Any] = {} + if errs: + # light meta summary + meta["errors"] = {k: errs.count(k) for k in set(errs)} + return embs, meta + + # --- do the embedding calls ------------------------------------------------ + t_embed = time.perf_counter() + inputs_all = to_embed_inputs # [query] + uncached passages + embeddings, meta = await _embed_inputs(inputs_all, embed_model) + + # simple model fallback if *all* returned empty + if not any(len(e) for e in embeddings): + tried = [embed_model] + alt = (embed_model.split(":", 1)[0] if ":" in embed_model else embed_model + ":latest") + if alt != embed_model: + embeddings, meta2 = await _embed_inputs(inputs_all, alt) + if any(len(e) for e in embeddings): + print(f"[web] embed() recovered with fallback model {alt} (meta={meta or meta2})") + embed_model = alt + else: + print(f"[web] embed() FAILED (models tried={tried + [alt]}, meta={meta or meta2})") + return [(u, t, 0.0) for (u, t) in docs] + + # split q vs passages and update cache + q_emb = embeddings[0] if embeddings else [] + if not q_emb: + print("[web] embed() empty query vector — aborting rerank") + return [(u, t, 0.0) for (u, t) in docs] + + # positions >=1 correspond to passages (only those that weren’t cached) + for pos, emb_vec in enumerate(embeddings[1:], start=1): + i = pos_to_passage_idx.get(pos) + if i is None: + continue + if not keys[i].startswith("nocache-"): + emb_cache[keys[i]] = emb_vec + + # build aligned passage vectors + p_emb_list: List[List[float]] = [emb_cache.get(k, []) for k in keys] + + # logging + q_dim = len(q_emb) + p_dims = [len(v) for v in p_emb_list] + print(f"[web] embed() took {time.perf_counter() - t_embed:.3f}s " + f"(model='{embed_model}', q_dim={q_dim}, p_cached={p_cached}, " + f"p_fresh={sum(1 for d in p_dims if d>0)}, total_p={len(p_emb_list)})") + + # --- score + rank ---------------------------------------------------------- + t_score = time.perf_counter() + scored: List[Tuple[str, str, float]] = [] + for (u, t), p_emb in zip(docs, p_emb_list): + cos = _cosine(q_emb, p_emb) + score_0_100 = max(0.0, min(100.0, (cos + 1.0) * 50.0)) + scored.append((u, t, score_0_100)) + scored.sort(key=lambda x: x[2], reverse=True) + print(f"[web] cosine scoring took {time.perf_counter() - t_score:.3f}s; rerank total {time.perf_counter() - t0:.3f}s") + + return scored + +def build_enriched_prompt(user_prompt: str, ranked: List[Tuple[str, str, float]], top_k: int = 6) -> Tuple[str, List[str]]: + """ + Build an enriched prompt using only high-quality documents. + - Keep at most `top_k` docs with score >= MIN_SCORE. + - If none survive, return a telling the assistant that + a web search was performed but no good results were found. + """ + MIN_SCORE = 70.0 + + # Sort defensively (should already be sorted desc) + ranked = sorted(ranked, key=lambda x: x[2], reverse=True) + + # Strict cutoff + selected = [(u, t, sc) for (u, t, sc) in ranked if sc >= MIN_SCORE][:top_k] + + # Debug summary + try: + all_scores = [sc for (_u, _t, sc) in ranked] + sel_scores = [sc for (_u, _t, sc) in selected] + if all_scores: + print(f"[web] selection ≥{MIN_SCORE} → total={len(all_scores)}, selected={len(selected)}, " + f"top_sel={max(sel_scores) if sel_scores else 0:.1f}, " + f"min_sel={min(sel_scores) if sel_scores else 0:.1f}") + except Exception: + pass + + if not selected: + # No “good” results → still enrich with an explicit no-results message + parts = [ + "", + f"No suitable web results found (score < {MIN_SCORE}). " + "Tell the user you performed a web search but couldn't find any good results. " + "If you can answer from prior context or general knowledge, say so clearly and " + "do not fabricate citations or URLs.", + "", + ] + enriched = f"{user_prompt}\n\n" + "\n".join(parts) + return enriched, [] + + # Build normal context + sources = [u for (u, _, _) in selected] + parts = [""] + for i, (u, t, _) in enumerate(selected, 1): + snippet = t.strip().replace("\n", " ") + if len(snippet) > 1200: + snippet = snippet[:1200] + parts.append(f"[{i}] {snippet} (source: {u})") + parts.append("") + parts.append("\nAnswer the user using the context when relevant.\nMake your response long enough to reflect all interesting information found. No duplicate information.\nUnless related to the topic, don't mention the source from the web search.") + + enriched = f"{user_prompt}\n\n" + "\n".join(parts) + return enriched, sources + +# ----- Public API ---------------------------------------------------------------- + +async def enrich_prompt( + user_prompt: str, + model: str, + messages: Optional[List[Dict[str, str]]] = None, + *, + searx_url: Optional[str] = None, + engines: Optional[List[str]] = None, +) -> Tuple[str, List[str]]: + import time # local import to avoid touching global imports + start_all = time.perf_counter() + + def _no_results_enriched(reason: str, queries: Optional[List[str]] = None) -> Tuple[str, List[str]]: + parts = [""] + parts.append( + "Web search was performed but no suitable results were found. " + "Tell the user you searched the web but couldn't find any good results. " + "If you can still answer from prior context or general knowledge, say so clearly, " + "and do not fabricate citations or URLs." + ) + if queries: + try: + qshow = ", ".join(queries[:3]) + parts.append(f"Queries attempted: {qshow}") + except Exception: + pass + if engines: + try: + eshow = ", ".join(engines) + parts.append(f"Engines selected: {eshow}") + except Exception: + pass + parts.append(f"(reason: {reason})") + parts.append("") + return f"{user_prompt}\n\n" + "\n".join(parts), [] + + context_excerpt = render_recent_context(messages, char_limit=1200) + + # 1) queries + try: + t0 = time.perf_counter() + queries = await generate_queries(user_prompt, model=model, context_excerpt=context_excerpt) + print(f"[web] queries: {queries} (took {time.perf_counter() - t0:.3f}s)") + except Exception: + print("[web] ERROR in generate_queries:\n" + traceback.format_exc()) + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("query_generation_failed") + + # 2) search + fetch + try: + print("[web] opening httpx client") + async with httpx.AsyncClient( + headers=DEFAULT_HEADERS, + follow_redirects=True, + http2=True, + limits=HTTP_LIMITS, + timeout=HTTP_TIMEOUT, + ) as client: + t1 = time.perf_counter() + print("[web] calling searx_search_many() …") + all_urls = await searx_search_many( + client, + queries, + per_query_max=6, + overall_max=MAX_PAGES_FETCH * 2, + searx_url=searx_url, + engines=engines, + ) + print(f"[web] searx_search_many() -> {len(all_urls)} urls (took {time.perf_counter() - t1:.3f}s)") + + if not all_urls: + print("[web] no URLs from SearX — emitting no-results context") + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("no_urls", queries) + + t2 = time.perf_counter() + print("[web] calling gather_pages() …") + pages = await gather_pages(client, all_urls, max_pages=MAX_PAGES_FETCH) + print(f"[web] gather_pages() -> {len(pages)} pages (cap={MAX_PAGES_FETCH}, took {time.perf_counter() - t2:.3f}s)") + except Exception: + print("[web] ERROR during search/fetch:\n" + traceback.format_exc()) + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("search_fetch_error") + + # 3) docs + try: + t3 = time.perf_counter() + docs = [(u, pages[u]) for u in all_urls if u in pages] + print(f"[web] docs for rerank: {len(docs)} (built in {time.perf_counter() - t3:.3f}s)") + if not docs: + print("[web] no docs after fetch/filter — emitting no-results context") + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("no_docs_after_fetch", queries) + except Exception: + print("[web] ERROR building docs list:\n" + traceback.format_exc()) + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("docs_build_failed", queries) + + # 4) rerank + try: + t4 = time.perf_counter() + print("[web] calling rerank() …") + ranked = await rerank(user_prompt, docs, model=model, context_excerpt=context_excerpt) + print(f"[web] rerank() -> {len(ranked)} scored docs (took {time.perf_counter() - t4:.3f}s)") + try: + scores_only = [sc for (_u, _t, sc) in ranked] + if scores_only: + scores_sorted = sorted(scores_only) + n = len(scores_sorted) + p50 = scores_sorted[n // 2] + p75 = scores_sorted[int(n * 0.75) - 1 if n > 3 else n - 1] + print(f"[web] score stats → min={scores_sorted[0]:.1f}, p50={p50:.1f}, p75={p75:.1f}, max={scores_sorted[-1]:.1f}") + print(f"[web] top scores → {[round(s,1) for s in scores_only[:min(6,len(scores_only))]]}") + except Exception: + pass + except Exception: + print("[web] ERROR in rerank:\n" + traceback.format_exc()) + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("rerank_failed", queries) + + # 5) build prompt + try: + t5 = time.perf_counter() + enriched = build_enriched_prompt(user_prompt, ranked, top_k=6) + print(f"[web] build_enriched_prompt() done (took {time.perf_counter() - t5:.3f}s)") + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return enriched + except Exception: + print("[web] ERROR in build_enriched_prompt:\n" + traceback.format_exc()) + print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s") + return _no_results_enriched("build_enriched_failed", queries) \ No newline at end of file diff --git a/src/App.jsx b/src/App.jsx index 5d77b9d..8592ed3 100644 --- a/src/App.jsx +++ b/src/App.jsx @@ -4,6 +4,7 @@ import { flushSync } from 'react-dom'; import TextareaAutosize from 'react-textarea-autosize'; import GeneralSettings from './GeneralSettings' import InterfaceSettings from './InterfaceSettings' +import WebsearchSettings from './WebsearchSettings' import { markdownToHTML } from './markdown'; // Extract or block (first occurrence) and return { think, answer } function splitThinkBlocks(text) { @@ -43,7 +44,7 @@ function splitThinkBlocks(text) { } // Renders assistant message with a collapsible "Thoughts" block (if present) -function AssistantMessageContent({ content, streamOutput }) { +function AssistantMessageContent({ content, streamOutput, sources }) { const { think, answer } = splitThinkBlocks(content || ''); const [open, setOpen] = React.useState(false); const showThinkButton = !!think; @@ -76,12 +77,30 @@ function AssistantMessageContent({ content, streamOutput }) { className="msg-content" dangerouslySetInnerHTML={{ __html: markdownToHTML(answer || content || '') }} /> + {Array.isArray(sources) && sources.length > 0 && ( +
+ {sources.map((u, i) => { + let label = u; + try { + const host = new URL(u).hostname || u; + label = host.replace(/^www\./i, ''); + } catch {} + return ( + + {label} + + ); + })} +
+ )} ); } const API_URL_KEY = 'ollamaApiUrl'; const COLOR_SCHEME_KEY = 'colorScheme'; +const WEBSEARCH_URL_KEY = 'websearch.searxUrl'; +const WEBSEARCH_ENGINES_KEY = 'websearch.engines'; // Initial API value will be set by useEffect after settings are loaded let API = import.meta.env.VITE_API_URL ?? 'http://127.0.0.1:8000'; @@ -103,6 +122,24 @@ export default function App() { const [ollamaApiUrl, setOllamaApiUrl] = useState(API); // State for Ollama API URL const [colorScheme, setColorScheme] = useState('Default'); // State for color scheme const [streamOutput, setStreamOutput] = useState(false); + const [searxUrl, setSearxUrl] = useState(localStorage.getItem(WEBSEARCH_URL_KEY) || 'http://localhost:8888'); + const [searxEngines, setSearxEngines] = useState(() => { + try { + const raw = localStorage.getItem(WEBSEARCH_ENGINES_KEY); + if (raw) return JSON.parse(raw); + } catch {} + return ["duckduckgo","bing","wikipedia","github","stack_overflow"]; + }); + useEffect(() => { + localStorage.setItem(WEBSEARCH_URL_KEY, searxUrl || ''); + }, [searxUrl]); + + useEffect(() => { + try { + localStorage.setItem(WEBSEARCH_ENGINES_KEY, JSON.stringify(searxEngines || [])); + } catch {} + }, [searxEngines]); + const [webSearchEnabled, setWebSearchEnabled] = useState(false); const [isSending, setIsSending] = useState(false); const [loading, setLoading] = useState(true); // Loading state for initial session fetch const [unreadSessions, setUnreadSessions] = useState([]); // Track unread messages @@ -210,126 +247,170 @@ export default function App() { } // Continue conversation from the edited message - await regenerateFromIndex(index); + await regenerateFromIndex(index, next); } - async function regenerateFromIndex(index) { - const sessionId = activeSessionId; - if (!sessionId || typeof index !== 'number') return; +async function regenerateFromIndex(index, overrideUserText = null) { + const sessionId = activeSessionId; + if (!sessionId || typeof index !== 'number') return; - const msgs = (chatSessions.find(s => s.session_id === sessionId)?.messages) || []; - let lastUserIdx = index; - for (let i = index; i >= 0; i--) { - if (msgs[i]?.role === 'user') { lastUserIdx = i; break; } + const msgs = (chatSessions.find(s => s.session_id === sessionId)?.messages) || []; + let lastUserIdx = index; + for (let i = index; i >= 0; i--) { + if (msgs[i]?.role === 'user') { lastUserIdx = i; break; } + } + + // Prune UI to lastUserIdx + setChatSessions(prev => + prev.map(s => s.session_id === sessionId + ? { ...s, messages: (s.messages || []).slice(0, lastUserIdx + 1) } + : s + ) + ); + + setIsSending(true); + + // --- optional websearch enrichment for regenerate --- + let enrichedPrompt = null; + let citationSources = []; + if (webSearchEnabled) { + try { + // Use the freshly edited user text when provided + const promptText = (overrideUserText != null ? overrideUserText : (msgs[lastUserIdx]?.content || '')); + + // Build compact recent history and overwrite the last user turn with promptText + const historyForSearch = msgs + .slice(Math.max(0, lastUserIdx - 7), lastUserIdx + 1) + .map(m => ({ role: m.role, content: m.content || '' })); + if (historyForSearch.length > 0) { + historyForSearch[historyForSearch.length - 1] = { role: 'user', content: promptText }; + } + + const resp = await fetch(`${ollamaApiUrl}/websearch`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + prompt: promptText, + model, + messages: historyForSearch, + history_limit: 8, + searx_url: searxUrl || null, + engines: Array.isArray(searxEngines) ? searxEngines : null, + }) + }); + const data = await resp.json(); + if (data && typeof data.enriched_prompt === 'string') { + enrichedPrompt = data.enriched_prompt; + citationSources = Array.isArray(data.sources) ? data.sources : []; + } + } catch (e) { + console.warn('web search enrichment (regenerate) failed', e); } + } - // Prune UI to lastUserIdx + if (streamOutput) { + const assistantMsgId = `msg-${Date.now()}-${Math.random()}`; + // add placeholder assistant message (keep sources on the placeholder) setChatSessions(prev => prev.map(s => s.session_id === sessionId - ? { ...s, messages: (s.messages || []).slice(0, lastUserIdx + 1) } + ? { ...s, messages: [...(s.messages || []), { id: assistantMsgId, role: 'assistant', content: '', sources: citationSources }] } : s ) ); - setIsSending(true); + try { + const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + index, + model, + stream: true, + enriched_message: enrichedPrompt, + sources: citationSources || [] + }) + }); + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let full = ''; + let unreadMarked = false; - if (streamOutput) { - const assistantMsgId = `msg-${Date.now()}-${Math.random()}`; - // add placeholder assistant message - setChatSessions(prev => - prev.map(s => s.session_id === sessionId - ? { ...s, messages: [...(s.messages || []), { id: assistantMsgId, role: 'assistant', content: '' }] } - : s - ) - ); + while (true) { + const { value, done } = await reader.read(); + if (done) break; - try { - const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ index, model, stream: true }) - }); - const reader = res.body.getReader(); - const decoder = new TextDecoder(); - let full = ''; - let unreadMarked = false; // NEW + const chunk = decoder.decode(value, { stream: true }); + full += chunk; - while (true) { - const { value, done } = await reader.read(); - if (done) break; - - const chunk = decoder.decode(value, { stream: true }); - full += chunk; - - // Update the growing assistant message - setChatSessions(prev => - prev.map(s => s.session_id === sessionId - ? { ...s, messages: (s.messages || []).map(m => m.id === assistantMsgId ? { ...m, content: full } : m) } - : s - ) - ); - - // If this session is not active while streaming, mark unread once - if (!unreadMarked && activeSessionIdRef.current !== sessionId) { - unreadMarked = true; - setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId })); - setUnreadSessions(prev => [...new Set([...prev, sessionId])]); - } - } - - // On stream end: if user is in another chat, ensure unread + guided scroll are set - if (activeSessionIdRef.current !== sessionId) { - setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId })); - setUnreadSessions(prev => [...new Set([...prev, sessionId])]); - } else { - // If user stayed here and didn't scroll up, align the finished answer nicely - if (!userScrolledUpRef.current[sessionId]) { - requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId)); - } else { - // show the tip if they had scrolled away - setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId })); - } - } - } catch (e) { - console.error(e); - } finally { - setIsSending(false); - } - } else { - try { - const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ index, model, stream: false }) - }); - const data = await res.json(); - const assistantMsgId = `msg-${Date.now()}`; + // Update the growing assistant message (sources remain intact) setChatSessions(prev => prev.map(s => s.session_id === sessionId - ? { ...s, messages: [...(s.messages || []), { role: 'assistant', content: data.reply, id: assistantMsgId }] } + ? { ...s, messages: (s.messages || []).map(m => m.id === assistantMsgId ? { ...m, content: full } : m) } : s ) ); - if (activeSessionIdRef.current !== sessionId) { - // reply landed in background -> mark unread + remember where to scroll + if (!unreadMarked && activeSessionIdRef.current !== sessionId) { + unreadMarked = true; setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId })); setUnreadSessions(prev => [...new Set([...prev, sessionId])]); - } else { - // same chat -> align unless the user scrolled away - if (!userScrolledUpRef.current[sessionId]) { - requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId)); - } else { - setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId })); - } } - } catch (e) { - console.error(e); - } finally { - setIsSending(false); } + + if (activeSessionIdRef.current !== sessionId) { + setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId })); + setUnreadSessions(prev => [...new Set([...prev, sessionId])]); + } else { + if (!userScrolledUpRef.current[sessionId]) { + requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId)); + } else { + setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId })); + } + } + } catch (e) { + console.error(e); + } finally { + setIsSending(false); + } + } else { + try { + const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + index, + model, + stream: false, + enriched_message: enrichedPrompt, + sources: citationSources || [] + }) + }); + const data = await res.json(); + const assistantMsgId = `msg-${Date.now()}`; + setChatSessions(prev => + prev.map(s => s.session_id === sessionId + ? { ...s, messages: [...(s.messages || []), { role: 'assistant', content: data.reply, id: assistantMsgId, sources: citationSources }] } + : s + ) + ); + + if (activeSessionIdRef.current !== sessionId) { + setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId })); + setUnreadSessions(prev => [...new Set([...prev, sessionId])]); + } else { + if (!userScrolledUpRef.current[sessionId]) { + requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId)); + } else { + setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId })); + } + } + } catch (e) { + console.error(e); + } finally { + setIsSending(false); } } +} // Persist userScrolledUp state per session + live ref for closures (streaming) @@ -739,108 +820,115 @@ export default function App() { }; - async function sendMessage() { - if (!input.trim() || !model) return; +async function sendMessage() { + if (!input.trim() || !model) return; - let targetSessionId = activeSessionId; - let isNewChat = false; - if (!targetSessionId) { - const newSession = await createNewChat(); - await new Promise(resolve => setTimeout(resolve, 200)); - targetSessionId = newSession.session_id; - isNewChat = true; - } else { - const currentSession = chatSessions.find(s => s.session_id === targetSessionId); - isNewChat = currentSession && currentSession.name === "New Chat" && currentSession.messages.length === 0; - } - - const userMsg = { role: 'user', content: input.trim(), id: `msg-${Date.now()}-${Math.random()}` }; - justSentMessage.current = true; - lastSentSessionRef.current = targetSessionId; - setUserScrolledUp(targetSessionId, false); + let targetSessionId = activeSessionId; + let isNewChat = false; + if (!targetSessionId) { + const newSession = await createNewChat(); + await new Promise(resolve => setTimeout(resolve, 200)); + targetSessionId = newSession.session_id; + isNewChat = true; + } else { + const currentSession = chatSessions.find(s => s.session_id === targetSessionId); + isNewChat = currentSession && currentSession.name === "New Chat" && currentSession.messages.length === 0; + } - // Cancel any pending restore for the active session (we're about to control the scroll) - if (activeSessionIdRef.current === targetSessionId) { - restoredForRef.current = activeSessionIdRef.current; // mark as already restored + const userMsg = { role: 'user', content: input.trim(), id: `msg-${Date.now()}-${Math.random()}` }; + justSentMessage.current = true; + lastSentSessionRef.current = targetSessionId; + setUserScrolledUp(targetSessionId, false); + + if (activeSessionIdRef.current === targetSessionId) { + restoredForRef.current = activeSessionIdRef.current; + } + + flushSync(() => { + setChatSessions(prevSessions => + prevSessions.map(session => + session.session_id === targetSessionId + ? { ...session, messages: [...(session.messages || []), userMsg] } + : session + ) + ); + setInput(''); + }); + requestAnimationFrame(() => scrollToBottom('auto', targetSessionId)); + + setIsSending(true); + try { + // Build compact recent history for context-aware websearch (resource-friendly). + // We only send the last 8 turns by default, including assistant replies, + // and we also append the *current* user message (same content as `userMsg`). + let historyForSearch = []; + try { + const existing = (chatSessions.find(s => s.session_id === targetSessionId)?.messages) || []; + const lastFew = existing.slice(-8).map(m => ({ role: m.role, content: m.content || '' })); + historyForSearch = [...lastFew, { role: 'user', content: userMsg.content }]; + } catch {} + + // Decide on enrichment using the toggle + let enrichedPrompt = userMsg.content; + let citationSources = []; + if (webSearchEnabled) { + try { + const resp = await fetch(`${ollamaApiUrl}/websearch`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + prompt: userMsg.content, + model, + messages: historyForSearch, + history_limit: 8, + searx_url: searxUrl || null, + engines: Array.isArray(searxEngines) ? searxEngines : null, + }) + }); + const data = await resp.json(); + if (data && typeof data.enriched_prompt === 'string') { + enrichedPrompt = data.enriched_prompt; + citationSources = Array.isArray(data.sources) ? data.sources : []; + } + } catch (e) { + console.warn('web search enrichment failed', e); + } } - // Optimistic add and flush DOM, then scroll to bottom - flushSync(() => { + if (streamOutput) { + const assistantMsgId = `msg-${Date.now()}-${Math.random()}`; + const assistantMsg = { role: 'assistant', content: '', id: assistantMsgId, sources: citationSources }; setChatSessions(prevSessions => prevSessions.map(session => session.session_id === targetSessionId - ? { ...session, messages: [...(session.messages || []), userMsg] } + ? { ...session, messages: [...(session.messages || []), assistantMsg] } : session ) ); - setInput(''); - }); - requestAnimationFrame(() => scrollToBottom('auto', targetSessionId)); - setIsSending(true); - try { - if (streamOutput) { - const assistantMsgId = `msg-${Date.now()}-${Math.random()}`; - const assistantMsg = { role: 'assistant', content: '', id: assistantMsgId }; - setChatSessions(prevSessions => - prevSessions.map(session => - session.session_id === targetSessionId - ? { ...session, messages: [...(session.messages || []), assistantMsg] } - : session - ) - ); + (async () => { + try { + const res = await fetch(`${ollamaApiUrl}/chat`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + session_id: targetSessionId, + model, + message: userMsg.content, + enriched_message: webSearchEnabled ? enrichedPrompt : null, + stream: true, + sources: citationSources || [] + }) + }); - (async () => { - try { - const res = await fetch(`${ollamaApiUrl}/chat`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: targetSessionId, - model, - message: userMsg.content, - stream: true - }) - }); + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + let fullReply = ''; + let pendingMarked = false; - const reader = res.body.getReader(); - const decoder = new TextDecoder(); - let fullReply = ''; - let pendingMarked = false; - - while (true) { - const { value, done } = await reader.read(); - if (done) { - setChatSessions(prevSessions => - prevSessions.map(session => - session.session_id === targetSessionId - ? { - ...session, - messages: session.messages.map(m => - m.id === assistantMsgId ? { ...m, content: fullReply } : m - ) - } - : session - ) - ); - - if (activeSessionIdRef.current === targetSessionId) { - if (!userScrolledUpRef.current[targetSessionId]) { - // user stayed at bottom -> reveal the message immediately - requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId)); - } else { - // user scrolled away while it was generating -> show tip instead of auto-scroll - setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); - } - } else { - setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); - setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]); - } - - break; - } - const chunk = decoder.decode(value, { stream: true }); - fullReply += chunk; + while (true) { + const { value, done } = await reader.read(); + if (done) { setChatSessions(prevSessions => prevSessions.map(session => session.session_id === targetSessionId @@ -853,113 +941,154 @@ export default function App() { : session ) ); - // Keep sticky-bottom *only* when streaming in the active chat and user is at/near bottom. - // This restores the old "push down while generating" behavior without fighting user scrolls. - if ( - activeSessionIdRef.current === targetSessionId && - !userScrolledUpRef.current[targetSessionId] - ) { - // use 'auto' so it stays snappy during streaming - scrollToBottom('auto', targetSessionId); - } - // If streaming in a background chat, prepare a one-time guided scroll - if (activeSessionIdRef.current !== targetSessionId && !pendingMarked) { + + if (activeSessionIdRef.current === targetSessionId) { + if (!userScrolledUpRef.current[targetSessionId]) { + requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId)); + } else { + setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); + } + } else { setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); - pendingMarked = true; + setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]); } + + break; } - } catch (e) { - console.error("Failed to send message:", e); - const errorMsg = { role: 'assistant', content: 'Error: ' + e.message, id: `msg-${Date.now()}-${Math.random()}` }; + const chunk = decoder.decode(value, { stream: true }); + fullReply += chunk; setChatSessions(prevSessions => prevSessions.map(session => session.session_id === targetSessionId - ? { ...session, messages: [...session.messages.slice(0, -1), errorMsg] } + ? { + ...session, + messages: session.messages.map(m => + m.id === assistantMsgId ? { ...m, content: fullReply } : m + ) + } : session ) ); - } finally { - setIsSending(false); - } - })(); - } else { - const res = await fetch(`${ollamaApiUrl}/chat`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: targetSessionId, - model, - message: userMsg.content, - stream: false - }) - }); - const data = await res.json(); - const assistantMsgId = `msg-${Date.now()}`; - const assistantMsg = { role: 'assistant', content: data.reply, id: assistantMsgId }; - setChatSessions(prevSessions => - prevSessions.map(session => - session.session_id === targetSessionId - ? { ...session, messages: [...(session.messages || []), assistantMsg] } - : session - ) - ); - - // For non-stream: align new ASSISTANT message to top, unless user scrolled away - if (assistantMsgId) { - if (activeSessionIdRef.current === targetSessionId) { - if (!userScrolledUpRef.current[targetSessionId]) { - requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId)); - } else { - // <<< show the tip if user scrolled away while waiting >>> - setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); + if ( + activeSessionIdRef.current === targetSessionId && + !userScrolledUpRef.current[targetSessionId] + ) { + scrollToBottom('auto', targetSessionId); + } + if (activeSessionIdRef.current !== targetSessionId && !pendingMarked) { + setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); + pendingMarked = true; } - } else { - setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); } - } - setIsSending(false); - } - - if (activeSessionIdRef.current !== targetSessionId) { - setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]); - } - - if (isNewChat) { - fetch(`${ollamaApiUrl}/generate-title`, { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: targetSessionId, - message: userMsg.content, - model: model - }) - }) - .then(r => r.json()) - .then(data => { - const sanitizedTitle = data.title.replace(/[\s\S]*?<\/think(?:ing)?>/i, '').trim(); + } catch (e) { + console.error('Failed to send message:', e); + const errorMsg = { + role: 'assistant', + content: 'Error: ' + e.message, + id: `msg-${Date.now()}-${Math.random()}`, + sources: citationSources + }; setChatSessions(prevSessions => prevSessions.map(session => - session.session_id === targetSessionId ? { ...session, name: sanitizedTitle } : session + session.session_id === targetSessionId + ? { ...session, messages: [...session.messages.slice(0, -1), errorMsg] } + : session ) ); - }); - } - } catch (e) { - console.error("Failed to send message:", e); - const errorMsg = { role: 'assistant', content: 'Error: ' + e.message, id: `msg-${Date.now()}-${Math.random()}` }; + } finally { + setIsSending(false); + } + })(); + } else { + const res = await fetch(`${ollamaApiUrl}/chat`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + session_id: targetSessionId, + model, + message: userMsg.content, + enriched_message: webSearchEnabled ? enrichedPrompt : null, + stream: false, + sources: citationSources || [] + }) + }); + const data = await res.json(); + const assistantMsgId = `msg-${Date.now()}`; + const assistantMsg = { + role: 'assistant', + content: data.reply, + id: assistantMsgId, + sources: citationSources + }; + setChatSessions(prevSessions => prevSessions.map(session => session.session_id === targetSessionId - ? { ...session, messages: [...session.messages, errorMsg] } + ? { ...session, messages: [...(session.messages || []), assistantMsg] } : session ) ); + + if (assistantMsgId) { + if (activeSessionIdRef.current === targetSessionId) { + if (!userScrolledUpRef.current[targetSessionId]) { + requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId)); + } else { + setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); + } + } else { + setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId })); + } + } setIsSending(false); } - } - async function createNewChat() { + if (activeSessionIdRef.current !== targetSessionId) { + setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]); + } + + if (isNewChat) { + fetch(`${ollamaApiUrl}/generate-title`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + session_id: targetSessionId, + message: userMsg.content, + model: model + }) + }) + .then(r => r.json()) + .then(data => { + const sanitizedTitle = data.title.replace(/[\s\S]*?<\/think(?:ing)?>/i, '').trim(); + setChatSessions(prevSessions => + prevSessions.map(session => + session.session_id === targetSessionId ? { ...session, name: sanitizedTitle } : session + ) + ); + }); + } + } catch (e) { + console.error("Failed to send message:", e); + const errorMsg = { role: 'assistant', content: 'Error: ' + e.message, id: `msg-${Date.now()}-${Math.random()}` }; + setChatSessions(prevSessions => + prevSessions.map(session => + session.session_id === targetSessionId + ? { ...session, messages: [...session.messages, errorMsg] } + : session + ) + ); + setIsSending(false); + } +} + + + +function toggleWebSearch() { + setWebSearchEnabled(prev => !prev); +} + +async function createNewChat() { const newSessionId = 'sess-' + Math.random().toString(36).slice(2) + Date.now().toString(36); const res = await fetch(`${ollamaApiUrl}/sessions`, { method: 'POST', @@ -1155,6 +1284,12 @@ export default function App() { > Interface +
setActiveSettingsSubmenu('Websearch')} + > + Websearch +
)} @@ -1192,7 +1327,7 @@ export default function App() { > {m.role === 'assistant' ? (
- + {!isSending && (
@@ -1321,6 +1459,14 @@ export default function App() { /> )} {activeSettingsSubmenu === 'Interface' && } + {activeSettingsSubmenu === 'Websearch' && ( + + )} )}
diff --git a/src/WebsearchSettings.jsx b/src/WebsearchSettings.jsx new file mode 100644 index 0000000..6060cec --- /dev/null +++ b/src/WebsearchSettings.jsx @@ -0,0 +1,62 @@ +// src/WebsearchSettings.jsx +import React, { useEffect, useMemo, useState } from 'react'; + +export default function WebsearchSettings({ + searxUrl, + setSearxUrl, + engines, + setEngines, +}) { + const KNOWN_ENGINES = useMemo( + () => ["google","bing","yahoo","duckduckgo","brave","github","stackoverflow","reddit","arxiv"], + [] + ); + + const [custom, setCustom] = useState(""); + + const toggleEngine = (name) => { + const set = new Set(engines || []); + if (set.has(name)) set.delete(name); else set.add(name); + setEngines(Array.from(set)); + }; + + const addCustom = () => { + const name = custom.trim(); + if (!name) return; + const set = new Set(engines || []); + set.add(name); + setEngines(Array.from(set)); + setCustom(""); + }; + +return ( +
+
+

SearXNG URL

+ setSearxUrl(e.target.value)} + placeholder="e.g., http://localhost:8888" + /> +
+ +
+

Search Engines

+
+ {KNOWN_ENGINES.map(name => ( + + ))} +
+
+
+); +} \ No newline at end of file diff --git a/src/styles.css b/src/styles.css index ca42cf0..db39da1 100644 --- a/src/styles.css +++ b/src/styles.css @@ -889,3 +889,40 @@ input:checked + .slider:before { outline: none; box-shadow: none; } + + +/* Web search toggle */ +.websearch-toggle { + display: inline-flex; + align-items: center; + justify-content: center; + width: 38px; + height: 38px; + border-radius: 10px; + border: 1px solid var(--border); + background: var(--input-bg); + cursor: pointer; +} +.websearch-toggle svg { width: 16px; height: 16px; } +.websearch-toggle.active { outline: 2px solid var(--accent); } +.msg-sources { margin-top: 8px; font-size: 12px; color: var(--muted); } +.msg-sources a { color: var(--accent); text-decoration: none; margin-right: 8px; } +.msg-sources a:hover { text-decoration: underline; } +.msg-sources.chips { + display: flex; + flex-wrap: wrap; + margin: 0.5rem 0 0.5rem 0; +} + +.msg-sources.chips .chip { + display: inline-flex; + align-items: center; + padding: .25rem .6rem; + border-radius: 9999px; + border: 1px solid var(--border); + text-decoration: none; + font-size: 0.85rem; + line-height: 1; + white-space: nowrap; + margin-top: 0.5rem; +}