feat: add resident-base train and eval harness
This commit is contained in:
parent
e810d0e442
commit
23e749c878
1 changed files with 187 additions and 0 deletions
187
src/tiararodney/sekft/resident.py
Normal file
187
src/tiararodney/sekft/resident.py
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""Resident harness: load the base ONCE, cycle adapters.
|
||||||
|
|
||||||
|
On a slow link (OcuLink / PCIe 3.0 x4) the 14 GB base transfer dominates every
|
||||||
|
process start. This loads the base once and keeps it hot, so the
|
||||||
|
iterate-train-eval loop pays the transfer only at startup. Each ``fit`` trains a
|
||||||
|
fresh LoRA adapter on the resident base and ``unload``s it back to clean; each
|
||||||
|
``evaluate`` attaches a saved adapter for inference and unloads.
|
||||||
|
|
||||||
|
Interactive (IPython on the GPU box) is the intended use:
|
||||||
|
|
||||||
|
from tiararodney.sekft.resident import Resident
|
||||||
|
r = Resident("~/llm-models/mistral-7b-instruct-v0.2", load_4bit=True)
|
||||||
|
r.fit("~/sekft/trajectories", "~/sekft/ckpt-a", lora_r=16, lr=2e-4, epochs=3)
|
||||||
|
r.evaluate("~/sekft/ckpt-a", "~/sekft/holdout", n=10)
|
||||||
|
r.fit("~/sekft/trajectories", "~/sekft/ckpt-b", lora_r=32) # NO base reload
|
||||||
|
|
||||||
|
Or `python -m tiararodney.sekft.resident --base <dir> --selftest-data <stub_dir>` to prove the
|
||||||
|
base loads once and two adapters train against it.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import Dataset
|
||||||
|
from peft import (LoraConfig, PeftModel, get_peft_model,
|
||||||
|
prepare_model_for_kbit_training)
|
||||||
|
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
|
||||||
|
DataCollatorForSeq2Seq, Trainer, TrainingArguments)
|
||||||
|
|
||||||
|
from .sft import build_masked_example, iter_keepers, normalize_for_template
|
||||||
|
|
||||||
|
# Module names handed to ``LoraConfig(target_modules=...)`` when ``Resident.fit``
# attaches a fresh adapter to the resident base.
LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||||
|
|
||||||
|
|
||||||
|
def _free() -> None:
    """Release memory held by dropped objects.

    Runs a garbage-collection pass first so unreachable objects are
    finalised, then asks torch to empty its CUDA cache.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
|
||||||
|
|
||||||
|
|
||||||
|
class Resident:
    """A base model held resident on the GPU; adapters cycle through it.

    The base is loaded once in ``__init__``. ``fit`` attaches a fresh LoRA
    adapter, trains and saves it, then strips it again; ``evaluate`` attaches
    a saved adapter (or none) for rollouts and strips it again. Both strip the
    adapter in a ``finally`` block so that a failed or interrupted call does
    not leave LoRA layers attached to the shared base.
    """

    def __init__(self, base: str, load_4bit: bool = False) -> None:
        """Load tokenizer and base model from *base*.

        Args:
            base: Path to the base model directory (``~`` is expanded).
            load_4bit: Load the base NF4-quantised via bitsandbytes and
                prepare it for k-bit training.
        """
        self.base_path = str(Path(base).expanduser())
        self.load_4bit = load_4bit
        self.tok = AutoTokenizer.from_pretrained(self.base_path)
        if self.tok.pad_token is None:
            # The collator in fit() pads batches, so a pad token is required.
            self.tok.pad_token = self.tok.eos_token
        quant = None
        if load_4bit:
            quant = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True)
        print(f"[resident] loading base ONCE: {self.base_path} (4bit={load_4bit}) ...")
        self.base = AutoModelForCausalLM.from_pretrained(
            self.base_path, dtype=torch.float16, quantization_config=quant)
        if load_4bit:
            self.base = prepare_model_for_kbit_training(self.base)
        else:
            self.base.enable_input_require_grads()
            # No device_map is passed above, so an unquantised model is not
            # placed on the GPU by from_pretrained. Move it now so the base is
            # actually GPU-resident before the first fit()/evaluate().
            if torch.cuda.is_available():
                self.base = self.base.to("cuda")
        dev = next(self.base.parameters()).device
        mem = torch.cuda.memory_allocated() / 1e9
        print(f"[resident] base resident on {dev}; {mem:.1f} GB VRAM")

    # -- build masked rows from kept trajectories --------------------------

    def _rows(self, data_dir: Path, max_len: int) -> list[dict]:
        """Build training rows from the keeper trajectories in *data_dir*.

        A row is kept only if it is at most *max_len* tokens long and has at
        least one label that is not ``-100``.

        Raises:
            SystemExit: If no trajectory yields a usable row.
        """
        rows = []
        for turns in iter_keepers(data_dir):
            ex = build_masked_example(turns, self.tok)
            if len(ex["input_ids"]) <= max_len and any(l != -100 for l in ex["labels"]):
                rows.append(ex)
        if not rows:
            raise SystemExit(f"no usable keeper trajectories in {data_dir}")
        return rows

    # -- train a fresh adapter on the resident base ------------------------

    def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4,
            epochs: float = 3.0, batch: int = 1, accum: int = 8,
            max_len: int = 4096) -> Path:
        """Train a fresh LoRA adapter on the resident base and save it to *out*.

        Args:
            data_dir: Directory of trajectories (``~`` is expanded).
            out: Output directory for adapter, tokenizer, TensorBoard runs and
                ``log_history.jsonl`` (``~`` is expanded).
            lora_r: LoRA rank; ``lora_alpha`` is set to twice this value.
            lr: Learning rate.
            epochs: Number of training epochs.
            batch: Per-device train batch size.
            accum: Gradient accumulation steps.
            max_len: Rows longer than this many tokens are dropped.

        Returns:
            The expanded output directory.

        Raises:
            SystemExit: If *data_dir* yields no usable rows.
        """
        data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser()
        ds = Dataset.from_list(self._rows(data_dir, max_len))
        if not self.load_4bit:
            self.base.gradient_checkpointing_enable()
        model = get_peft_model(self.base, LoraConfig(
            r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05,
            task_type="CAUSAL_LM", target_modules=LORA_TARGETS))
        tr = None
        try:
            model.print_trainable_parameters()
            args = TrainingArguments(
                output_dir=str(out), per_device_train_batch_size=batch,
                gradient_accumulation_steps=accum, num_train_epochs=epochs,
                learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no",
                report_to=["tensorboard"], logging_dir=str(out / "runs"),
                remove_unused_columns=False, warmup_ratio=0.03)
            tr = Trainer(model=model, args=args, train_dataset=ds,
                         data_collator=DataCollatorForSeq2Seq(
                             self.tok, padding=True, label_pad_token_id=-100))
            tr.train()
            out.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(str(out))
            self.tok.save_pretrained(str(out))
            (out / "log_history.jsonl").write_text(
                "\n".join(json.dumps(r) for r in tr.state.log_history))
            losses = [h["loss"] for h in tr.state.log_history if "loss" in h]
            print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}")
        finally:
            # Always strip LoRA and restore the resident base, even when
            # training or saving failed; otherwise the next fit()/evaluate()
            # would run on a base that still carries this adapter.
            self.base = model.unload()
            del model, tr, ds
            _free()
        return out

    # -- behavioural eval of a saved adapter -------------------------------

    @staticmethod
    def _summarize(rows: list) -> dict:
        """Aggregate rollout results into rate metrics rounded to 3 places.

        Each element of *rows* is read for ``steps``, ``meta``, ``terminal``,
        ``verified`` and ``keep``. An empty list gives ``n == 0`` and all
        rates ``0.0``.
        """
        d = len(rows) or 1  # avoid division by zero when nothing was rolled out

        def rate(total) -> float:
            return round(total / d, 3)

        return {
            "n": len(rows),
            # bool(): meta.get("clean") is None when the key is absent, and
            # sum() cannot add None.
            "operate_rate": rate(sum(
                bool(t.steps > 0 and t.meta.get("clean")) for t in rows)),
            "terminate_rate": rate(sum(t.terminal in ("exit", "panic") for t in rows)),
            "verified_rate": rate(sum(t.verified for t in rows)),
            "clean_rate": rate(sum(t.keep for t in rows)),
        }

    def evaluate(self, adapter: str | None, scenarios_dir: str, n: int = 10,
                 max_steps: int = 30, temperature: float = 0.7) -> dict:
        """Roll out scenarios with a saved adapter (or the bare base) and score them.

        Args:
            adapter: Path to a saved adapter (``~`` is expanded). A falsy value
                (``None`` or ``""``) evaluates the base model itself, i.e. the
                within-holdout baseline.
            scenarios_dir: Directory of ``*.json`` scenario files; the first
                *n* in sorted order are used.
            n: Maximum number of scenarios to roll out.
            max_steps: Step limit passed to each rollout.
            temperature: Sampling temperature; ``0`` or less disables sampling.

        Returns:
            Dict with ``n``, ``operate_rate``, ``terminate_rate``,
            ``verified_rate`` and ``clean_rate``.

        Raises:
            SystemExit: If the rollout backend reports itself unavailable.
        """
        from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
        from tiararodney.posix_sdc.factory.rollout import rollout
        from tiararodney.posix_sdc.schema import Scenario
        if not available():
            raise SystemExit("sekft-dash image unavailable on this box")
        # adapter=None -> evaluate the BASE model (the within-holdout baseline).
        if adapter:
            adapter = str(Path(adapter).expanduser())
            pm = PeftModel.from_pretrained(self.base, adapter)
        else:
            pm = self.base
        try:
            pm.eval()

            def operator(messages):
                msgs = normalize_for_template(messages)
                ids = self.tok.apply_chat_template(
                    msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device)
                with torch.no_grad():
                    o = pm.generate(ids, max_new_tokens=64, do_sample=temperature > 0,
                                    temperature=max(temperature, 1e-2),
                                    eos_token_id=self.tok.eos_token_id,
                                    pad_token_id=self.tok.eos_token_id)
                # Decode only the newly generated tokens, not the prompt.
                return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()

            backend = DashDocker()
            rows = []
            for f in sorted(Path(scenarios_dir).expanduser().glob("*.json"))[:n]:
                sc = Scenario.from_dict(json.loads(f.read_text()))
                tj = rollout(sc, backend, max_steps=max_steps, temperature=temperature,
                             operator=operator, use_scaffold=False)
                rows.append(tj)
                print(f" {sc.id}: {tj.outcome} terminal={tj.terminal} verified={tj.verified}")
            m = self._summarize(rows)
        finally:
            # Detach the adapter even if a rollout raised, so the resident
            # base is clean for the next call.
            if adapter:  # base is unwrapped only if we wrapped it
                self.base = pm.unload()
            del pm
            _free()
        print("[resident] eval:", json.dumps(m))
        return m
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Command-line entry point.

    Loads the resident base named by ``--base`` (optionally ``--load-4bit``).
    With ``--selftest-data`` it fits two adapters back to back on that data
    against the same resident base; without it, it only prints a usage hint.
    """
    parser = argparse.ArgumentParser(description="Resident base; cycle adapters.")
    parser.add_argument("--base", required=True)
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument(
        "--selftest-data",
        help="fit two adapters on this data to prove resident multi-fit")
    opts = parser.parse_args()

    resident = Resident(opts.base, opts.load_4bit)

    if not opts.selftest_data:
        print("Resident ready. Import and use r.fit() / r.evaluate(), "
              "or pass --selftest-data <dir>.")
        return

    print("=== selftest: two fits on the SAME resident base (no reload) ===")
    # Two consecutive fits on one Resident instance: the base is not reloaded.
    for ckpt_dir in ("/tmp/res-a", "/tmp/res-b"):
        resident.fit(opts.selftest_data, ckpt_dir, epochs=1, lora_r=8)
    print("=== selftest OK: base loaded once, two adapters trained ===")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the CLI only when this module is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue