import httpx
import json
import re
import time
from typing import Dict, Any, List, AsyncGenerator, Tuple

from .app_settings import get_ollama_api_url

# Cache of /api/show payloads keyed by (normalized Ollama URL, model name).
# Each entry stores (monotonic timestamp, payload); show_model() treats an
# entry as fresh for _MODEL_DETAILS_TTL_S seconds.
_MODEL_DETAILS_CACHE: Dict[Tuple[str, str], Tuple[float, Dict[str, Any]]] = {}
_MODEL_DETAILS_TTL_S = 15.0


def _cache_key(model: str) -> Tuple[str, str]:
    ollama_url = get_ollama_api_url()
    return (ollama_url.rstrip('/'), str(model or '').strip())


def _get_cached_model_details(model: str) -> Dict[str, Any]:
    # Best-effort lookup: returns the cached payload even if it is older than
    # the TTL (freshness is only enforced by show_model), or {} on a miss.
    cached = _MODEL_DETAILS_CACHE.get(_cache_key(model))
    return cached[1] if cached else {}


def _string_tokens(value: Any) -> List[str]:
    """Recursively collect non-empty strings from nested dicts, lists, tuples, and sets."""
    if isinstance(value, str):
        trimmed = value.strip()
        return [trimmed] if trimmed else []
    if isinstance(value, dict):
        out: List[str] = []
        for key, item in value.items():
            out.extend(_string_tokens(key))
            out.extend(_string_tokens(item))
        return out
    if isinstance(value, (list, tuple, set)):
        out: List[str] = []
        for item in value:
            out.extend(_string_tokens(item))
        return out
    return []


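# Illustrative flattening behaviour (hypothetical value, not a real /api/tags
# payload; blank strings are dropped):
#
#   _string_tokens({"details": {"family": "llama", "tags": ["vision", " "]}})
#   -> ["details", "family", "llama", "tags", "vision"]

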
def _normalize_capabilities(model_data: Dict[str, Any]) -> List[str]:
    # Lower-case and de-duplicate the capability strings, preserving order.
    out: List[str] = []
    for item in model_data.get("capabilities") or []:
        text = str(item).strip().lower()
        if text and text not in out:
            out.append(text)
    return out


def _combined_model_tokens(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> str:
    # Build a single lower-cased "haystack" out of the model name, the tag
    # entry's details, and the /api/show payload, for substring matching below.
    return " ".join(
        token.lower()
        for token in _string_tokens(name) + _string_tokens(tag_item.get("details")) + _string_tokens(model_data)
    )


def _is_embedding_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
    # Prefer the explicit capability flag when the /api/show payload has one.
    capabilities = set(_normalize_capabilities(model_data))
    if "embedding" in capabilities or "embeddings" in capabilities:
        return True

    # Fall back to name-based heuristics. Note the space-padded markers
    # (" embed ", " embedding ") only match tokens in the middle of the joined
    # string; the trailing startswith("bge") check covers names at its start.
    lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
    return any(
        marker in lowered_tokens
        for marker in (
            " embed ",
            " embedding ",
            "embed-",
            "-embed",
            "nomic-embed",
            "mxbai-embed",
            "snowflake-arctic-embed",
            "bge-m3",
            "bge ",
        )
    ) or lowered_tokens.startswith("bge")


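# Names this heuristic flags as embedding models (illustrative, based on the
# markers above): "nomic-embed-text", "mxbai-embed-large", "bge-m3",
# "snowflake-arctic-embed". A plain chat model name such as "llama3.1"
# matches none of the markers and is not flagged.

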
def _is_rerank_model(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
    # Embedding models are treated as rerank-capable as well; otherwise look
    # for explicit reranker naming.
    lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
    return (
        _is_embedding_model(name, model_data, tag_item)
        or "rerank" in lowered_tokens
        or "cross-encoder" in lowered_tokens
    )


def _supports_vision_fast(name: str, model_data: Dict[str, Any], tag_item: Dict[str, Any]) -> bool:
    # Trust the /api/show payload first, then fall back to name heuristics
    # covering well-known vision model families.
    if supports_vision(model_data):
        return True

    lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
    return any(
        marker in lowered_tokens
        for marker in (
            " vision ",
            "-vision",
            " vision-",
            "vision:",
            "llava",
            "bakllava",
            "moondream",
            "minicpm-v",
            "minicpmv",
            "pixtral",
            "qwen-vl",
            "qwen2vl",
            "qwen2.5vl",
            "qwen2.5-omni",
            "granite3.2-vision",
            "llama3.2-vision",
            "gemma3",
            "gemma4",
            "-vl",
            " vl ",
        )
    )


def _build_model_catalog_entry(tag_item: Dict[str, Any], model_data: Dict[str, Any]) -> Dict[str, Any]:
    name = str((tag_item or {}).get("name") or "").strip()
    capabilities = _normalize_capabilities(model_data)
    is_embedding = _is_embedding_model(name, model_data, tag_item)
    is_rerank = _is_rerank_model(name, model_data, tag_item)
    has_vision = _supports_vision_fast(name, model_data, tag_item)
    lowered_tokens = _combined_model_tokens(name, model_data, tag_item)
    # Embedding and reranker models cannot hold a chat conversation.
    is_non_chat = is_embedding or "rerank" in lowered_tokens or "cross-encoder" in lowered_tokens
    return {
        "name": name,
        "capabilities": capabilities,
        "supports_vision": has_vision,
        "is_embedding": is_embedding,
        "can_chat": not is_non_chat,
        "can_rerank": is_rerank,
    }


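# Shape of one catalog entry (hypothetical model; actual field values depend on
# the installed model and on whether its /api/show payload is cached):
#
#   {"name": "llava:7b", "capabilities": ["completion", "vision"],
#    "supports_vision": True, "is_embedding": False,
#    "can_chat": True, "can_rerank": False}

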
async def list_models() -> Dict[str, Any]:
    ollama_url = get_ollama_api_url()
    async with httpx.AsyncClient(timeout=30.0) as client:
        r = await client.get(f"{ollama_url}/api/tags")
        r.raise_for_status()
        data = r.json()
        # Normalize to a simple list of names, skipping entries without one.
        models = [m.get('name') for m in data.get('models', []) if m.get('name')]
        return {"models": models}


async def list_model_catalog() -> Dict[str, Any]:
    ollama_url = get_ollama_api_url()
    async with httpx.AsyncClient(timeout=30.0) as client:
        r = await client.get(f"{ollama_url}/api/tags")
        r.raise_for_status()
        payload = r.json()

    raw_models = payload.get("models", []) or []
    # Classify each installed model using whatever /api/show details are
    # already cached; entries without a name are skipped.
    models = [
        _build_model_catalog_entry(
            item or {},
            _get_cached_model_details(str((item or {}).get("name") or "").strip()),
        )
        for item in raw_models
        if str((item or {}).get("name") or "").strip()
    ]

    return {
        "models": models,
        "chat_models": [model["name"] for model in models if model["can_chat"]],
        "embedding_models": [model["name"] for model in models if model["is_embedding"]],
        "vision_models": [model["name"] for model in models if model["supports_vision"]],
        "reranking_models": [model["name"] for model in models if model["can_rerank"]],
    }


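# Usage sketch (illustrative; assumes a reachable Ollama server configured via
# get_ollama_api_url, and must run inside an async context):
#
#   catalog = await list_model_catalog()
#   for name in catalog["chat_models"]:
#       print(name)

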
async def show_model(model: str, *, refresh: bool = False) -> Dict[str, Any]:
    # Serve from the cache unless the entry is stale or a refresh is forced.
    cache_key = _cache_key(model)
    cached = _MODEL_DETAILS_CACHE.get(cache_key)
    now = time.monotonic()
    if not refresh and cached and (now - cached[0]) < _MODEL_DETAILS_TTL_S:
        return cached[1]

    ollama_url = get_ollama_api_url()
    async with httpx.AsyncClient(timeout=30.0) as client:
        r = await client.post(f"{ollama_url}/api/show", json={"model": model})
        r.raise_for_status()
        data = r.json()
        _MODEL_DETAILS_CACHE[cache_key] = (now, data)
        return data


def supports_vision(model_data: Dict[str, Any]) -> bool:
    # An explicit "vision" capability from /api/show wins.
    capabilities = model_data.get("capabilities") or []
    if any(str(item).strip().lower() == "vision" for item in capabilities):
        return True

    # Otherwise look for vision-related keys in model_info, e.g. projector or
    # multimodal ("mm") metadata and per-image token counts.
    model_info = model_data.get("model_info") or {}
    if isinstance(model_info, dict):
        for key in model_info.keys():
            lowered = str(key).strip().lower()
            if ".vision." in lowered or lowered.endswith(".vision"):
                return True
            if lowered.endswith("tokens_per_image") or re.search(r"\bmm\b", lowered):
                return True

    return False


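# Example model_info keys that would trigger the fallback (hypothetical names;
# actual keys vary by model family): "clip.vision.image_size",
# "gemma3.vision.block_count", "llava.mm.projector_dim".

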
async def chat(model: str, messages: List[Dict[str, Any]]) -> str:
    ollama_url = get_ollama_api_url()
    payload = {
        "model": model,
        "messages": messages,
        "stream": False,
    }
    async with httpx.AsyncClient(timeout=600.0) as client:
        r = await client.post(f"{ollama_url}/api/chat", json=payload)
        r.raise_for_status()
        data = r.json()
        # A non-streaming /api/chat response carries the assistant reply in a
        # single "message" object.
        try:
            return data["message"]["content"]
        except (KeyError, TypeError):
            # Defensive fallbacks for responses that carry a messages list or
            # top-level content instead.
            msgs = data.get("messages") or []
            if msgs:
                return msgs[-1].get("content", "")
            return data.get("content", "")


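# Usage sketch (illustrative; "llama3.2" is a placeholder for an installed model):
#
#   reply = await chat("llama3.2", [{"role": "user", "content": "Say hi."}])

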
async def chat_stream(model: str, messages: List[Dict[str, Any]]) -> AsyncGenerator[str, None]:
    ollama_url = get_ollama_api_url()
    payload = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=600.0) as client:
        async with client.stream("POST", f"{ollama_url}/api/chat", json=payload) as r:
            r.raise_for_status()
            # A streaming /api/chat response is newline-delimited JSON, each
            # chunk carrying a partial reply in message.content.
            async for line in r.aiter_lines():
                if not line:
                    continue
                try:
                    chunk = json.loads(line)
                except json.JSONDecodeError:
                    continue  # Ignore malformed lines
                message = chunk.get("message")
                if isinstance(message, dict) and "content" in message:
                    yield message["content"]
                elif "content" in chunk:
                    # Fallback for payloads with top-level content
                    yield chunk["content"]
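

# Streaming usage sketch (illustrative; "llama3.2" is a placeholder for an
# installed model, and this must run inside an async context):
#
#   async for piece in chat_stream("llama3.2", [{"role": "user", "content": "hi"}]):
#       print(piece, end="", flush=True)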