test: annotate the sft test helpers
This commit is contained in:
parent
9397280e6f
commit
64020c321d
1 changed files with 2 additions and 2 deletions
|
|
@ -15,7 +15,7 @@ class FakeTok:
|
||||||
"""Additive chat template: each turn renders to ``<role> tokens... </e>``;
|
"""Additive chat template: each turn renders to ``<role> tokens... </e>``;
|
||||||
the generation prompt appends ``<assistant>``."""
|
the generation prompt appends ``<assistant>``."""
|
||||||
|
|
||||||
def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False,
|
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||||
return_tensors: Any = None) -> list[str]:
|
return_tensors: Any = None) -> list[str]:
|
||||||
toks: list[str] = []
|
toks: list[str] = []
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
|
|
@ -65,7 +65,7 @@ def test_mask_trains_assistant_turns_only() -> None:
|
||||||
|
|
||||||
def test_mask_raises_on_non_additive_template() -> None:
|
def test_mask_raises_on_non_additive_template() -> None:
|
||||||
class BadTok:
|
class BadTok:
|
||||||
def apply_chat_template(self, msgs: list[dict], add_generation_prompt: bool = False,
|
def apply_chat_template(self, msgs: list[dict[str, str]], add_generation_prompt: bool = False,
|
||||||
return_tensors: Any = None) -> list[int]:
|
return_tensors: Any = None) -> list[int]:
|
||||||
return list(range(len(msgs), 0, -1)) # reversed: prefixes do not nest
|
return list(range(len(msgs), 0, -1)) # reversed: prefixes do not nest
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue