refactor: annotate the trainer modules under mypy strict

This commit is contained in:
Tiara Rodney 2026-06-17 14:03:52 +02:00
parent e60495b2ce
commit 9397280e6f
Signed by: tiara
GPG key ID: 5CD8EC1D46106723
3 changed files with 30 additions and 23 deletions

View file

@ -20,7 +20,9 @@ from __future__ import annotations
import argparse import argparse
import json import json
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
from tiararodney.posix_sdc.factory.rollout import rollout from tiararodney.posix_sdc.factory.rollout import rollout
@ -30,7 +32,7 @@ from .sft import normalize_for_template
def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64, def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
temperature: float = 0.7): temperature: float = 0.7) -> Callable[[list[dict[str, str]]], str]:
"""A ``messages -> command`` callable backed by base + LoRA adapter. """A ``messages -> command`` callable backed by base + LoRA adapter.
Renders the conversation exactly as the model was trained, appends the Renders the conversation exactly as the model was trained, appends the
@ -46,7 +48,7 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
model = PeftModel.from_pretrained(model, adapter) model = PeftModel.from_pretrained(model, adapter)
model.eval() model.eval()
def operator(messages): def operator(messages: list[dict[str, str]]) -> str:
msgs = normalize_for_template(messages) msgs = normalize_for_template(messages)
ids = tok.apply_chat_template( ids = tok.apply_chat_template(
msgs, add_generation_prompt=True, return_tensors="pt").to(model.device) msgs, add_generation_prompt=True, return_tensors="pt").to(model.device)
@ -55,13 +57,14 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
ids, max_new_tokens=max_new_tokens, ids, max_new_tokens=max_new_tokens,
do_sample=temperature > 0, temperature=max(temperature, 1e-2), do_sample=temperature > 0, temperature=max(temperature, 1e-2),
eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id) eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id)
return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip() text: str = tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
return text
return operator return operator
def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int, def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int,
max_steps: int, temperature: float) -> dict: max_steps: int, temperature: float) -> dict[str, Any]:
if not available(): if not available():
raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`") raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`")
operator = make_local_operator(base, adapter, temperature=temperature) operator = make_local_operator(base, adapter, temperature=temperature)

View file

@ -23,6 +23,7 @@ import argparse
import gc import gc
import json import json
from pathlib import Path from pathlib import Path
from typing import Any
import torch import torch
from datasets import Dataset from datasets import Dataset
@ -68,7 +69,7 @@ class Resident:
# -- build masked rows from kept trajectories -------------------------- # -- build masked rows from kept trajectories --------------------------
def _rows(self, data_dir: Path, max_len: int) -> list[dict]: def _rows(self, data_dir: Path, max_len: int) -> list[dict[str, list[Any]]]:
rows = [] rows = []
for turns in iter_keepers(data_dir): for turns in iter_keepers(data_dir):
ex = build_masked_example(turns, self.tok) ex = build_masked_example(turns, self.tok)
@ -83,8 +84,8 @@ class Resident:
def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4, def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4,
epochs: float = 3.0, batch: int = 1, accum: int = 8, epochs: float = 3.0, batch: int = 1, accum: int = 8,
max_len: int = 4096) -> Path: max_len: int = 4096) -> Path:
data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser() ddir, odir = Path(data_dir).expanduser(), Path(out).expanduser()
ds = Dataset.from_list(self._rows(data_dir, max_len)) ds = Dataset.from_list(self._rows(ddir, max_len))
if not self.load_4bit: if not self.load_4bit:
self.base.gradient_checkpointing_enable() self.base.gradient_checkpointing_enable()
model = get_peft_model(self.base, LoraConfig( model = get_peft_model(self.base, LoraConfig(
@ -92,31 +93,31 @@ class Resident:
task_type="CAUSAL_LM", target_modules=LORA_TARGETS)) task_type="CAUSAL_LM", target_modules=LORA_TARGETS))
model.print_trainable_parameters() model.print_trainable_parameters()
args = TrainingArguments( args = TrainingArguments(
output_dir=str(out), per_device_train_batch_size=batch, output_dir=str(odir), per_device_train_batch_size=batch,
gradient_accumulation_steps=accum, num_train_epochs=epochs, gradient_accumulation_steps=accum, num_train_epochs=epochs,
learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no", learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no",
report_to=["tensorboard"], logging_dir=str(out / "runs"), report_to=["tensorboard"], logging_dir=str(odir / "runs"),
remove_unused_columns=False, warmup_ratio=0.03) remove_unused_columns=False, warmup_ratio=0.03)
tr = Trainer(model=model, args=args, train_dataset=ds, tr = Trainer(model=model, args=args, train_dataset=ds,
data_collator=DataCollatorForSeq2Seq( data_collator=DataCollatorForSeq2Seq(
self.tok, padding=True, label_pad_token_id=-100)) self.tok, padding=True, label_pad_token_id=-100))
tr.train() tr.train()
out.mkdir(parents=True, exist_ok=True) odir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(out)) model.save_pretrained(str(odir))
self.tok.save_pretrained(str(out)) self.tok.save_pretrained(str(odir))
(out / "log_history.jsonl").write_text( (odir / "log_history.jsonl").write_text(
"\n".join(json.dumps(r) for r in tr.state.log_history)) "\n".join(json.dumps(r) for r in tr.state.log_history))
losses = [h["loss"] for h in tr.state.log_history if "loss" in h] losses = [h["loss"] for h in tr.state.log_history if "loss" in h]
print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}") print(f"[resident] fit -> {odir} final loss {losses[-1] if losses else '?'}")
self.base = model.unload() # strip LoRA, restore resident base self.base = model.unload() # strip LoRA, restore resident base
del model, tr, ds del model, tr, ds
_free() _free()
return out return odir
# -- behavioural eval of a saved adapter ------------------------------- # -- behavioural eval of a saved adapter -------------------------------
def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10, def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10,
max_steps: int = 30, temperature: float = 0.7) -> dict: max_steps: int = 30, temperature: float = 0.7) -> dict[str, Any]:
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
from tiararodney.posix_sdc.factory.rollout import rollout from tiararodney.posix_sdc.factory.rollout import rollout
from tiararodney.posix_sdc.schema import Scenario from tiararodney.posix_sdc.schema import Scenario
@ -130,7 +131,7 @@ class Resident:
pm = self.base pm = self.base
pm.eval() pm.eval()
def operator(messages): def operator(messages: list[dict[str, str]]) -> str:
msgs = normalize_for_template(messages) msgs = normalize_for_template(messages)
ids = self.tok.apply_chat_template( ids = self.tok.apply_chat_template(
msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device) msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device)
@ -139,7 +140,8 @@ class Resident:
temperature=max(temperature, 1e-2), temperature=max(temperature, 1e-2),
eos_token_id=self.tok.eos_token_id, eos_token_id=self.tok.eos_token_id,
pad_token_id=self.tok.eos_token_id) pad_token_id=self.tok.eos_token_id)
return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip() text: str = self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
return text
backend = DashDocker() backend = DashDocker()
rows = [] rows = []

