diff --git a/src/tiararodney/sekft/resident.py b/src/tiararodney/sekft/resident.py new file mode 100644 index 0000000..c56979b --- /dev/null +++ b/src/tiararodney/sekft/resident.py @@ -0,0 +1,187 @@ +"""Resident harness: load the base ONCE, cycle adapters. + +On a slow link (OcuLink / PCIe 3.0 x4) the 14 GB base transfer dominates every +process start. This loads the base once and keeps it hot, so the +iterate-train-eval loop pays the transfer only at startup. Each ``fit`` trains a +fresh LoRA adapter on the resident base and ``unload``s it back to clean; each +``evaluate`` attaches a saved adapter for inference and unloads. + +Interactive (IPython on the GPU box) is the intended use: + + from resident import Resident + r = Resident("~/llm-models/mistral-7b-instruct-v0.2", load_4bit=True) + r.fit("~/sekft/trajectories", "~/sekft/ckpt-a", lora_r=16, lr=2e-4, epochs=3) + r.evaluate("~/sekft/ckpt-a", "~/sekft/holdout", n=10) + r.fit("~/sekft/trajectories", "~/sekft/ckpt-b", lora_r=32) # NO base reload + +Or `python resident.py --base --selftest-data ` to prove the +base loads once and two adapters train against it. +""" +from __future__ import annotations + +import argparse +import gc +import json +from pathlib import Path + +import torch +from datasets import Dataset +from peft import (LoraConfig, PeftModel, get_peft_model, + prepare_model_for_kbit_training) +from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, + DataCollatorForSeq2Seq, Trainer, TrainingArguments) + +from .sft import build_masked_example, iter_keepers, normalize_for_template + +LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"] + + +def _free() -> None: + gc.collect() + torch.cuda.empty_cache() + + +class Resident: + """A base model held resident on the GPU; adapters cycle through it.""" + + def __init__(self, base: str, load_4bit: bool = False) -> None: + self.base_path = str(Path(base).expanduser()) + self.load_4bit = load_4bit + self.tok = AutoTokenizer.from_pretrained(self.base_path) + if self.tok.pad_token is None: + self.tok.pad_token = self.tok.eos_token + quant = None + if load_4bit: + quant = BitsAndBytesConfig( + load_in_4bit=True, bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True) + print(f"[resident] loading base ONCE: {self.base_path} (4bit={load_4bit}) ...") + self.base = AutoModelForCausalLM.from_pretrained( + self.base_path, dtype=torch.float16, quantization_config=quant) + self.base = (prepare_model_for_kbit_training(self.base) if load_4bit + else self.base) + if not load_4bit: + self.base.enable_input_require_grads() + dev = next(self.base.parameters()).device + mem = torch.cuda.memory_allocated() / 1e9 + print(f"[resident] base resident on {dev}; {mem:.1f} GB VRAM") + + # -- build masked rows from kept trajectories -------------------------- + + def _rows(self, data_dir: Path, max_len: int) -> list[dict]: + rows = [] + for turns in iter_keepers(data_dir): + ex = build_masked_example(turns, self.tok) + if len(ex["input_ids"]) <= max_len and any(l != -100 for l in ex["labels"]): + rows.append(ex) + if not rows: + raise SystemExit(f"no usable keeper trajectories in {data_dir}") + return rows + + # -- train a fresh adapter on the resident base ------------------------ + + def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4, + epochs: float = 3.0, batch: int = 1, accum: int = 8, + max_len: int = 4096) -> Path: + data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser() + ds = Dataset.from_list(self._rows(data_dir, max_len)) + if not self.load_4bit: + self.base.gradient_checkpointing_enable() + model = get_peft_model(self.base, LoraConfig( + r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, + task_type="CAUSAL_LM", target_modules=LORA_TARGETS)) + model.print_trainable_parameters() + args = TrainingArguments( + output_dir=str(out), per_device_train_batch_size=batch, + gradient_accumulation_steps=accum, num_train_epochs=epochs, + learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no", + report_to=["tensorboard"], logging_dir=str(out / "runs"), + remove_unused_columns=False, warmup_ratio=0.03) + tr = Trainer(model=model, args=args, train_dataset=ds, + data_collator=DataCollatorForSeq2Seq( + self.tok, padding=True, label_pad_token_id=-100)) + tr.train() + out.mkdir(parents=True, exist_ok=True) + model.save_pretrained(str(out)) + self.tok.save_pretrained(str(out)) + (out / "log_history.jsonl").write_text( + "\n".join(json.dumps(r) for r in tr.state.log_history)) + losses = [h["loss"] for h in tr.state.log_history if "loss" in h] + print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}") + self.base = model.unload() # strip LoRA, restore resident base + del model, tr, ds + _free() + return out + + # -- behavioural eval of a saved adapter ------------------------------- + + def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10, + max_steps: int = 30, temperature: float = 0.7) -> dict: + from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available + from tiararodney.posix_sdc.factory.rollout import rollout + from tiararodney.posix_sdc.schema import Scenario + if not available(): + raise SystemExit("sekft-dash image unavailable on this box") + # adapter=None -> evaluate the BASE model (the within-holdout baseline). + if adapter: + adapter = str(Path(adapter).expanduser()) + pm = PeftModel.from_pretrained(self.base, adapter) + else: + pm = self.base + pm.eval() + + def operator(messages): + msgs = normalize_for_template(messages) + ids = self.tok.apply_chat_template( + msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device) + with torch.no_grad(): + o = pm.generate(ids, max_new_tokens=64, do_sample=temperature > 0, + temperature=max(temperature, 1e-2), + eos_token_id=self.tok.eos_token_id, + pad_token_id=self.tok.eos_token_id) + return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip() + + backend = DashDocker() + rows = [] + for f in sorted(Path(scenarios_dir).expanduser().glob("*.json"))[:n]: + sc = Scenario.from_dict(json.loads(f.read_text())) + tj = rollout(sc, backend, max_steps=max_steps, temperature=temperature, + operator=operator, use_scaffold=False) + rows.append(tj) + print(f" {sc.id}: {tj.outcome} terminal={tj.terminal} verified={tj.verified}") + d = len(rows) or 1 + m = { + "n": len(rows), + "operate_rate": round(sum(t.steps > 0 and t.meta.get("clean") for t in rows) / d, 3), + "terminate_rate": round(sum(t.terminal in ("exit", "panic") for t in rows) / d, 3), + "verified_rate": round(sum(t.verified for t in rows) / d, 3), + "clean_rate": round(sum(t.keep for t in rows) / d, 3), + } + if adapter: # base is unwrapped only if we wrapped it + self.base = pm.unload() + del pm + _free() + print("[resident] eval:", json.dumps(m)) + return m + + +def main() -> None: + ap = argparse.ArgumentParser(description="Resident base; cycle adapters.") + ap.add_argument("--base", required=True) + ap.add_argument("--load-4bit", action="store_true") + ap.add_argument("--selftest-data", + help="fit two adapters on this data to prove resident multi-fit") + ns = ap.parse_args() + r = Resident(ns.base, ns.load_4bit) + if ns.selftest_data: + print("=== selftest: two fits on the SAME resident base (no reload) ===") + r.fit(ns.selftest_data, "/tmp/res-a", epochs=1, lora_r=8) + r.fit(ns.selftest_data, "/tmp/res-b", epochs=1, lora_r=8) + print("=== selftest OK: base loaded once, two adapters trained ===") + else: + print("Resident ready. Import and use r.fit() / r.evaluate(), " + "or pass --selftest-data .") + + +if __name__ == "__main__": + main()