diff --git a/backend/ollama_client.py b/backend/ollama_client.py index fd00e8e..37e3cbf 100644 --- a/backend/ollama_client.py +++ b/backend/ollama_client.py @@ -1,4 +1,4 @@ - +import asyncio import httpx import json import re @@ -10,6 +10,87 @@ from .app_settings import get_ollama_api_url _MODEL_DETAILS_CACHE: Dict[Tuple[str, str], Tuple[float, Dict[str, Any]]] = {} _MODEL_DETAILS_TTL_S = 15.0 + +def _string_tokens(value: Any) -> List[str]: + if isinstance(value, str): + trimmed = value.strip() + return [trimmed] if trimmed else [] + if isinstance(value, dict): + out: List[str] = [] + for key, item in value.items(): + out.extend(_string_tokens(key)) + out.extend(_string_tokens(item)) + return out + if isinstance(value, (list, tuple, set)): + out: List[str] = [] + for item in value: + out.extend(_string_tokens(item)) + return out + return [] + + +def _normalize_capabilities(model_data: Dict[str, Any]) -> List[str]: + out = [] + for item in model_data.get("capabilities") or []: + text = str(item).strip().lower() + if text and text not in out: + out.append(text) + return out + + +def _is_embedding_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool: + capabilities = set(_normalize_capabilities(model_data)) + if "embedding" in capabilities or "embeddings" in capabilities: + return True + + lowered_tokens = " ".join( + token.lower() + for token in _string_tokens(name) + _string_tokens(tag_item.get("details")) + _string_tokens(model_data) + ) + return any( + marker in lowered_tokens + for marker in ( + " embed ", + " embedding ", + "embed-", + "-embed", + "nomic-embed", + "mxbai-embed", + "snowflake-arctic-embed", + "bge-m3", + "bge ", + ) + ) or lowered_tokens.startswith("bge") + + +def _is_rerank_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool: + lowered_tokens = " ".join( + token.lower() + for token in _string_tokens(name) + _string_tokens(tag_item.get("details")) + _string_tokens(model_data) + ) + return ( + _is_embedding_model(name, model_data, tag_item) + or "rerank" in lowered_tokens + or "cross-encoder" in lowered_tokens + ) + + +def _build_model_catalog_entry(tag_item: Dict[str, Any], model_data: Dict[str, Any]) -> Dict[str, Any]: + name = str((tag_item or {}).get("name") or "").strip() + capabilities = _normalize_capabilities(model_data) + is_embedding = _is_embedding_model(name, model_data, tag_item) + is_rerank = _is_rerank_model(name, model_data, tag_item) + has_vision = supports_vision(model_data) + return { + "name": name, + "capabilities": capabilities, + "supports_vision": has_vision, + "is_embedding": is_embedding, + "can_chat": not is_embedding, + "can_rerank": is_rerank, + } + + async def list_models() -> Dict[str, Any]: ollama_url = get_ollama_api_url() async with httpx.AsyncClient(timeout=30.0) as client: @@ -20,6 +101,40 @@ async def list_models() -> Dict[str, Any]: models = [m.get('name') for m in data.get('models', [])] return {"models": models} + +async def list_model_catalog(*, refresh: bool = False) -> Dict[str, Any]: + ollama_url = get_ollama_api_url() + async with httpx.AsyncClient(timeout=30.0) as client: + r = await client.get(f"{ollama_url}/api/tags") + r.raise_for_status() + payload = r.json() + + raw_models = payload.get("models", []) or [] + names = [str((item or {}).get("name") or "").strip() for item in raw_models] + details = await asyncio.gather( + *(show_model(name, refresh=refresh) for name in names if name), + return_exceptions=True, + ) + + detail_by_name: Dict[str, Dict[str, Any]] = {} + for name, detail in zip([name for name in names if name], details): + if isinstance(detail, dict): + detail_by_name[name] = detail + + models = [ + _build_model_catalog_entry(item or {}, detail_by_name.get(str((item or {}).get("name") or "").strip(), {})) + for item in raw_models + if str((item or {}).get("name") or "").strip() + ] + + return { + "models": models, + "chat_models": [model["name"] for model in models if model["can_chat"]], + "embedding_models": [model["name"] for model in models if model["is_embedding"]], + "vision_models": [model["name"] for model in models if model["supports_vision"]], + "reranking_models": [model["name"] for model in models if model["can_rerank"]], + } + async def show_model(model: str, *, refresh: bool = False) -> Dict[str, Any]: ollama_url = get_ollama_api_url() cache_key = (ollama_url.rstrip('/'), str(model or '').strip())