diff --git a/Pipfile b/Pipfile index 08edb5c..fffe94e 100644 --- a/Pipfile +++ b/Pipfile @@ -12,6 +12,7 @@ name = "pypi" [dev-packages] tox = "*" pytest = "*" +mypy = "*" build = "*" twine = "*" setuptools-scm = "~=8.2.0" diff --git a/TODO b/TODO index 3ed93ce..acb41d8 100644 --- a/TODO +++ b/TODO @@ -140,7 +140,7 @@ Content-Type: application/issue ID: 9 Type: feature Title: Type-check the package under mypy strict -Status: in-progress +Status: done Priority: medium Created: 2026-06-17 Module: sekft diff --git a/pyproject.toml b/pyproject.toml index fa22ee5..ae51a61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,9 @@ Git = "https://git.code.tiararodney.com/tiararodney/sekft" where = ["src"] namespaces = true +[tool.setuptools.package-data] +"tiararodney.sekft" = ["py.typed"] + [tool.pytest.ini_options] pythonpath = ["src", "../posix-sdc/src"] testpaths = ["tests"] @@ -70,6 +73,16 @@ markers = [ [tool.mypy] strict = true +mypy_path = "src" +explicit_package_bases = true +namespace_packages = true + +[[tool.mypy.overrides]] +module = [ + "torch.*", "transformers.*", "peft.*", "datasets.*", "bitsandbytes.*", + "tiararodney.posix_sdc.*", +] +ignore_missing_imports = true [tool.autopep8] max_line_length = 80 diff --git a/src/tiararodney/sekft/eval.py b/src/tiararodney/sekft/eval.py index 4438134..59f5bfe 100644 --- a/src/tiararodney/sekft/eval.py +++ b/src/tiararodney/sekft/eval.py @@ -20,7 +20,9 @@ from __future__ import annotations import argparse import json +from collections.abc import Callable from pathlib import Path +from typing import Any from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available from tiararodney.posix_sdc.factory.rollout import rollout @@ -30,7 +32,7 @@ from .sft import normalize_for_template def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64, - temperature: float = 0.7): + temperature: float = 0.7) -> Callable[[list[dict[str, str]]], str]: """A ``messages -> command`` callable backed by base + LoRA adapter. Renders the conversation exactly as the model was trained, appends the @@ -46,7 +48,7 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64, model = PeftModel.from_pretrained(model, adapter) model.eval() - def operator(messages): + def operator(messages: list[dict[str, str]]) -> str: msgs = normalize_for_template(messages) ids = tok.apply_chat_template( msgs, add_generation_prompt=True, return_tensors="pt").to(model.device) @@ -55,13 +57,14 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64, ids, max_new_tokens=max_new_tokens, do_sample=temperature > 0, temperature=max(temperature, 1e-2), eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id) - return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip() + text: str = tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip() + return text return operator def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int, - max_steps: int, temperature: float) -> dict: + max_steps: int, temperature: float) -> dict[str, Any]: if not available(): raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`") operator = make_local_operator(base, adapter, temperature=temperature) diff --git a/src/tiararodney/sekft/py.typed b/src/tiararodney/sekft/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/tiararodney/sekft/resident.py b/src/tiararodney/sekft/resident.py index c295d48..3b0b8e0 100644 --- a/src/tiararodney/sekft/resident.py +++ b/src/tiararodney/sekft/resident.py @@ -23,6 +23,7 @@ import argparse import gc import json from pathlib import Path +from typing import Any import torch from datasets import Dataset @@ -68,7 +69,7 @@ class Resident: # -- build masked rows from kept trajectories -------------------------- - def _rows(self, data_dir: Path, max_len: int) -> list[dict]: + def _rows(self, data_dir: Path, max_len: int) -> list[dict[str, list[Any]]]: rows = [] for turns in iter_keepers(data_dir): ex = build_masked_example(turns, self.tok) @@ -83,8 +84,8 @@ class Resident: def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4, epochs: float = 3.0, batch: int = 1, accum: int = 8, max_len: int = 4096) -> Path: - data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser() - ds = Dataset.from_list(self._rows(data_dir, max_len)) + ddir, odir = Path(data_dir).expanduser(), Path(out).expanduser() + ds = Dataset.from_list(self._rows(ddir, max_len)) if not self.load_4bit: self.base.gradient_checkpointing_enable() model = get_peft_model(self.base, LoraConfig( @@ -92,31 +93,31 @@ class Resident: task_type="CAUSAL_LM", target_modules=LORA_TARGETS)) model.print_trainable_parameters() args = TrainingArguments( - output_dir=str(out), per_device_train_batch_size=batch, + output_dir=str(odir), per_device_train_batch_size=batch, gradient_accumulation_steps=accum, num_train_epochs=epochs, learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no", - report_to=["tensorboard"], logging_dir=str(out / "runs"), + report_to=["tensorboard"], logging_dir=str(odir / "runs"), remove_unused_columns=False, warmup_ratio=0.03) tr = Trainer(model=model, args=args, train_dataset=ds, data_collator=DataCollatorForSeq2Seq( self.tok, padding=True, label_pad_token_id=-100)) tr.train() - out.mkdir(parents=True, exist_ok=True) - model.save_pretrained(str(out)) - self.tok.save_pretrained(str(out)) - (out / "log_history.jsonl").write_text( + odir.mkdir(parents=True, exist_ok=True) + model.save_pretrained(str(odir)) + self.tok.save_pretrained(str(odir)) + (odir / "log_history.jsonl").write_text( "\n".join(json.dumps(r) for r in tr.state.log_history)) losses = [h["loss"] for h in tr.state.log_history if "loss" in h] - print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}") + print(f"[resident] fit -> {odir} final loss {losses[-1] if losses else '?'}") self.base = model.unload() # strip LoRA, restore resident base del model, tr, ds _free() - return out + return odir # -- behavioural eval of a saved adapter ------------------------------- def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10, - max_steps: int = 30, temperature: float = 0.7) -> dict: + max_steps: int = 30, temperature: float = 0.7) -> dict[str, Any]: from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available from tiararodney.posix_sdc.factory.rollout import rollout from tiararodney.posix_sdc.schema import Scenario @@ -130,7 +131,7 @@ class Resident: pm = self.base pm.eval() - def operator(messages): + def operator(messages: list[dict[str, str]]) -> str: msgs = normalize_for_template(messages) ids = self.tok.apply_chat_template( msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device) @@ -139,7 +140,8 @@ class Resident: temperature=max(temperature, 1e-2), eos_token_id=self.tok.eos_token_id, pad_token_id=self.tok.eos_token_id) - return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip() + text: str = self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip() + return text backend = DashDocker() rows = [] diff --git a/src/tiararodney/sekft/sft.py b/src/tiararodney/sekft/sft.py index 5ac8633..0716f69 100644 --- a/src/tiararodney/sekft/sft.py +++ b/src/tiararodney/sekft/sft.py @@ -28,9 +28,11 @@ from __future__ import annotations import argparse import json +from collections.abc import Iterator from pathlib import Path +from typing import Any -def normalize_for_template(messages: list[dict]) -> list[dict]: +def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]: """Canonicalise a trajectory for instruct chat templates that have no system role and require strict user/assistant alternation (Mistral and friends): treat ``system`` as ``user``, then merge consecutive same-role turns by @@ -42,7 +44,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]: The serving side MUST apply the same canonicalisation, or train and serve diverge again. """ - out: list[dict] = [] + out: list[dict[str, str]] = [] for m in messages: role = "user" if m["role"] == "system" else m["role"] if out and out[-1]["role"] == role: @@ -52,7 +54,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]: return out -def build_masked_example(messages: list[dict], tokenizer) -> dict: +def build_masked_example(messages: list[dict[str, str]], tokenizer: Any) -> dict[str, list[Any]]: """Tokenize a trajectory with the tokenizer's OWN chat template and build an assistant-only loss mask. @@ -81,7 +83,7 @@ def build_masked_example(messages: list[dict], tokenizer) -> dict: return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels} -def iter_keepers(data_dir: Path): +def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]: """Yield ``turns`` (message lists) from trajectory JSONs marked keep.""" for f in sorted(data_dir.glob("*.json")): d = json.loads(f.read_text()) @@ -89,7 +91,7 @@ def iter_keepers(data_dir: Path): yield d["turns"] -def mask_stats(example: dict) -> tuple[int, int]: +def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]: """(trained tokens, total tokens) for an example.""" trained = sum(1 for x in example["labels"] if x != -100) return trained, len(example["labels"]) diff --git a/tests/unit/test_sft.py b/tests/unit/test_sft.py index f2e84b4..d24eef0 100644 --- a/tests/unit/test_sft.py +++ b/tests/unit/test_sft.py @@ -15,7 +15,7 @@ class FakeTok: """Additive chat template: each turn renders to `` tokens... ``; the generation prompt appends ````.""" - def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False, + def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False, return_tensors: Any = None) -> list[str]: toks: list[str] = [] for m in msgs: @@ -65,7 +65,7 @@ def test_mask_trains_assistant_turns_only() -> None: def test_mask_raises_on_non_additive_template() -> None: class BadTok: - def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False, + def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False, return_tensors: Any = None) -> list[int]: return list(range(len(msgs), 0, -1)) # reversed: prefixes do not nest