"""sekft trainer: SFT a base model on kept shell-operation trajectories.

Trains assistant turns ONLY -- the commands and the terminal ``exit`` / ``panic``.
The environment turns (system orientation, prompts, command output) are masked
to ``-100`` so the model learns to *produce* commands, not to predict the
environment's replies. Getting this mask wrong is the classic way to ruin a
shell-operator SFT (the model starts hallucinating output), so it is the part
worth testing hardest -- and it is framework-independent.

Render uses the tokenizer's OWN chat template (``apply_chat_template``), so the
training render is identical to what the serving harness produces (ccpty sends
structured messages and the inference endpoint applies the model's default
template). Trajectories are canonicalised first (``normalize_for_template``):
a leading ``system`` turn is folded into the first ``user`` turn and consecutive
same-role turns are merged, because instruct templates such as Mistral's have no
system role and require strict user/assistant alternation. That same
canonicalisation must run on the serving side. Everything else is standard
causal-LM SFT with an assistant-only loss mask.

    python sft.py --data ./trajectories --base <model> --out ./ckpt
    python sft.py --data ./trajectories --base <model> --inspect  # mask stats, no training

``<model>`` is a Hugging Face model id or a local directory; ``--base`` is
required and always takes a value.

Training needs torch + transformers + peft (a GPU box). ``--inspect`` and the
normalize/mask helpers run anywhere a tokenizer with a chat template is
available.
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path


def normalize_for_template(messages: list[dict]) -> list[dict]:
    """Canonicalise a trajectory for templates with strict user/assistant turns.

    Instruct chat templates such as Mistral's have no system role and require
    strict alternation, so every ``system`` turn is relabelled ``user`` and
    consecutive turns with the same role are merged, their contents joined by
    a single newline.

    This is loss-neutral for the assistant mask (only environment/user turns
    ever merge; the assistant commands are never adjacent in this data) and it
    is what lets ``apply_chat_template`` render the multi-turn shell dialogue.
    The serving side MUST apply the same canonicalisation, or train and serve
    diverge again.

    Args:
        messages: Turns as dicts with ``"role"`` and ``"content"`` keys.

    Returns:
        A new list of ``{"role", "content"}`` dicts; the input list and its
        dicts are not modified.
    """
    out: list[dict] = []
    for m in messages:
        role = "user" if m["role"] == "system" else m["role"]
        if out and out[-1]["role"] == role:
            # Same speaker twice in a row: fold into the previous turn.
            out[-1] = {"role": role,
                       "content": out[-1]["content"] + "\n" + m["content"]}
        else:
            out.append({"role": role, "content": m["content"]})
    return out


def _render_ids(tokenizer, msgs: list[dict]) -> list[int]:
    """Render ``msgs`` through the tokenizer's chat template as a flat id list."""
    # tokenize/return_dict are pinned explicitly instead of relying on library
    # defaults, so the result shape does not depend on the installed
    # transformers release. NOTE(review): confirm defaults for the pinned version.
    rendered = tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=False, return_dict=False)
    # Defensive: if a mapping comes back anyway, take its token ids.
    if hasattr(rendered, "keys"):
        rendered = rendered["input_ids"]
    return list(rendered)


