bugfix(15): normalise apply_chat_template's BatchEncoding (transformers 5.x)
apply_chat_template returns a BatchEncoding ({input_ids: [...]}) on transformers
>= 5 where 4.x returned a bare list[int]. build_masked_example treated the render
as a bare list, so len/slicing were wrong and the prefix-differencing spuriously
raised "chat template is not additive" on every real model. Extract the id
sequence via a _render_ids helper; verified the assistant-only mask against
mistralai/Mistral-7B-Instruct-v0.2. The fake tokenizer returned a bare list and
missed this, so a BatchEncoding-returning variant now guards it.
This commit is contained in:
parent
7853224796
commit
4987d951ce
3 changed files with 51 additions and 2 deletions
|
|
@ -27,6 +27,15 @@ class FakeTok:
|
|||
return toks
|
||||
|
||||
|
||||
class FakeTokBatchEncoding(FakeTok):
|
||||
"""Like FakeTok, but returns a dict as transformers >= 5's
|
||||
``apply_chat_template`` does (a BatchEncoding), to exercise the id-extraction."""
|
||||
|
||||
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||
return_tensors: Any = None) -> dict[str, list[str]]:
|
||||
return {"input_ids": super().apply_chat_template(msgs, add_generation_prompt, return_tensors)}
|
||||
|
||||
|
||||
def test_normalize_folds_system_and_merges_consecutive() -> None:
|
||||
raw = [
|
||||
{"role": "system", "content": "orient"},
|
||||
|
|
@ -63,6 +72,20 @@ def test_mask_trains_assistant_turns_only() -> None:
|
|||
assert {"orient", "login", "out"} <= set(masked) # environment masked
|
||||
|
||||
|
||||
def test_mask_handles_batchencoding_return() -> None:
|
||||
# transformers >= 5 returns a BatchEncoding ({input_ids: [...]}) rather than a
|
||||
# bare list[int]; the mask must come out identical. Regression for the 5.x bug
|
||||
# that made every real template look "not additive".
|
||||
raw = [
|
||||
{"role": "user", "content": "login"},
|
||||
{"role": "assistant", "content": "cat f"},
|
||||
{"role": "user", "content": "out"},
|
||||
{"role": "assistant", "content": "exit"},
|
||||
]
|
||||
assert (sft.build_masked_example(raw, FakeTokBatchEncoding())
|
||||
== sft.build_masked_example(raw, FakeTok()))
|
||||
|
||||
|
||||
def test_mask_raises_on_non_additive_template() -> None:
|
||||
class BadTok:
|
||||
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue