"""SQLite-backed TTL cache for AXL responses.

Keyed on (method_name, sorted_kwargs_json). Cache survives server restarts,
which makes exploratory audit sessions dramatically faster — the LLM can
re-run the same `listPhone` queries across conversations without paying
the SOAP round-trip every time.
"""

from __future__ import annotations

import json
import sqlite3
import time
from contextlib import closing
from pathlib import Path
from typing import Any

SCHEMA = """
CREATE TABLE IF NOT EXISTS axl_cache (
    cache_key   TEXT PRIMARY KEY,
    method      TEXT NOT NULL,
    args_json   TEXT NOT NULL,
    result_json TEXT NOT NULL,
    created_at  REAL NOT NULL,
    expires_at  REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS axl_cache_method_idx ON axl_cache(method);
CREATE INDEX IF NOT EXISTS axl_cache_expires_idx ON axl_cache(expires_at);
"""


class AxlCache:
    """SQLite TTL cache. Thread-safe via per-call connections.

    Every public method opens its own short-lived connection and closes it
    before returning, so no connection object is shared between callers.
    Entries are stored as JSON text together with an absolute expiry
    timestamp (``time.time()`` based); expired rows stay in the table until
    ``purge_expired()`` or ``clear()`` removes them.
    """

    def __init__(self, db_path: Path, default_ttl: int):
        """Create the cache file (and its parent directories) if needed.

        Args:
            db_path: Location of the SQLite database file. A plain string
                path is accepted as well and converted to ``Path``.
            default_ttl: Lifetime in seconds applied when ``set()`` is called
                without an explicit ``ttl``. A value ``<= 0`` disables
                lookups entirely (see ``get()``).
        """
        self.db_path = Path(db_path)
        self.default_ttl = default_ttl
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        with closing(self._conn()) as c:
            c.executescript(SCHEMA)

    def _conn(self) -> sqlite3.Connection:
        """Open a new connection configured for this cache.

        The caller owns the returned connection and must close it. Note that
        ``with connection:`` on a sqlite3 connection only commits or rolls
        back; it does NOT close, which is why callers in this class wrap the
        result in ``contextlib.closing``.
        """
        # isolation_level=None: autocommit, each statement is committed as it runs.
        conn = sqlite3.connect(self.db_path, isolation_level=None)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA synchronous=NORMAL")
        except BaseException:
            # Do not leak the handle if configuring the connection fails.
            conn.close()
            raise
        return conn

    @staticmethod
    def _make_key(method: str, kwargs: dict) -> str:
        """Build the primary-key string for a (method, kwargs) pair."""
        # sort_keys gives us a deterministic key regardless of dict order;
        # default=str stringifies any value json cannot serialise natively.
        return f"{method}::{json.dumps(kwargs, sort_keys=True, default=str)}"

    def get(self, method: str, kwargs: dict) -> Any | None:
        """Return the cached result for ``method``/``kwargs``, or ``None``.

        ``None`` is returned when there is no entry, when the entry has
        expired, or when ``default_ttl <= 0``. In the last case the database
        is not consulted at all, so entries written by ``set()`` with an
        explicit ``ttl`` are not returned either. A cached JSON ``null`` is
        indistinguishable from a miss.
        """
        if self.default_ttl <= 0:
            return None
        key = self._make_key(method, kwargs)
        now = time.time()
        with closing(self._conn()) as c:
            row = c.execute(
                "SELECT result_json FROM axl_cache "
                "WHERE cache_key = ? AND expires_at > ?",
                (key, now),
            ).fetchone()
        return json.loads(row[0]) if row else None

    def set(self, method: str, kwargs: dict, result: Any, ttl: int | None = None) -> None:
        """Store ``result`` for ``method``/``kwargs``, replacing any old entry.

        Args:
            method: Method name the result belongs to.
            kwargs: Call arguments; serialised with sorted keys.
            result: Value to cache; serialised with ``json.dumps`` using
                ``default=str`` for non-JSON types.
            ttl: Lifetime in seconds. Falls back to ``default_ttl`` when
                ``None``. Nothing is stored when the effective TTL is ``<= 0``.
        """
        if self.default_ttl <= 0 and ttl is None:
            return
        ttl = ttl if ttl is not None else self.default_ttl
        if ttl <= 0:
            return
        key = self._make_key(method, kwargs)
        now = time.time()
        with closing(self._conn()) as c:
            c.execute(
                """
                INSERT OR REPLACE INTO axl_cache
                    (cache_key, method, args_json, result_json, created_at, expires_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    key,
                    method,
                    json.dumps(kwargs, sort_keys=True, default=str),
                    json.dumps(result, default=str),
                    now,
                    now + ttl,
                ),
            )

    def stats(self) -> dict:
        """Return entry counts for the cache.

        The returned dict has the keys ``db_path``, ``default_ttl_seconds``,
        ``total_entries``, ``live_entries``, ``expired_entries`` and
        ``by_method`` (live entries per method, most frequent first).
        """
        now = time.time()
        with closing(self._conn()) as c:
            total = c.execute("SELECT COUNT(*) FROM axl_cache").fetchone()[0]
            live = c.execute(
                "SELECT COUNT(*) FROM axl_cache WHERE expires_at > ?", (now,)
            ).fetchone()[0]
            by_method = {
                row[0]: row[1]
                for row in c.execute(
                    "SELECT method, COUNT(*) FROM axl_cache "
                    "WHERE expires_at > ? GROUP BY method ORDER BY 2 DESC",
                    (now,),
                ).fetchall()
            }
        return {
            "db_path": str(self.db_path),
            "default_ttl_seconds": self.default_ttl,
            "total_entries": total,
            "live_entries": live,
            "expired_entries": total - live,
            "by_method": by_method,
        }

    def clear(self, method_pattern: str | None = None) -> int:
        """Delete entries and return how many rows were removed.

        Args:
            method_pattern: When given (and non-empty), only rows whose
                method matches the pattern are deleted. ``*`` is translated
                to the SQL ``%`` wildcard and the result is used in a
                ``LIKE`` comparison; ``%`` and ``_`` already present in the
                pattern are passed through unescaped and act as wildcards
                too. When omitted, every row is deleted.
        """
        with closing(self._conn()) as c:
            if method_pattern:
                cursor = c.execute(
                    "DELETE FROM axl_cache WHERE method LIKE ?",
                    (method_pattern.replace("*", "%"),),
                )
            else:
                cursor = c.execute("DELETE FROM axl_cache")
            # Read the count while the connection is still open.
            return cursor.rowcount

    def purge_expired(self) -> int:
        """Delete every entry whose expiry time has passed; return the count."""
        with closing(self._conn()) as c:
            cursor = c.execute(
                "DELETE FROM axl_cache WHERE expires_at <= ?", (time.time(),)
            )
            # Read the count while the connection is still open.
            return cursor.rowcount