"""Unit tests for the SFT render canonicalisation and assistant-only mask.

These run anywhere: a fake additive tokenizer stands in for a real chat
template, so no torch/transformers is needed.
"""

from __future__ import annotations

from typing import Any

import pytest

from tiararodney.sekft import sft


class FakeTok:
    """Fake tokenizer whose chat template emits string tokens.

    Each turn renders to ``<role>``, followed by the whitespace-split content,
    followed by an end-of-turn marker; the generation prompt appends one more
    marker token.

    NOTE(review): both markers are currently the empty string ``""``. With
    these values a render that ends in the generation prompt
    (``..., "", ""``) is not a prefix of the render that includes the next
    assistant turn (``..., "", "<assistant>", ...``). The literals look like
    angle-bracket tokens that were stripped by a markup filter -- confirm the
    intended values against version control before relying on these tests.
    """

    def apply_chat_template(
        self,
        msgs: list[dict[str, str]],
        add_generation_prompt: bool = False,
        return_tensors: Any = None,
    ) -> list[str]:
        """Render ``msgs`` to a flat list of string tokens.

        Args:
            msgs: Chat messages; each must carry ``"role"`` and ``"content"``.
            add_generation_prompt: When true, append the generation-prompt
                token after the last turn.
            return_tensors: Accepted for signature compatibility; unused.

        Returns:
            The rendered token list.
        """
        toks: list[str] = []
        for m in msgs:
            toks.append(f"<{m['role']}>")
            toks += m["content"].split()
            toks.append("")  # end-of-turn marker (see class NOTE)
        if add_generation_prompt:
            toks.append("")  # generation-prompt marker (see class NOTE)
        return toks


class FakeTokBatchEncoding(FakeTok):
    """Like FakeTok, but returns a dict as transformers >= 5's
    ``apply_chat_template`` does (a BatchEncoding), to exercise the
    id-extraction."""

    def apply_chat_template(
        self,
        msgs: list[dict[str, str]],
        add_generation_prompt: bool = False,
        return_tensors: Any = None,
    ) -> dict[str, list[str]]:
        """Return the parent's rendered tokens wrapped under ``"input_ids"``."""
        return {
            "input_ids": super().apply_chat_template(
                msgs, add_generation_prompt, return_tensors
            )
        }


def test_normalize_folds_system_and_merges_consecutive() -> None:
    """Roles must strictly alternate and the first turn must absorb the
    system text plus the consecutive user messages, newline-joined."""
    raw = [
        {"role": "system", "content": "orient"},
        {"role": "user", "content": "login"},
        {"role": "user", "content": "prompt"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "user", "content": "prompt"},
        {"role": "assistant", "content": "exit"},
    ]
    norm = sft.normalize_for_template(raw)
    assert [m["role"] for m in norm] == ["user", "assistant", "user", "assistant"]
    assert norm[0]["content"] == "orient\nlogin\nprompt"


def test_normalize_leaves_clean_alternation_untouched() -> None:
    """An already alternating user/assistant conversation compares equal to
    its normalised form."""
    raw = [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}]
    assert sft.normalize_for_template(raw) == raw


def test_mask_trains_assistant_turns_only() -> None:
    """Tokens whose label is not -100 must come from assistant turns only;
    system and user content must carry the -100 label."""
    raw = [
        {"role": "system", "content": "orient"},
        {"role": "user", "content": "login"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "assistant", "content": "exit"},
    ]
    ex = sft.build_masked_example(raw, FakeTok())
    trained = [t for t, lab in zip(ex["input_ids"], ex["labels"]) if lab != -100]
    masked = [t for t, lab in zip(ex["input_ids"], ex["labels"]) if lab == -100]
    # NOTE(review): the two "" entries are duplicates in this set literal --
    # presumably two distinct marker tokens originally; confirm.
    assert set(trained) <= {"", "cat", "f", "exit", ""}
    assert "cat" in trained and "exit" in trained  # both commands present
    assert {"orient", "login", "out"} <= set(masked)  # environment masked


def test_mask_handles_batchencoding_return() -> None:
    # transformers >= 5 returns a BatchEncoding ({input_ids: [...]}) rather than a
    # bare list[int]; the mask must come out identical. Regression for the 5.x bug
    # that made every real template look "not additive".
    raw = [
        {"role": "user", "content": "login"},
        {"role": "assistant", "content": "cat f"},
        {"role": "user", "content": "out"},
        {"role": "assistant", "content": "exit"},
    ]
    assert (sft.build_masked_example(raw, FakeTokBatchEncoding())
            == sft.build_masked_example(raw, FakeTok()))


def test_mask_raises_on_non_additive_template() -> None:
    """A template whose shorter renders are not prefixes of longer ones must
    be rejected with ``ValueError``."""

    class BadTok:
        def apply_chat_template(
            self,
            msgs: list[dict[str, str]],
            add_generation_prompt: bool = False,
            return_tensors: Any = None,
        ) -> list[int]:
            return list(range(len(msgs), 0, -1))  # reversed: prefixes do not nest

    with pytest.raises(ValueError):
        sft.build_masked_example(
            [{"role": "user", "content": "a"}, {"role": "assistant", "content": "b"}],
            BadTok())