Merge branch 'feature/10'

feat(10): structured logging for the trainer
This commit is contained in:
Tiara Rodney 2026-06-17 23:47:52 +02:00
commit 1c890c703f
Signed by: tiara
GPG key ID: 5CD8EC1D46106723
3 changed files with 71 additions and 12 deletions

2
TODO
View file

@ -155,7 +155,7 @@ Content-Type: application/issue
ID: 10 ID: 10
Type: feature Type: feature
Title: structured logging for the trainer (sft) Title: structured logging for the trainer (sft)
Status: in-progress Status: done
Priority: medium Priority: medium
Created: 2026-06-17 Created: 2026-06-17
Module: sekft Module: sekft

View file

@ -0,0 +1,20 @@
"""Console logging setup shared by the sekft entry points.
Logs go to stderr so stdout stays clean for a command's actual output (metrics
JSON, a path a caller might capture). Call :func:`setup` once at the top of a
``main()``; modules then log through ``logging.getLogger("sekft.<area>")``.
"""
from __future__ import annotations
import logging
def setup(verbose: bool = False, quiet: bool = False) -> None:
"""Configure root logging to stderr. ``quiet`` shows warnings and worse,
``verbose`` adds debug; the default is info."""
level = logging.WARNING if quiet else logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)-5s %(name)s %(message)s",
datefmt="%H:%M:%S",
)

View file

