"""
Shared helper utilities for the legal-assistant agent.
• invoke_with_retry – resilient LLM / chain invocation
• _clean_llm_text – strips blocks, markdown fences, whitespace
• _first_json_block – extracts first {...} that contains a key
• safe_parse_json_block – tolerant JSON→dict loader
• salvage_json – multi-strategy JSON extraction fallback
• detect_placeholders – finds tokens like [DATE], [NAME]
• strip_llm_fluff – removes “Here is …” boilerplate
• is_finalization_command – simplistic “sign-off” detector
"""
from __future__ import annotations
import ast
import json
import logging
import re
import time
from typing import Any, Dict, List, Optional
logger = logging.getLogger("agent.utils")
logger.setLevel(logging.DEBUG)
# ────────────────────────────────────────────────────────────
# Retry wrapper
# ────────────────────────────────────────────────────────────
def invoke_with_retry(
chain_or_runnable,
inputs: Dict[str, Any],
max_retries: int = 100,
):
"""
Invoke *any* LangChain object (Runnable, Chain, or plain callable)
with simple exponential-backoff retry.
"""
for attempt in range(1, max_retries + 1):
try:
if hasattr(chain_or_runnable, "invoke"):
return chain_or_runnable.invoke(inputs)
if hasattr(chain_or_runnable, "run"):
if isinstance(inputs, dict):
return chain_or_runnable.run(**inputs)
return chain_or_runnable.run(inputs)
return chain_or_runnable(inputs)
except Exception as exc:
# Retry logic only for transient HTTP 503 / rate-limit style errors
transient = False
msg = str(exc).lower()
if "status code: 503" in msg or "503" in msg:
transient = True
if "rate limit" in msg or "temporarily unavailable" in msg:
transient = True
if not transient:
raise # propagate other errors immediately
logger.warning(
"⚠️ invoke_with_retry (%d/%d transient) failed: %s",
attempt,
max_retries,
exc,
)
if attempt == max_retries:
raise
# Exponential backoff: 10s * attempt
time.sleep(10 * attempt)
# ────────────────────────────────────────────────────────────
# Text-cleanup helpers
# ────────────────────────────────────────────────────────────
_THINK_END_RE = re.compile(r"", re.IGNORECASE)
_CODE_FENCE_START_RE = re.compile(r"^\s*```[a-zA-Z0-9_-]*\s*")
_CODE_FENCE_END_RE = re.compile(r"\s*```\s*$")
def _clean_llm_text(text: str) -> str:
"""
Strip … blocks *and* leading / trailing code fences.
"""
m = list(_THINK_END_RE.finditer(text))
if m:
text = text[m[-1].end():]
text = _CODE_FENCE_START_RE.sub("", text)
text = _CODE_FENCE_END_RE.sub("", text)
return text.strip()
# ────────────────────────────────────────────────────────────
# Balanced-brace JSON extraction helpers
# ────────────────────────────────────────────────────────────
def _iterate_json_candidates(text: str):
depth = 0
start = None
in_str = False
escape = False
for i, ch in enumerate(text):
if ch == '"' and not escape:
in_str = not in_str
escape = (ch == "\\" and not escape)
if in_str:
continue
if ch == "{":
if depth == 0:
start = i
depth += 1
elif ch == "}":
depth -= 1
if depth == 0 and start is not None:
yield text[start : i + 1]
start = None
_SINGLE_TO_DOUBLE_RE = re.compile(r"'([^']+?)'")
def safe_parse_json_block(block: str) -> Optional[Dict[str, Any]]:
try:
return json.loads(block)
except json.JSONDecodeError:
pass
try:
converted = _SINGLE_TO_DOUBLE_RE.sub(r'"\1"', block)
return json.loads(converted)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(block)
except Exception:
return None
# ────────────────────────────────────────────────────────────
# Fallback JSON-extraction strategies
# ────────────────────────────────────────────────────────────
_JSON_FENCE_RE = re.compile(
r"```json\s*({[\s\S]+?})\s*```",
re.IGNORECASE,
)
def salvage_json(text: str, required_key: str = "destination") -> Optional[Dict[str, Any]]:
"""
Robust, multi-strategy JSON extraction.
• `required_key` – only return a dict that contains this key
(defaults to "destination" for router output).
"""
# 1) ```json … ``` fenced blocks
for m in _JSON_FENCE_RE.finditer(text):
block = m.group(1)
parsed = safe_parse_json_block(block)
if parsed and required_key in parsed:
logger.debug("🛟 Salvaged JSON from fenced block.")
return parsed
# 2) Any balanced { … } group
for block in _iterate_json_candidates(text):
if f'"{required_key}"' not in block and f"'{required_key}'" not in block:
continue
parsed = safe_parse_json_block(block)
if parsed and required_key in parsed:
logger.debug("🛟 Salvaged JSON from balanced scan.")
return parsed
# 3) Reverse scan – last brace group containing the key
idx = text.rfind("}")
while idx != -1:
start = text.rfind("{", 0, idx)
if start == -1:
break
block = text[start : idx + 1]
if f'"{required_key}"' in block or f"'{required_key}'" in block:
parsed = safe_parse_json_block(block)
if parsed and required_key in parsed:
logger.debug("🛟 Salvaged JSON from reverse scan.")
return parsed
idx = text.rfind("}", 0, start)
return None
# ────────────────────────────────────────────────────────────
# Misc helpers
# ────────────────────────────────────────────────────────────
_PLACEHOLDER_RE = re.compile(r"\[[A-Z0-9_]+\]")
def detect_placeholders(doc: str) -> List[str]:
return list(dict.fromkeys(_PLACEHOLDER_RE.findall(doc)))
_FLUFF_RE = re.compile(
r"^\s*(Here is|Below is|Sure[,:\-]?|Certainly[,:\-]?|Here's the)\b[^\n]*\n+",
re.IGNORECASE,
)
def strip_llm_fluff(text: str) -> str:
return _FLUFF_RE.sub("", text).strip()
_FINAL_CMD_RE = re.compile(
r"\b(finalise|finalize|looks\s+good|approved|no\s+further\s+changes)\b",
re.IGNORECASE,
)
def is_finalization_command(user_input: str) -> bool:
return bool(_FINAL_CMD_RE.search(user_input))