"""
|
|
Client LLM minimal côté worker pour appels à Ollama
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import logging
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class WorkerLLMClient:
    """Thin HTTP client for an Ollama server, configured via environment variables."""

    def __init__(self) -> None:
        # Endpoint and default model can be overridden via the environment.
        self.base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
        self.default_model = os.getenv("OLLAMA_DEFAULT_MODEL", "llama3:8b")
        # A shared session pools HTTP connections across requests.
        self.session = requests.Session()

    def generate(self, prompt: str, model: Optional[str] = None, max_tokens: int = 2000) -> str:
        """Send a non-streaming generation request to Ollama and return the response text."""
        model_name = model or self.default_model
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                # Low temperature for more deterministic output.
                "temperature": 0.1,
                "top_p": 0.9,
                # Ollama names this option "num_predict"; "max_tokens" is ignored.
                "num_predict": max_tokens,
            },
        }
        try:
            resp = self.session.post(url, json=payload, timeout=120)
            if resp.status_code != 200:
                logger.error("LLM error %s: %s", resp.status_code, resp.text)
                raise RuntimeError(f"LLM HTTP {resp.status_code}")
            data = resp.json()
            return data.get("response", "")
        except Exception as exc:
            logger.error("LLM call failed: %s", exc)
            raise

    @staticmethod
    def extract_first_json(text: str) -> Optional[dict]:
        """Return the first JSON object found in ``text``, or None if parsing fails."""
        try:
            # Greedy match from the first "{" to the last "}"; assumes the
            # response embeds a single (possibly nested) JSON object.
            m = re.search(r"\{[\s\S]*\}", text)
            if not m:
                return None
            return json.loads(m.group(0))
        except Exception as exc:
            logger.warning("Failed to parse JSON from LLM response: %s", exc)
            return None
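

# Usage sketch (an illustrative addition, not part of the original module):
# assumes an Ollama server is reachable at OLLAMA_BASE_URL and that the
# default model has already been pulled.
if __name__ == "__main__":
    client = WorkerLLMClient()
    raw = client.generate('Reply with a JSON object like {"status": "ok"}.')
    print(WorkerLLMClient.extract_first_json(raw))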