@ -28,10 +28,16 @@ from __future__ import annotations
import argparse import argparse
import json import json
import logging
from collections.abc import Iterator from collections.abc import Iterator
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from ._log import setup as _setup_logging
log = logging.getLogger("sekft.train")
def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]: def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]:
"""Canonicalise a trajectory for instruct chat templates that have no system """Canonicalise a trajectory for instruct chat templates that have no system
role and require strict user/assistant alternation (Mistral and friends): role and require strict user/assistant alternation (Mistral and friends):
@ -109,21 +115,44 @@ def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
from peft import LoraConfig, get_peft_model from peft import LoraConfig, get_peft_model
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
DataCollatorForSeq2Seq, Trainer, TrainingArguments) DataCollatorForSeq2Seq, Trainer, TrainingArguments)
from transformers.utils import logging as hf_logging
# Surface the Trainer's own per-step curve (loss/lr/grad_norm); it is at
# WARNING by default, which is most of why training looks silent.
hf_logging.set_verbosity_info()
log.info("base=%s data=%s out=%s", base, data_dir, out)
log.info("loading tokenizer: %s", base)
tok = AutoTokenizer.from_pretrained(base) tok = AutoTokenizer.from_pretrained(base)
if tok.pad_token is None: if tok.pad_token is None:
tok.pad_token = tok.eos_token tok.pad_token = tok.eos_token
rows = [] log.info("building masked examples from %s ...", data_dir)
rows: list[dict[str, list[Any]]] = []
n_seen = n_long = n_empty = 0
for turns in iter_keepers(data_dir): for turns in iter_keepers(data_dir):
n_seen += 1
ex = build_masked_example(turns, tok) ex = build_masked_example(turns, tok)
if len(ex["input_ids"]) <= max_len and any(l != -100 for l in ex["labels"]): log.debug(" trajectory %d: %d turns -> %d tokens, %d trained",
rows.append(ex) n_seen, len(turns), len(ex["input_ids"]), mask_stats(ex)[0])
if n_seen % 100 == 0:
log.info(" ... %d trajectories processed, %d usable", n_seen, len(rows))
if len(ex["input_ids"]) > max_len:
n_long += 1
continue
if not any(l != -100 for l in ex["labels"]):
n_empty += 1
continue
rows.append(ex)
if not rows: if not rows:
raise SystemExit(f"no usable keeper trajectories in {data_dir}") raise SystemExit(f"no usable keeper trajectories in {data_dir}")
print(f"examples: {len(rows)}; " trained = sum(mask_stats(r)[0] for r in rows)
f"trained/total tokens: {sum(mask_stats(r)[0] for r in rows)}" total = sum(mask_stats(r)[1] for r in rows)
f"/{sum(mask_stats(r)[1] for r in rows)}") log.info("dataset: %d keepers -> %d usable; %d trained / %d tokens (%.1f%% assistant)",
n_seen, len(rows), trained, total, 100 * trained / total)
if n_long or n_empty:
log.warning("dropped %d trajectories: %d over --max-len %d, %d empty-mask",
n_long + n_empty, n_long, max_len, n_empty)
ds = Dataset.from_list(rows) ds = Dataset.from_list(rows)
# 4-bit (QLoRA) shrinks the base from ~14 GB to ~4 GB to move across the # 4-bit (QLoRA) shrinks the base from ~14 GB to ~4 GB to move across the
@ -136,6 +165,8 @@ def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
load_in_4bit=True, bnb_4bit_quant_type="nf4", load_in_4bit=True, bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True,
) )
log.info("loading base model: %s (%s)", base,
"4-bit QLoRA" if load_4bit else "fp16")
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base, dtype=torch.float16, quantization_config=quant) base, dtype=torch.float16, quantization_config=quant)
if load_4bit: if load_4bit:
@ -148,7 +179,9 @@ def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, task_type="CAUSAL_LM", r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.05, task_type="CAUSAL_LM",
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)) ))
model.print_trainable_parameters() n_train, n_all = model.get_nb_trainable_parameters()
log.info("LoRA r=%d: %d trainable / %d params (%.3f%%)",
lora_r, n_train, n_all, 100 * n_train / n_all)
args = TrainingArguments( args = TrainingArguments(
output_dir=str(out), per_device_train_batch_size=batch, output_dir=str(out), per_device_train_batch_size=batch,
@ -161,18 +194,21 @@ def train(data_dir: Path, base: str, out: Path, epochs: float, lr: float,
model=model, args=args, train_dataset=ds, model=model, args=args, train_dataset=ds,
data_collator=DataCollatorForSeq2Seq(tok, padding=True, label_pad_token_id=-100), data_collator=DataCollatorForSeq2Seq(tok, padding=True, label_pad_token_id=-100),
) )
log.info("training: %g epochs, lr=%g, batch=%d x accum=%d (effective %d), max_len=%d",
epochs, lr, batch, accum, batch * accum, max_len)
trainer.train() trainer.train()
model.save_pretrained(str(out)) model.save_pretrained(str(out))
tok.save_pretrained(str(out)) tok.save_pretrained(str(out))
# durable, greppable record of the curve (loss/lr/grad_norm per step). # durable, greppable record of the curve (loss/lr/grad_norm per step).
(out / "log_history.jsonl").write_text( (out / "log_history.jsonl").write_text(
"\n".join(json.dumps(r) for r in trainer.state.log_history)) "\n".join(json.dumps(r) for r in trainer.state.log_history))
print(f"saved LoRA adapter + log_history.jsonl -> {out} " log.info("saved LoRA adapter + log_history.jsonl -> %s (tensorboard: --logdir %s)",
f"(tensorboard: --logdir {out / 'runs'})") out, out / "runs")
def inspect(data_dir: Path, base: str) -> None: def inspect(data_dir: Path, base: str) -> None:
from transformers import AutoTokenizer from transformers import AutoTokenizer
log.info("loading tokenizer: %s", base)
tok = AutoTokenizer.from_pretrained(base) tok = AutoTokenizer.from_pretrained(base)
n = tt = tr = 0 n = tt = tr = 0
for turns in iter_keepers(data_dir): for turns in iter_keepers(data_dir):
@ -181,8 +217,8 @@ def inspect(data_dir: Path, base: str) -> None:
tr += t; tt += total; n += 1 tr += t; tt += total; n += 1
if not n: if not n:
raise SystemExit(f"no keeper trajectories in {data_dir}") raise SystemExit(f"no keeper trajectories in {data_dir}")
print(f"{n} keeper trajectories; {tr}/{tt} tokens trained " log.info("%d keeper trajectories; %d/%d tokens trained (%.1f%% assistant, rest masked)",
f"({100*tr/tt:.1f}% assistant, rest masked)") n, tr, tt, 100 * tr / tt)
def main() -> None: def main() -> None:
@ -199,7 +235,10 @@ def main() -> None:
ap.add_argument("--lora-r", type=int, default=16) ap.add_argument("--lora-r", type=int, default=16)
ap.add_argument("--load-4bit", action="store_true", ap.add_argument("--load-4bit", action="store_true",
help="QLoRA: load base in 4-bit (less to move over the link, less VRAM)") help="QLoRA: load base in 4-bit (less to move over the link, less VRAM)")
ap.add_argument("-v", "--verbose", action="store_true", help="debug-level logging")
ap.add_argument("-q", "--quiet", action="store_true", help="warnings and errors only")
ns = ap.parse_args() ns = ap.parse_args()
_setup_logging(verbose=ns.verbose, quiet=ns.quiet)
if ns.inspect: if ns.inspect:
inspect(ns.data, ns.base) inspect(ns.data, ns.base)
else: else: