Merge branch 'feature/2'

This commit is contained in:
Tiara Rodney 2026-06-16 20:14:16 +02:00
commit 908f59f5a0
Signed by: tiara
GPG key ID: 5CD8EC1D46106723
2 changed files with 210 additions and 1 deletions

2
TODO
View file

@ -37,7 +37,7 @@ Content-Type: application/issue
ID: 2
Type: feature
Title: SFT trainer with chat-template render and assistant-only mask
Status: in-progress
Status: done
Priority: medium
Created: 2026-06-16
Module: sekft

View file

@ -0,0 +1,209 @@
"""sekft trainer: SFT a base model on kept shell-operation trajectories.
Trains assistant turns ONLY -- the commands and the terminal ``exit`` / ``panic``.
The environment turns (system orientation, prompts, command output) are masked
to ``-100`` so the model learns to *produce* commands, not to predict the
environment's replies. Getting this mask wrong is the classic way to ruin a
shell-operator SFT (the model starts hallucinating output), so it is the part
worth testing hardest -- and it is framework-independent.
Render uses the tokenizer's OWN chat template (``apply_chat_template``), so the
training render is identical to what the serving harness produces (ccpty sends
structured messages and the inference endpoint applies the model's default
template). Trajectories are canonicalised first (``normalize_for_template``):
a leading ``system`` turn is folded into the first ``user`` turn and consecutive
same-role turns are merged, because instruct templates such as Mistral's have no
system role and require strict user/assistant alternation. That same
canonicalisation must run on the serving side. Everything else is standard
causal-LM SFT with an assistant-only loss mask.
Usage::

    python sft.py --data ./trajectories --base <hf-model-dir> --out ./ckpt
    python sft.py --data ./trajectories --base <dir> --inspect  # mask stats, no training

Training needs torch + transformers + peft (a GPU box). ``--inspect`` and the
normalize/mask helpers run anywhere a tokenizer with a chat template is
available.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
def normalize_for_template(messages: list[dict]) -> list[dict]:
    """Canonicalise a trajectory for strict-alternation instruct templates.

    Instruct chat templates such as Mistral's have no ``system`` role and
    require strict user/assistant alternation. To make a multi-turn shell
    dialogue renderable by ``apply_chat_template``:

    * every ``system`` turn is relabelled ``user``;
    * each run of consecutive same-role turns is collapsed into a single
      turn whose content is the run's contents joined with a newline.

    Only the ``role`` and ``content`` keys are carried over; the input list
    and its dicts are left untouched.

    Merging applies to any role, assistant included. The trajectories are
    expected never to contain adjacent assistant turns, so in practice only
    environment/user turns merge and the assistant loss mask is unaffected.
    The serving side MUST apply the same canonicalisation, or train and
    serve renders diverge.

    Args:
        messages: Chat turns, each a dict with ``role`` and ``content``.

    Returns:
        A new list of ``{"role", "content"}`` dicts with no ``system`` role
        and no two adjacent turns sharing a role.

    Raises:
        KeyError: If a turn lacks ``role`` or ``content``.
    """
    merged: list[dict] = []
    for turn in messages:
        role = turn["role"]
        if role == "system":
            role = "user"
        text = turn["content"]
        if merged and merged[-1]["role"] == role:
            # Same speaker as the previous turn: fold both into one turn.
            previous = merged.pop()
            text = previous["content"] + "\n" + text
        merged.append({"role": role, "content": text})
    return merged
def build_masked_example(messages: list[dict], tokenizer) -> dict:
    """Tokenize a trajectory with the tokenizer's chat template and build an
    assistant-only loss mask.

    The turns are first canonicalised with ``normalize_for_template`` and
    rendered with ``tokenizer.apply_chat_template`` (no generation prompt).
    The mask is derived by token-prefix differencing: the conversation is
    re-rendered one turn at a time, and the tokens a turn contributes are
    those by which its render extends the previous prefix. Tokens
    contributed by ``assistant`` turns keep their id as the label (this
    includes whatever end-of-turn tokens the template emits for that turn);
    every other position is labelled ``-100``.

    This only works for an additive template, i.e. one where rendering the
    first ``k + 1`` turns extends the render of the first ``k``. A template
    that violates this raises instead of silently producing a wrong mask.

    Args:
        messages: Chat turns, each a dict with ``role`` and ``content``.
        tokenizer: Any object exposing
            ``apply_chat_template(turns, add_generation_prompt=False)`` that
            returns token ids (or a mapping holding them under
            ``"input_ids"``).

    Returns:
        A dict with ``input_ids``, an all-ones ``attention_mask`` and
        ``labels``, all plain lists of equal length.

    Raises:
        ValueError: If the chat template is not additive.
    """
    from collections.abc import Mapping

    def render(turns: list[dict]) -> list[int]:
        rendered = tokenizer.apply_chat_template(turns, add_generation_prompt=False)
        # NOTE(review): depending on the installed transformers version the
        # default return may be a dict-like encoding rather than a bare id
        # list -- confirm against the pinned version. Accept both so that
        # ``len()`` and the slicing below always operate on token ids.
        if isinstance(rendered, Mapping):
            rendered = rendered["input_ids"]
        return list(rendered)

    msgs = normalize_for_template(messages)
    ids = render(msgs)
    labels = [-100] * len(ids)
    prev: list[int] = []
    last = len(msgs) - 1
    for i, m in enumerate(msgs):
        # The final prefix is the whole conversation, which is already
        # rendered as ``ids``; skip re-tokenizing the longest sequence.
        upto = ids if i == last else render(msgs[:i + 1])
        if ids[:len(upto)] != upto or upto[:len(prev)] != prev:
            raise ValueError("chat template is not additive; cannot derive an "
                             "assistant loss mask by token-prefix differencing")
        if m["role"] == "assistant":
            # Train on exactly the tokens this assistant turn appended.
            labels[len(prev):len(upto)] = upto[len(prev):]
        prev = upto
    return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
def iter_keepers(data_dir: Path):
    """Yield the ``turns`` message list of every trajectory marked keep.

    Scans the ``*.json`` files directly inside ``data_dir`` (no recursion) in
    sorted path order, so the iteration order is deterministic. A file
    contributes its ``turns`` value only when its top-level ``keep`` entry is
    truthy; files without ``keep`` or with a falsy value are skipped.

    Files are always decoded as UTF-8 rather than with the platform's default
    encoding, so non-ASCII command output is read identically on every box.

    Args:
        data_dir: Directory holding the trajectory JSON files.

    Yields:
        The ``turns`` value of each kept trajectory.

    Raises:
        json.JSONDecodeError: If a file is not valid JSON.
        KeyError: If a kept trajectory has no ``turns`` entry.
    """
    for path in sorted(data_dir.glob("*.json")):
        record = json.loads(path.read_text(encoding="utf-8"))
        if record.get("keep"):
            yield record["turns"]
def mask_stats(example: dict) -> tuple[int, int]:
    """Count how much of an example contributes to the loss.

    Args:
        example: A dict with a ``labels`` sequence in which ``-100`` marks a
            masked (untrained) position.

    Returns:
        ``(trained, total)``: the number of labels different from ``-100``
        and the overall number of labels.
    """
    labels = example["labels"]
    trained = 0
    for label in labels:
        if label != -100:
            trained += 1
    return trained, len(labels)
# --------------------------------------------------------------------------
# Training (GPU box: torch + transformers + peft)
# --------------------------------------------------------------------------
def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
          batch: int, accum: int, max_len: int, lora_r: int,
          load_4bit: bool = False) -> None:
    """LoRA-fine-tune ``base`` on the keeper trajectories in ``data_dir``.

    Every keeper trajectory is rendered and masked with
    ``build_masked_example``. Examples longer than ``max_len`` tokens, or
    with no trainable (assistant) token at all, are dropped; the number
    dropped for each reason is printed so data loss is visible. The base
    model is loaded in fp16 (optionally 4-bit nf4), wrapped with LoRA
    adapters on the attention projections, and trained with the Hugging Face
    ``Trainer``. The adapter, the tokenizer and a ``log_history.jsonl`` file
    (one JSON object per logged step) are written to ``out``.

    Args:
        data_dir: Directory of trajectory JSON files.
        base: Hugging Face model id or local model directory.
        out: Output directory for checkpoints, adapter, tokenizer and logs.
        epochs: Number of training epochs.
        lr: Learning rate.
        batch: Per-device train batch size.
        accum: Gradient accumulation steps.
        max_len: Maximum example length in tokens; longer examples are
            skipped, not truncated.
        lora_r: LoRA rank; ``lora_alpha`` is set to twice this value.
        load_4bit: Load the base model 4-bit quantised (QLoRA).

    Raises:
        SystemExit: If no usable keeper trajectory remains after filtering.
    """
    # Imported here rather than at module top so that merely importing this
    # module does not require the training stack.
    import torch
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model
    from transformers import (AutoModelForCausalLM, AutoTokenizer,
                              DataCollatorForSeq2Seq, Trainer, TrainingArguments)

    tok = AutoTokenizer.from_pretrained(base)
    if tok.pad_token is None:
        # The collator needs a pad token; fall back to EOS when none is set.
        tok.pad_token = tok.eos_token

    rows = []
    trained_tokens = total_tokens = 0
    too_long = untrained = 0
    for turns in iter_keepers(data_dir):
        ex = build_masked_example(turns, tok)
        trained, total = mask_stats(ex)
        if total > max_len:
            too_long += 1
        elif not trained:
            untrained += 1
        else:
            rows.append(ex)
            trained_tokens += trained
            total_tokens += total
    if too_long or untrained:
        # Make the filtering visible instead of silently shrinking the set.
        print(f"skipped keepers: {too_long} longer than {max_len} tokens, "
              f"{untrained} without any assistant token")
    if not rows:
        raise SystemExit(f"no usable keeper trajectories in {data_dir}")
    print(f"examples: {len(rows)}; "
          f"trained/total tokens: {trained_tokens}"
          f"/{total_tokens}")
    ds = Dataset.from_list(rows)

    # 4-bit (QLoRA) shrinks the base from ~14 GB to ~4 GB to move across the
    # OcuLink/PCIe link and to hold in VRAM; nf4 + fp16 compute works on the
    # V100 (sm_70). Without it, plain fp16 weights.
    quant = None
    if load_4bit:
        from transformers import BitsAndBytesConfig
        quant = BitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
        )
    model = AutoModelForCausalLM.from_pretrained(
        base, dtype=torch.float16, quantization_config=quant)
    if load_4bit:
        from peft import prepare_model_for_kbit_training
        model = prepare_model_for_kbit_training(model)  # handles ckpt + input grads
    else:
        model.enable_input_require_grads()
        model.gradient_checkpointing_enable()
    model = get_peft_model(model, LoraConfig(
        r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    ))
    model.print_trainable_parameters()

    args = TrainingArguments(
        output_dir=str(out), per_device_train_batch_size=batch,
        gradient_accumulation_steps=accum, num_train_epochs=epochs,
        learning_rate=lr, fp16=True, logging_steps=1, save_strategy="epoch",
        report_to=["tensorboard"], logging_dir=str(out / "runs"),
        remove_unused_columns=False, warmup_ratio=0.03,
    )
    trainer = Trainer(
        model=model, args=args, train_dataset=ds,
        # Pads input_ids/attention_mask with the pad token and labels with
        # -100 so padding never contributes to the loss.
        data_collator=DataCollatorForSeq2Seq(tok, padding=True, label_pad_token_id=-100),
    )
    trainer.train()
    model.save_pretrained(str(out))
    tok.save_pretrained(str(out))

    # durable, greppable record of the curve (loss/lr/grad_norm per step).
    # One newline-terminated JSON object per line, so the file is
    # well-formed JSONL and can be appended to or counted with ``wc -l``.
    (out / "log_history.jsonl").write_text(
        "".join(json.dumps(r) + "\n" for r in trainer.state.log_history))
    print(f"saved LoRA adapter + log_history.jsonl -> {out} "
          f"(tensorboard: --logdir {out / 'runs'})")
def inspect(data_dir: Path, base: str) -> None:
    """Print assistant-mask statistics for the keeper trajectories.

    Loads only the tokenizer of ``base``, masks every keeper trajectory in
    ``data_dir`` with ``build_masked_example`` and prints how many
    trajectories were seen and what share of all tokens is trained on. No
    model is loaded and nothing is trained.

    Args:
        data_dir: Directory of trajectory JSON files.
        base: Hugging Face model id or local directory providing the
            tokenizer.

    Raises:
        SystemExit: If ``data_dir`` contains no keeper trajectory.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(base)
    count = 0
    trained_sum = 0
    token_sum = 0
    for turns in iter_keepers(data_dir):
        trained, total = mask_stats(build_masked_example(turns, tokenizer))
        trained_sum += trained
        token_sum += total
        count += 1
    if count == 0:
        raise SystemExit(f"no keeper trajectories in {data_dir}")
    share = 100 * trained_sum / token_sum
    print(f"{count} keeper trajectories; {trained_sum}/{token_sum} tokens trained "
          f"({share:.1f}% assistant, rest masked)")
