diff --git a/services/worker/pipelines/extract.py b/services/worker/pipelines/extract.py
index 3bab275..f4e1ed1 100644
--- a/services/worker/pipelines/extract.py
+++ b/services/worker/pipelines/extract.py
@@ -6,6 +6,7 @@ import os
 import logging
 import re
 from typing import Dict, Any, List
+from services.worker.utils.llm_client import WorkerLLMClient
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +21,17 @@ def run(doc_id: str, ctx: Dict[str, Any]) -> None:
     # Extraction basique
     entities = _extract_basic_entities(ocr_text, document_type)
 
+    # Extraction avancée via LLM (merge non destructif)
+    # Best effort: an LLM/Ollama outage must not discard the basic
+    # extraction already computed above, so failures are logged, not raised.
+    try:
+        llm = WorkerLLMClient()
+        prompt = _build_extraction_prompt(ocr_text[:3000] if ocr_text else "", document_type)
+        llm_response = llm.generate(prompt)
+        llm_json = WorkerLLMClient.extract_first_json(llm_response) or {}
+        entities = _merge_entities_basic_with_llm(entities, llm_json)
+    except Exception as exc:
+        logger.warning("Extraction LLM indisponible pour %s: %s", doc_id, exc)
+
     ctx.update({
         "extracted_entities": entities,
         "entities_count": len(entities)
@@ -63,4 +75,33 @@ def _extract_basic_entities(text: str, doc_type: str) -> List[Dict[str, Any]]:
             "confidence": 0.8
         })
 
-    return entities
\ No newline at end of file
+    return entities
+
+
+def _build_extraction_prompt(text: str, doc_type: str) -> str:
+    return f"""
+Tu es un extracteur d'entités pour documents notariaux.
+Type de document: {doc_type}
+Extrait en JSON strict les objets: identites, adresses, biens, entreprises, montants, dates.
+Réponds UNIQUEMENT par un JSON.
+
+TEXTE:
+{text}
+"""
+
+
+def _merge_entities_basic_with_llm(basic: List[Dict[str, Any]], advanced: Dict[str, Any]) -> List[Dict[str, Any]]:
+    merged = list(basic)
+    if not isinstance(advanced, dict):
+        return merged
+    # Aplatit les entités LLM en liste simple type/value pour compatibilité minimale
+    for key in ["identites", "adresses", "biens", "entreprises", "montants", "dates"]:
+        items = advanced.get(key, []) or []
+        for item in items:
+            try:
+                value = item.get("adresse_complete") or item.get("date") or item.get("montant") or item.get("nom") or item.get("description") or str(item)
+                if value:
+                    merged.append({"type": key, "value": value, "confidence": item.get("confidence", 0.8)})
+            except Exception:
+                continue
+    return merged
diff --git a/services/worker/utils/llm_client.py b/services/worker/utils/llm_client.py
new file mode 100644
index 0000000..7e920b5
--- /dev/null
+++ b/services/worker/utils/llm_client.py
@@ -0,0 +1,56 @@
+"""
+Client LLM minimal côté worker pour appels à Ollama
+"""
+
+import json
+import logging
+import os
+import re
+from typing import Optional
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerLLMClient:
+    def __init__(self) -> None:
+        self.base_url = os.getenv("OLLAMA_BASE_URL", "http://ollama:11434")
+        self.default_model = os.getenv("OLLAMA_DEFAULT_MODEL", "llama3:8b")
+        self.session = requests.Session()
+
+    def generate(self, prompt: str, model: Optional[str] = None, max_tokens: int = 2000) -> str:
+        model_name = model or self.default_model
+        url = f"{self.base_url}/api/generate"
+        payload = {
+            "model": model_name,
+            "prompt": prompt,
+            "stream": False,
+            "options": {
+                "temperature": 0.1,
+                "top_p": 0.9,
+                # Ollama names this option num_predict; "max_tokens" is ignored
+                "num_predict": max_tokens,
+            },
+        }
+        try:
+            resp = self.session.post(url, json=payload, timeout=120)
+            if resp.status_code != 200:
+                logger.error("Erreur LLM %s: %s", resp.status_code, resp.text)
+                raise RuntimeError(f"LLM HTTP {resp.status_code}")
+            data = resp.json()
+            return data.get("response", "")
+        except Exception as exc:
+            logger.error("Erreur appel LLM: %s", exc)
+            raise
+
+    @staticmethod
+    def extract_first_json(text: str) -> Optional[dict]:
+        try:
+            m = re.search(r"\{[\s\S]*\}", text)
+            if not m:
+                return None
+            return json.loads(m.group(0))
+        except Exception as exc:
+            logger.warning("JSON non parsé depuis la réponse LLM: %s", exc)
+            return None