feat(extraction): appels LLM avancés côté worker (Ollama)
This commit is contained in:
parent
425cb21e20
commit
fddfa6f7bf
@ -6,6 +6,7 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
|
from services.worker.utils.llm_client import WorkerLLMClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -20,6 +21,13 @@ def run(doc_id: str, ctx: Dict[str, Any]) -> None:
|
|||||||
# Extraction basique
|
# Extraction basique
|
||||||
entities = _extract_basic_entities(ocr_text, document_type)
|
entities = _extract_basic_entities(ocr_text, document_type)
|
||||||
|
|
||||||
|
# Extraction avancée via LLM (merge non destructif)
|
||||||
|
llm = WorkerLLMClient()
|
||||||
|
prompt = _build_extraction_prompt(ocr_text[:3000] if ocr_text else "", document_type)
|
||||||
|
llm_response = llm.generate(prompt)
|
||||||
|
llm_json = WorkerLLMClient.extract_first_json(llm_response) or {}
|
||||||
|
entities = _merge_entities_basic_with_llm(entities, llm_json)
|
||||||
|
|
||||||
ctx.update({
|
ctx.update({
|
||||||
"extracted_entities": entities,
|
"extracted_entities": entities,
|
||||||
"entities_count": len(entities)
|
"entities_count": len(entities)
|
||||||
@ -63,4 +71,33 @@ def _extract_basic_entities(text: str, doc_type: str) -> List[Dict[str, Any]]:
|
|||||||
"confidence": 0.8
|
"confidence": 0.8
|
||||||
})
|
})
|
||||||
|
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
|
||||||
|
def _build_extraction_prompt(text: str, doc_type: str) -> str:
|
||||||
|
return f"""
|
||||||
|
Tu es un extracteur d'entités pour documents notariaux.
|
||||||
|
Type de document: {doc_type}
|
||||||
|
Extrait en JSON strict les objets: identites, adresses, biens, entreprises, montants, dates.
|
||||||
|
Réponds UNIQUEMENT par un JSON.
|
||||||
|
|
||||||
|
TEXTE:
|
||||||
|
{text}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_entities_basic_with_llm(basic: List[Dict[str, Any]], advanced: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
merged = list(basic)
|
||||||
|
if not isinstance(advanced, dict):
|
||||||
|
return merged
|
||||||
|
# Aplatit les entités LLM en liste simple type/value pour compatibilité minimale
|
||||||
|
for key in ["identites", "adresses", "biens", "entreprises", "montants", "dates"]:
|
||||||
|
items = advanced.get(key, []) or []
|
||||||
|
for item in items:
|
||||||
|
try:
|
||||||
|
value = item.get("adresse_complete") or item.get("date") or item.get("montant") or item.get("nom") or item.get("description") or str(item)
|
||||||
|
if value:
|
||||||
|
merged.append({"type": key, "value": value, "confidence": item.get("confidence", 0.8)})
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return merged
|
58
services/worker/utils/llm_client.py
Normal file
58
services/worker/utils/llm_client.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""
|
||||||
|
Client LLM minimal côté worker pour appels à Ollama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerLLMClient:
    """Minimal worker-side LLM client for Ollama's HTTP API.

    Configuration is read from environment variables:
    ``OLLAMA_BASE_URL`` (default ``http://ollama:11434``) and
    ``OLLAMA_DEFAULT_MODEL`` (default ``llama3:8b``).
    """

    def __init__(self) -> None:
        self.base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
        self.default_model = os.getenv("OLLAMA_DEFAULT_MODEL", "llama3:8b")
        # A shared Session reuses TCP connections across successive calls.
        self.session = requests.Session()

    def generate(self, prompt: str, model: Optional[str] = None, max_tokens: int = 2000) -> str:
        """Run a non-streaming completion against Ollama's ``/api/generate``.

        Args:
            prompt: Full prompt text sent to the model.
            model: Model name; falls back to the configured default.
            max_tokens: Generation cap, forwarded as Ollama's ``num_predict``.

        Returns:
            The model's response text (empty string if the field is absent).

        Raises:
            RuntimeError: On a non-200 HTTP status from Ollama.
            Exception: Transport-level errors are logged and re-raised.
        """
        model_name = model or self.default_model
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": model_name,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.1,
                "top_p": 0.9,
                # BUGFIX: Ollama's option for the generation cap is
                # "num_predict"; "max_tokens" is not a recognized option
                # and was silently ignored.
                "num_predict": max_tokens,
            },
        }
        try:
            resp = self.session.post(url, json=payload, timeout=120)
            if resp.status_code != 200:
                logger.error("Erreur LLM %s: %s", resp.status_code, resp.text)
                raise RuntimeError(f"LLM HTTP {resp.status_code}")
            data = resp.json()
            return data.get("response", "")
        except Exception as exc:
            logger.error("Erreur appel LLM: %s", exc)
            raise

    @staticmethod
    def extract_first_json(text: str) -> Optional[dict]:
        """Extract and parse the first JSON object embedded in *text*.

        Uses a greedy match from the first ``{`` to the last ``}`` — fine
        for a single JSON object surrounded by prose, but it will not
        separate multiple top-level objects.

        Returns:
            The parsed dict, or None when nothing matches or parsing fails.
        """
        try:
            import re

            m = re.search(r"\{[\s\S]*\}", text)
            if not m:
                return None
            return json.loads(m.group(0))
        except Exception as exc:
            logger.warning("JSON non parsé depuis la réponse LLM: %s", exc)
            return None
|
||||||
|
|
||||||
|
|
Loading…
x
Reference in New Issue
Block a user