diff --git a/src/tiararodney/sekft/eval.py b/src/tiararodney/sekft/eval.py new file mode 100644 index 0000000..b381385 --- /dev/null +++ b/src/tiararodney/sekft/eval.py @@ -0,0 +1,101 @@ +"""Behavioural eval: the metric that matters. + +Train loss says nothing about whether the model operates the shell and leaves. +This loads a fine-tuned model (base + LoRA adapter), drops it into held-out +scenarios with NO scaffold (the trained behaviour must stand on its own), and +reports the rates that count: does it reach command-mode, does it terminate, +does the checker pass. + + python eval.py --base --adapter ./ckpt-mistral-r16 \ + --scenarios ./holdout-scenarios --n 10 + +Reuses the rollout loop with a *local* operator: the model formats and +generates in the same role-delimited render it was trained on (train == eval == +deploy, or the prompts go out of distribution). Prerequisites on the box: torch ++ transformers + peft, the ``sekft-dash`` image, and held-out SCENARIO bundles +(from ``generate.py`` -- not trajectories; the eval stands up and verifies each). +""" +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available +from tiararodney.posix_sdc.factory.rollout import rollout +from tiararodney.posix_sdc.schema import Scenario + +from .sft import normalize_for_template + + +def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64, + temperature: float = 0.7): + """A ``messages -> command`` callable backed by base + LoRA adapter. + + Renders the conversation exactly as the model was trained, appends the + assistant header, generates one turn, and cuts at the first stop marker. + """ + import torch + from peft import PeftModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + tok = AutoTokenizer.from_pretrained(adapter) + model = AutoModelForCausalLM.from_pretrained( + base, torch_dtype=torch.float16, device_map="auto") + model = PeftModel.from_pretrained(model, adapter) + model.eval() + + def operator(messages): + msgs = normalize_for_template(messages) + ids = tok.apply_chat_template( + msgs, add_generation_prompt=True, return_tensors="pt").to(model.device) + with torch.no_grad(): + out = model.generate( + ids, max_new_tokens=max_new_tokens, + do_sample=temperature > 0, temperature=max(temperature, 1e-2), + eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id) + return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip() + + return operator + + +def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int, + max_steps: int, temperature: float) -> dict: + if not available(): + raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`") + operator = make_local_operator(base, adapter, temperature=temperature) + backend = DashDocker() + rows = [] + for f in sorted(scenarios_dir.glob("*.json"))[:n]: + sc = Scenario.from_dict(json.loads(f.read_text())) + tj = rollout(sc, backend, max_steps=max_steps, temperature=temperature, + operator=operator, use_scaffold=False) + rows.append(tj) + print(f" {sc.id}: {tj.outcome} (terminal={tj.terminal} " + f"verified={tj.verified} steps={tj.steps})") + d = len(rows) or 1 + return { + "n": len(rows), + "operate_rate": round(sum(t.steps > 0 and t.meta.get("clean") for t in rows) / d, 3), + "terminate_rate": round(sum(t.terminal in ("exit", "panic") for t in rows) / d, 3), + "verified_rate": round(sum(t.verified for t in rows) / d, 3), + "clean_rate": round(sum(t.keep for t in rows) / d, 3), + } + + +def main() -> None: + ap = argparse.ArgumentParser(description="Behavioural eval of a tuned model.") + ap.add_argument("--base", required=True) + ap.add_argument("--adapter", required=True) + ap.add_argument("--scenarios", type=Path, required=True) + ap.add_argument("--n", type=int, default=10) + ap.add_argument("--max-steps", type=int, default=30) + ap.add_argument("--temperature", type=float, default=0.7) + ns = ap.parse_args() + m = evaluate(ns.base, ns.adapter, ns.scenarios, ns.n, ns.max_steps, ns.temperature) + print("\n=== behavioural metrics ===") + print(json.dumps(m, indent=2)) + + +if __name__ == "__main__": + main()