feat: add SFT trainer with chat-template render and assistant-only mask
This commit is contained in:
parent
4533a04021
commit
0bb9c1983b
1 changed files with 209 additions and 0 deletions
209
src/tiararodney/sekft/sft.py
Normal file
209
src/tiararodney/sekft/sft.py
Normal file
|
|
@ -0,0 +1,209 @@
|
||||||
|
"""sekft trainer: SFT a base model on kept shell-operation trajectories.
|
||||||
|
|
||||||
|
Trains assistant turns ONLY -- the commands and the terminal ``exit`` / ``panic``.
|
||||||
|
The environment turns (system orientation, prompts, command output) are masked
|
||||||
|
to ``-100`` so the model learns to *produce* commands, not to predict the
|
||||||
|
environment's replies. Getting this mask wrong is the classic way to ruin a
|
||||||
|
shell-operator SFT (the model starts hallucinating output), so it is the part
|
||||||
|
worth testing hardest -- and it is framework-independent.
|
||||||
|
|
||||||
|
Render uses the tokenizer's OWN chat template (``apply_chat_template``), so the
|
||||||
|
training render is identical to what the serving harness produces (ccpty sends
|
||||||
|
structured messages and the inference endpoint applies the model's default
|
||||||
|
template). Trajectories are canonicalised first (``normalize_for_template``):
|
||||||
|
a leading ``system`` turn is folded into the first ``user`` turn and consecutive
|
||||||
|
same-role turns are merged, because instruct templates such as Mistral's have no
|
||||||
|
system role and require strict user/assistant alternation. That same
|
||||||
|
canonicalisation must run on the serving side. Everything else is standard
|
||||||
|
causal-LM SFT with an assistant-only loss mask.
|
||||||
|
|
||||||
|
python sft.py --data ./trajectories --base <hf-model-dir> --out ./ckpt
|
||||||
|
python sft.py --data ./trajectories --base <dir> --inspect # mask stats, no training
|
||||||
|
|
||||||
|
Training needs torch + transformers + peft (a GPU box). ``--inspect`` and the
|
||||||
|
normalize/mask helpers run anywhere a tokenizer with a chat template is
|
||||||
|
available.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def normalize_for_template(messages: list[dict]) -> list[dict]:
    """Canonicalise a trajectory for chat templates without a system role.

    Instruct templates such as Mistral's have no ``system`` role and insist on
    strict user/assistant alternation. To satisfy them, every ``system`` turn
    is relabelled ``user`` and any run of consecutive same-role turns is
    collapsed into a single turn whose content is the pieces joined by a
    newline.

    The input list and its dicts are not modified; only the ``role`` and
    ``content`` keys are carried over into the result.

    Per the original author's notes, assistant turns are never adjacent in
    this dataset, so in practice only environment/user turns merge and the
    assistant loss mask is unaffected. The serving side must apply the same
    canonicalisation, otherwise the training and serving renders diverge.

    Args:
        messages: Turns as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        A new list of ``{"role", "content"}`` dicts with no ``system`` role
        and no two adjacent turns sharing a role.
    """
    canonical: list[dict] = []
    for turn in messages:
        role = turn["role"]
        if role == "system":
            role = "user"
        text = turn["content"]

        # A role change (or the very first turn) starts a fresh entry.
        if not canonical or canonical[-1]["role"] != role:
            canonical.append({"role": role, "content": text})
            continue

        # Same role as the previous entry: replace it with the merged turn.
        previous = canonical.pop()
        canonical.append({"role": role, "content": previous["content"] + "\n" + text})
    return canonical
|
||||||
|
|
||||||
|
|
||||||
|
def _token_ids(rendered) -> list[int]:
    """Return the flat token-id list from an ``apply_chat_template`` result.

    Depending on the transformers version and its ``return_dict`` default, the
    call may hand back a bare list of ids or a dict-like ``BatchEncoding``.
    The masking logic needs a plain list so that ``len`` and slicing operate
    on tokens rather than on mapping keys.
    """
    if hasattr(rendered, "keys"):
        rendered = rendered["input_ids"]
    return list(rendered)


def build_masked_example(messages: list[dict], tokenizer) -> dict:
    """Tokenize a trajectory with the tokenizer's own chat template and build
    an assistant-only loss mask.

    The turns are first canonicalised with ``normalize_for_template`` and then
    rendered by ``tokenizer.apply_chat_template`` (no generation prompt), which
    is intended to match the render the serving side produces.

    The mask is derived by token-prefix differencing: the conversation is
    re-rendered once per turn (turns ``0..i``), and the tokens a turn
    contributes are those by which its render extends the previous one. Tokens
    contributed by ``assistant`` turns keep their id as the label -- this
    covers whatever the template appends to the turn, e.g. an end-of-turn
    token -- while every other position stays ``-100``.

    Args:
        messages: Raw trajectory turns as ``{"role", "content"}`` dicts.
        tokenizer: Object exposing ``apply_chat_template(msgs,
            add_generation_prompt=False)`` returning token ids (either a list
            or a mapping with an ``"input_ids"`` entry).

    Returns:
        A dict with ``input_ids``, an all-ones ``attention_mask`` and
        ``labels``, all of the same length.

    Raises:
        ValueError: If the template is not additive, i.e. some prefix render
            is not a prefix of the full render or does not extend the previous
            prefix render. Raising here avoids silently mis-masking.
    """
    msgs = normalize_for_template(messages)
    ids = _token_ids(tokenizer.apply_chat_template(msgs, add_generation_prompt=False))
    labels = [-100] * len(ids)
    prev: list[int] = []
    for i, m in enumerate(msgs):
        upto = _token_ids(
            tokenizer.apply_chat_template(msgs[:i + 1], add_generation_prompt=False))
        if ids[:len(upto)] != upto or upto[:len(prev)] != prev:
            raise ValueError("chat template is not additive; cannot derive an "
                             "assistant loss mask by token-prefix differencing")
        if m["role"] == "assistant":
            # ``upto`` was just verified to be a prefix of ``ids``, so the new
            # tail of ``upto`` is exactly ids[len(prev):len(upto)].
            labels[len(prev):len(upto)] = upto[len(prev):]
        prev = upto
    return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
|
||||||
|
|
||||||
|
|
||||||
|
def iter_keepers(data_dir: Path):
    """Yield ``turns`` (message lists) from trajectory JSONs marked keep.

    Every ``*.json`` file directly inside ``data_dir`` is parsed, in sorted
    path order so the iteration order is deterministic. A file contributes its
    ``"turns"`` value only when its ``"keep"`` entry is present and truthy.

    Files are decoded as UTF-8 explicitly: JSON interchange is UTF-8, and
    relying on the platform default encoding would mis-decode non-ASCII
    command output on systems whose locale is not UTF-8.

    Args:
        data_dir: Directory containing the trajectory ``*.json`` files.

    Yields:
        The ``"turns"`` value of each kept trajectory.

    Raises:
        json.JSONDecodeError: If a file is not valid JSON.
        KeyError: If a kept trajectory has no ``"turns"`` entry.
    """
    for path in sorted(data_dir.glob("*.json")):
        record = json.loads(path.read_text(encoding="utf-8"))
        if record.get("keep"):
            yield record["turns"]
|
||||||
|
|
||||||
|
|
||||||
|
def mask_stats(example: dict) -> tuple[int, int]:
    """Count how many label positions of an example contribute to the loss.

    Args:
        example: A dict with a ``"labels"`` sequence in which ``-100`` marks a
            masked position.

    Returns:
        ``(trained, total)``: the number of labels that are not ``-100`` and
        the overall number of labels.
    """
    labels = example["labels"]
    unmasked = [label for label in labels if label != -100]
    return len(unmasked), len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
# Training (GPU box: torch + transformers + peft)
|
||||||
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
          batch: int, accum: int, max_len: int, lora_r: int,
          load_4bit: bool = False) -> None:
    """LoRA-fine-tune ``base`` on the keeper trajectories found in ``data_dir``.

    Each kept trajectory is turned into a masked example with
    ``build_masked_example``; examples longer than ``max_len`` tokens, or with
    no trained (non ``-100``) label at all, are dropped. The remaining
    examples are trained with the Hugging Face ``Trainer`` on a PEFT LoRA
    wrapper around the attention projections. Afterwards the adapter, the
    tokenizer and a JSONL dump of the trainer's log history are written to
    ``out``.

    Args:
        data_dir: Directory of trajectory JSON files (see ``iter_keepers``).
        base: Model id or local directory passed to ``from_pretrained``.
        out: Output directory for checkpoints, adapter, tokenizer and logs.
        epochs: Passed as ``num_train_epochs``.
        lr: Passed as ``learning_rate``.
        batch: Passed as ``per_device_train_batch_size``.
        accum: Passed as ``gradient_accumulation_steps``.
        max_len: Maximum example length in tokens; longer examples are skipped.
        lora_r: LoRA rank; ``lora_alpha`` is set to twice this value.
        load_4bit: If true, load the base model 4-bit quantised (nf4, fp16
            compute, double quantisation) and prepare it for k-bit training.

    Raises:
        SystemExit: If no example survives the filtering.
    """
    # Heavy dependencies are imported inside the function, so importing this
    # module does not require them.
    import torch
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              DataCollatorForSeq2Seq, Trainer, TrainingArguments)

    tok = AutoTokenizer.from_pretrained(base)
    # The collator pads batches, so a pad token must exist; fall back to EOS.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Build the masked examples, keeping only those that fit in max_len and
    # have at least one label that is actually trained.
    rows = []
    for turns in iter_keepers(data_dir):
        ex = build_masked_example(turns, tok)
        if len(ex["input_ids"]) <= max_len and any(l != -100 for l in ex["labels"]):
            rows.append(ex)
    if not rows:
        raise SystemExit(f"no usable keeper trajectories in {data_dir}")
    print(f"examples: {len(rows)}; "
          f"trained/total tokens: {sum(mask_stats(r)[0] for r in rows)}"
          f"/{sum(mask_stats(r)[1] for r in rows)}")
    ds = Dataset.from_list(rows)

    # 4-bit (QLoRA) shrinks the base from ~14 GB to ~4 GB to move across the
    # OcuLink/PCIe link and to hold in VRAM; nf4 + fp16 compute works on the
    # V100 (sm_70). Without it, plain fp16 weights.
    quant = None
    if load_4bit:
        from transformers import BitsAndBytesConfig
        quant = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        base, dtype=torch.float16, quantization_config=quant)
    if load_4bit:
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(model)  # handles ckpt + input grads
    else:
        # Non-quantised path: set up input grads and gradient checkpointing
        # by hand, before the model is wrapped with LoRA below.
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()
    # LoRA adapters on the four attention projections only.
    model = get_peft_model(model, LoraConfig(
        r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    ))
    model.print_trainable_parameters()

    # Log every step, save once per epoch, and send TensorBoard logs to
    # <out>/runs. remove_unused_columns=False leaves the dataset columns as
    # built above.
    args = TrainingArguments(
        output_dir=str(out), per_device_train_batch_size=batch,
        gradient_accumulation_steps=accum, num_train_epochs=epochs,
        learning_rate=lr, fp16=True, logging_steps=1, save_strategy="epoch",
        report_to=["tensorboard"], logging_dir=str(out / "runs"),
        remove_unused_columns=False, warmup_ratio=0.03,
    )
    # The collator pads each batch; padded label positions get -100, the same
    # value used for masked tokens.
    trainer = Trainer(
        model=model, args=args, train_dataset=ds,
        data_collator=DataCollatorForSeq2Seq(tok, padding=True, label_pad_token_id=-100),
    )
    trainer.train()
    model.save_pretrained(str(out))
    tok.save_pretrained(str(out))
    # durable, greppable record of the curve (loss/lr/grad_norm per step).
    (out / "log_history.jsonl").write_text(
        "\n".join(json.dumps(r) for r in trainer.state.log_history))
    print(f"saved LoRA adapter + log_history.jsonl -> {out} "
          f"(tensorboard: --logdir {out / 'runs'})")
|
||||||
|
|
||||||
|
|
||||||
|
def inspect(data_dir: Path, base: str) -> None:
    """Print assistant-mask statistics for the keeper trajectories; no training.

    Loads only the tokenizer for ``base``, builds the masked example for every
    kept trajectory in ``data_dir`` and reports how many trajectories there
    are and what share of all tokens is trained (not masked to ``-100``).

    Args:
        data_dir: Directory of trajectory JSON files (see ``iter_keepers``).
        base: Model id or local directory passed to
            ``AutoTokenizer.from_pretrained``.

    Raises:
        SystemExit: If ``data_dir`` contains no kept trajectory.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base)
    count = 0
    trained_total = 0
    token_total = 0
    for turns in iter_keepers(data_dir):
        trained, total = mask_stats(build_masked_example(turns, tokenizer))
        trained_total += trained
        token_total += total
        count += 1
    if count == 0:
        raise SystemExit(f"no keeper trajectories in {data_dir}")
    percent = 100 * trained_total / token_total
    print(f"{count} keeper trajectories; {trained_total}/{token_total} tokens trained "
          f"({percent:.1f}% assistant, rest masked)")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse the command line and run either mask inspection or training.

    With ``--inspect`` only the mask statistics are printed (``inspect``);
    otherwise the parsed hyperparameters are forwarded to ``train``.
    """
    parser = argparse.ArgumentParser(description="SFT a model on shell trajectories.")
    parser.add_argument("--data", type=Path, default=Path("./trajectories"))
    parser.add_argument("--base", required=True, help="HF model id or local dir")
    parser.add_argument("--out", type=Path, default=Path("./ckpt"))
    parser.add_argument("--inspect", action="store_true", help="mask stats only, no training")

    # Plain numeric hyperparameters: (flag, type, default), in help order.
    numeric_options = (
        ("--epochs", float, 3.0),
        ("--lr", float, 2e-4),
        ("--batch", int, 1),
        ("--accum", int, 8),
        ("--max-len", int, 4096),
        ("--lora-r", int, 16),
    )
    for flag, kind, default in numeric_options:
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument("--load-4bit", action="store_true",
                        help="QLoRA: load base in 4-bit (less to move over the link, less VRAM)")
    args = parser.parse_args()

    if args.inspect:
        inspect(args.data, args.base)
        return
    train(
        args.data, args.base, args.out, args.epochs, args.lr, args.batch,
        args.accum, args.max_len, args.lora_r, args.load_4bit,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue