refactor: annotate the trainer modules under mypy strict

This commit is contained in:
Tiara Rodney 2026-06-17 14:03:52 +02:00
parent e60495b2ce
commit 9397280e6f
Signed by: tiara
GPG key ID: 5CD8EC1D46106723
3 changed files with 30 additions and 23 deletions

View file

@ -20,7 +20,9 @@ from __future__ import annotations
import argparse
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
from tiararodney.posix_sdc.factory.rollout import rollout
@ -30,7 +32,7 @@ from .sft import normalize_for_template
def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
temperature: float = 0.7):
temperature: float = 0.7) -> Callable[[list[dict[str, str]]], str]:
"""A ``messages -> command`` callable backed by base + LoRA adapter.
Renders the conversation exactly as the model was trained, appends the
@ -46,7 +48,7 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
model = PeftModel.from_pretrained(model, adapter)
model.eval()
def operator(messages):
def operator(messages: list[dict[str, str]]) -> str:
msgs = normalize_for_template(messages)
ids = tok.apply_chat_template(
msgs, add_generation_prompt=True, return_tensors="pt").to(model.device)
@ -55,13 +57,14 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
ids, max_new_tokens=max_new_tokens,
do_sample=temperature > 0, temperature=max(temperature, 1e-2),
eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id)
return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
text: str = tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
return text
return operator
def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int,
max_steps: int, temperature: float) -> dict:
max_steps: int, temperature: float) -> dict[str, Any]:
if not available():
raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`")
operator = make_local_operator(base, adapter, temperature=temperature)

View file

@ -23,6 +23,7 @@ import argparse
import gc
import json
from pathlib import Path
from typing import Any
import torch
from datasets import Dataset
@ -68,7 +69,7 @@ class Resident:
# -- build masked rows from kept trajectories --------------------------
def _rows(self, data_dir: Path, max_len: int) -> list[dict]:
def _rows(self, data_dir: Path, max_len: int) -> list[dict[str, list[Any]]]:
rows = []
for turns in iter_keepers(data_dir):
ex = build_masked_example(turns, self.tok)
@ -83,8 +84,8 @@ class Resident:
def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4,
epochs: float = 3.0, batch: int = 1, accum: int = 8,
max_len: int = 4096) -> Path:
data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser()
ds = Dataset.from_list(self._rows(data_dir, max_len))
ddir, odir = Path(data_dir).expanduser(), Path(out).expanduser()
ds = Dataset.from_list(self._rows(ddir, max_len))
if not self.load_4bit:
self.base.gradient_checkpointing_enable()
model = get_peft_model(self.base, LoraConfig(
@ -92,31 +93,31 @@ class Resident:
task_type="CAUSAL_LM", target_modules=LORA_TARGETS))
model.print_trainable_parameters()
args = TrainingArguments(
output_dir=str(out), per_device_train_batch_size=batch,
output_dir=str(odir), per_device_train_batch_size=batch,
gradient_accumulation_steps=accum, num_train_epochs=epochs,
learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no",
report_to=["tensorboard"], logging_dir=str(out / "runs"),
report_to=["tensorboard"], logging_dir=str(odir / "runs"),
remove_unused_columns=False, warmup_ratio=0.03)
tr = Trainer(model=model, args=args, train_dataset=ds,
data_collator=DataCollatorForSeq2Seq(
self.tok, padding=True, label_pad_token_id=-100))
tr.train()
out.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(out))
self.tok.save_pretrained(str(out))
(out / "log_history.jsonl").write_text(
odir.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(odir))
self.tok.save_pretrained(str(odir))
(odir / "log_history.jsonl").write_text(
"\n".join(json.dumps(r) for r in tr.state.log_history))
losses = [h["loss"] for h in tr.state.log_history if "loss" in h]
print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}")
print(f"[resident] fit -> {odir} final loss {losses[-1] if losses else '?'}")
self.base = model.unload() # strip LoRA, restore resident base
del model, tr, ds
_free()
return out
return odir
# -- behavioural eval of a saved adapter -------------------------------
def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10,
max_steps: int = 30, temperature: float = 0.7) -> dict:
max_steps: int = 30, temperature: float = 0.7) -> dict[str, Any]:
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
from tiararodney.posix_sdc.factory.rollout import rollout
from tiararodney.posix_sdc.schema import Scenario
@ -130,7 +131,7 @@ class Resident:
pm = self.base
pm.eval()
def operator(messages):
def operator(messages: list[dict[str, str]]) -> str:
msgs = normalize_for_template(messages)
ids = self.tok.apply_chat_template(
msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device)
@ -139,7 +140,8 @@ class Resident:
temperature=max(temperature, 1e-2),
eos_token_id=self.tok.eos_token_id,
pad_token_id=self.tok.eos_token_id)
return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
text: str = self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
return text
backend = DashDocker()
rows = []

View file

@ -28,9 +28,11 @@ from __future__ import annotations
import argparse
import json
from collections.abc import Iterator
from pathlib import Path
from typing import Any
def normalize_for_template(messages: list[dict]) -> list[dict]:
def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]:
"""Canonicalise a trajectory for instruct chat templates that have no system
role and require strict user/assistant alternation (Mistral and friends):
treat ``system`` as ``user``, then merge consecutive same-role turns by
@ -42,7 +44,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
The serving side MUST apply the same canonicalisation, or train and serve
diverge again.
"""
out: list[dict] = []
out: list[dict[str, str]] = []
for m in messages:
role = "user" if m["role"] == "system" else m["role"]
if out and out[-1]["role"] == role:
@ -52,7 +54,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
return out
def build_masked_example(messages: list[dict], tokenizer) -> dict:
def build_masked_example(messages: list[dict[str, str]], tokenizer: Any) -> dict[str, list[Any]]:
"""Tokenize a trajectory with the tokenizer's OWN chat template and build an
assistant-only loss mask.
@ -81,7 +83,7 @@ def build_masked_example(messages: list[dict], tokenizer) -> dict:
return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
def iter_keepers(data_dir: Path):
def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]:
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
for f in sorted(data_dir.glob("*.json")):
d = json.loads(f.read_text())
@ -89,7 +91,7 @@ def iter_keepers(data_dir: Path):
yield d["turns"]
def mask_stats(example: dict) -> tuple[int, int]:
def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]:
"""(trained tokens, total tokens) for an example."""
trained = sum(1 for x in example["labels"] if x != -100)
return trained, len(example["labels"])