def build_masked_example(messages: list[dict], tokenizer) -> dict:
    """Tokenize a trajectory with the tokenizer's OWN chat template and build an
    assistant-only loss mask.

    The render is ``tokenizer.apply_chat_template`` on the canonicalised turns.
    The mask is derived by token-prefix differencing: the tokens an assistant
    turn contributes are exactly those that appear when it extends the rendered
    prefix, which trains the commands plus whatever the template appends to the
    turn (e.g. an end-of-turn token, so the model learns to stop) and masks
    every environment turn to ``-100``. This assumes an additive template (each
    turn extends the previous render).

    Args:
        messages: Raw trajectory turns (``{"role", "content"}`` dicts).
        tokenizer: Object exposing ``apply_chat_template`` (a Hugging Face
            tokenizer in production).

    Returns:
        Dict with ``input_ids``, an all-ones ``attention_mask`` and ``labels``
        (token id on assistant-contributed positions, ``-100`` elsewhere).

    Raises:
        ValueError: If rendering a prefix of the turns is not a token-prefix of
            the full render, i.e. the template is not additive.
    """
    msgs = normalize_for_template(messages)
    ids = _render_ids(tokenizer, msgs)
    labels = [-100] * len(ids)
    prev: list[int] = []
    for i, m in enumerate(msgs):
        upto = _render_ids(tokenizer, msgs[:i + 1])
        # Both checks are needed: the prefix must agree with the full render
        # AND extend the previous prefix, otherwise the span is meaningless.
        if ids[:len(upto)] != upto or upto[:len(prev)] != prev:
            raise ValueError("chat template is not additive; cannot derive an "
                             "assistant loss mask by token-prefix differencing")
        if m["role"] == "assistant":
            start, end = len(prev), len(upto)
            labels[start:end] = ids[start:end]
        prev = upto
    return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}


def iter_keepers(data_dir: Path):
    """Yield ``turns`` (message lists) from trajectory JSONs marked keep.

    Files matching ``*.json`` directly inside ``data_dir`` are visited in
    sorted order; a file is yielded only when its top-level ``keep`` value is
    truthy. Files are decoded as UTF-8 regardless of the platform default.
    """
    for f in sorted(data_dir.glob("*.json")):
        d = json.loads(f.read_text(encoding="utf-8"))
        if d.get("keep"):
            yield d["turns"]


def mask_stats(example: dict) -> tuple[int, int]:
    """(trained tokens, total tokens) for an example."""
    labels = example["labels"]
    trained = sum(1 for x in labels if x != -100)
    return trained, len(labels)


# --------------------------------------------------------------------------
# Training (GPU box: torch + transformers + peft)
# --------------------------------------------------------------------------

def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
          batch: int, accum: int, max_len: int, lora_r: int,
          load_4bit: bool = False) -> None:
    """LoRA-fine-tune ``base`` on the keeper trajectories in ``data_dir``.

    Args:
        data_dir: Directory of trajectory ``*.json`` files.
        base: Hugging Face model id or local directory.
        out: Output directory for the adapter, tokenizer, TensorBoard runs and
            ``log_history.jsonl``.
        epochs: Number of training epochs.
        lr: Learning rate.
        batch: Per-device train batch size.
        accum: Gradient accumulation steps.
        max_len: Examples with more tokens than this are skipped (not truncated).
        lora_r: LoRA rank; ``lora_alpha`` is set to twice this value.
        load_4bit: Load the base model 4-bit quantised (QLoRA).

    Raises:
        SystemExit: If no example survives the length / trained-token filter.
    """
    # Heavy dependencies are imported lazily so the mask helpers and --inspect
    # stay importable without torch/peft installed.
    import torch
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              DataCollatorForSeq2Seq, Trainer, TrainingArguments)

    tok = AutoTokenizer.from_pretrained(base)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    rows = []
    too_long = untrained = 0
    for turns in iter_keepers(data_dir):
        ex = build_masked_example(turns, tok)
        if len(ex["input_ids"]) > max_len:
            too_long += 1
        elif all(label == -100 for label in ex["labels"]):
            untrained += 1
        else:
            rows.append(ex)
    if too_long or untrained:
        # Make dropped data visible instead of shrinking the set silently.
        print(f"skipped: {too_long} longer than {max_len} tokens, "
              f"{untrained} with no assistant tokens")
    if not rows:
        raise SystemExit(f"no usable keeper trajectories in {data_dir}")
    stats = [mask_stats(r) for r in rows]
    print(f"examples: {len(rows)}; "
          f"trained/total tokens: {sum(t for t, _ in stats)}"
          f"/{sum(n for _, n in stats)}")
    ds = Dataset.from_list(rows)

    # 4-bit (QLoRA) shrinks the base from ~14 GB to ~4 GB to move across the
    # OcuLink/PCIe link and to hold in VRAM; nf4 + fp16 compute works on the
    # V100 (sm_70). Without it, plain fp16 weights.
    quant = None
    if load_4bit:
        from transformers import BitsAndBytesConfig
        quant = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        base, dtype=torch.float16, quantization_config=quant)
    if load_4bit:
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(model)  # handles ckpt + input grads
    else:
        model.enable_input_require_grads()
    model.gradient_checkpointing_enable()
    model = get_peft_model(model, LoraConfig(
        r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    ))
    model.print_trainable_parameters()

    args = TrainingArguments(
        output_dir=str(out), per_device_train_batch_size=batch,
        gradient_accumulation_steps=accum, num_train_epochs=epochs,
        learning_rate=lr, fp16=True, logging_steps=1, save_strategy="epoch",
        report_to=["tensorboard"], logging_dir=str(out / "runs"),
        remove_unused_columns=False, warmup_ratio=0.03,
    )
    trainer = Trainer(
        model=model, args=args, train_dataset=ds,
        data_collator=DataCollatorForSeq2Seq(tok, padding=True, label_pad_token_id=-100),
    )
    trainer.train()
    model.save_pretrained(str(out))
    tok.save_pretrained(str(out))
    # durable, greppable record of the curve (loss/lr/grad_norm per step).
    # One JSON object per line, each newline-terminated (JSONL convention).
    (out / "log_history.jsonl").write_text(
        "".join(json.dumps(r) + "\n" for r in trainer.state.log_history),
        encoding="utf-8")
    print(f"saved LoRA adapter + log_history.jsonl -> {out} "
          f"(tensorboard: --logdir {out / 'runs'})")


