The transformers 5.x return-type change behind #15 also breaks generation: `apply_chat_template(add_generation_prompt=True, return_tensors="pt")` now returns a `BatchEncoding`, and `eval.py` and `resident.py` passed it straight to `model.generate`, which reads `inputs.shape[0]` and raises `AttributeError` (the holdout eval crashed on scenario 1). #15 fixed only the trainer. This change factors out a shared `_input_ids` helper and a `render_prompt_ids` function, which both generation operators now use. Tests cover `_input_ids` for both return shapes, and `render_prompt_ids`.
113 lines
4.8 KiB
Python
113 lines
4.8 KiB
Python
"""Unit tests for the SFT render canonicalisation and assistant-only mask.
|
|
|
|
These run anywhere: a fake additive tokenizer stands in for a real chat
|
|
template, so no torch/transformers is needed."""
|
|
from __future__ import annotations

from collections import UserDict
from typing import Any

import pytest

from tiararodney.sekft import sft
|
|
|
|
|
|
class FakeTok:
    """Stand-in tokenizer with a trivially additive chat template.

    Each message renders as ``<role>``, then its whitespace-split content, then
    ``</e>``. Requesting a generation prompt adds one trailing ``<assistant>``.
    Rendering a prefix of a conversation therefore yields a prefix of the full
    rendering, which is what the mask builder requires.
    """

    def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
                            return_tensors: Any = None) -> list[str]:
        # Flatten every turn into role marker, content words and end marker, in order.
        rendered = [
            piece
            for msg in msgs
            for piece in (f"<{msg['role']}>", *msg["content"].split(), "</e>")
        ]
        # return_tensors is accepted for signature compatibility only; output is always a list.
        return [*rendered, "<assistant>"] if add_generation_prompt else rendered
|
|
|
|
|
|
class FakeTokBatchEncoding(FakeTok):
    """Like FakeTok, but returns a mapping the way transformers >= 5's
    ``apply_chat_template`` does (a BatchEncoding), to exercise the id-extraction.

    The result is a ``UserDict`` rather than a plain ``dict`` because
    ``BatchEncoding`` subclasses ``UserDict``. An ``isinstance(x, dict)`` check
    would accept a plain-dict fake yet miss the real object. The mapping also
    carries an ``attention_mask`` next to ``input_ids``, as a real BatchEncoding
    does, so callers must select ``input_ids`` by key.
    """

    def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
                            return_tensors: Any = None) -> UserDict[str, list[Any]]:
        ids = super().apply_chat_template(msgs, add_generation_prompt, return_tensors)
        # All-ones mask: no padding in a single unbatched render.
        return UserDict({"input_ids": ids, "attention_mask": [1] * len(ids)})
|
|
|
|
|
|
def test_normalize_folds_system_and_merges_consecutive() -> None:
    """System text joins the first user turn, and runs of same-role turns merge."""
    turns = [
        ("system", "orient"),
        ("user", "login"),
        ("user", "prompt"),
        ("assistant", "cat f"),
        ("user", "out"),
        ("user", "prompt"),
        ("assistant", "exit"),
    ]
    raw = [{"role": role, "content": text} for role, text in turns]

    norm = sft.normalize_for_template(raw)

    roles = [m["role"] for m in norm]
    assert roles == ["user", "assistant", "user", "assistant"]
    # merged pieces are newline-joined, system text first
    first_turn = norm[0]
    assert first_turn["content"] == "orient\nlogin\nprompt"
|
|
|
|
|
|
def test_normalize_leaves_clean_alternation_untouched() -> None:
    """An already-alternating user/assistant conversation comes back equal."""
    user_turn = {"role": "user", "content": "a"}
    assistant_turn = {"role": "assistant", "content": "b"}
    raw = [user_turn, assistant_turn]

    normalized = sft.normalize_for_template(raw)

    assert normalized == raw
|
|
|
|
|
|
def test_mask_trains_assistant_turns_only() -> None:
    """Only assistant output carries labels; system/user (environment) text is masked."""
    raw = [
        {"role": "system", "content": "orient"},
        {"role": "user", "content": "login"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "assistant", "content": "exit"},
    ]
    ex = sft.build_masked_example(raw, FakeTok())

    # A label of -100 marks a position that is masked out of training.
    token_labels = list(zip(ex["input_ids"], ex["labels"]))
    trained = [tok for tok, label in token_labels if label != -100]
    masked = [tok for tok, label in token_labels if label == -100]

    assistant_vocab = {"<assistant>", "cat", "f", "exit", "</e>"}
    assert set(trained) <= assistant_vocab
    assert "cat" in trained and "exit" in trained  # both commands present
    assert {"orient", "login", "out"} <= set(masked)  # environment masked
|
|
|
|
|
|
def test_mask_handles_batchencoding_return() -> None:
    """The mask is identical whether the tokenizer returns bare ids or a BatchEncoding."""
    # Regression for the transformers 5.x bug that made every real template look
    # "not additive": 5.x returns a BatchEncoding ({input_ids: [...]}) instead of
    # a bare list[int].
    raw = [
        {"role": "user", "content": "login"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "assistant", "content": "exit"},
    ]
    via_mapping = sft.build_masked_example(raw, FakeTokBatchEncoding())
    via_list = sft.build_masked_example(raw, FakeTok())
    assert via_mapping == via_list
|
|
|
|
|
|
def test_input_ids_extracts_from_batchencoding_or_passthrough() -> None:
    """``_input_ids`` unwraps a BatchEncoding-like mapping and passes bare ids through."""
    # transformers 5.x: BatchEncoding -> its input_ids. A real BatchEncoding
    # subclasses UserDict, not dict, so check both shapes. A dict-only
    # isinstance check would otherwise slip through.
    encodings = (
        {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]},
        UserDict({"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}),
    )
    for enc in encodings:
        assert sft._input_ids(enc) == [1, 2, 3]
    # transformers 4.x: bare list/tensor -> returned as-is
    assert sft._input_ids([4, 5, 6]) == [4, 5, 6]
|
|
|
|
|
|
def test_render_prompt_ids_normalises_and_appends_generation_prompt() -> None:
    """``render_prompt_ids`` folds the system turn, appends the generation prompt,
    and yields bare ids for both tokenizer return shapes."""
    # The generation operators pass the result straight to model.generate, so it
    # must be the ids themselves, never a BatchEncoding mapping.
    raw = [{"role": "system", "content": "orient"}, {"role": "user", "content": "go"}]

    ids = sft.render_prompt_ids(FakeTok(), raw)

    assert ids[-1] == "<assistant>"  # generation prompt appended
    assert {"orient", "go"} <= set(ids)  # both contents survive
    # the system turn is folded into the single user turn, not rendered separately
    assert "<system>" not in ids
    assert ids.count("<user>") == 1
    # transformers 5.x shape: the mapping must be unwrapped to the very same ids
    assert sft.render_prompt_ids(FakeTokBatchEncoding(), raw) == ids
|
|
|
|
|
|
def test_mask_raises_on_non_additive_template() -> None:
    """A template whose renders of successive prefixes do not nest is rejected."""

    class BadTok:
        def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
                                return_tensors: Any = None) -> list[int]:
            # Counts down from len(msgs): e.g. [1] is not a prefix of [2, 1],
            # so prefixes do not nest.
            return list(reversed(range(1, len(msgs) + 1)))

    conversation = [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}]
    with pytest.raises(ValueError):
        sft.build_masked_example(conversation, BadTok())
|