From 96c8e4685e6ca7f4211ac1b1db9e83e6dcc9141e Mon Sep 17 00:00:00 2001 From: Tiara Rodney Date: Sat, 6 Jun 2026 14:35:02 +0200 Subject: [PATCH] test: add config framework unit tests --- .../byteb4rb1e/utils/config/test_config.py | 347 ++++++++++++++++++ 1 file changed, 347 insertions(+) create mode 100644 tests/unit/byteb4rb1e/utils/config/test_config.py diff --git a/tests/unit/byteb4rb1e/utils/config/test_config.py b/tests/unit/byteb4rb1e/utils/config/test_config.py new file mode 100644 index 0000000..6b25306 --- /dev/null +++ b/tests/unit/byteb4rb1e/utils/config/test_config.py @@ -0,0 +1,347 @@ +"""Unit tests for the config framework.""" + +from argparse import ArgumentParser, Namespace +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from byteb4rb1e.utils.config import ( + add_config_arguments, + apply_cli_overrides, + apply_overrides, + ensure_ini, + ensure_ini_multi, + format_help, + format_section, + load_ini, + resolve_hints, +) + + +@dataclass +class SampleConfig: + name: str = "default" + count: int = 10 + ratio: float = 0.5 + enabled: bool = True + + +class TestLoadIni: + def test_loads_values(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text( + "[sample]\n" + "name = custom\n" + "count = 42\n" + "ratio = 0.75\n" + ) + config = load_ini(SampleConfig, ini) + assert config.name == "custom" + assert config.count == 42 + assert config.ratio == 0.75 + assert config.enabled is True # default + + def test_missing_section_uses_defaults(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[other]\nfoo = bar\n") + config = load_ini(SampleConfig, ini) + assert config.name == "default" + assert config.count == 10 + + def test_missing_file_uses_defaults(self, tmp_path): + config = load_ini( + SampleConfig, tmp_path / "missing.ini" + ) + assert config.name == "default" + + def test_unknown_key_raises(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[sample]\nunknown_key = bad\n") + with pytest.raises(ValueError, match="unknown_key"): + load_ini(SampleConfig, ini) + + def test_custom_section_name(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[mysection]\nname = custom\n") + config = load_ini( + SampleConfig, ini, section="mysection" + ) + assert config.name == "custom" + + def test_comments_ignored(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text( + "[sample]\n" + "# this is a comment\n" + "name = works # inline comment\n" + ) + config = load_ini(SampleConfig, ini) + assert config.name == "works" + + +class TestAddConfigArguments: + def test_generates_flags(self): + parser = ArgumentParser() + add_config_arguments(SampleConfig, parser) + args = parser.parse_args( + ["--name", "cli", "--count", "99"] + ) + assert args.name == "cli" + assert args.count == 99 + + def test_defaults_are_none(self): + parser = ArgumentParser() + add_config_arguments(SampleConfig, parser) + args = parser.parse_args([]) + assert args.name is None + assert args.count is None + + def test_underscores_become_dashes(self): + @dataclass + class DashConfig: + my_long_name: str = "x" + + parser = ArgumentParser() + add_config_arguments(DashConfig, parser) + args = parser.parse_args( + ["--my-long-name", "val"] + ) + assert args.my_long_name == "val" + + +class TestApplyCliOverrides: + def test_overrides_set_values(self): + config = SampleConfig() + args = Namespace(name="override", count=None, + ratio=None, enabled=None) + result = apply_cli_overrides(config, args) + assert result.name == "override" + assert result.count == 10 # unchanged + + def test_no_overrides_returns_same(self): + config = SampleConfig() + args = Namespace(name=None, count=None, + ratio=None, enabled=None) + result = apply_cli_overrides(config, args) + assert result.name == "default" + assert result is config + + +class TestEnsureIni: + def test_creates_file_if_missing(self, tmp_path): + ini = tmp_path / "new.ini" + assert not ini.exists() + config = ensure_ini(SampleConfig, ini) + assert ini.exists() + assert config.name == "default" + assert config.count == 10 + + def test_created_file_has_all_fields(self, tmp_path): + ini = tmp_path / "new.ini" + ensure_ini(SampleConfig, ini) + content = ini.read_text() + assert "name" in content + assert "count" in content + assert "ratio" in content + assert "enabled" in content + + def test_created_file_has_comments(self, tmp_path): + ini = tmp_path / "new.ini" + ensure_ini(SampleConfig, ini) + content = ini.read_text() + assert "# name (str)" in content + assert "# count (int)" in content + + def test_reads_existing_file(self, tmp_path): + ini = tmp_path / "existing.ini" + ini.write_text("[sample]\ncount = 42\n") + config = ensure_ini(SampleConfig, ini) + assert config.count == 42 + + def test_does_not_overwrite_existing(self, tmp_path): + ini = tmp_path / "existing.ini" + ini.write_text("[sample]\ncount = 42\n") + ensure_ini(SampleConfig, ini) + content = ini.read_text() + assert content == "[sample]\ncount = 42\n" + + def test_created_file_is_loadable(self, tmp_path): + ini = tmp_path / "new.ini" + ensure_ini(SampleConfig, ini) + config = load_ini(SampleConfig, ini) + assert config.name == "default" + assert config.count == 10 + assert config.ratio == 0.5 + + +class TestIntegration: + def test_ini_then_cli_override(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[sample]\ncount = 42\n") + config = load_ini(SampleConfig, ini) + assert config.count == 42 + + args = Namespace(name=None, count=99, + ratio=None, enabled=None) + config = apply_cli_overrides(config, args) + assert config.count == 99 + assert config.name == "default" + + def test_ensure_then_cli_override(self, tmp_path): + ini = tmp_path / "new.ini" + config = ensure_ini(SampleConfig, ini) + assert config.count == 10 + + args = Namespace(name=None, count=99, + ratio=None, enabled=None) + config = apply_cli_overrides(config, args) + assert config.count == 99 + assert config.name == "default" + + # Config file unchanged + reloaded = load_ini(SampleConfig, ini) + assert reloaded.count == 10 + + +class TestResolveHints: + def test_returns_type_dict(self): + hints = resolve_hints(SampleConfig) + assert hints["name"] is str + assert hints["count"] is int + assert hints["ratio"] is float + assert hints["enabled"] is bool + + +class TestFormatSection: + def test_includes_section_header(self): + text = format_section(SampleConfig) + assert "[sample]" in text + + def test_custom_section_name(self): + text = format_section(SampleConfig, "custom") + assert "[custom]" in text + + def test_includes_all_fields(self): + text = format_section(SampleConfig) + assert "name = default" in text + assert "count = 10" in text + assert "ratio = 0.5" in text + assert "enabled = True" in text + + def test_includes_type_comments(self): + text = format_section(SampleConfig) + assert "# name (str)" in text + assert "# count (int)" in text + + def test_is_loadable(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text(format_section(SampleConfig) + "\n") + config = load_ini(SampleConfig, ini) + assert config.name == "default" + assert config.count == 10 + + +class TestEnsureIniMulti: + def test_creates_file_with_multiple_sections(self, tmp_path): + @dataclass + class OtherConfig: + host: str = "localhost" + port: int = 8080 + + ini = tmp_path / "multi.ini" + ensure_ini_multi([ + (SampleConfig, None), + (OtherConfig, "server"), + ], ini) + + content = ini.read_text() + assert "[sample]" in content + assert "[server]" in content + assert "name = default" in content + assert "host = localhost" in content + + def test_does_not_overwrite_existing(self, tmp_path): + ini = tmp_path / "multi.ini" + ini.write_text("[existing]\nfoo = bar\n") + ensure_ini_multi([(SampleConfig, None)], ini) + content = ini.read_text() + assert content == "[existing]\nfoo = bar\n" + + def test_sections_are_loadable(self, tmp_path): + @dataclass + class DbConfig: + url: str = "sqlite:///test.db" + + ini = tmp_path / "multi.ini" + ensure_ini_multi([ + (SampleConfig, None), + (DbConfig, "database"), + ], ini) + + sample = load_ini(SampleConfig, ini) + db = load_ini(DbConfig, ini, section="database") + assert sample.name == "default" + assert db.url == "sqlite:///test.db" + + +class TestApplyOverrides: + def test_applies_dotted_path(self): + config = SampleConfig() + result = apply_overrides(config, { + "provider.name": "custom", + "provider.count": "99", + }, prefix="provider") + assert result.name == "custom" + assert result.count == 99 + + def test_without_prefix(self): + config = SampleConfig() + result = apply_overrides(config, { + "name": "direct", + "count": "42", + }) + assert result.name == "direct" + assert result.count == 42 + + def test_no_matching_keys_returns_same(self): + config = SampleConfig() + result = apply_overrides(config, {"other.key": "val"}, prefix="provider") + assert result is config + + def test_bool_coercion(self): + config = SampleConfig() + result = apply_overrides(config, {"enabled": "false"}) + assert result.enabled is False + + def test_preserves_unset_fields(self): + config = SampleConfig() + result = apply_overrides(config, {"name": "changed"}) + assert result.name == "changed" + assert result.count == 10 # unchanged + assert result.ratio == 0.5 # unchanged + + +class TestFormatHelp: + def test_lists_all_fields(self): + lines = format_help(SampleConfig) + assert len(lines) == 4 + assert any("name" in l for l in lines) + assert any("count" in l for l in lines) + + def test_includes_types(self): + lines = format_help(SampleConfig) + text = "\n".join(lines) + assert "str" in text + assert "int" in text + + def test_includes_defaults(self): + lines = format_help(SampleConfig) + text = "\n".join(lines) + assert "default" in text + assert "10" in text + + def test_with_prefix(self): + lines = format_help(SampleConfig, prefix="provider") + assert any("provider.name" in l for l in lines) + assert any("provider.count" in l for l in lines)