Add model catalog functionality to backend/ollama_client.py
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import re
|
||||
@@ -10,6 +10,87 @@ from .app_settings import get_ollama_api_url
|
||||
# Module-level cache mapping a (str, str) key to a (float, details-dict) pair.
# NOTE(review): presumably keyed by (base URL, model name) with the float being
# the time the entry was stored -- confirm against the code that fills it.
_MODEL_DETAILS_CACHE: Dict[Tuple[str, str], Tuple[float, Dict[str, Any]]] = {}
|
||||
# NOTE(review): presumably the lifetime of a cached model-details entry, in
# seconds (per the ``_TTL_S`` suffix) -- confirm where this value is read.
_MODEL_DETAILS_TTL_S = 15.0
|
||||
|
||||
|
||||
def _string_tokens(value: Any) -> List[str]:
|
||||
if isinstance(value, str):
|
||||
trimmed = value.strip()
|
||||
return [trimmed] if trimmed else []
|
||||
if isinstance(value, dict):
|
||||
out: List[str] = []
|
||||
for key, item in value.items():
|
||||
out.extend(_string_tokens(key))
|
||||
out.extend(_string_tokens(item))
|
||||
return out
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
out: List[str] = []
|
||||
for item in value:
|
||||
out.extend(_string_tokens(item))
|
||||
return out
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_capabilities(model_data: Dict[str, Any]) -> List[str]:
|
||||
out = []
|
||||
for item in model_data.get("capabilities") or []:
|
||||
text = str(item).strip().lower()
|
||||
if text and text not in out:
|
||||
out.append(text)
|
||||
return out
|
||||
|
||||
|
||||
def _is_embedding_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
    """Heuristically decide whether a model is an embedding model.

    The model counts as an embedding model when its normalized capabilities
    contain ``embedding`` or ``embeddings``, or when the lowercased text
    gathered from ``name``, ``tag_item["details"]`` and ``model_data``
    contains one of the known embedding markers (or starts with ``bge``).

    Args:
        name: Model name.
        model_data: Detail payload for the model; its ``capabilities`` entry
            and every string key/value inside it are inspected.
        tag_item: Tag-list entry for the model; only ``details`` is read.

    Returns:
        True when any of the signals above matches, otherwise False.
    """
    capabilities = set(_normalize_capabilities(model_data))
    if "embedding" in capabilities or "embeddings" in capabilities:
        return True

    tokens = (
        _string_tokens(name)
        + _string_tokens(tag_item.get("details"))
        + _string_tokens(model_data)
    )
    joined = " ".join(token.lower() for token in tokens)
    if joined.startswith("bge"):
        return True

    # Pad with a space on both ends so the whole-word markers (" embed ",
    # " embedding ", "bge ") also match when the word is the very first or
    # very last token; without the padding those positions were never matched.
    padded = f" {joined} "
    markers = (
        " embed ",
        " embedding ",
        "embed-",
        "-embed",
        "nomic-embed",
        "mxbai-embed",
        "snowflake-arctic-embed",
        "bge-m3",
        "bge ",
    )
    return any(marker in padded for marker in markers)
|
||||
|
||||
|
||||
def _is_rerank_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
    """Heuristically decide whether a model can be used for reranking.

    Every model recognized by ``_is_embedding_model`` qualifies; otherwise the
    lowercased text gathered from ``name``, ``tag_item["details"]`` and
    ``model_data`` must contain ``rerank`` or ``cross-encoder``.
    """
    pieces = [
        *_string_tokens(name),
        *_string_tokens(tag_item.get("details")),
        *_string_tokens(model_data),
    ]
    haystack = " ".join(piece.lower() for piece in pieces)

    if _is_embedding_model(name, model_data, tag_item):
        return True
    return any(marker in haystack for marker in ("rerank", "cross-encoder"))
