sekft/tests/unit/test_sft.py
Tiara Rodney 4987d951ce
bugfix(15): normalise apply_chat_template's BatchEncoding (transformers 5.x)
apply_chat_template returns a BatchEncoding ({input_ids: [...]}) on transformers
>= 5 where 4.x returned a bare list[int]. build_masked_example treated the render
as a dict, so len/slicing were wrong and the prefix-differencing spuriously
raised "chat template is not additive" on every real model. Extract the id
sequence via a _render_ids helper; verified the assistant-only mask against
mistralai/Mistral-7B-Instruct-v0.2. The fake tokenizer returned a bare list and
missed this, so a BatchEncoding-returning variant now guards it.
2026-06-18 12:37:01 +02:00

98 lines
4 KiB
Python

"""Unit tests for the SFT render canonicalisation and assistant-only mask.
These run anywhere: a fake additive tokenizer stands in for a real chat
template, so no torch/transformers is needed."""
from __future__ import annotations

from collections import UserDict
from typing import Any

import pytest

from tiararodney.sekft import sft
class FakeTok:
    """Stand-in tokenizer whose chat template is strictly additive.

    Every message renders as ``<role>``, then its whitespace-split words, then
    ``</e>``. Requesting a generation prompt tacks a trailing ``<assistant>``
    onto the end. ``return_tensors`` is accepted for signature parity only."""

    def apply_chat_template(
        self,
        msgs: list[dict[str, str]],
        add_generation_prompt: bool = False,
        return_tensors: Any = None,
    ) -> list[str]:
        rendered = [
            piece
            for turn in msgs
            for piece in (f"<{turn['role']}>", *turn["content"].split(), "</e>")
        ]
        if add_generation_prompt:
            rendered.append("<assistant>")
        return rendered
class _FakeBatchEncoding(UserDict):
    """Minimal stand-in for ``transformers.BatchEncoding``.

    The real class derives from ``collections.UserDict`` -- *not* ``dict`` --
    and also exposes its keys as attributes (``enc.input_ids``). Mirroring both
    keeps the regression guard honest: a plain ``dict`` would let an
    ``isinstance(x, dict)`` check pass here while failing on a real tokenizer."""

    def __getattr__(self, item: str) -> Any:
        # Only reached when normal attribute lookup fails. Guard ``data`` so a
        # half-initialised instance (e.g. during copy/unpickle) cannot recurse.
        if item == "data":
            raise AttributeError(item)
        try:
            return self.data[item]
        except KeyError:
            raise AttributeError(item) from None


class FakeTokBatchEncoding(FakeTok):
    """Like FakeTok, but wraps the render the way transformers >= 5's
    ``apply_chat_template`` does -- in a BatchEncoding carrying ``input_ids``
    (plus an ``attention_mask``) -- to exercise the id-extraction."""

    def apply_chat_template(
        self,
        msgs: list[dict[str, str]],
        add_generation_prompt: bool = False,
        return_tensors: Any = None,
    ) -> _FakeBatchEncoding:
        ids = super().apply_chat_template(msgs, add_generation_prompt, return_tensors)
        # One mask entry per token, as a real tokenizer returns alongside the ids.
        return _FakeBatchEncoding({"input_ids": ids, "attention_mask": [1] * len(ids)})
def test_normalize_folds_system_and_merges_consecutive() -> None:
    """The system prompt folds into the first user turn, and consecutive user
    turns merge into one newline-joined turn, restoring strict alternation."""
    raw = [
        {"role": "system", "content": "orient"},
        {"role": "user", "content": "login"},
        {"role": "user", "content": "prompt"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "user", "content": "prompt"},
        {"role": "assistant", "content": "exit"},
    ]
    norm = sft.normalize_for_template(raw)
    assert [m["role"] for m in norm] == ["user", "assistant", "user", "assistant"]
    # Check every turn, not just the first: the later "out"/"prompt" run must
    # merge the same way, and assistant turns must pass through verbatim.
    assert [m["content"] for m in norm] == [
        "orient\nlogin\nprompt",
        "cat f",
        "out\nprompt",
        "exit",
    ]
def test_normalize_leaves_clean_alternation_untouched() -> None:
    """An already-alternating user/assistant history comes back unchanged."""
    clean = [
        {"role": "user", "content": "a"},
        {"role": "assistant", "content": "b"},
    ]
    result = sft.normalize_for_template(clean)
    assert result == clean
def test_mask_trains_assistant_turns_only() -> None:
    """Only assistant-turn tokens are left unmasked; system/user (environment)
    tokens are masked with label -100."""
    raw = [
        {"role": "system", "content": "orient"},
        {"role": "user", "content": "login"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "assistant", "content": "exit"},
    ]
    ex = sft.build_masked_example(raw, FakeTok())
    # zip() silently truncates to the shorter input, so pin the token/label
    # alignment first; otherwise a length mismatch would slip past the checks.
    assert len(ex["labels"]) == len(ex["input_ids"])
    pairs = list(zip(ex["input_ids"], ex["labels"]))
    trained = [t for t, lab in pairs if lab != -100]
    masked = [t for t, lab in pairs if lab == -100]
    assert set(trained) <= {"<assistant>", "cat", "f", "exit", "</e>"}
    assert "cat" in trained and "exit" in trained  # both commands present
    assert {"orient", "login", "out"} <= set(masked)  # environment masked
def test_mask_handles_batchencoding_return() -> None:
    """Regression for the transformers 5.x bug.

    ``apply_chat_template`` there returns a BatchEncoding ({input_ids: [...]})
    instead of a bare id list, which made every real template look "not
    additive". The masked example must be identical either way."""
    conversation = [
        {"role": "user", "content": "login"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "assistant", "content": "exit"},
    ]
    from_batch_encoding = sft.build_masked_example(conversation, FakeTokBatchEncoding())
    from_bare_list = sft.build_masked_example(conversation, FakeTok())
    assert from_batch_encoding == from_bare_list
def test_mask_raises_on_non_additive_template() -> None:
    """A template whose renders do not nest as prefixes must be rejected."""

    class BadTok:
        """Renders k messages as [k, k-1, ..., 1], so the render of k >= 1
        messages is never a prefix of the render of k + 1."""

        def apply_chat_template(
            self,
            msgs: list[dict[str, str]],
            add_generation_prompt: bool = False,
            return_tensors: Any = None,
        ) -> list[int]:
            return list(range(len(msgs), 0, -1))  # reversed: prefixes do not nest

    # match= pins the additivity check itself, so an unrelated ValueError
    # raised elsewhere in build_masked_example cannot satisfy the test.
    with pytest.raises(ValueError, match="not additive"):
        sft.build_masked_example(
            [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}],
            BadTok(),
        )