View file

@ -28,9 +28,11 @@ from __future__ import annotations
import argparse import argparse
import json import json
from collections.abc import Iterator
from pathlib import Path from pathlib import Path
from typing import Any
def normalize_for_template(messages: list[dict]) -> list[dict]: def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]:
"""Canonicalise a trajectory for instruct chat templates that have no system """Canonicalise a trajectory for instruct chat templates that have no system
role and require strict user/assistant alternation (Mistral and friends): role and require strict user/assistant alternation (Mistral and friends):
treat ``system`` as ``user``, then merge consecutive same-role turns by treat ``system`` as ``user``, then merge consecutive same-role turns by
@ -42,7 +44,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
The serving side MUST apply the same canonicalisation, or train and serve The serving side MUST apply the same canonicalisation, or train and serve
diverge again. diverge again.
""" """
out: list[dict] = [] out: list[dict[str, str]] = []
for m in messages: for m in messages:
role = "user" if m["role"] == "system" else m["role"] role = "user" if m["role"] == "system" else m["role"]
if out and out[-1]["role"] == role: if out and out[-1]["role"] == role:
@ -52,7 +54,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
return out return out
def build_masked_example(messages: list[dict], tokenizer) -> dict: def build_masked_example(messages: list[dict[str, str]], tokenizer: Any) -> dict[str, list[Any]]:
"""Tokenize a trajectory with the tokenizer's OWN chat template and build an """Tokenize a trajectory with the tokenizer's OWN chat template and build an
assistant-only loss mask. assistant-only loss mask.
@ -81,7 +83,7 @@ def build_masked_example(messages: list[dict], tokenizer) -> dict:
return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels} return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
def iter_keepers(data_dir: Path): def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]:
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep.""" """Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
for f in sorted(data_dir.glob("*.json")): for f in sorted(data_dir.glob("*.json")):
d = json.loads(f.read_text()) d = json.loads(f.read_text())
@ -89,7 +91,7 @@ def iter_keepers(data_dir: Path):
yield d["turns"] yield d["turns"]
def mask_stats(example: dict[str, list[Any]], ignore_index: int = -100) -> tuple[int, int]:
    """Return ``(trained tokens, total tokens)`` for an example.

    A position counts as trained when its entry in ``example["labels"]``
    differs from ``ignore_index``; the total is the length of that list.

    Args:
        example: Mapping that carries a ``"labels"`` list.
        ignore_index: Label value that marks a position as not trained.
            Defaults to ``-100``, the value this function previously
            hard-coded, so existing callers are unaffected.

    Returns:
        A ``(trained, total)`` tuple of ints.

    Raises:
        KeyError: If ``example`` has no ``"labels"`` entry.
    """
    # Look the list up once; it is used for both the count and the total.
    labels = example["labels"]
    trained = sum(1 for label in labels if label != ignore_index)
    return trained, len(labels)