def main() -> None:
    """Command-line entry point: parse the flags, then inspect or train.

    With ``--inspect`` only the mask statistics are printed; otherwise a
    full training run is started with the parsed hyper-parameters.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--data", {"type": Path, "default": Path("./trajectories")}),
        ("--base", {"required": True, "help": "HF model id or local dir"}),
        ("--out", {"type": Path, "default": Path("./ckpt")}),
        ("--inspect", {"action": "store_true",
                       "help": "mask stats only, no training"}),
        ("--epochs", {"type": float, "default": 3.0}),
        ("--lr", {"type": float, "default": 2e-4}),
        ("--batch", {"type": int, "default": 1}),
        ("--accum", {"type": int, "default": 8}),
        ("--max-len", {"type": int, "default": 4096}),
        ("--lora-r", {"type": int, "default": 16}),
        ("--load-4bit", {
            "action": "store_true",
            "help": "QLoRA: load base in 4-bit (less to move over the link, less VRAM)",
        }),
    )
    parser = argparse.ArgumentParser(description="SFT a model on shell trajectories.")
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    opts = parser.parse_args()

    if opts.inspect:
        inspect(opts.data, opts.base)
        return
    train(
        data_dir=opts.data,
        base=opts.base,
        out=opts.out,
        epochs=opts.epochs,
        lr=opts.lr,
        batch=opts.batch,
        accum=opts.accum,
        max_len=opts.max_len,
        lora_r=opts.lora_r,
        load_4bit=opts.load_4bit,
    )
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()