Merge branch 'feature/9'
This commit is contained in:
commit
86df915524
8 changed files with 47 additions and 26 deletions
1
Pipfile
1
Pipfile
|
|
@ -12,6 +12,7 @@ name = "pypi"
|
||||||
[dev-packages]
|
[dev-packages]
|
||||||
tox = "*"
|
tox = "*"
|
||||||
pytest = "*"
|
pytest = "*"
|
||||||
|
mypy = "*"
|
||||||
build = "*"
|
build = "*"
|
||||||
twine = "*"
|
twine = "*"
|
||||||
setuptools-scm = "~=8.2.0"
|
setuptools-scm = "~=8.2.0"
|
||||||
|
|
|
||||||
2
TODO
2
TODO
|
|
@ -140,7 +140,7 @@ Content-Type: application/issue
|
||||||
ID: 9
|
ID: 9
|
||||||
Type: feature
|
Type: feature
|
||||||
Title: Type-check the package under mypy strict
|
Title: Type-check the package under mypy strict
|
||||||
Status: in-progress
|
Status: done
|
||||||
Priority: medium
|
Priority: medium
|
||||||
Created: 2026-06-17
|
Created: 2026-06-17
|
||||||
Module: sekft
|
Module: sekft
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,9 @@ Git = "https://git.code.tiararodney.com/tiararodney/sekft"
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
namespaces = true
|
namespaces = true
|
||||||
|
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
"tiararodney.sekft" = ["py.typed"]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
pythonpath = ["src", "../posix-sdc/src"]
|
pythonpath = ["src", "../posix-sdc/src"]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|
@ -70,6 +73,16 @@ markers = [
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
strict = true
|
strict = true
|
||||||
|
mypy_path = "src"
|
||||||
|
explicit_package_bases = true
|
||||||
|
namespace_packages = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = [
|
||||||
|
"torch.*", "transformers.*", "peft.*", "datasets.*", "bitsandbytes.*",
|
||||||
|
"tiararodney.posix_sdc.*",
|
||||||
|
]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
[tool.autopep8]
|
[tool.autopep8]
|
||||||
max_line_length = 80
|
max_line_length = 80
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@ from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
|
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
|
||||||
from tiararodney.posix_sdc.factory.rollout import rollout
|
from tiararodney.posix_sdc.factory.rollout import rollout
|
||||||
|
|
@ -30,7 +32,7 @@ from .sft import normalize_for_template
|
||||||
|
|
||||||
|
|
||||||
def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
||||||
temperature: float = 0.7):
|
temperature: float = 0.7) -> Callable[[list[dict[str, str]]], str]:
|
||||||
"""A ``messages -> command`` callable backed by base + LoRA adapter.
|
"""A ``messages -> command`` callable backed by base + LoRA adapter.
|
||||||
|
|
||||||
Renders the conversation exactly as the model was trained, appends the
|
Renders the conversation exactly as the model was trained, appends the
|
||||||
|
|
@ -46,7 +48,7 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
||||||
model = PeftModel.from_pretrained(model, adapter)
|
model = PeftModel.from_pretrained(model, adapter)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
def operator(messages):
|
def operator(messages: list[dict[str, str]]) -> str:
|
||||||
msgs = normalize_for_template(messages)
|
msgs = normalize_for_template(messages)
|
||||||
ids = tok.apply_chat_template(
|
ids = tok.apply_chat_template(
|
||||||
msgs, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
msgs, add_generation_prompt=True, return_tensors="pt").to(model.device)
|
||||||
|
|
@ -55,13 +57,14 @@ def make_local_operator(base: str, adapter: str, max_new_tokens: int = 64,
|
||||||
ids, max_new_tokens=max_new_tokens,
|
ids, max_new_tokens=max_new_tokens,
|
||||||
do_sample=temperature > 0, temperature=max(temperature, 1e-2),
|
do_sample=temperature > 0, temperature=max(temperature, 1e-2),
|
||||||
eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id)
|
eos_token_id=tok.eos_token_id, pad_token_id=tok.eos_token_id)
|
||||||
return tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
text: str = tok.decode(out[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
return operator
|
return operator
|
||||||
|
|
||||||
|
|
||||||
def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int,
|
def evaluate(base: str, adapter: str, scenarios_dir: Path, n: int,
|
||||||
max_steps: int, temperature: float) -> dict:
|
max_steps: int, temperature: float) -> dict[str, Any]:
|
||||||
if not available():
|
if not available():
|
||||||
raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`")
|
raise SystemExit("sekft-dash image unavailable; `docker build -t sekft-dash .`")
|
||||||
operator = make_local_operator(base, adapter, temperature=temperature)
|
operator = make_local_operator(base, adapter, temperature=temperature)
|
||||||
|
|
|
||||||
0
src/tiararodney/sekft/py.typed
Normal file
0
src/tiararodney/sekft/py.typed
Normal file
|
|
@ -23,6 +23,7 @@ import argparse
|
||||||
import gc
|
import gc
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
|
@ -68,7 +69,7 @@ class Resident:
|
||||||
|
|
||||||
# -- build masked rows from kept trajectories --------------------------
|
# -- build masked rows from kept trajectories --------------------------
|
||||||
|
|
||||||
def _rows(self, data_dir: Path, max_len: int) -> list[dict]:
|
def _rows(self, data_dir: Path, max_len: int) -> list[dict[str, list[Any]]]:
|
||||||
rows = []
|
rows = []
|
||||||
for turns in iter_keepers(data_dir):
|
for turns in iter_keepers(data_dir):
|
||||||
ex = build_masked_example(turns, self.tok)
|
ex = build_masked_example(turns, self.tok)
|
||||||
|
|
@ -83,8 +84,8 @@ class Resident:
|
||||||
def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4,
|
def fit(self, data_dir: str, out: str, lora_r: int = 16, lr: float = 2e-4,
|
||||||
epochs: float = 3.0, batch: int = 1, accum: int = 8,
|
epochs: float = 3.0, batch: int = 1, accum: int = 8,
|
||||||
max_len: int = 4096) -> Path:
|
max_len: int = 4096) -> Path:
|
||||||
data_dir, out = Path(data_dir).expanduser(), Path(out).expanduser()
|
ddir, odir = Path(data_dir).expanduser(), Path(out).expanduser()
|
||||||
ds = Dataset.from_list(self._rows(data_dir, max_len))
|
ds = Dataset.from_list(self._rows(ddir, max_len))
|
||||||
if not self.load_4bit:
|
if not self.load_4bit:
|
||||||
self.base.gradient_checkpointing_enable()
|
self.base.gradient_checkpointing_enable()
|
||||||
model = get_peft_model(self.base, LoraConfig(
|
model = get_peft_model(self.base, LoraConfig(
|
||||||
|
|
@ -92,31 +93,31 @@ class Resident:
|
||||||
task_type="CAUSAL_LM", target_modules=LORA_TARGETS))
|
task_type="CAUSAL_LM", target_modules=LORA_TARGETS))
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
args = TrainingArguments(
|
args = TrainingArguments(
|
||||||
output_dir=str(out), per_device_train_batch_size=batch,
|
output_dir=str(odir), per_device_train_batch_size=batch,
|
||||||
gradient_accumulation_steps=accum, num_train_epochs=epochs,
|
gradient_accumulation_steps=accum, num_train_epochs=epochs,
|
||||||
learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no",
|
learning_rate=lr, fp16=True, logging_steps=1, save_strategy="no",
|
||||||
report_to=["tensorboard"], logging_dir=str(out / "runs"),
|
report_to=["tensorboard"], logging_dir=str(odir / "runs"),
|
||||||
remove_unused_columns=False, warmup_ratio=0.03)
|
remove_unused_columns=False, warmup_ratio=0.03)
|
||||||
tr = Trainer(model=model, args=args, train_dataset=ds,
|
tr = Trainer(model=model, args=args, train_dataset=ds,
|
||||||
data_collator=DataCollatorForSeq2Seq(
|
data_collator=DataCollatorForSeq2Seq(
|
||||||
self.tok, padding=True, label_pad_token_id=-100))
|
self.tok, padding=True, label_pad_token_id=-100))
|
||||||
tr.train()
|
tr.train()
|
||||||
out.mkdir(parents=True, exist_ok=True)
|
odir.mkdir(parents=True, exist_ok=True)
|
||||||
model.save_pretrained(str(out))
|
model.save_pretrained(str(odir))
|
||||||
self.tok.save_pretrained(str(out))
|
self.tok.save_pretrained(str(odir))
|
||||||
(out / "log_history.jsonl").write_text(
|
(odir / "log_history.jsonl").write_text(
|
||||||
"\n".join(json.dumps(r) for r in tr.state.log_history))
|
"\n".join(json.dumps(r) for r in tr.state.log_history))
|
||||||
losses = [h["loss"] for h in tr.state.log_history if "loss" in h]
|
losses = [h["loss"] for h in tr.state.log_history if "loss" in h]
|
||||||
print(f"[resident] fit -> {out} final loss {losses[-1] if losses else '?'}")
|
print(f"[resident] fit -> {odir} final loss {losses[-1] if losses else '?'}")
|
||||||
self.base = model.unload() # strip LoRA, restore resident base
|
self.base = model.unload() # strip LoRA, restore resident base
|
||||||
del model, tr, ds
|
del model, tr, ds
|
||||||
_free()
|
_free()
|
||||||
return out
|
return odir
|
||||||
|
|
||||||
# -- behavioural eval of a saved adapter -------------------------------
|
# -- behavioural eval of a saved adapter -------------------------------
|
||||||
|
|
||||||
def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10,
|
def evaluate(self, adapter: str, scenarios_dir: str, n: int = 10,
|
||||||
max_steps: int = 30, temperature: float = 0.7) -> dict:
|
max_steps: int = 30, temperature: float = 0.7) -> dict[str, Any]:
|
||||||
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
|
from tiararodney.posix_sdc.factory.dashdocker import DashDocker, available
|
||||||
from tiararodney.posix_sdc.factory.rollout import rollout
|
from tiararodney.posix_sdc.factory.rollout import rollout
|
||||||
from tiararodney.posix_sdc.schema import Scenario
|
from tiararodney.posix_sdc.schema import Scenario
|
||||||
|
|
@ -130,7 +131,7 @@ class Resident:
|
||||||
pm = self.base
|
pm = self.base
|
||||||
pm.eval()
|
pm.eval()
|
||||||
|
|
||||||
def operator(messages):
|
def operator(messages: list[dict[str, str]]) -> str:
|
||||||
msgs = normalize_for_template(messages)
|
msgs = normalize_for_template(messages)
|
||||||
ids = self.tok.apply_chat_template(
|
ids = self.tok.apply_chat_template(
|
||||||
msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device)
|
msgs, add_generation_prompt=True, return_tensors="pt").to(pm.device)
|
||||||
|
|
@ -139,7 +140,8 @@ class Resident:
|
||||||
temperature=max(temperature, 1e-2),
|
temperature=max(temperature, 1e-2),
|
||||||
eos_token_id=self.tok.eos_token_id,
|
eos_token_id=self.tok.eos_token_id,
|
||||||
pad_token_id=self.tok.eos_token_id)
|
pad_token_id=self.tok.eos_token_id)
|
||||||
return self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
text: str = self.tok.decode(o[0][ids.shape[1]:], skip_special_tokens=True).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
backend = DashDocker()
|
backend = DashDocker()
|
||||||
rows = []
|
rows = []
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,11 @@ from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
def normalize_for_template(messages: list[dict]) -> list[dict]:
|
def normalize_for_template(messages: list[dict[str, str]]) -> list[dict[str, str]]:
|
||||||
"""Canonicalise a trajectory for instruct chat templates that have no system
|
"""Canonicalise a trajectory for instruct chat templates that have no system
|
||||||
role and require strict user/assistant alternation (Mistral and friends):
|
role and require strict user/assistant alternation (Mistral and friends):
|
||||||
treat ``system`` as ``user``, then merge consecutive same-role turns by
|
treat ``system`` as ``user``, then merge consecutive same-role turns by
|
||||||
|
|
@ -42,7 +44,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
|
||||||
The serving side MUST apply the same canonicalisation, or train and serve
|
The serving side MUST apply the same canonicalisation, or train and serve
|
||||||
diverge again.
|
diverge again.
|
||||||
"""
|
"""
|
||||||
out: list[dict] = []
|
out: list[dict[str, str]] = []
|
||||||
for m in messages:
|
for m in messages:
|
||||||
role = "user" if m["role"] == "system" else m["role"]
|
role = "user" if m["role"] == "system" else m["role"]
|
||||||
if out and out[-1]["role"] == role:
|
if out and out[-1]["role"] == role:
|
||||||
|
|
@ -52,7 +54,7 @@ def normalize_for_template(messages: list[dict]) -> list[dict]:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def build_masked_example(messages: list[dict], tokenizer) -> dict:
|
def build_masked_example(messages: list[dict[str, str]], tokenizer: Any) -> dict[str, list[Any]]:
|
||||||
"""Tokenize a trajectory with the tokenizer's OWN chat template and build an
|
"""Tokenize a trajectory with the tokenizer's OWN chat template and build an
|
||||||
assistant-only loss mask.
|
assistant-only loss mask.
|
||||||
|
|
||||||
|
|
@ -81,7 +83,7 @@ def build_masked_example(messages: list[dict], tokenizer) -> dict:
|
||||||
return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
|
return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels}
|
||||||
|
|
||||||
|
|
||||||
def iter_keepers(data_dir: Path):
|
def iter_keepers(data_dir: Path) -> Iterator[list[dict[str, str]]]:
|
||||||
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
|
"""Yield ``turns`` (message lists) from trajectory JSONs marked keep."""
|
||||||
for f in sorted(data_dir.glob("*.json")):
|
for f in sorted(data_dir.glob("*.json")):
|
||||||
d = json.loads(f.read_text())
|
d = json.loads(f.read_text())
|
||||||
|
|
@ -89,7 +91,7 @@ def iter_keepers(data_dir: Path):
|
||||||
yield d["turns"]
|
yield d["turns"]
|
||||||
|
|
||||||
|
|
||||||
def mask_stats(example: dict) -> tuple[int, int]:
|
def mask_stats(example: dict[str, list[Any]]) -> tuple[int, int]:
|
||||||
"""(trained tokens, total tokens) for an example."""
|
"""(trained tokens, total tokens) for an example."""
|
||||||
trained = sum(1 for x in example["labels"] if x != -100)
|
trained = sum(1 for x in example["labels"] if x != -100)
|
||||||
return trained, len(example["labels"])
|
return trained, len(example["labels"])
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ class FakeTok:
|
||||||
"""Additive chat template: each turn renders to ``<role> tokens... </e>``;
|
"""Additive chat template: each turn renders to ``<role> tokens... </e>``;
|
||||||
the generation prompt appends ``<assistant>``."""
|
the generation prompt appends ``<assistant>``."""
|
||||||
|
|
||||||
def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False,
|
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||||
return_tensors: Any = None) -> list[str]:
|
return_tensors: Any = None) -> list[str]:
|
||||||
toks: list[str] = []
|
toks: list[str] = []
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
|
|
@ -65,7 +65,7 @@ def test_mask_trains_assistant_turns_only() -> None:
|
||||||
|
|
||||||
def test_mask_raises_on_non_additive_template() -> None:
|
def test_mask_raises_on_non_additive_template() -> None:
|
||||||
class BadTok:
|
class BadTok:
|
||||||
def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False,
|
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||||
return_tensors: Any = None) -> list[int]:
|
return_tensors: Any = None) -> list[int]:
|
||||||
return list(range(len(msgs), 0, -1)) # reversed: prefixes do not nest
|
return list(range(len(msgs), 0, -1)) # reversed: prefixes do not nest
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue