Merge branch 'feature/9'
This commit is contained in:
commit
86df915524
8 changed files with 47 additions and 26 deletions
1
Pipfile
1
Pipfile
|
|
@ -12,6 +12,7 @@ name = "pypi"
|
|||
[dev-packages]
|
||||
tox = "*"
|
||||
pytest = "*"
|
||||
mypy = "*"
|
||||
build = "*"
|
||||
twine = "*"
|
||||
setuptools-scm = "~=8.2.0"
|
||||
|
|
|
|||
2
TODO
2
TODO
|
|
@ -140,7 +140,7 @@ Content-Type: application/issue
|
|||
ID: 9
|
||||
Type: feature
|
||||
Title: Type-check the package under mypy strict
|
||||
Status: in-progress
|
||||
Status: done
|
||||
Priority: medium
|
||||
Created: 2026-06-17
|
||||
Module: sekft
|
||||
|
|
|
|||
|
|
@ -59,6 +59,9 @@ Git = "https://git.code.tiararodney.com/tiararodney/sekft"
|
|||
where = ["src"]
|
||||
namespaces = true
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
"tiararodney.sekft" = ["py.typed"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
pythonpath = ["src", "../posix-sdc/src"]
|
||||
testpaths = ["tests"]
|
||||
|
|
@ -70,6 +73,16 @@ markers = [
|
|||
|
||||
[tool.mypy]
|
||||
strict = true
|
||||
mypy_path = "src"
|
||||
explicit_package_bases = true
|
||||
namespace_packages = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"torch.*", "transformers.*", "peft.*", "datasets.*", "bitsandbytes.*",
|
||||
"tiararodney.posix_sdc.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.autopep8]
|
||||
max_line_length = 80
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ from __future__ import annotations
|
|||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
|
||||
from tiararodney.posix_sdc.factory.rollout import rollout
|
||||
|
|
@ -30,7 +32,7 @@ from .sft import normalize_for_template
|
|||
|
||||
|
||||
def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
||||
temperature: float = 0.7):
|
||||
temperature: float = 0.7) -> Callable[[list[dict[str, str]]], str]:
|
||||
"""A ``messages -> command`` callable backed by base + LoRA adapter.
|
||||
|
||||
Renders the conversation exactly as the model was trained, appends the
|
||||
|
|
@ -46,7 +48,7 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
|||
model = PeftModel.from_pretrained(model, adapter)
|
||||
model.eval()
|
||||
|
||||
def operator(messages):
|
||||
def operator(messages: list[dict[str, str]]) -> str:
|
||||
msgs = normalize_for_template(messages)
|
||||
ids = tok.apply_chat_template(
|
||||
msgs, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
||||
|
|
@ -55,13 +57,14 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
|||
ids, max_new_tokens=max_new_tokens,
|
||||
do_sample=temperature > 0, temperature=max(temperature, 1e-2),
|
||||
eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id)
|
||||
return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
||||
text: str = tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
||||
return text
|
||||
|
||||
return operator
|
||||
|
||||
|
||||
def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int,
|
||||
max_steps: int, temperature: float) -> dict:
|
||||
max_steps: int, temperature: float) -> dict[str, Any]:
|
||||
if not available():
|
||||
raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`")
|
||||
operator = make_local_operator(base, adapter, temperature=temperature)
|
||||
|
|
|
|||
0
src/tiararodney/sekft/py.typed
Normal file
0
src/tiararodney/sekft/py.typed
Normal file
|
|
@ -23,6 +23,7 @@ import argparse
|
|||
import gc
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
|
@ -68,7 +69,7 @@ class Resident:
|
|||
|
||||
# -- build masked rows from kept trajectories --------------------------
|
||||
|
||||
def _rows(self, data_dir: Path, max_len: int) -> list[dict]:
|
||||
def _rows(self, data_dir: Path, max_len: int) -> list[dict[str, list[Any]]]:
|
||||
rows = []
|
||||
for turns in iter_keepers(data_dir):
|
||||
ex = build_masked_example(turns, self.tok)
|
||||
|
|
@ -83,8 +84,8 @@ class Resident:
|
|||
def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4,
|
||||
epochs: float = 3.0, batch: int = 1, accum: int = 8,
|
||||
max_len: int = 4096) -> Path:
|
||||
data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser()
|
||||
ds = Dataset.from_list(self._rows(data_dir, max_len))
|
||||
ddir, odir = Path(data_dir).expanduser(), Path(out).expanduser()
|
||||
ds = Dataset.from_list(self._rows(ddir, max_len))
|
||||
if not self.load_4bit:
|
||||
self.base.gradient_checkpointing_enable()
|
||||
model = get_peft_model(self.base, LoraConfig(
|
||||
|
|
@ -92,31 +93,31 @@ class Resident:
|
|||
task_type="CAUSAL_LM", target_modules=LORA_TARGETS))
|
||||
model.print_trainable_parameters()
|
||||
args = TrainingArguments(
|
||||
output_dir=str(out), per_device_train_batch_size=batch,
|
||||
output_dir=str(odir), per_device_train_batch_size=batch,
|
||||
gradient_accumulation_steps=accum, num_train_epochs=epochs,
|
||||
learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no",
|
||||
report_to=["tensorboard"], logging_dir=str(out / "runs"),
|
||||
report_to=["tensorboard"], logging_dir=str(odir / "runs"),
|
||||
remove_unused_columns=False, warmup_ratio=0.03)
|
||||
tr = Trainer(model=model, args=args, train_dataset=ds,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
self.tok, padding=True, label_pad_token_id=-100))
|
||||
tr.train()
|
||||
out.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(str(out))
|
||||
self.tok.save_pretrained(str(out))
|
||||
(out / "log_history.jsonl").write_text(
|
||||
odir.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(str(odir))
|
||||
self.tok.save_pretrained(str(odir))
|
||||
(odir / "log_history.jsonl").write_text(
|
||||
"\n".join(json.dumps(r) for r in tr.state.log_history))
|
||||
losses = [h["loss"] for h in tr.state.log_history if "loss" in h]
|
||||
print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}")
|
||||
print(f"[resident] fit -> {odir} final loss {losses[-1] if losses else '?'}")
|
||||
self.base = model.unload() # strip LoRA, restore resident base
|
||||
del model, tr, ds
|
||||
_free()
|
||||
return out
|
||||
return odir
|
||||
|
||||
# -- behavioural eval of a saved adapter -------------------------------
|
||||
|
||||
def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10,
|
||||
max_steps: int = 30, temperature: float = 0.7) -> dict:
|
||||
max_steps: int = 30, temperature: float = 0.7) -> dict[str, Any]:
|
||||
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
|
||||
from tiararodney.posix_sdc.factory.rollout import rollout
|
||||
from tiararodney.posix_sdc.schema import Scenario
|
||||
|
|
@ -130,7 +131,7 @@ class Resident:
|
|||
pm = self.base
|
||||
pm.eval()
|
||||
|
||||
def operator(messages):
|
||||
def operator(messages: list[dict[str, str]]) -> str:
|
||||
msgs = normalize_for_template(messages)
|
||||
ids = self.tok.apply_chat_template(
|
||||
msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device)
|
||||
|
|
@ -139,7 +140,8 @@ class Resident:
|
|||
temperature=max(temperature, 1e-2),
|
||||
eos_token_id=self.tok.eos_token_id,
|
||||
pad_token_id=self.tok.eos_token_id)
|
||||
return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
||||
text: str = self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
||||
return text
|
||||
|
||||
backend = DashDocker()
|
||||
rows = []
|
||||
|
|
|
|||
|
|
@ -28,9 +28,11 @@ from __future__ import annotations
|
|||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
def normalize_for_template(messages: list[dict]) -> list[dict]:
|
||||
def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]:
|
||||
"""Canonicalise a trajectory for instruct chat templates that have no system
|
||||
role and require strict user/assistant alternation (Mistral and friends):
|
||||
treat ``system`` as ``user``, then merge consecutive same-role turns by
|
||||
|
|
@ -42,7 +44,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
|
|||
The serving side MUST apply the same canonicalisation, or train and serve
|
||||
diverge again.
|
||||
"""
|
||||
out: list[dict] = []
|
||||
out: list[dict[str, str]] = []
|
||||
for m in messages:
|
||||
role = "user" if m["role"] == "system" else m["role"]
|
||||
if out and out[-1]["role"] == role:
|
||||
|
|
@ -52,7 +54,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
|
|||
return out
|
||||
|
||||
|
||||
def build_masked_example(messages: list[dict], tokenizer) -> dict:
|
||||
def build_masked_example(messages: list[dict[str, str]], tokenizer: Any) -> dict[str, list[Any]]:
|
||||
"""Tokenize a trajectory with the tokenizer's OWN chat template and build an
|
||||
assistant-only loss mask.
|
||||
|
||||
|
|
@ -81,7 +83,7 @@ def build_masked_example(messages: list[dict], tokenizer) -> dict:
|
|||
return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
|
||||
|
||||
|
||||
def iter_keepers(data_dir: Path):
|
||||
def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]:
|
||||
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
|
||||
for f in sorted(data_dir.glob("*.json")):
|
||||
d = json.loads(f.read_text())
|
||||
|
|
@ -89,7 +91,7 @@ def iter_keepers(data_dir: Path):
|
|||
yield d["turns"]
|
||||
|
||||
|
||||
def mask_stats(example: dict) -> tuple[int, int]:
|
||||
def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]:
|
||||
"""(trained tokens, total tokens) for an example."""
|
||||
trained = sum(1 for x in example["labels"] if x != -100)
|
||||
return trained, len(example["labels"])
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class FakeTok:
|
|||
"""Additive chat template: each turn renders to ``<role> tokens... </e>``;
|
||||
the generation prompt appends ``<assistant>``."""
|
||||
|
||||
def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False,
|
||||
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||
return_tensors: Any = None) -> list[str]:
|
||||
toks: list[str] = []
|
||||
for m in msgs:
|
||||
|
|
@ -65,7 +65,7 @@ def test_mask_trains_assistant_turns_only() -> None:
|
|||
|
||||
def test_mask_raises_on_non_additive_template() -> None:
|
||||
class BadTok:
|
||||
def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False,
|
||||
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||
return_tensors: Any = None) -> list[int]:
|
||||
return list(range(len(msgs), 0, -1)) # reversed: prefixes do not nest
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue