clawbench/scripts/classify_regimes.py
2026-04-28 10:50:07 -07:00

220 lines
7.4 KiB
Python

#!/usr/bin/env python3
"""Classify posterior run trajectories into dynamical regimes.
We embed each assistant turn using bag-of-words text plus tool-call summaries,
then compute simple geometric proxies:
drift_mean = mean ||x_t - x_{t-1}||
from_start = max ||x_t - x_0||
recurrence = max cosine(x_i, x_j) for non-adjacent turns
vol_log = log det(Sigma + eps I)
Runs are then bucketed into coarse regimes such as trapped, limit_cycle, and
diffusive using quartile-based thresholds estimated from the observed archive.
"""
from __future__ import annotations
import argparse
import json
import re
import sys
from collections import Counter
from pathlib import Path
import numpy as np
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from clawbench.dynamics_archive import load_task_runs_by_model
WORD_RE = re.compile(r"[a-z]{3,}")
STOPWORDS = set(
"the and that with this have from what your will can but not "
"was are been one would there they their has had its were only some "
"than about these which into also each when where them how who very "
"much more most other then here such does like just make many want need take".split()
)
def tokenize(text: str) -> list[str]:
return [w for w in WORD_RE.findall((text or "").lower()) if w not in STOPWORDS]
def build_vocab(texts: list[str], top_k: int = 500) -> dict[str, int]:
counter = Counter()
for text in texts:
counter.update(set(tokenize(text)))
return {w: i for i, (w, _) in enumerate(counter.most_common(top_k))}
def vectorize(text: str, vocab: dict[str, int]) -> np.ndarray:
vec = np.zeros(len(vocab), dtype=np.float32)
for word, cnt in Counter(tokenize(text)).items():
if word in vocab:
vec[vocab[word]] = cnt
norm = np.linalg.norm(vec)
return vec / norm if norm > 0 else vec
def turn_texts(run, fallback_any_message: bool = False) -> list[str]:
source = run.transcript.messages if fallback_any_message else run.transcript.assistant_messages
out = []
for msg in source:
parts = []
if msg.text:
parts.append(msg.text)
for tc in msg.tool_calls:
parts.append(tc.name)
if tc.input:
parts.append(json.dumps(tc.input, sort_keys=True)[:200])
if parts:
out.append(" ".join(parts))
return out
def trajectory_metrics(vecs: np.ndarray) -> dict[str, float]:
"""Compute drift, recurrence, and support-volume proxies for one run."""
n = vecs.shape[0]
if n < 2:
return {
"n_turns": float(n),
"drift_mean": 0.0,
"from_start": 0.0,
"recurrence": 0.0,
"vol_log": -12.0,
}
diffs = np.linalg.norm(np.diff(vecs, axis=0), axis=1)
drift_mean = float(diffs.mean())
from_start = float(np.linalg.norm(vecs - vecs[0:1], axis=1).max())
recurrence = 0.0
for i in range(n):
for j in range(i + 2, n):
ni = np.linalg.norm(vecs[i])
nj = np.linalg.norm(vecs[j])
if ni > 0 and nj > 0:
sim = float(vecs[i] @ vecs[j] / (ni * nj))
recurrence = max(recurrence, sim)
if n >= 3:
sigma = np.cov(vecs.T)
eigs = np.linalg.eigvalsh(sigma + 1e-6 * np.eye(vecs.shape[1], dtype=np.float32))
vol_log = float(np.log(np.clip(eigs, 1e-12, None)).sum())
else:
vol_log = -12.0
return {
"n_turns": float(n),
"drift_mean": drift_mean,
"from_start": from_start,
"recurrence": recurrence,
"vol_log": vol_log,
}
def classify(metrics: dict[str, float], thresholds: dict[str, float]) -> str:
"""Map trajectory metrics to a coarse regime label."""
n_turns = int(metrics["n_turns"])
if n_turns < 3:
return "too_short"
drift = metrics["drift_mean"]
recurrence = metrics["recurrence"]
vol = metrics["vol_log"]
if drift < thresholds["drift_low"] and vol < thresholds["vol_low"]:
return "trapped"
if recurrence > thresholds["rec_hi"] and drift < thresholds["drift_med"]:
return "limit_cycle"
if drift > thresholds["drift_hi"] and vol > thresholds["vol_hi"]:
return "diffusive"
return "mixed"
def main() -> None:
parser = argparse.ArgumentParser(description="Classify cached run regimes")
parser.add_argument("--archive-dir", type=Path, default=Path(".clawbench/run_cache"))
parser.add_argument("--reports-dir", type=Path, default=Path("reports"))
parser.add_argument("--tier", choices=["tier1", "tier2", "tier3", "tier4", "tier5"], default=None)
args = parser.parse_args()
grouped = load_task_runs_by_model(args.archive_dir, tier=args.tier)
if not grouped:
raise SystemExit(f"No cached runs found under {args.archive_dir}")
all_turn_texts: list[str] = []
run_turns: dict[str, list[str]] = {}
for model_name, task_runs in grouped.items():
for task_id, runs in task_runs.items():
for run in runs:
ts = turn_texts(run, fallback_any_message=False)
key = f"{model_name}/{task_id}/run{run.run_index}"
run_turns[key] = ts
all_turn_texts.extend(ts)
used_fallback_messages = False
if not all_turn_texts:
used_fallback_messages = True
all_turn_texts = []
run_turns = {}
for model_name, task_runs in grouped.items():
for task_id, runs in task_runs.items():
for run in runs:
ts = turn_texts(run, fallback_any_message=True)
key = f"{model_name}/{task_id}/run{run.run_index}"
run_turns[key] = ts
all_turn_texts.extend(ts)
if not all_turn_texts:
raise SystemExit("No usable turn text found in archive.")
vocab = build_vocab(all_turn_texts, top_k=500)
per_run: dict[str, dict[str, float | str]] = {}
for key, ts in run_turns.items():
if not ts:
continue
vecs = np.stack([vectorize(text, vocab) for text in ts])
per_run[key] = trajectory_metrics(vecs)
eligible = [r for r in per_run.values() if int(r["n_turns"]) >= 3]
if eligible:
drifts = np.array([float(v["drift_mean"]) for v in eligible])
recs = np.array([float(v["recurrence"]) for v in eligible])
vols = np.array([float(v["vol_log"]) for v in eligible])
thresholds = {
"drift_low": float(np.percentile(drifts, 25)),
"drift_med": float(np.percentile(drifts, 50)),
"drift_hi": float(np.percentile(drifts, 75)),
"vol_low": float(np.percentile(vols, 25)),
"vol_hi": float(np.percentile(vols, 75)),
"rec_hi": float(np.percentile(recs, 75)),
}
else:
thresholds = {
"drift_low": 0.15,
"drift_med": 0.25,
"drift_hi": 0.35,
"vol_low": -6.0,
"vol_hi": -3.0,
"rec_hi": 0.8,
}
for key, metrics in per_run.items():
metrics["regime"] = classify(metrics, thresholds)
metrics["turn_source"] = "any_message" if used_fallback_messages else "assistant"
args.reports_dir.mkdir(parents=True, exist_ok=True)
out = args.reports_dir / "regimes.json"
out.write_text(json.dumps(per_run, indent=2), encoding="utf-8")
counts = Counter(str(v["regime"]) for v in per_run.values())
print(f"Wrote: {out}")
print(f"Regime counts: {dict(counts)}")
if __name__ == "__main__":
main()