Refactor model details caching and vision support checks in ollama_client.py

This commit is contained in:
2026-04-17 08:46:05 +02:00
parent d8a8e9be20
commit e9b96812f2

View File

@@ -11,6 +11,16 @@ _MODEL_DETAILS_CACHE: Dict[Tuple[str, str], Tuple[float, Dict[str, Any]]] = {}
_MODEL_DETAILS_TTL_S = 15.0
def _cache_key(model: str) -> Tuple[str, str]:
ollama_url = get_ollama_api_url()
return (ollama_url.rstrip('/'), str(model or '').strip())
def _get_cached_model_details(model: str) -> Dict[str, Any]:
cached = _MODEL_DETAILS_CACHE.get(_cache_key(model))
return cached[1] if cached else {}
def _string_tokens(value: Any) -> List[str]:
if isinstance(value, str):
trimmed = value.strip()
@@ -38,15 +48,19 @@ def _normalize_capabilities(model_data: Dict[str, Any]) -> List[str]:
return out
def _combined_model_tokens(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> str:
return " ".join(
token.lower()
for token in _string_tokens(name) + _string_tokens(tag_item.get("details")) + _string_tokens(model_data)
)
def _is_embedding_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
capabilities = set(_normalize_capabilities(model_data))
if "embedding" in capabilities or "embeddings" in capabilities:
return True
lowered_tokens = " ".join(
token.lower()
for token in _string_tokens(name) + _string_tokens(tag_item.get("details")) + _string_tokens(model_data)
)
lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
return any(
marker in lowered_tokens
for marker in (
@@ -64,10 +78,7 @@ def _is_embedding_model(name: str, model_data: Dict[str, Any], tag_item: Dict[st
def _is_rerank_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
lowered_tokens = " ".join(
token.lower()
for token in _string_tokens(name) + _string_tokens(tag_item.get("details")) + _string_tokens(model_data)
)
lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
return (
_is_embedding_model(name, model_data, tag_item)
or "rerank" in lowered_tokens
@@ -75,18 +86,52 @@ def _is_rerank_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str,
)
def _supports_vision_fast(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
if supports_vision(model_data):
return True
lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
return any(
marker in lowered_tokens
for marker in (
" vision ",
"-vision",
" vision-",
"vision:",
"llava",
"bakllava",
"moondream",
"minicpm-v",
"minicpmv",
"pixtral",
"qwen-vl",
"qwen2vl",
"qwen2.5vl",
"qwen2.5-omni",
"granite3.2-vision",
"llama3.2-vision",
"gemma3",
"gemma4",
"-vl",
" vl ",
)
)
def _build_model_catalog_entry(tag_item: Dict[str, Any], model_data: Dict[str, Any]) -> Dict[str, Any]:
name = str((tag_item or {}).get("name") or "").strip()
capabilities = _normalize_capabilities(model_data)
is_embedding = _is_embedding_model(name, model_data, tag_item)
is_rerank = _is_rerank_model(name, model_data, tag_item)
has_vision = supports_vision(model_data)
has_vision = _supports_vision_fast(name, model_data, tag_item)
lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
is_non_chat = is_embedding or "rerank" in lowered_tokens or "cross-encoder" in lowered_tokens
return {
"name": name,
"capabilities": capabilities,
"supports_vision": has_vision,
"is_embedding": is_embedding,
"can_chat": not is_embedding,
"can_chat": not is_non_chat,
"can_rerank": is_rerank,
}
@@ -110,19 +155,11 @@ async def list_model_catalog(*, refresh: bool = False) -> Dict[str, Any]:
payload = r.json()
raw_models = payload.get("models", []) or []
names = [str((item or {}).get("name") or "").strip() for item in raw_models]
details = await asyncio.gather(
*(show_model(name, refresh=refresh) for name in names if name),
return_exceptions=True,
)
detail_by_name: Dict[str, Dict[str, Any]] = {}
for name, detail in zip([name for name in names if name], details):
if isinstance(detail, dict):
detail_by_name[name] = detail
models = [
_build_model_catalog_entry(item or {}, detail_by_name.get(str((item or {}).get("name") or "").strip(), {}))
_build_model_catalog_entry(
item or {},
show_model(str((item or {}).get("name") or "").strip(), refresh=refresh) if False else _get_cached_model_details(str((item or {}).get("name") or "").strip()),
)
for item in raw_models
if str((item or {}).get("name") or "").strip()
]