|
||||
|
||||
|
||||
def _build_model_catalog_entry(tag_item: Dict[str, Any], model_data: Dict[str, Any]) -> Dict[str, Any]:
    """Build one catalog record describing a model.

    Args:
        tag_item: Tag-list entry for the model; ``name`` is read here and the
            entry is passed on to the embedding/rerank classifiers. A falsy
            value (e.g. ``None``) is treated as an empty entry.
        model_data: Detail payload for the model, passed to the capability,
            embedding, rerank and vision helpers.

    Returns:
        A dict with the keys ``name``, ``capabilities``, ``supports_vision``,
        ``is_embedding``, ``can_chat`` (the negation of ``is_embedding``) and
        ``can_rerank``.
    """
    # Normalize once: previously only the name lookup tolerated a missing
    # entry, while the classifiers below received the raw value and failed on
    # ``tag_item.get(...)``.
    tag_item = tag_item or {}

    name = str(tag_item.get("name") or "").strip()
    capabilities = _normalize_capabilities(model_data)
    is_embedding = _is_embedding_model(name, model_data, tag_item)
    is_rerank = _is_rerank_model(name, model_data, tag_item)
    has_vision = supports_vision(model_data)
    return {
        "name": name,
        "capabilities": capabilities,
        "supports_vision": has_vision,
        "is_embedding": is_embedding,
        "can_chat": not is_embedding,
        "can_rerank": is_rerank,
    }
|
||||
|
||||
|
||||
async def list_models() -> Dict[str, Any]:
|
||||
ollama_url = get_ollama_api_url()
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
@@ -20,6 +101,40 @@ async def list_models() -> Dict[str, Any]:
|
||||
models = [m.get('name') for m in data.get('models', [])]
|
||||
return {"models": models}
|
||||
|
||||
|
||||
async def list_model_catalog(*, refresh: bool = False) -> Dict[str, Any]:
    """Fetch the tag list and return a categorized catalog of models.

    Requests ``/api/tags`` from the configured base URL, then calls
    ``show_model`` concurrently for every entry with a non-empty name. Detail
    lookups that raise are ignored (that model is cataloged with empty
    details). Entries without a name are left out.

    Args:
        refresh: Forwarded to ``show_model`` for each model.

    Returns:
        A dict with ``models`` (one catalog entry per named model) plus the
        name lists ``chat_models``, ``embedding_models``, ``vision_models``
        and ``reranking_models``.

    Raises:
        Whatever ``raise_for_status()`` raises when the tag request fails.
    """
    base_url = get_ollama_api_url()
    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.get(f"{base_url}/api/tags")
        response.raise_for_status()
        payload = response.json()

    tag_items = payload.get("models", []) or []
    # Normalize each entry's name once; blank names are filtered out below.
    all_names = [str((entry or {}).get("name") or "").strip() for entry in tag_items]
    wanted = [model_name for model_name in all_names if model_name]

    results = await asyncio.gather(
        *(show_model(model_name, refresh=refresh) for model_name in wanted),
        return_exceptions=True,
    )

    # Keep only successful lookups; exceptions returned by gather are dropped.
    detail_by_name: Dict[str, Dict[str, Any]] = {
        model_name: result
        for model_name, result in zip(wanted, results)
        if isinstance(result, dict)
    }

    models = []
    for entry, model_name in zip(tag_items, all_names):
        if not model_name:
            continue
        models.append(
            _build_model_catalog_entry(entry or {}, detail_by_name.get(model_name, {}))
        )

    def names_where(flag: str) -> List[str]:
        # Names of all catalog entries whose ``flag`` field is truthy.
        return [entry["name"] for entry in models if entry[flag]]

    return {
        "models": models,
        "chat_models": names_where("can_chat"),
        "embedding_models": names_where("is_embedding"),
        "vision_models": names_where("supports_vision"),
        "reranking_models": names_where("can_rerank"),
    }
|
||||
|
||||
async def show_model(model: str, *, refresh: bool = False) -> Dict[str, Any]:
|
||||
ollama_url = get_ollama_api_url()
|
||||
cache_key = (ollama_url.rstrip('/'), str(model or '').strip())
|
||||
|
||||
Reference in New Issue
Block a user