Add web search enrichment feature with source persistence and UI integration

Introduce an optional web search enrichment flow for chat and regenerate requests. A new /websearch endpoint calls enrich_prompt (backed by SearXNG) and returns enriched_prompt plus citation sources.

DB: add sources_json column to chat_messages via ensure_sources_column migration helper.

Backend: persist sources_json for assistant replies (streaming and non-streaming); extend ChatRequest/RegenerateRequest to accept enriched_message and sources; history endpoint returns sources.

Frontend: add a toggle for web search, settings for SearXNG URL + engines, and optional enrichment calls in sendMessage/regenerate. Render citation sources as rounded chips under assistant replies, labeled with the source hostname (www-stripped).

Dependencies: add beautifulsoup4, httpx[http2], numpy for enrichment pipeline.
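A minimal sketch of the end-to-end flow this commit enables (editor's illustration, not part of the commit; assumes the backend on 127.0.0.1:8000, an existing session created via POST /sessions, and a model name such as "llama3"):

    import httpx

    BASE = "http://127.0.0.1:8000"

    # 1) Ask the backend to enrich the prompt; on any failure the endpoint
    #    falls back to {"enriched_prompt": <prompt>, "sources": []}.
    ws = httpx.post(f"{BASE}/websearch", json={
        "prompt": "What changed in httpx 0.27?",
        "model": "llama3",
        "history_limit": 8,
    }, timeout=120.0).json()

    # 2) Send the chat request: the base message is persisted, while the
    #    enriched copy is what actually reaches the LLM; sources are stored
    #    alongside the assistant reply.
    reply = httpx.post(f"{BASE}/chat", json={
        "session_id": "sess-demo",
        "model": "llama3",
        "message": "What changed in httpx 0.27?",
        "enriched_message": ws["enriched_prompt"],
        "sources": ws["sources"],
        "stream": False,
    }, timeout=300.0).json()
    print(reply["reply"])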
2025-08-27 04:27:18 +02:00
parent e262e4b4fe
commit 728c7763e2
9 changed files with 1296 additions and 300 deletions

backend/database.py
View File

@@ -1,6 +1,20 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from sqlalchemy import text
"""
Database utilities and configuration. This module defines the SQLAlchemy
engine, session factory and base class for models. It also contains a
lightweight migration helper used to evolve the schema over time. The
`ensure_sources_column` helper adds a new `sources_json` column to the
`chat_messages` table if it does not already exist. This is required
for persisting citation sources alongside assistant messages.
The migration uses SQLite's `ALTER TABLE` syntax and runs once at startup.
It is safe to call repeatedly: when the column already exists, the function
is simply a no-op.
"""
DATABASE_URL = "sqlite:///./backend/app.db"
@@ -12,3 +26,13 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
class Base(DeclarativeBase):
pass
def ensure_sources_column(engine):
    try:
        # engine.begin() commits the DDL; a bare connect() would roll the
        # ALTER TABLE back on close under SQLAlchemy 2.0's transaction handling.
        with engine.begin() as conn:
            cols = [row[1] for row in conn.execute(text("PRAGMA table_info(chat_messages)"))]
            if "sources_json" not in cols:
                conn.execute(text("ALTER TABLE chat_messages ADD COLUMN sources_json TEXT DEFAULT '[]'"))
    except Exception as e:
        print("[db] ensure_sources_column error:", e)

backend/main.py
View File

@@ -3,14 +3,17 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import List
import re # Import the regex module
import html # Import the html module for unescaping
import re
import html
import json
from . import models, schemas
from .database import Base, engine, SessionLocal
from .database import Base, engine, SessionLocal, ensure_sources_column
from .ollama_client import list_models as ollama_list, chat as ollama_chat, chat_stream as ollama_chat_stream
from .websearch import enrich_prompt
# Create tables
# Create tables + ensure migration
Base.metadata.create_all(bind=engine)
ensure_sources_column(engine)
app = FastAPI(title="LLM Desktop Backend", version="0.1.0" )
@@ -60,8 +63,21 @@ def history(session_id: str, db: Session = Depends(get_db)):
session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
if not session:
return {"messages": []}
rows = db.query(models.ChatMessage) .filter(models.ChatMessage.session_pk == session.id) .order_by(models.ChatMessage.created_at.asc()) .all()
msgs = [{"role": r.role, "content": r.content} for r in rows]
rows = (
db.query(models.ChatMessage)
.filter(models.ChatMessage.session_pk == session.id)
.order_by(models.ChatMessage.created_at.asc())
.all()
)
msgs = []
for r in rows:
sources = []
try:
if getattr(r, "sources_json", None):
sources = json.loads(r.sources_json or "[]")
except Exception:
sources = []
msgs.append({"role": r.role, "content": r.content, "sources": sources})
return {"messages": msgs}
@app.post("/chat")
@@ -74,16 +90,31 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
db.commit()
db.refresh(session)
# Save user message
# Store the BASE user prompt
user_row = models.ChatMessage(session_pk=session.id, role='user', content=req.message)
db.add(user_row)
db.commit()
# Build minimal conversation context (last 20 messages)
last_msgs = db.query(models.ChatMessage) .filter(models.ChatMessage.session_pk == session.id) .order_by(models.ChatMessage.created_at.asc()) .all()[-20:]
# Build minimal context (last 20)
last_msgs = (
db.query(models.ChatMessage)
.filter(models.ChatMessage.session_pk == session.id)
.order_by(models.ChatMessage.created_at.asc())
.all()[-20:]
)
messages = [{"role": m.role, "content": m.content} for m in last_msgs]
# Patch last user with enriched_message only for LLM call
if req.enriched_message:
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "user":
messages = messages.copy()
messages[i] = {**messages[i], "content": req.enriched_message}
break
# Sources to persist with the assistant reply
sources = req.sources or []
if req.stream:
async def stream_generator():
full_reply = ""
@@ -92,14 +123,16 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
full_reply += chunk
yield chunk
except Exception as e:
# How to handle errors in a stream? Could yield an error message.
yield f"Ollama error: {e}"
# Save full reply after stream is complete
as_row = models.ChatMessage(session_pk=session.id, role='assistant', content=full_reply)
# Persist assistant reply (include sources_json)
as_row = models.ChatMessage(
session_pk=session.id, role='assistant', content=full_reply,
sources_json=json.dumps(sources or [])
)
db.add(as_row)
db.commit()
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
try:
@@ -107,11 +140,12 @@ async def chat(req: schemas.ChatRequest, db: Session = Depends(get_db)):
except Exception as e:
raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
# Save assistant reply
as_row = models.ChatMessage(session_pk=session.id, role='assistant', content=reply)
as_row = models.ChatMessage(
session_pk=session.id, role='assistant', content=reply,
sources_json=json.dumps(sources or [])
)
db.add(as_row)
db.commit()
return {"reply": reply}
@app.post("/generate-title", response_model=schemas.GenerateTitleResponse)
@@ -200,13 +234,10 @@ def update_user_message(session_id: str, index: int, req: schemas.EditMessageReq
# ADD or REPLACE this whole function
@app.post("/sessions/{session_id}/regenerate")
async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Session = Depends(get_db)):
"""
Regenerate an assistant response for the conversation state at/before req.index.
If req.index points at an assistant message, we regenerate from the preceding user message.
"""
idx = req.index
model = req.model
stream = bool(req.stream)
sources = req.sources or []
session = db.query(models.ChatSession).filter(models.ChatSession.session_id == session_id).first()
if not session:
@@ -218,43 +249,49 @@ async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Sessio
.order_by(models.ChatMessage.created_at.asc())
.all()
)
if idx < 0 or idx >= len(msgs):
raise HTTPException(status_code=400, detail="Invalid message index")
# Find the last user message at/before idx
# last user idx at/before idx
last_user_idx = idx
for i in range(idx, -1, -1):
if msgs[i].role == "user":
last_user_idx = i
break
# Prune everything after last_user_idx
# prune after that user
if last_user_idx < len(msgs) - 1:
for m in msgs[last_user_idx + 1:]:
db.delete(m)
db.commit()
# Build the conversation up to & incl. the last user message
conversation = [{"role": m.role, "content": m.content} for m in msgs[: last_user_idx + 1]]
# Avoid DetachedInstanceError during streaming
if req.enriched_message:
for j in range(len(conversation) - 1, -1, -1):
if conversation[j].get("role") == "user":
conversation = conversation.copy()
conversation[j] = {**conversation[j], "content": req.enriched_message}
break
session_pk = session.id
if stream:
async def stream_generator():
full_reply = ""
try:
# ollama_chat_stream must already exist in your codebase (used by /chat)
async for chunk in ollama_chat_stream(model, conversation):
full_reply += chunk
yield chunk
except Exception as e:
yield f"Ollama error: {e}"
# Persist with a fresh DB session (streaming context)
# persist (with sources)
try:
db_sess = SessionLocal()
db_sess.add(models.ChatMessage(session_pk=session_pk, role="assistant", content=full_reply))
db_sess.add(models.ChatMessage(
session_pk=session_pk, role="assistant", content=full_reply,
sources_json=json.dumps(sources or [])
))
db_sess.commit()
finally:
try:
@@ -264,16 +301,38 @@ async def regenerate(session_id: str, req: schemas.RegenerateRequest, db: Sessio
return StreamingResponse(stream_generator(), media_type="text/plain")
# Non-streaming
try:
# ollama_chat must already exist in your codebase (used by /chat)
reply = await ollama_chat(model, conversation)
except Exception as e:
raise HTTPException(status_code=502, detail=f"Ollama error: {e}")
db.add(models.ChatMessage(session_pk=session_pk, role="assistant", content=reply))
db.add(models.ChatMessage(
session_pk=session_pk, role="assistant", content=reply,
sources_json=json.dumps(sources or [])
))
db.commit()
return {"reply": reply}
# -----------------------------------------------------------------------------
# Web search enrichment endpoint
@app.post("/websearch", response_model=schemas.WebSearchResponse)
async def websearch_route(req: schemas.WebSearchRequest):
"""
Return an enriched prompt (with citations) for a given user prompt.
Optionally uses the last `history_limit` turns from `req.messages`.
"""
try:
messages = (req.messages or [])[-int(req.history_limit or 8):]
enriched, sources = await enrich_prompt(
user_prompt=req.prompt,
model=req.model,
messages=[{"role": m.role, "content": m.content} for m in messages],
searx_url=req.searx_url,
engines=req.engines,
)
return {"enriched_prompt": enriched, "sources": sources}
except Exception:
return {"enriched_prompt": req.prompt, "sources": []}
# To run standalone: python -m uvicorn backend.main:app --host 127.0.0.1 --port 8000

backend/models.py
View File

@@ -20,6 +20,8 @@ class ChatMessage(Base):
session_pk = Column(Integer, ForeignKey('chat_sessions.id'), nullable=False)
role = Column(String(16), nullable=False) # 'user' | 'assistant'
content = Column(Text, nullable=False)
# JSON-encoded list of citation URLs; null/empty => no chips
sources_json = Column(Text, nullable=True, default='[]')
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
session = relationship("ChatSession", back_populates="messages")

requirements.txt
View File

@@ -4,3 +4,8 @@ uvicorn[standard]==0.30.1
SQLAlchemy==2.0.32
httpx==0.27.0
pydantic==2.7.4
# Web search enrichment dependencies
beautifulsoup4==4.12.3
httpx[http2]>=0.27.0
numpy
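Note: `clean_text` in backend/websearch.py prefers the lxml parser and silently falls back to html.parser; lxml is not pinned here, so install it separately to get the faster path.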

backend/schemas.py
View File

@@ -5,12 +5,15 @@ from datetime import datetime
class Message(BaseModel):
role: str
content: str
sources: Optional[List[str]] = None
class ChatRequest(BaseModel):
session_id: str
model: str
message: str
enriched_message: Optional[str] = None
stream: Optional[bool] = False
sources: Optional[List[str]] = None
class ChatResponse(BaseModel):
reply: str
@@ -47,4 +50,20 @@ class EditMessageRequest(BaseModel):
class RegenerateRequest(BaseModel):
index: int
model: Optional[str] = None
stream: bool = True
enriched_message: Optional[str] = None
stream: bool = True
sources: Optional[List[str]] = None
# Request payload for the web search enrichment endpoint.
class WebSearchRequest(BaseModel):
prompt: str
model: str
messages: Optional[List[Message]] = None
history_limit: Optional[int] = 8
searx_url: Optional[str] = None
engines: Optional[List[str]] = None
# Response payload for the web search enrichment endpoint.
class WebSearchResponse(BaseModel):
enriched_prompt: str
sources: List[str] = []
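# Editor's illustration: a /websearch payload validated by these models
# (values are made up; engine names must match SearXNG engine identifiers):
#   WebSearchRequest(
#       prompt="latest httpx release notes",
#       model="llama3",
#       messages=[Message(role="user", content="what changed in httpx?")],
#       history_limit=8,
#       searx_url="http://localhost:8888",
#       engines=["duckduckgo", "wikipedia"],
#   )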

backend/websearch.py (new file, 642 lines)
View File

@@ -0,0 +1,642 @@
import asyncio
from typing import List, Tuple, Dict, Any, Optional
import httpx
from bs4 import BeautifulSoup
import re
import json
import traceback
import hashlib
from .ollama_client import chat as ollama_chat
# Configure your local SearXNG instance URL (no trailing slash)
SEARX_URL = "http://localhost:8888"
# ----- Utilities ----------------------------------------------------------------
def clean_text(html: str, max_len: int = 120_000) -> str:
# Prefer lxml parser if available (significantly faster); fall back gracefully.
try:
soup = BeautifulSoup(html, "lxml")
except Exception:
soup = BeautifulSoup(html, "html.parser")
for tag in soup.select("script,style,noscript"):
tag.decompose()
# Fast text extraction
text = soup.get_text("\n", strip=True)
text = re.sub(r"[ \t]+", " ", text)
text = re.sub(r"\n{3,}", "\n\n", text)
return text[:max_len] if len(text) > max_len else text
def render_recent_context(messages: Optional[List[Dict[str, str]]], char_limit: int = 1200) -> str:
"""
Turn the last few turns into a compact excerpt:
user: ...
assistant: ...
Truncate aggressively for resource-friendliness.
"""
if not messages:
return ""
# Keep as-is order (oldest->newest). We assume caller already sliced last N.
lines: List[str] = []
for m in messages:
role = m.get("role", "user")
content = (m.get("content") or "").strip()
if not content:
continue
# compactify each line
content = re.sub(r"\s+", " ", content)
lines.append(f"{role}: {content}")
joined = "\n".join(lines)
if len(joined) > char_limit:
return joined[-char_limit:] # keep tail (most recent part)
return joined
# ----- SearXNG search & fetching -------------------------------------------------
from urllib.parse import urlparse
import itertools
# HTTP client tuning (keep-alive pool + HTTP/2)
HTTP_LIMITS = httpx.Limits(max_keepalive_connections=32, max_connections=64)
HTTP_TIMEOUT = httpx.Timeout(connect=5.0, read=10.0, write=10.0, pool=5.0)
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome Safari",
"Accept": "text/html,application/xhtml+xml;q=0.9,*/*;q=0.8",
}
# Fetch policy
MAX_PAGES_FETCH = 8 # cap on pages we'll fetch per enrichment
FETCH_CONCURRENCY = 8 # concurrent page fetches
MAX_BYTES_PER_PAGE = 1_500_000 # ~1.5 MB read cap per page (streamed)
MIN_TEXT_LENGTH = 500 # discard useless pages early
SKIP_EXTS = {
".pdf", ".jpg", ".jpeg", ".png", ".gif", ".webp", ".svg",
".mp4", ".mp3", ".mov", ".avi", ".zip", ".gz", ".7z", ".tar", ".rar",
".woff", ".woff2", ".ttf", ".otf"
}
_PAGE_CACHE: Dict[str, str] = {} # naive in-proc cache (speeds repeated calls)
_EMB_CACHE: Dict[str, List[float]] = {} # NEW: embedding cache by SHA1 of snippet
def _is_probably_html_url(url: str) -> bool:
try:
path = urlparse(url).path.lower()
ext = (path.rsplit(".", 1)[-1] if "." in path else "")
return (("." not in path) or (ext and f".{ext}" not in SKIP_EXTS))
except Exception:
return True
async def searx_search(
client: httpx.AsyncClient,
query: str,
max_results: int = 6,
*,
searx_url: Optional[str] = None,
engines: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
base_url = (searx_url or SEARX_URL).rstrip("/")
params = {"q": query, "format": "json", "safesearch": 1}
if engines:
# SearXNG accepts a comma-separated 'engines' filter
params["engines"] = ",".join(engines)
try:
r = await client.get(f"{base_url}/search", params=params)
r.raise_for_status()
except Exception as e:
print(f"[web] searx_search error for {base_url} q='{query}': {repr(e)}")
return []
try:
data = r.json()
except Exception as e:
print(f"[web] searx_search JSON parse failed q='{query}': {repr(e)} | content-type={r.headers.get('content-type')}")
return []
seen = set()
out = []
for item in data.get("results", []):
url = item.get("url")
if not url or url in seen:
continue
seen.add(url)
out.append({"url": url, "title": item.get("title", ""), "engine": item.get("engine")})
if len(out) >= max_results:
break
if not out:
print(f"[web] searx_search returned 0 urls for q='{query}' (engines={engines})")
return out
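# Editor's note: SearXNG only serves format=json when "json" is enabled under
# search.formats in its settings.yml; otherwise this request typically returns
# 403, the JSON parse above fails, and every query yields zero results.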
async def searx_search_many(
client: httpx.AsyncClient,
queries: List[str],
per_query_max: int = 6,
overall_max: int = 32,
*,
searx_url: Optional[str] = None,
engines: Optional[List[str]] = None,
) -> List[str]:
print(f"[web] searx_search_many: {len(queries)} queries -> {searx_url or SEARX_URL} engines={engines or 'default'}")
tasks = [searx_search(client, q, max_results=per_query_max, searx_url=searx_url, engines=engines) for q in queries]
results = await asyncio.gather(*tasks, return_exceptions=True)
urls: List[str] = []
seen = set()
for q, res in zip(queries, results):
if isinstance(res, Exception):
print(f"[web] searx_search_many task error q='{q}': {repr(res)}")
continue
for item in res:
u = item.get("url")
if not u or u in seen:
continue
if not _is_probably_html_url(u):
print(f"[web] skip non-HTML url: {u}")
continue
seen.add(u)
urls.append(u)
if len(urls) >= overall_max:
print(f"[web] searx_search_many hit overall_max={overall_max}")
return urls
print(f"[web] searx_search_many collected {len(urls)} urls")
return urls
async def fetch_page(client: httpx.AsyncClient, url: str) -> str:
if url in _PAGE_CACHE:
return _PAGE_CACHE[url]
if not _is_probably_html_url(url):
return ""
try:
async with client.stream("GET", url, headers=DEFAULT_HEADERS) as r:
ctype = (r.headers.get("content-type") or "").lower()
if ctype and not ("text/html" in ctype or "application/xhtml+xml" in ctype or ctype.startswith("text/")):
print(f"[web] skip content-type {ctype} url={url}")
return ""
buf = bytearray()
async for chunk in r.aiter_bytes():
if not chunk:
break
buf.extend(chunk)
if len(buf) >= MAX_BYTES_PER_PAGE:
break
try:
text = buf.decode(r.encoding or "utf-8", errors="ignore")
except Exception:
text = buf.decode("utf-8", errors="ignore")
_PAGE_CACHE[url] = text
return text
except Exception as e:
print(f"[web] fetch_page error url={url}: {repr(e)}")
return ""
async def gather_pages(
client: httpx.AsyncClient,
urls: List[str],
max_pages: int = MAX_PAGES_FETCH,
) -> Dict[str, str]:
urls = list(itertools.islice(urls, 0, max_pages))
sem = asyncio.Semaphore(FETCH_CONCURRENCY)
out: Dict[str, str] = {}
counters = {"tried": 0, "ok": 0, "too_short": 0, "empty": 0}
async def _one(u: str):
async with sem:
counters["tried"] += 1
html = await fetch_page(client, u)
if not html:
counters["empty"] += 1
return
txt = clean_text(html)
if len(txt) < MIN_TEXT_LENGTH:
counters["too_short"] += 1
return
out[u] = txt
counters["ok"] += 1
await asyncio.gather(*(_one(u) for u in urls))
print(f"[web] gather_pages summary: {counters} / returned={len(out)}")
return out
# ----- LLM helpers ---------------------------------------------------------------
async def generate_queries(prompt: str, model: str, context_excerpt: str) -> List[str]:
"""
Single LLM call that *already* incorporates recent chat context to
resolve ellipses/pronouns. No extra rounds.
"""
ctx_block = f"\nRecent conversation (latest last):\n{context_excerpt}\n" if context_excerpt else ""
ask = f"""You are a search query generator.
Given the new user message and recent conversation, produce 3 diverse, terse web search queries that best address the user's *current* intent.
Rules:
- Resolve ambiguous references using the recent conversation when present.
- No quotes, no site: operators, no overly long queries.
- Return each query on its own line.
User message:
{prompt}
{ctx_block}"""
resp = await ollama_chat(model, [{"role": "user", "content": ask}])
lines = [l.strip(" -\t") for l in resp.splitlines() if l.strip()]
uniq, seen = [], set()
for l in lines:
k = l.lower()
if k in seen:
continue
uniq.append(l)
seen.add(k)
if len(uniq) >= 3:
break
return uniq or [prompt]
async def rerank(
prompt: str,
docs: List[Tuple[str, str]],
model: str, # kept for signature compatibility (unused here)
context_excerpt: str,
embed_model: str = "bge-m3:latest" # prefer explicit tag; we will auto-fallback
) -> List[Tuple[str, str, float]]:
"""
Embedding-based reranker (bge-m3 via Ollama) using cosine similarity.
Robustness upgrades:
- Try embed_model; if needed fallback to the ':latest' or untagged alias.
- Normalize multiple possible response schemas from Ollama.
- If the query vector is empty, retry the query alone.
- If only the query is returned, retry passages in a second call.
- Detailed logging of counts and dimensions.
"""
import time
t0 = time.perf_counter()
# --- optional fast cosine via NumPy ---------------------------------------
try:
import numpy as _np # optional
except Exception:
_np = None
def _cosine(a: List[float], b: List[float]) -> float:
if _np is not None:
va = _np.asarray(a, dtype="float32")
vb = _np.asarray(b, dtype="float32")
n = min(va.size, vb.size)
if n == 0:
return 0.0
va = va[:n]; vb = vb[:n]
na = _np.linalg.norm(va); nb = _np.linalg.norm(vb)
if na <= 0.0 or nb <= 0.0:
return 0.0
return float(va.dot(vb) / (na * nb))
# pure Python fallback
n = min(len(a), len(b))
if n == 0:
return 0.0
dot = na = nb = 0.0
for i in range(n):
x = a[i]; y = b[i]
dot += x * y
na += x * x
nb += y * y
if na <= 0.0 or nb <= 0.0:
return 0.0
return dot / ((na ** 0.5) * (nb ** 0.5))
# --- build query + passages ------------------------------------------------
ctx_tail = (("\ncontext: " + context_excerpt[-400:]) if context_excerpt else "")
q_text = f"query: {prompt.strip()}{ctx_tail}"
DOC_SNIPPET_CHARS = 800 # keep short & fast
passages: List[str] = []
raw_snippets: List[str] = [] # for cache keys
for (_u, t) in docs:
snippet = t.replace("\n", " ")
if len(snippet) > DOC_SNIPPET_CHARS:
snippet = snippet[:DOC_SNIPPET_CHARS]
passages.append(f"passage: {snippet}")
raw_snippets.append(snippet)
# --- embedding cache -------------------------------------------------------
emb_cache: Dict[str, List[float]] = globals().get("_EMB_CACHE", {})
try:
import hashlib
keys = [hashlib.sha1(s.encode("utf-8", errors="ignore")).hexdigest() for s in raw_snippets]
except Exception:
keys = [f"nocache-{i}" for i in range(len(raw_snippets))]
# Prepare inputs:
# 0 -> query; >=1 -> only passages NOT present in cache
to_embed_inputs: List[str] = [q_text]
pos_to_passage_idx: Dict[int, int] = {}
p_cached = 0
for i, (p, k) in enumerate(zip(passages, keys)):
if k in emb_cache:
p_cached += 1
continue
pos = len(to_embed_inputs)
to_embed_inputs.append(p)
pos_to_passage_idx[pos] = i
# --- helpers to call Ollama embeddings ------------------------------------
async def _embed_inputs(inputs: List[str], model_name: str) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
Call Ollama /api/embeddings once per input with {"model", "prompt"}.
Keeps order; returns [[]] for failures at a given index.
"""
timeout = httpx.Timeout(connect=5.0, read=30.0, write=10.0, pool=5.0)
sem = asyncio.Semaphore(8) # cap concurrency a bit
async def _one(text: str) -> Tuple[List[float], Optional[str]]:
payload = {"model": model_name, "prompt": text}
try:
async with sem:
async with httpx.AsyncClient(timeout=timeout) as client:
r = await client.post("http://localhost:11434/api/embeddings", json=payload)
r.raise_for_status()
data = r.json()
except httpx.HTTPStatusError as e:
return [], f"http_error:{e}"
except Exception as e:
return [], f"request_error:{e}"
# normalize common shapes
if isinstance(data, dict):
if "error" in data:
return [], f"model_error:{data.get('error')}"
if "embedding" in data and isinstance(data["embedding"], list):
return data["embedding"], None
if "data" in data and isinstance(data["data"], list) and data["data"]:
em = data["data"][0].get("embedding", [])
return (em if isinstance(em, list) else []), None
if "embeddings" in data and isinstance(data["embeddings"], list):
# some servers may return {"embeddings":[vector]} even for single prompt
emb0 = data["embeddings"][0]
if isinstance(emb0, dict):
emb0 = emb0.get("embedding", [])
return (emb0 if isinstance(emb0, list) else []), None
return [], "parse_error"
tasks = [_one(t) for t in inputs]
results = await asyncio.gather(*tasks, return_exceptions=False)
embs: List[List[float]] = []
errs: List[str] = []
for emb, err in results:
embs.append(emb)
if err:
errs.append(err)
meta: Dict[str, Any] = {}
if errs:
# light meta summary
meta["errors"] = {k: errs.count(k) for k in set(errs)}
return embs, meta
# --- do the embedding calls ------------------------------------------------
t_embed = time.perf_counter()
inputs_all = to_embed_inputs # [query] + uncached passages
embeddings, meta = await _embed_inputs(inputs_all, embed_model)
# simple model fallback if *all* returned empty
if not any(len(e) for e in embeddings):
tried = [embed_model]
alt = (embed_model.split(":", 1)[0] if ":" in embed_model else embed_model + ":latest")
if alt != embed_model:
embeddings, meta2 = await _embed_inputs(inputs_all, alt)
if any(len(e) for e in embeddings):
print(f"[web] embed() recovered with fallback model {alt} (meta={meta or meta2})")
embed_model = alt
else:
print(f"[web] embed() FAILED (models tried={tried + [alt]}, meta={meta or meta2})")
return [(u, t, 0.0) for (u, t) in docs]
# split q vs passages and update cache
q_emb = embeddings[0] if embeddings else []
if not q_emb:
print("[web] embed() empty query vector — aborting rerank")
return [(u, t, 0.0) for (u, t) in docs]
# positions >=1 correspond to passages (only those that weren't cached)
for pos, emb_vec in enumerate(embeddings[1:], start=1):
i = pos_to_passage_idx.get(pos)
if i is None:
continue
if not keys[i].startswith("nocache-"):
emb_cache[keys[i]] = emb_vec
# build aligned passage vectors
p_emb_list: List[List[float]] = [emb_cache.get(k, []) for k in keys]
# logging
q_dim = len(q_emb)
p_dims = [len(v) for v in p_emb_list]
print(f"[web] embed() took {time.perf_counter() - t_embed:.3f}s "
f"(model='{embed_model}', q_dim={q_dim}, p_cached={p_cached}, "
f"p_fresh={sum(1 for d in p_dims if d>0)}, total_p={len(p_emb_list)})")
# --- score + rank ----------------------------------------------------------
t_score = time.perf_counter()
scored: List[Tuple[str, str, float]] = []
for (u, t), p_emb in zip(docs, p_emb_list):
cos = _cosine(q_emb, p_emb)
score_0_100 = max(0.0, min(100.0, (cos + 1.0) * 50.0))
scored.append((u, t, score_0_100))
scored.sort(key=lambda x: x[2], reverse=True)
print(f"[web] cosine scoring took {time.perf_counter() - t_score:.3f}s; rerank total {time.perf_counter() - t0:.3f}s")
return scored
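# Editor's note: (cos + 1) * 50 maps cosine [-1, 1] linearly onto [0, 100]
# (cos 0 -> 50, cos 1 -> 100), so the MIN_SCORE = 70 cutoff applied later in
# build_enriched_prompt keeps only passages with cosine >= 0.4 against the
# query embedding.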
def build_enriched_prompt(user_prompt: str, ranked: List[Tuple[str, str, float]], top_k: int = 6) -> Tuple[str, List[str]]:
"""
Build an enriched prompt using only high-quality documents.
- Keep at most `top_k` docs with score >= MIN_SCORE.
- If none survive, return a <websearch_context> telling the assistant that
a web search was performed but no good results were found.
"""
MIN_SCORE = 70.0
# Sort defensively (should already be sorted desc)
ranked = sorted(ranked, key=lambda x: x[2], reverse=True)
# Strict cutoff
selected = [(u, t, sc) for (u, t, sc) in ranked if sc >= MIN_SCORE][:top_k]
# Debug summary
try:
all_scores = [sc for (_u, _t, sc) in ranked]
sel_scores = [sc for (_u, _t, sc) in selected]
if all_scores:
print(f"[web] selection ≥{MIN_SCORE} → total={len(all_scores)}, selected={len(selected)}, "
f"top_sel={max(sel_scores) if sel_scores else 0:.1f}, "
f"min_sel={min(sel_scores) if sel_scores else 0:.1f}")
except Exception:
pass
if not selected:
# No “good” results → still enrich with an explicit no-results message
parts = [
"<websearch_context>",
f"No suitable web results found (score < {MIN_SCORE}). "
"Tell the user you performed a web search but couldn't find any good results. "
"If you can answer from prior context or general knowledge, say so clearly and "
"do not fabricate citations or URLs.",
"</websearch_context>",
]
enriched = f"{user_prompt}\n\n" + "\n".join(parts)
return enriched, []
# Build normal context
sources = [u for (u, _, _) in selected]
parts = ["<websearch_context>"]
for i, (u, t, _) in enumerate(selected, 1):
snippet = t.strip().replace("\n", " ")
if len(snippet) > 1200:
snippet = snippet[:1200]
parts.append(f"[{i}] {snippet} (source: {u})")
parts.append("</websearch_context>")
parts.append("\nAnswer the user using the context when relevant.\nMake your response long enough to reflect all interesting information found. No duplicate information.\nUnless related to the topic, don't mention the source from the web search.")
enriched = f"{user_prompt}\n\n" + "\n".join(parts)
return enriched, sources
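# Editor's illustration of the enriched prompt layout produced above:
#   <original user prompt>
#
#   <websearch_context>
#   [1] first snippet ... (source: https://example.com/a)
#   [2] second snippet ... (source: https://example.com/b)
#   </websearch_context>
#   Answer the user using the context when relevant. ...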
# ----- Public API ----------------------------------------------------------------
async def enrich_prompt(
user_prompt: str,
model: str,
messages: Optional[List[Dict[str, str]]] = None,
*,
searx_url: Optional[str] = None,
engines: Optional[List[str]] = None,
) -> Tuple[str, List[str]]:
import time # local import to avoid touching global imports
start_all = time.perf_counter()
def _no_results_enriched(reason: str, queries: Optional[List[str]] = None) -> Tuple[str, List[str]]:
parts = ["<websearch_context>"]
parts.append(
"Web search was performed but no suitable results were found. "
"Tell the user you searched the web but couldn't find any good results. "
"If you can still answer from prior context or general knowledge, say so clearly, "
"and do not fabricate citations or URLs."
)
if queries:
try:
qshow = ", ".join(queries[:3])
parts.append(f"Queries attempted: {qshow}")
except Exception:
pass
if engines:
try:
eshow = ", ".join(engines)
parts.append(f"Engines selected: {eshow}")
except Exception:
pass
parts.append(f"(reason: {reason})")
parts.append("</websearch_context>")
return f"{user_prompt}\n\n" + "\n".join(parts), []
context_excerpt = render_recent_context(messages, char_limit=1200)
# 1) queries
try:
t0 = time.perf_counter()
queries = await generate_queries(user_prompt, model=model, context_excerpt=context_excerpt)
print(f"[web] queries: {queries} (took {time.perf_counter() - t0:.3f}s)")
except Exception:
print("[web] ERROR in generate_queries:\n" + traceback.format_exc())
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("query_generation_failed")
# 2) search + fetch
try:
print("[web] opening httpx client")
async with httpx.AsyncClient(
headers=DEFAULT_HEADERS,
follow_redirects=True,
http2=True,
limits=HTTP_LIMITS,
timeout=HTTP_TIMEOUT,
) as client:
t1 = time.perf_counter()
print("[web] calling searx_search_many() …")
all_urls = await searx_search_many(
client,
queries,
per_query_max=6,
overall_max=MAX_PAGES_FETCH * 2,
searx_url=searx_url,
engines=engines,
)
print(f"[web] searx_search_many() -> {len(all_urls)} urls (took {time.perf_counter() - t1:.3f}s)")
if not all_urls:
print("[web] no URLs from SearX — emitting no-results context")
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("no_urls", queries)
t2 = time.perf_counter()
print("[web] calling gather_pages() …")
pages = await gather_pages(client, all_urls, max_pages=MAX_PAGES_FETCH)
print(f"[web] gather_pages() -> {len(pages)} pages (cap={MAX_PAGES_FETCH}, took {time.perf_counter() - t2:.3f}s)")
except Exception:
print("[web] ERROR during search/fetch:\n" + traceback.format_exc())
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("search_fetch_error")
# 3) docs
try:
t3 = time.perf_counter()
docs = [(u, pages[u]) for u in all_urls if u in pages]
print(f"[web] docs for rerank: {len(docs)} (built in {time.perf_counter() - t3:.3f}s)")
if not docs:
print("[web] no docs after fetch/filter — emitting no-results context")
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("no_docs_after_fetch", queries)
except Exception:
print("[web] ERROR building docs list:\n" + traceback.format_exc())
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("docs_build_failed", queries)
# 4) rerank
try:
t4 = time.perf_counter()
print("[web] calling rerank() …")
ranked = await rerank(user_prompt, docs, model=model, context_excerpt=context_excerpt)
print(f"[web] rerank() -> {len(ranked)} scored docs (took {time.perf_counter() - t4:.3f}s)")
try:
scores_only = [sc for (_u, _t, sc) in ranked]
if scores_only:
scores_sorted = sorted(scores_only)
n = len(scores_sorted)
p50 = scores_sorted[n // 2]
p75 = scores_sorted[int(n * 0.75) - 1 if n > 3 else n - 1]
print(f"[web] score stats → min={scores_sorted[0]:.1f}, p50={p50:.1f}, p75={p75:.1f}, max={scores_sorted[-1]:.1f}")
print(f"[web] top scores → {[round(s,1) for s in scores_only[:min(6,len(scores_only))]]}")
except Exception:
pass
except Exception:
print("[web] ERROR in rerank:\n" + traceback.format_exc())
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("rerank_failed", queries)
# 5) build prompt
try:
t5 = time.perf_counter()
# unpack explicitly: build_enriched_prompt returns (enriched_prompt, sources)
enriched, sources = build_enriched_prompt(user_prompt, ranked, top_k=6)
print(f"[web] build_enriched_prompt() done (took {time.perf_counter() - t5:.3f}s)")
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return enriched, sources
except Exception:
print("[web] ERROR in build_enriched_prompt:\n" + traceback.format_exc())
print(f"[web] enrich_prompt total: {time.perf_counter() - start_all:.3f}s")
return _no_results_enriched("build_enriched_failed", queries)

src/App.jsx
View File

@@ -4,6 +4,7 @@ import { flushSync } from 'react-dom';
import TextareaAutosize from 'react-textarea-autosize';
import GeneralSettings from './GeneralSettings'
import InterfaceSettings from './InterfaceSettings'
import WebsearchSettings from './WebsearchSettings'
import { markdownToHTML } from './markdown';
// Extract <think> or <thinking> block (first occurrence) and return { think, answer }
function splitThinkBlocks(text) {
@@ -43,7 +44,7 @@ function splitThinkBlocks(text) {
}
// Renders assistant message with a collapsible "Thoughts" block (if present)
function AssistantMessageContent({ content, streamOutput }) {
function AssistantMessageContent({ content, streamOutput, sources }) {
const { think, answer } = splitThinkBlocks(content || '');
const [open, setOpen] = React.useState(false);
const showThinkButton = !!think;
@@ -76,12 +77,30 @@ function AssistantMessageContent({ content, streamOutput }) {
className="msg-content"
dangerouslySetInnerHTML={{ __html: markdownToHTML(answer || content || '') }}
/>
{Array.isArray(sources) && sources.length > 0 && (
<div className="msg-sources chips">
{sources.map((u, i) => {
let label = u;
try {
const host = new URL(u).hostname || u;
label = host.replace(/^www\./i, '');
} catch {}
return (
<a key={u + i} className="chip" href={u} target="_blank" rel="noreferrer" title={u}>
{label}
</a>
);
})}
</div>
)}
</div>
);
}
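// Editor's note: chip labels are hostnames with a leading "www." stripped,
// e.g. new URL("https://www.example.com/a").hostname.replace(/^www\./i, "")
// === "example.com"; other subdomains are kept (docs.python.org stays as-is).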
const API_URL_KEY = 'ollamaApiUrl';
const COLOR_SCHEME_KEY = 'colorScheme';
const WEBSEARCH_URL_KEY = 'websearch.searxUrl';
const WEBSEARCH_ENGINES_KEY = 'websearch.engines';
// Initial API value will be set by useEffect after settings are loaded
let API = import.meta.env.VITE_API_URL ?? 'http://127.0.0.1:8000';
@@ -103,6 +122,24 @@ export default function App() {
const [ollamaApiUrl, setOllamaApiUrl] = useState(API); // State for Ollama API URL
const [colorScheme, setColorScheme] = useState('Default'); // State for color scheme
const [streamOutput, setStreamOutput] = useState(false);
const [searxUrl, setSearxUrl] = useState(localStorage.getItem(WEBSEARCH_URL_KEY) || 'http://localhost:8888');
const [searxEngines, setSearxEngines] = useState(() => {
try {
const raw = localStorage.getItem(WEBSEARCH_ENGINES_KEY);
if (raw) return JSON.parse(raw);
} catch {}
return ["duckduckgo","bing","wikipedia","github","stack_overflow"];
});
useEffect(() => {
localStorage.setItem(WEBSEARCH_URL_KEY, searxUrl || '');
}, [searxUrl]);
useEffect(() => {
try {
localStorage.setItem(WEBSEARCH_ENGINES_KEY, JSON.stringify(searxEngines || []));
} catch {}
}, [searxEngines]);
const [webSearchEnabled, setWebSearchEnabled] = useState(false);
const [isSending, setIsSending] = useState(false);
const [loading, setLoading] = useState(true); // Loading state for initial session fetch
const [unreadSessions, setUnreadSessions] = useState([]); // Track unread messages
@@ -210,126 +247,170 @@ export default function App() {
}
// Continue conversation from the edited message
await regenerateFromIndex(index);
await regenerateFromIndex(index, next);
}
async function regenerateFromIndex(index) {
const sessionId = activeSessionId;
if (!sessionId || typeof index !== 'number') return;
async function regenerateFromIndex(index, overrideUserText = null) {
const sessionId = activeSessionId;
if (!sessionId || typeof index !== 'number') return;
const msgs = (chatSessions.find(s => s.session_id === sessionId)?.messages) || [];
let lastUserIdx = index;
for (let i = index; i >= 0; i--) {
if (msgs[i]?.role === 'user') { lastUserIdx = i; break; }
const msgs = (chatSessions.find(s => s.session_id === sessionId)?.messages) || [];
let lastUserIdx = index;
for (let i = index; i >= 0; i--) {
if (msgs[i]?.role === 'user') { lastUserIdx = i; break; }
}
// Prune UI to lastUserIdx
setChatSessions(prev =>
prev.map(s => s.session_id === sessionId
? { ...s, messages: (s.messages || []).slice(0, lastUserIdx + 1) }
: s
)
);
setIsSending(true);
// --- optional websearch enrichment for regenerate ---
let enrichedPrompt = null;
let citationSources = [];
if (webSearchEnabled) {
try {
// Use the freshly edited user text when provided
const promptText = (overrideUserText != null ? overrideUserText : (msgs[lastUserIdx]?.content || ''));
// Build compact recent history and overwrite the last user turn with promptText
const historyForSearch = msgs
.slice(Math.max(0, lastUserIdx - 7), lastUserIdx + 1)
.map(m => ({ role: m.role, content: m.content || '' }));
if (historyForSearch.length > 0) {
historyForSearch[historyForSearch.length - 1] = { role: 'user', content: promptText };
}
const resp = await fetch(`${ollamaApiUrl}/websearch`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
prompt: promptText,
model,
messages: historyForSearch,
history_limit: 8,
searx_url: searxUrl || null,
engines: Array.isArray(searxEngines) ? searxEngines : null,
})
});
const data = await resp.json();
if (data && typeof data.enriched_prompt === 'string') {
enrichedPrompt = data.enriched_prompt;
citationSources = Array.isArray(data.sources) ? data.sources : [];
}
} catch (e) {
console.warn('web search enrichment (regenerate) failed', e);
}
}
// Prune UI to lastUserIdx
if (streamOutput) {
const assistantMsgId = `msg-${Date.now()}-${Math.random()}`;
// add placeholder assistant message (keep sources on the placeholder)
setChatSessions(prev =>
prev.map(s => s.session_id === sessionId
? { ...s, messages: (s.messages || []).slice(0, lastUserIdx + 1) }
? { ...s, messages: [...(s.messages || []), { id: assistantMsgId, role: 'assistant', content: '', sources: citationSources }] }
: s
)
);
setIsSending(true);
try {
const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
index,
model,
stream: true,
enriched_message: enrichedPrompt,
sources: citationSources || []
})
});
const reader = res.body.getReader();
const decoder = new TextDecoder();
let full = '';
let unreadMarked = false;
if (streamOutput) {
const assistantMsgId = `msg-${Date.now()}-${Math.random()}`;
// add placeholder assistant message
setChatSessions(prev =>
prev.map(s => s.session_id === sessionId
? { ...s, messages: [...(s.messages || []), { id: assistantMsgId, role: 'assistant', content: '' }] }
: s
)
);
while (true) {
const { value, done } = await reader.read();
if (done) break;
try {
const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ index, model, stream: true })
});
const reader = res.body.getReader();
const decoder = new TextDecoder();
let full = '';
let unreadMarked = false; // NEW
const chunk = decoder.decode(value, { stream: true });
full += chunk;
while (true) {
const { value, done } = await reader.read();
if (done) break;
const chunk = decoder.decode(value, { stream: true });
full += chunk;
// Update the growing assistant message
setChatSessions(prev =>
prev.map(s => s.session_id === sessionId
? { ...s, messages: (s.messages || []).map(m => m.id === assistantMsgId ? { ...m, content: full } : m) }
: s
)
);
// If this session is not active while streaming, mark unread once
if (!unreadMarked && activeSessionIdRef.current !== sessionId) {
unreadMarked = true;
setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId }));
setUnreadSessions(prev => [...new Set([...prev, sessionId])]);
}
}
// On stream end: if user is in another chat, ensure unread + guided scroll are set
if (activeSessionIdRef.current !== sessionId) {
setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId }));
setUnreadSessions(prev => [...new Set([...prev, sessionId])]);
} else {
// If user stayed here and didn't scroll up, align the finished answer nicely
if (!userScrolledUpRef.current[sessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId));
} else {
// show the tip if they had scrolled away
setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId }));
}
}
} catch (e) {
console.error(e);
} finally {
setIsSending(false);
}
} else {
try {
const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ index, model, stream: false })
});
const data = await res.json();
const assistantMsgId = `msg-${Date.now()}`;
// Update the growing assistant message (sources remain intact)
setChatSessions(prev =>
prev.map(s => s.session_id === sessionId
? { ...s, messages: [...(s.messages || []), { role: 'assistant', content: data.reply, id: assistantMsgId }] }
? { ...s, messages: (s.messages || []).map(m => m.id === assistantMsgId ? { ...m, content: full } : m) }
: s
)
);
if (activeSessionIdRef.current !== sessionId) {
// reply landed in background -> mark unread + remember where to scroll
if (!unreadMarked && activeSessionIdRef.current !== sessionId) {
unreadMarked = true;
setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId }));
setUnreadSessions(prev => [...new Set([...prev, sessionId])]);
} else {
// same chat -> align unless the user scrolled away
if (!userScrolledUpRef.current[sessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId));
} else {
setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId }));
}
}
} catch (e) {
console.error(e);
} finally {
setIsSending(false);
}
if (activeSessionIdRef.current !== sessionId) {
setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId }));
setUnreadSessions(prev => [...new Set([...prev, sessionId])]);
} else {
if (!userScrolledUpRef.current[sessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId));
} else {
setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId }));
}
}
} catch (e) {
console.error(e);
} finally {
setIsSending(false);
}
} else {
try {
const res = await fetch(`${ollamaApiUrl}/sessions/${sessionId}/regenerate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
index,
model,
stream: false,
enriched_message: enrichedPrompt,
sources: citationSources || []
})
});
const data = await res.json();
const assistantMsgId = `msg-${Date.now()}`;
setChatSessions(prev =>
prev.map(s => s.session_id === sessionId
? { ...s, messages: [...(s.messages || []), { role: 'assistant', content: data.reply, id: assistantMsgId, sources: citationSources }] }
: s
)
);
if (activeSessionIdRef.current !== sessionId) {
setPendingScrollToLastUser(prev => ({ ...prev, [sessionId]: assistantMsgId }));
setUnreadSessions(prev => [...new Set([...prev, sessionId])]);
} else {
if (!userScrolledUpRef.current[sessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', sessionId));
} else {
setNewMsgTip(prev => ({ ...prev, [sessionId]: assistantMsgId }));
}
}
} catch (e) {
console.error(e);
} finally {
setIsSending(false);
}
}
}
// Persist userScrolledUp state per session + live ref for closures (streaming)
@@ -739,108 +820,115 @@ export default function App() {
};
async function sendMessage() {
if (!input.trim() || !model) return;
async function sendMessage() {
if (!input.trim() || !model) return;
let targetSessionId = activeSessionId;
let isNewChat = false;
if (!targetSessionId) {
const newSession = await createNewChat();
await new Promise(resolve => setTimeout(resolve, 200));
targetSessionId = newSession.session_id;
isNewChat = true;
} else {
const currentSession = chatSessions.find(s => s.session_id === targetSessionId);
isNewChat = currentSession && currentSession.name === "New Chat" && currentSession.messages.length === 0;
}
const userMsg = { role: 'user', content: input.trim(), id: `msg-${Date.now()}-${Math.random()}` };
justSentMessage.current = true;
lastSentSessionRef.current = targetSessionId;
setUserScrolledUp(targetSessionId, false);
let targetSessionId = activeSessionId;
let isNewChat = false;
if (!targetSessionId) {
const newSession = await createNewChat();
await new Promise(resolve => setTimeout(resolve, 200));
targetSessionId = newSession.session_id;
isNewChat = true;
} else {
const currentSession = chatSessions.find(s => s.session_id === targetSessionId);
isNewChat = currentSession && currentSession.name === "New Chat" && currentSession.messages.length === 0;
}
// Cancel any pending restore for the active session (we're about to control the scroll)
if (activeSessionIdRef.current === targetSessionId) {
restoredForRef.current = activeSessionIdRef.current; // mark as already restored
const userMsg = { role: 'user', content: input.trim(), id: `msg-${Date.now()}-${Math.random()}` };
justSentMessage.current = true;
lastSentSessionRef.current = targetSessionId;
setUserScrolledUp(targetSessionId, false);
if (activeSessionIdRef.current === targetSessionId) {
restoredForRef.current = activeSessionIdRef.current;
}
flushSync(() => {
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...(session.messages || []), userMsg] }
: session
)
);
setInput('');
});
requestAnimationFrame(() => scrollToBottom('auto', targetSessionId));
setIsSending(true);
try {
// Build compact recent history for context-aware websearch (resource-friendly).
// We only send the last 8 turns by default, including assistant replies,
// and we also append the *current* user message (same content as `userMsg`).
let historyForSearch = [];
try {
const existing = (chatSessions.find(s => s.session_id === targetSessionId)?.messages) || [];
const lastFew = existing.slice(-8).map(m => ({ role: m.role, content: m.content || '' }));
historyForSearch = [...lastFew, { role: 'user', content: userMsg.content }];
} catch {}
// Decide on enrichment using the toggle
let enrichedPrompt = userMsg.content;
let citationSources = [];
if (webSearchEnabled) {
try {
const resp = await fetch(`${ollamaApiUrl}/websearch`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
prompt: userMsg.content,
model,
messages: historyForSearch,
history_limit: 8,
searx_url: searxUrl || null,
engines: Array.isArray(searxEngines) ? searxEngines : null,
})
});
const data = await resp.json();
if (data && typeof data.enriched_prompt === 'string') {
enrichedPrompt = data.enriched_prompt;
citationSources = Array.isArray(data.sources) ? data.sources : [];
}
} catch (e) {
console.warn('web search enrichment failed', e);
}
}
// Optimistic add and flush DOM, then scroll to bottom
flushSync(() => {
if (streamOutput) {
const assistantMsgId = `msg-${Date.now()}-${Math.random()}`;
const assistantMsg = { role: 'assistant', content: '', id: assistantMsgId, sources: citationSources };
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...(session.messages || []), userMsg] }
? { ...session, messages: [...(session.messages || []), assistantMsg] }
: session
)
);
setInput('');
});
requestAnimationFrame(() => scrollToBottom('auto', targetSessionId));
setIsSending(true);
try {
if (streamOutput) {
const assistantMsgId = `msg-${Date.now()}-${Math.random()}`;
const assistantMsg = { role: 'assistant', content: '', id: assistantMsgId };
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...(session.messages || []), assistantMsg] }
: session
)
);
(async () => {
try {
const res = await fetch(`${ollamaApiUrl}/chat`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
session_id: targetSessionId,
model,
message: userMsg.content,
enriched_message: webSearchEnabled ? enrichedPrompt : null,
stream: true,
sources: citationSources || []
})
});
(async () => {
try {
const res = await fetch(`${ollamaApiUrl}/chat`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
session_id: targetSessionId,
model,
message: userMsg.content,
stream: true
})
});
const reader = res.body.getReader();
const decoder = new TextDecoder();
let fullReply = '';
let pendingMarked = false;
const reader = res.body.getReader();
const decoder = new TextDecoder();
let fullReply = '';
let pendingMarked = false;
while (true) {
const { value, done } = await reader.read();
if (done) {
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? {
...session,
messages: session.messages.map(m =>
m.id === assistantMsgId ? { ...m, content: fullReply } : m
)
}
: session
)
);
if (activeSessionIdRef.current === targetSessionId) {
if (!userScrolledUpRef.current[targetSessionId]) {
// user stayed at bottom -> reveal the message immediately
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId));
} else {
// user scrolled away while it was generating -> show tip instead of auto-scroll
setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
}
} else {
setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]);
}
break;
}
const chunk = decoder.decode(value, { stream: true });
fullReply += chunk;
while (true) {
const { value, done } = await reader.read();
if (done) {
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
@@ -853,113 +941,154 @@ export default function App() {
: session
)
);
// Keep sticky-bottom *only* when streaming in the active chat and user is at/near bottom.
// This restores the old "push down while generating" behavior without fighting user scrolls.
if (
activeSessionIdRef.current === targetSessionId &&
!userScrolledUpRef.current[targetSessionId]
) {
// use 'auto' so it stays snappy during streaming
scrollToBottom('auto', targetSessionId);
}
// If streaming in a background chat, prepare a one-time guided scroll
if (activeSessionIdRef.current !== targetSessionId && !pendingMarked) {
if (activeSessionIdRef.current === targetSessionId) {
if (!userScrolledUpRef.current[targetSessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId));
} else {
setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
}
} else {
setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
pendingMarked = true;
setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]);
}
break;
}
} catch (e) {
console.error("Failed to send message:", e);
const errorMsg = { role: 'assistant', content: 'Error: ' + e.message, id: `msg-${Date.now()}-${Math.random()}` };
const chunk = decoder.decode(value, { stream: true });
fullReply += chunk;
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...session.messages.slice(0, -1), errorMsg] }
? {
...session,
messages: session.messages.map(m =>
m.id === assistantMsgId ? { ...m, content: fullReply } : m
)
}
: session
)
);
} finally {
setIsSending(false);
}
})();
} else {
const res = await fetch(`${ollamaApiUrl}/chat`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
session_id: targetSessionId,
model,
message: userMsg.content,
stream: false
})
});
const data = await res.json();
const assistantMsgId = `msg-${Date.now()}`;
const assistantMsg = { role: 'assistant', content: data.reply, id: assistantMsgId };
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...(session.messages || []), assistantMsg] }
: session
)
);
// For non-stream: align new ASSISTANT message to top, unless user scrolled away
if (assistantMsgId) {
if (activeSessionIdRef.current === targetSessionId) {
if (!userScrolledUpRef.current[targetSessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId));
} else {
// <<< show the tip if user scrolled away while waiting >>>
setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
if (
activeSessionIdRef.current === targetSessionId &&
!userScrolledUpRef.current[targetSessionId]
) {
scrollToBottom('auto', targetSessionId);
}
if (activeSessionIdRef.current !== targetSessionId && !pendingMarked) {
setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
pendingMarked = true;
}
} else {
setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
}
}
setIsSending(false);
}
if (activeSessionIdRef.current !== targetSessionId) {
setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]);
}
if (isNewChat) {
fetch(`${ollamaApiUrl}/generate-title`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
session_id: targetSessionId,
message: userMsg.content,
model: model
})
})
.then(r => r.json())
.then(data => {
const sanitizedTitle = data.title.replace(/<think(?:ing)?>[\s\S]*?<\/think(?:ing)?>/i, '').trim();
} catch (e) {
console.error('Failed to send message:', e);
const errorMsg = {
role: 'assistant',
content: 'Error: ' + e.message,
id: `msg-${Date.now()}-${Math.random()}`,
sources: citationSources
};
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId ? { ...session, name: sanitizedTitle } : session
session.session_id === targetSessionId
? { ...session, messages: [...session.messages.slice(0, -1), errorMsg] }
: session
)
);
});
}
} catch (e) {
console.error("Failed to send message:", e);
const errorMsg = { role: 'assistant', content: 'Error: ' + e.message, id: `msg-${Date.now()}-${Math.random()}` };
} finally {
setIsSending(false);
}
})();
} else {
const res = await fetch(`${ollamaApiUrl}/chat`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
session_id: targetSessionId,
model,
message: userMsg.content,
enriched_message: webSearchEnabled ? enrichedPrompt : null,
stream: false,
sources: citationSources || []
})
});
const data = await res.json();
const assistantMsgId = `msg-${Date.now()}`;
const assistantMsg = {
role: 'assistant',
content: data.reply,
id: assistantMsgId,
sources: citationSources
};
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...session.messages, errorMsg] }
? { ...session, messages: [...(session.messages || []), assistantMsg] }
: session
)
);
if (assistantMsgId) {
if (activeSessionIdRef.current === targetSessionId) {
if (!userScrolledUpRef.current[targetSessionId]) {
requestAnimationFrame(() => scrollMessageToTop(assistantMsgId, 'smooth', targetSessionId));
} else {
setNewMsgTip(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
}
} else {
setPendingScrollToLastUser(prev => ({ ...prev, [targetSessionId]: assistantMsgId }));
}
}
setIsSending(false);
}
}
async function createNewChat() {
if (activeSessionIdRef.current !== targetSessionId) {
setUnreadSessions(prev => [...new Set([...prev, targetSessionId])]);
}
if (isNewChat) {
fetch(`${ollamaApiUrl}/generate-title`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
session_id: targetSessionId,
message: userMsg.content,
model: model
})
})
.then(r => r.json())
.then(data => {
const sanitizedTitle = data.title.replace(/<think(?:ing)?>[\s\S]*?<\/think(?:ing)?>/i, '').trim();
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId ? { ...session, name: sanitizedTitle } : session
)
);
});
}
} catch (e) {
console.error("Failed to send message:", e);
const errorMsg = { role: 'assistant', content: 'Error: ' + e.message, id: `msg-${Date.now()}-${Math.random()}` };
setChatSessions(prevSessions =>
prevSessions.map(session =>
session.session_id === targetSessionId
? { ...session, messages: [...session.messages, errorMsg] }
: session
)
);
setIsSending(false);
}
}
function toggleWebSearch() {
setWebSearchEnabled(prev => !prev);
}
async function createNewChat() {
const newSessionId = 'sess-' + Math.random().toString(36).slice(2) + Date.now().toString(36);
const res = await fetch(`${ollamaApiUrl}/sessions`, {
method: 'POST',
@@ -1155,6 +1284,12 @@ export default function App() {
>
Interface
</div>
<div
className={`settings-item ${activeSettingsSubmenu === 'Websearch' ? 'active' : ''}`}
onClick={() => setActiveSettingsSubmenu('Websearch')}
>
Websearch
</div>
</div>
)}
</div>
@@ -1192,7 +1327,7 @@ export default function App() {
>
{m.role === 'assistant' ? (
<div className="assistant-message-wrapper">
<AssistantMessageContent content={m.content} streamOutput={streamOutput} />
<AssistantMessageContent content={m.content} streamOutput={streamOutput} sources={m.sources} />
{!isSending && (
<div className="message-options-bar assistant-options">
<button className="icon-button" title="Copy message" onClick={() => handleCopyMessage(m)}>
@@ -1295,6 +1430,9 @@ export default function App() {
placeholder="Ask any question..."
maxRows={13}
/>
<button className={"websearch-toggle" + (webSearchEnabled ? " active" : "")} onClick={toggleWebSearch} title="Toggle web search" aria-pressed={webSearchEnabled}>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
</button>
<button className="button" onClick={sendMessage} disabled={isSending}>
{isSending ? <div className="spinner"></div> : 'Send'}
</button>
@@ -1321,6 +1459,14 @@ export default function App() {
/>
)}
{activeSettingsSubmenu === 'Interface' && <InterfaceSettings />}
{activeSettingsSubmenu === 'Websearch' && (
<WebsearchSettings
searxUrl={searxUrl}
setSearxUrl={setSearxUrl}
engines={searxEngines}
setEngines={setSearxEngines}
/>
)}
</>
)}
</div>

src/WebsearchSettings.jsx (new file, 62 lines)
View File

@@ -0,0 +1,62 @@
// src/WebsearchSettings.jsx
import React, { useMemo, useState } from 'react';
export default function WebsearchSettings({
searxUrl,
setSearxUrl,
engines,
setEngines,
}) {
const KNOWN_ENGINES = useMemo(
() => ["google","bing","yahoo","duckduckgo","brave","github","stackoverflow","reddit","arxiv"],
[]
);
const [custom, setCustom] = useState("");
const toggleEngine = (name) => {
const set = new Set(engines || []);
if (set.has(name)) set.delete(name); else set.add(name);
setEngines(Array.from(set));
};
const addCustom = () => {
const name = custom.trim();
if (!name) return;
const set = new Set(engines || []);
set.add(name);
setEngines(Array.from(set));
setCustom("");
};
return (
<div className="settings-content-panel">
<div className="setting-section">
<h3>SearXNG URL</h3>
<input
type="text"
className="input"
value={searxUrl}
onChange={e => setSearxUrl(e.target.value)}
placeholder="e.g., http://localhost:8888"
/>
</div>
<div className="setting-section">
<h3>Search Engines</h3>
<div className="engine-grid">
{KNOWN_ENGINES.map(name => (
<label key={name} className="engine-row">
<input
type="checkbox"
checked={Array.isArray(engines) ? engines.includes(name) : false}
onChange={() => toggleEngine(name)}
/>
<span>{name}</span>
</label>
))}
</div>
</div>
</div>
);
}
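// Editor's note: engine names are passed straight through to SearXNG's
// `engines` query parameter, so they must match SearXNG engine identifiers
// (e.g. "stackoverflow", not "stack_overflow").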

View File

@@ -889,3 +889,40 @@ input:checked + .slider:before {
outline: none;
box-shadow: none;
}
/* Web search toggle */
.websearch-toggle {
display: inline-flex;
align-items: center;
justify-content: center;
width: 38px;
height: 38px;
border-radius: 10px;
border: 1px solid var(--border);
background: var(--input-bg);
cursor: pointer;
}
.websearch-toggle svg { width: 16px; height: 16px; }
.websearch-toggle.active { outline: 2px solid var(--accent); }
.msg-sources { margin-top: 8px; font-size: 12px; color: var(--muted); }
.msg-sources a { color: var(--accent); text-decoration: none; margin-right: 8px; }
.msg-sources a:hover { text-decoration: underline; }
.msg-sources.chips {
display: flex;
flex-wrap: wrap;
margin: 0.5rem 0 0.5rem 0;
}
.msg-sources.chips .chip {
display: inline-flex;
align-items: center;
padding: .25rem .6rem;
border-radius: 9999px;
border: 1px solid var(--border);
text-decoration: none;
font-size: 0.85rem;
line-height: 1;
white-space: nowrap;
margin-top: 0.5rem;
}