Merge branch 'feature/12'
feat(12): load training data from a raw dir, a curated jsonl, or the Hub
This commit is contained in:
commit
b8843557d7
3 changed files with 85 additions and 11 deletions
2
TODO
2
TODO
|
|
@ -197,7 +197,7 @@ Content-Type: application/issue
|
||||||
ID: 12
|
ID: 12
|
||||||
Type: feature
|
Type: feature
|
||||||
Title: load training data from a raw dir, a curated jsonl, or the Hub
|
Title: load training data from a raw dir, a curated jsonl, or the Hub
|
||||||
Status: in-progress
|
Status: done
|
||||||
Priority: medium
|
Priority: medium
|
||||||
Created: 2026-06-17
|
Created: 2026-06-17
|
||||||
Module: sekft
|
Module: sekft
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ canonicalisation must run on the serving side. Everything else is standard
|
||||||
causal-LM SFT with an assistant-only loss mask.
|
causal-LM SFT with an assistant-only loss mask.
|
||||||
|
|
||||||
sekft-train --data ./trajectories --base <hf-model-dir> --out ./ckpt
|
sekft-train --data ./trajectories --base <hf-model-dir> --out ./ckpt
|
||||||
|
sekft-train --data corpus.jsonl --base <dir> # a curated .jsonl corpus
|
||||||
|
sekft-train --hub --base <dir> # the published corpus (Hub)
|
||||||
sekft-train --data ./trajectories --base <dir> --inspect # mask stats, no training
|
sekft-train --data ./trajectories --base <dir> --inspect # mask stats, no training
|
||||||
|
|
||||||
Training needs torch + transformers + peft (a GPU box). ``--inspect`` and the
|
Training needs torch + transformers + peft (a GPU box). ``--inspect`` and the
|
||||||
|
|
@ -90,13 +92,42 @@ def build_masked_example(messages: list[dict[str, str]], tokenizer: Any) -> dict
|
||||||
|
|
||||||
|
|
||||||
def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]:
|
def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]:
|
||||||
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
|
"""Yield ``turns`` (message lists) from raw rollout JSONs marked keep."""
|
||||||
for f in sorted(data_dir.glob("*.json")):
|
for f in sorted(data_dir.glob("*.json")):
|
||||||
d = json.loads(f.read_text())
|
d = json.loads(f.read_text())
|
||||||
if d.get("keep"):
|
if d.get("keep"):
|
||||||
yield d["turns"]
|
yield d["turns"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_turns(data: Path, hub: bool = False,
|
||||||
|
revision: str | None = None) -> Iterator[list[dict[str, str]]]:
|
||||||
|
"""Yield assistant-bearing ``turns`` from one of three sources:
|
||||||
|
|
||||||
|
- ``--hub``: the published corpus via posix-sdc's ``load_trajectories`` (the
|
||||||
|
in-repo ``data/`` of a posix-sdc checkout, else the Hugging Face Hub);
|
||||||
|
- ``data`` a ``.jsonl`` file: a curated corpus, already keep-filtered, one
|
||||||
|
record per line;
|
||||||
|
- ``data`` a directory: raw rollout ``.json`` (keep-filtered here).
|
||||||
|
|
||||||
|
posix-sdc is imported lazily, so the raw-dir and ``.jsonl`` paths need
|
||||||
|
neither posix-sdc nor huggingface_hub installed.
|
||||||
|
"""
|
||||||
|
if hub:
|
||||||
|
from tiararodney.posix_sdc import load_trajectories
|
||||||
|
for r in load_trajectories(revision=revision):
|
||||||
|
yield r["turns"]
|
||||||
|
elif data.is_dir():
|
||||||
|
yield from iter_keepers(data)
|
||||||
|
elif data.suffix == ".jsonl":
|
||||||
|
with open(data) as fh:
|
||||||
|
for line in fh:
|
||||||
|
if line.strip():
|
||||||
|
yield json.loads(line)["turns"]
|
||||||
|
else:
|
||||||
|
raise SystemExit(
|
||||||
|
f"--data must be a rollout directory or a .jsonl corpus (got {data})")
|
||||||
|
|
||||||
|
|
||||||
def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]:
|
def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]:
|
||||||
"""(trained tokens, total tokens) for an example."""
|
"""(trained tokens, total tokens) for an example."""
|
||||||
trained = sum(1 for x in example["labels"] if x != -100)
|
trained = sum(1 for x in example["labels"] if x != -100)
|
||||||
|
|
@ -109,7 +140,8 @@ def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]:
|
||||||
|
|
||||||
def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
|
def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
|
||||||
batch: int, accum: int, max_len: int, lora_r: int,
|
batch: int, accum: int, max_len: int, lora_r: int,
|
||||||
load_4bit: bool = False) -> None:
|
load_4bit: bool = False, hub: bool = False,
|
||||||
|
revision: str | None = None) -> None:
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft import LoraConfig, get_peft_model
|
from peft import LoraConfig, get_peft_model
|
||||||
|
|
@ -121,16 +153,17 @@ def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
|
||||||
# WARNING by default, which is most of why training looks silent.
|
# WARNING by default, which is most of why training looks silent.
|
||||||
hf_logging.set_verbosity_info()
|
hf_logging.set_verbosity_info()
|
||||||
|
|
||||||
log.info("base=%s data=%s out=%s", base, data_dir, out)
|
source = "hub" if hub else data_dir
|
||||||
|
log.info("base=%s data=%s out=%s", base, source, out)
|
||||||
log.info("loading tokenizer: %s", base)
|
log.info("loading tokenizer: %s", base)
|
||||||
tok = AutoTokenizer.from_pretrained(base)
|
tok = AutoTokenizer.from_pretrained(base)
|
||||||
if tok.pad_token is None:
|
if tok.pad_token is None:
|
||||||
tok.pad_token = tok.eos_token
|
tok.pad_token = tok.eos_token
|
||||||
|
|
||||||
log.info("building masked examples from %s ...", data_dir)
|
log.info("building masked examples from %s ...", source)
|
||||||
rows: list[dict[str, list[Any]]] = []
|
rows: list[dict[str, list[Any]]] = []
|
||||||
n_seen = n_long = n_empty = 0
|
n_seen = n_long = n_empty = 0
|
||||||
for turns in iter_keepers(data_dir):
|
for turns in load_turns(data_dir, hub=hub, revision=revision):
|
||||||
n_seen += 1
|
n_seen += 1
|
||||||
ex = build_masked_example(turns, tok)
|
ex = build_masked_example(turns, tok)
|
||||||
log.debug(" trajectory %d: %d turns -> %d tokens, %d trained",
|
log.debug(" trajectory %d: %d turns -> %d tokens, %d trained",
|
||||||
|
|
@ -206,12 +239,13 @@ def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
|
||||||
out, out / "runs")
|
out, out / "runs")
|
||||||
|
|
||||||
|
|
||||||
def inspect(data_dir: Path, base: str) -> None:
|
def inspect(data_dir: Path, base: str, hub: bool = False,
|
||||||
|
revision: str | None = None) -> None:
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
log.info("loading tokenizer: %s", base)
|
log.info("loading tokenizer: %s", base)
|
||||||
tok = AutoTokenizer.from_pretrained(base)
|
tok = AutoTokenizer.from_pretrained(base)
|
||||||
n = tt = tr = 0
|
n = tt = tr = 0
|
||||||
for turns in iter_keepers(data_dir):
|
for turns in load_turns(data_dir, hub=hub, revision=revision):
|
||||||
ex = build_masked_example(turns, tok)
|
ex = build_masked_example(turns, tok)
|
||||||
t, total = mask_stats(ex)
|
t, total = mask_stats(ex)
|
||||||
tr += t; tt += total; n += 1
|
tr += t; tt += total; n += 1
|
||||||
|
|
@ -223,7 +257,12 @@ def inspect(data_dir: Path, base: str) -> None:
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
ap = argparse.ArgumentParser(description="SFT a model on shell trajectories.")
|
ap = argparse.ArgumentParser(description="SFT a model on shell trajectories.")
|
||||||
ap.add_argument("--data", type=Path, default=Path("./trajectories"))
|
ap.add_argument("--data", type=Path, default=Path("./trajectories"),
|
||||||
|
help="a raw rollout dir or a curated .jsonl corpus")
|
||||||
|
ap.add_argument("--hub", action="store_true",
|
||||||
|
help="load the published corpus via posix-sdc (Hub); ignores --data")
|
||||||
|
ap.add_argument("--revision", default=None,
|
||||||
|
help="dataset revision/tag to pin when using --hub")
|
||||||
ap.add_argument("--base", required=True, help="HF model id or local dir")
|
ap.add_argument("--base", required=True, help="HF model id or local dir")
|
||||||
ap.add_argument("--out", type=Path, default=Path("./ckpt"))
|
ap.add_argument("--out", type=Path, default=Path("./ckpt"))
|
||||||
ap.add_argument("--inspect", action="store_true", help="mask stats only, no training")
|
ap.add_argument("--inspect", action="store_true", help="mask stats only, no training")
|
||||||
|
|
@ -240,10 +279,10 @@ def main() -> None:
|
||||||
ns = ap.parse_args()
|
ns = ap.parse_args()
|
||||||
_setup_logging(verbose=ns.verbose, quiet=ns.quiet)
|
_setup_logging(verbose=ns.verbose, quiet=ns.quiet)
|
||||||
if ns.inspect:
|
if ns.inspect:
|
||||||
inspect(ns.data, ns.base)
|
inspect(ns.data, ns.base, hub=ns.hub, revision=ns.revision)
|
||||||
else:
|
else:
|
||||||
train(ns.data, ns.base, ns.out, ns.epochs, ns.lr, ns.batch, ns.accum,
|
train(ns.data, ns.base, ns.out, ns.epochs, ns.lr, ns.batch, ns.accum,
|
||||||
ns.max_len, ns.lora_r, ns.load_4bit)
|
ns.max_len, ns.lora_r, ns.load_4bit, hub=ns.hub, revision=ns.revision)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
35
tests/unit/test_load.py
Normal file
35
tests/unit/test_load.py
Normal file
|
|
@ -0,0 +1,35 @@
|
||||||
|
"""Unit tests for the trainer's three-source data loader (raw dir / curated
|
||||||
|
jsonl). The Hub path delegates to posix-sdc and is covered there."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tiararodney.sekft import sft
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_turns_from_raw_dir(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "a.json").write_text(json.dumps(
|
||||||
|
{"keep": True, "turns": [{"role": "assistant", "content": "ls"}]}))
|
||||||
|
(tmp_path / "b.json").write_text(json.dumps( # not kept -> excluded
|
||||||
|
{"keep": False, "turns": [{"role": "assistant", "content": "rm -rf /"}]}))
|
||||||
|
got = list(sft.load_turns(tmp_path))
|
||||||
|
assert len(got) == 1
|
||||||
|
assert got[0][0]["content"] == "ls"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_turns_from_jsonl(tmp_path: Path) -> None:
|
||||||
|
f = tmp_path / "corpus.jsonl"
|
||||||
|
f.write_text("\n".join(json.dumps({"turns": [{"role": "assistant", "content": c}]})
|
||||||
|
for c in ("ls", "cat x")) + "\n")
|
||||||
|
got = list(sft.load_turns(f))
|
||||||
|
assert [t[0]["content"] for t in got] == ["ls", "cat x"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_turns_rejects_other_paths(tmp_path: Path) -> None:
|
||||||
|
bad = tmp_path / "notes.txt"
|
||||||
|
bad.write_text("hi")
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
list(sft.load_turns(bad))
|
||||||
Loading…
Add table
Add a link
Reference in a new issue