"""Local LangExtract HTTP API.""" from __future__ import annotations import os from typing import Any import langextract as lx from fastapi import Depends, FastAPI, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from pydantic import BaseModel, Field app = FastAPI(title="langextract-api", version="0.1.0") _bearer = HTTPBearer(auto_error=False) def _expected_service_token() -> str: return os.environ.get("LANGEXTRACT_SERVICE_TOKEN", "").strip() def verify_service_token( creds: HTTPAuthorizationCredentials | None = Depends(_bearer), ) -> None: expected = _expected_service_token() if not expected: return if creds is None: raise HTTPException(status_code=401, detail="Unauthorized") token = creds.credentials.strip() if token != expected: raise HTTPException(status_code=401, detail="Unauthorized") class ExtractionIn(BaseModel): extraction_class: str extraction_text: str attributes: dict[str, Any] = Field(default_factory=dict) class ExampleIn(BaseModel): text: str extractions: list[ExtractionIn] class ExtractRequest(BaseModel): text: str prompt_description: str examples: list[ExampleIn] model_id: str model_url: str | None = None extraction_passes: int | None = None max_workers: int | None = None max_char_buffer: int | None = None api_key: str | None = Field( default=None, description="Optional Gemini / cloud key; else LANGEXTRACT_API_KEY from env.", ) fence_output: bool | None = None use_schema_constraints: bool | None = None def _normalize_attributes( attrs: dict[str, Any], ) -> dict[str, str | list[str]] | None: if not attrs: return None out: dict[str, str | list[str]] = {} for k, v in attrs.items(): if isinstance(v, list): out[k] = [str(x) for x in v] else: out[k] = str(v) return out def _examples_to_lx(examples: list[ExampleIn]) -> list[lx.data.ExampleData]: out: list[lx.data.ExampleData] = [] for ex in examples: extractions = [ lx.data.Extraction( extraction_class=e.extraction_class, extraction_text=e.extraction_text, attributes=_normalize_attributes(e.attributes), ) for e in ex.extractions ] out.append(lx.data.ExampleData(text=ex.text, extractions=extractions)) return out def _extraction_to_dict(e: Any) -> dict[str, Any]: d: dict[str, Any] = { "extraction_class": getattr(e, "extraction_class", None), "extraction_text": getattr(e, "extraction_text", None), "attributes": dict(getattr(e, "attributes", {}) or {}), } interval = getattr(e, "char_interval", None) if interval is not None: start = getattr(interval, "start", None) end = getattr(interval, "end", None) if start is not None and end is not None: d["char_interval"] = {"start": start, "end": end} return d def _document_to_dict(doc: Any) -> dict[str, Any]: extractions = getattr(doc, "extractions", None) or [] return { "extractions": [_extraction_to_dict(x) for x in extractions], } @app.get("/health") def health() -> dict[str, str]: return {"status": "ok"} @app.post("/extract", dependencies=[Depends(verify_service_token)]) def extract(req: ExtractRequest) -> dict[str, Any]: examples = _examples_to_lx(req.examples) kwargs: dict[str, Any] = { "text_or_documents": req.text, "prompt_description": req.prompt_description, "examples": examples, "model_id": req.model_id, } if req.model_url is not None: kwargs["model_url"] = req.model_url if req.extraction_passes is not None: kwargs["extraction_passes"] = req.extraction_passes if req.max_workers is not None: kwargs["max_workers"] = req.max_workers if req.max_char_buffer is not None: kwargs["max_char_buffer"] = req.max_char_buffer if req.api_key is not None: kwargs["api_key"] = req.api_key if req.fence_output is not None: kwargs["fence_output"] = req.fence_output if req.use_schema_constraints is not None: kwargs["use_schema_constraints"] = req.use_schema_constraints try: result = lx.extract(**kwargs) except Exception as e: raise HTTPException(status_code=400, detail=str(e)) from e if isinstance(result, list): return {"documents": [_document_to_dict(d) for d in result]} return {"documents": [_document_to_dict(result)]}