def inspect(data_dir: Path, base: str) -> None:
    """Print assistant-mask statistics for the keeper trajectories; no training.

    Args:
        data_dir: Directory of trajectory ``*.json`` files.
        base: Hugging Face model id or local directory whose tokenizer (and
            chat template) is used for rendering.

    Raises:
        SystemExit: If ``data_dir`` contains no keeper trajectories.
    """
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(base)
    n = tt = tr = 0
    for turns in iter_keepers(data_dir):
        t, total = mask_stats(build_masked_example(turns, tok))
        tr += t
        tt += total
        n += 1
    if not n:
        raise SystemExit(f"no keeper trajectories in {data_dir}")
    # Keepers that render to zero tokens must not crash the report.
    pct = 100 * tr / tt if tt else 0.0
    print(f"{n} keeper trajectories; {tr}/{tt} tokens trained "
          f"({pct:.1f}% assistant, rest masked)")


def main() -> None:
    """Parse the command line and run either ``inspect`` or ``train``."""
    ap = argparse.ArgumentParser(description="SFT a model on shell trajectories.")
    ap.add_argument("--data", type=Path, default=Path("./trajectories"))
    ap.add_argument("--base", required=True, help="HF model id or local dir")
    ap.add_argument("--out", type=Path, default=Path("./ckpt"))
    ap.add_argument("--inspect", action="store_true", help="mask stats only, no training")
    ap.add_argument("--epochs", type=float, default=3.0)
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--batch", type=int, default=1)
    ap.add_argument("--accum", type=int, default=8)
    ap.add_argument("--max-len", type=int, default=4096)
    ap.add_argument("--lora-r", type=int, default=16)
    ap.add_argument("--load-4bit", action="store_true",
                    help="QLoRA: load base in 4-bit (less to move over the link, less VRAM)")
    ns = ap.parse_args()
    if ns.inspect:
        inspect(ns.data, ns.base)
    else:
        train(data_dir=ns.data, base=ns.base, out=ns.out, epochs=ns.epochs,
              lr=ns.lr, batch=ns.batch, accum=ns.accum, max_len=ns.max_len,
              lora_r=ns.lora_r, load_4bit=ns.load_4bit)


if __name__ == "__main__":
    main()