Merge branch 'feature/2'
This commit is contained in:
commit
908f59f5a0
2 changed files with 210 additions and 1 deletions
2
TODO
2
TODO
|
|
@ -37,7 +37,7 @@ Content-Type: application/issue
|
|||
ID: 2
|
||||
Type: feature
|
||||
Title: SFT trainer with chat-template render and assistant-only mask
|
||||
Status: in-progress
|
||||
Status: done
|
||||
Priority: medium
|
||||
Created: 2026-06-16
|
||||
Module: sekft
|
||||
|
|
|
|||
209
src/tiararodney/sekft/sft.py
Normal file
209
src/tiararodney/sekft/sft.py
Normal file
|
|
@ -0,0 +1,209 @@
|
|||
"""sekft trainer: SFT a base model on kept shell-operation trajectories.
|
||||
|
||||
Trains assistant turns ONLY -- the commands and the terminal ``exit`` / ``panic``.
|
||||
The environment turns (system orientation, prompts, command output) are masked
|
||||
to ``-100`` so the model learns to *produce* commands, not to predict the
|
||||
environment's replies. Getting this mask wrong is the classic way to ruin a
|
||||
shell-operator SFT (the model starts hallucinating output), so it is the part
|
||||
worth testing hardest -- and it is framework-independent.
|
||||
|
||||
Render uses the tokenizer's OWN chat template (``apply_chat_template``), so the
|
||||
training render is identical to what the serving harness produces (ccpty sends
|
||||
structured messages and the inference endpoint applies the model's default
|
||||
template). Trajectories are canonicalised first (``normalize_for_template``):
|
||||
a leading ``system`` turn is folded into the first ``user`` turn and consecutive
|
||||
same-role turns are merged, because instruct templates such as Mistral's have no
|
||||
system role and require strict user/assistant alternation. That same
|
||||
canonicalisation must run on the serving side. Everything else is standard
|
||||
causal-LM SFT with an assistant-only loss mask.
|
||||
|
||||
python sft.py --data ./trajectories --base <hf-model-dir> --out ./ckpt
|
||||
python sft.py --data ./trajectories --base <dir> --inspect # mask stats, no training
|
||||
|
||||
Training needs torch + transformers + peft (a GPU box). ``--inspect`` and the
|
||||
normalize/mask helpers run anywhere a tokenizer with a chat template is
|
||||
available.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
def normalize_for_template(messages: list[dict]) -> list[dict]:
|
||||
"""Canonicalise a trajectory for instruct chat templates that have no system
|
||||
role and require strict user/assistant alternation (Mistral and friends):
|
||||
treat ``system`` as ``user``, then merge consecutive same-role turns by
|
||||
joining their content with a newline.
|
||||
|
||||
This is loss-neutral for the assistant mask (only environment/user turns
|
||||
ever merge; the assistant commands are never adjacent in this data) and it
|
||||
is what lets ``apply_chat_template`` render the multi-turn shell dialogue.
|
||||
The serving side MUST apply the same canonicalisation, or train and serve
|
||||
diverge again.
|
||||
"""
|
||||
out: list[dict] = []
|
||||
for m in messages:
|
||||
role = "user" if m["role"] == "system" else m["role"]
|
||||
if out and out[-1]["role"] == role:
|
||||
out[-1] = {"role": role, "content": out[-1]["content"] + "\n" + m["content"]}
|
||||
else:
|
||||
out.append({"role": role, "content": m["content"]})
|
||||
return out
|
||||
|
||||
|
||||
def build_masked_example(messages: list[dict], tokenizer) -> dict:
|
||||
"""Tokenize a trajectory with the tokenizer's OWN chat template and build an
|
||||
assistant-only loss mask.
|
||||
|
||||
The render is ``tokenizer.apply_chat_template`` on the canonicalised turns,
|
||||
so it is byte-identical to what the serving harness sends. The mask is
|
||||
derived by token-prefix differencing: the tokens an assistant turn
|
||||
contributes are exactly those that appear when it extends the rendered
|
||||
prefix, which trains the commands plus the template's end-of-turn token (so
|
||||
the model learns to stop) and masks every environment turn to ``-100``. This
|
||||
assumes an additive template (each turn extends the previous render); a
|
||||
non-additive one raises rather than silently mis-mask.
|
||||
"""
|
||||
msgs = normalize_for_template(messages)
|
||||
ids = tokenizer.apply_chat_template(msgs, add_generation_prompt=False)
|
||||
labels = [-100] * len(ids)
|
||||
prev: list[int] = []
|
||||
for i, m in enumerate(msgs):
|
||||
upto = tokenizer.apply_chat_template(msgs[:i + 1], add_generation_prompt=False)
|
||||
if ids[:len(upto)] != upto or upto[:len(prev)] != prev:
|
||||
raise ValueError("chat template is not additive; cannot derive an "
|
||||
"assistant loss mask by token-prefix differencing")
|
||||
if m["role"] == "assistant":
|
||||
for j in range(len(prev), len(upto)):
|
||||
labels[j] = ids[j]
|
||||
prev = upto
|
||||
return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
|
||||
|
||||
|
||||
def iter_keepers(data_dir: Path):
|
||||
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
|
||||
for f in sorted(data_dir.glob("*.json")):
|
||||
d = json.loads(f.read_text())
|
||||
if d.get("keep"):
|
||||
yield d["turns"]
|
||||
|
||||
|
||||
def mask_stats(example: dict) -> tuple[int, int]:
|
||||
"""(trained tokens, total tokens) for an example."""
|
||||
trained = sum(1 for x in example["labels"] if x != -100)
|
||||
return trained, len(example["labels"])
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# Training (GPU box: torch + transformers + peft)
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
|
||||
batch: int, accum: int, max_len: int, lora_r: int,
|
||||
load_4bit: bool = False) -> None:
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
DataCollatorForSeq2Seq, Trainer, TrainingArguments)
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(base)
|
||||
if tok.pad_token is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
|
||||
rows = []
|
||||
for turns in iter_keepers(data_dir):
|
||||
ex = build_masked_example(turns, tok)
|
||||
if len(ex["input_ids"]) <= max_len and any(l != -100 for l in ex["labels"]):
|
||||
rows.append(ex)
|
||||
if not rows:
|
||||
raise SystemExit(f"no usable keeper trajectories in {data_dir}")
|
||||
print(f"examples: {len(rows)}; "
|
||||
f"trained/total tokens: {sum(mask_stats(r)[0] for r in rows)}"
|
||||
f"/{sum(mask_stats(r)[1] for r in rows)}")
|
||||
ds = Dataset.from_list(rows)
|
||||
|
||||
# 4-bit (QLoRA) shrinks the base from ~14 GB to ~4 GB to move across the
|
||||
# OcuLink/PCIe link and to hold in VRAM; nf4 + fp16 compute works on the
|
||||
# V100 (sm_70). Without it, plain fp16 weights.
|
||||
quant = None
|
||||
if load_4bit:
|
||||
from transformers import BitsAndBytesConfig
|
||||
quant = BitsAndBytesConfig(
|
||||
load_in_4bit=True, bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base, dtype=torch.float16, quantization_config=quant)
|
||||
if load_4bit:
|
||||
from peft import prepare_model_for_kbit_training
|
||||
model = prepare_model_for_kbit_training(model) # handles ckpt + input grads
|
||||
else:
|
||||
model.enable_input_require_grads()
|
||||
model.gradient_checkpointing_enable()
|
||||
model = get_peft_model(model, LoraConfig(
|
||||
r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, task_type="CAUSAL_LM",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||
))
|
||||
model.print_trainable_parameters()
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=str(out), per_device_train_batch_size=batch,
|
||||
gradient_accumulation_steps=accum, num_train_epochs=epochs,
|
||||
learning_rate=lr, fp16=True, logging_steps=1, save_strategy="epoch",
|
||||
report_to=["tensorboard"], logging_dir=str(out / "runs"),
|
||||
remove_unused_columns=False, warmup_ratio=0.03,
|
||||
)
|
||||
trainer = Trainer(
|
||||
model=model, args=args, train_dataset=ds,
|
||||
data_collator=DataCollatorForSeq2Seq(tok, padding=True, label_pad_token_id=-100),
|
||||
)
|
||||
trainer.train()
|
||||
model.save_pretrained(str(out))
|
||||
tok.save_pretrained(str(out))
|
||||
# durable, greppable record of the curve (loss/lr/grad_norm per step).
|
||||
(out / "log_history.jsonl").write_text(
|
||||
"\n".join(json.dumps(r) for r in trainer.state.log_history))
|
||||
print(f"saved LoRA adapter + log_history.jsonl -> {out} "
|
||||
f"(tensorboard: --logdir {out / 'runs'})")
|
||||
|
||||
|
||||
def inspect(data_dir: Path, base: str) -> None:
|
||||
from transformers import AutoTokenizer
|
||||
tok = AutoTokenizer.from_pretrained(base)
|
||||
n = tt = tr = 0
|
||||
for turns in iter_keepers(data_dir):
|
||||
ex = build_masked_example(turns, tok)
|
||||
t, total = mask_stats(ex)
|
||||
tr += t; tt += total; n += 1
|
||||
if not n:
|
||||
raise SystemExit(f"no keeper trajectories in {data_dir}")
|
||||
print(f"{n} keeper trajectories; {tr}/{tt} tokens trained "
|
||||
f"({100*tr/tt:.1f}% assistant, rest masked)")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(description="SFT a model on shell trajectories.")
|
||||
ap.add_argument("--data", type=Path, default=Path("./trajectories"))
|
||||
ap.add_argument("--base", required=True, help="HF model id or local dir")
|
||||
ap.add_argument("--out", type=Path, default=Path("./ckpt"))
|
||||
ap.add_argument("--inspect", action="store_true", help="mask stats only, no training")
|
||||
ap.add_argument("--epochs", type=float, default=3.0)
|
||||
ap.add_argument("--lr", type=float, default=2e-4)
|
||||
ap.add_argument("--batch", type=int, default=1)
|
||||
ap.add_argument("--accum", type=int, default=8)
|
||||
ap.add_argument("--max-len", type=int, default=4096)
|
||||
ap.add_argument("--lora-r", type=int, default=16)
|
||||
ap.add_argument("--load-4bit", action="store_true",
|
||||
help="QLoRA: load base in 4-bit (less to move over the link, less VRAM)")
|
||||
ns = ap.parse_args()
|
||||
if ns.inspect:
|
||||
inspect(ns.data, ns.base)
|
||||
else:
|
||||
train(ns.data, ns.base, ns.out, ns.epochs, ns.lr, ns.batch, ns.accum,
|
||||
ns.max_len, ns.lora_r, ns.load_4bit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue