"""Unit tests for the config framework."""

from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, fields
from pathlib import Path

import pytest

from byteb4rb1e.utils.config import (
    add_config_arguments,
    apply_cli_overrides,
    apply_overrides,
    ensure_ini,
    ensure_ini_multi,
    format_help,
    format_section,
    load_ini,
    resolve_hints,
)


@dataclass
class SampleConfig:
    # Fixture dataclass used throughout these tests: one field per scalar
    # type the config framework handles (str, int, float, bool).
    name: str = "default"
    count: int = 10
    ratio: float = 0.5
    enabled: bool = True


def _cli_args(**overrides):
    """Return a Namespace shaped like parsed SampleConfig CLI arguments.

    Every SampleConfig field starts as None -- the value that
    add_config_arguments() leaves for a flag that was not given (see
    TestAddConfigArguments.test_defaults_are_none) -- and *overrides* sets
    individual fields. Deriving the keys from the dataclass keeps these
    namespaces in sync if SampleConfig gains or loses fields.

    Raises:
        TypeError: if an override names a field SampleConfig does not have,
            so a typo fails loudly instead of being silently ignored.
    """
    values = {f.name: None for f in fields(SampleConfig)}
    unknown = overrides.keys() - values.keys()
    if unknown:
        raise TypeError(f"unknown SampleConfig field(s): {sorted(unknown)}")
    values.update(overrides)
    return Namespace(**values)


class TestLoadIni:
    """load_ini(): populate a dataclass from one section of an INI file."""

    def test_loads_values(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text(
            "[sample]\n"
            "name = custom\n"
            "count = 42\n"
            "ratio = 0.75\n"
        )
        config = load_ini(SampleConfig, ini)
        assert config.name == "custom"
        assert config.count == 42
        assert config.ratio == 0.75
        assert config.enabled is True  # absent from file -> dataclass default

    def test_missing_section_uses_defaults(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text("[other]\nfoo = bar\n")
        config = load_ini(SampleConfig, ini)
        assert config.name == "default"
        assert config.count == 10

    def test_missing_file_uses_defaults(self, tmp_path):
        config = load_ini(SampleConfig, tmp_path / "missing.ini")
        assert config.name == "default"

    def test_unknown_key_raises(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text("[sample]\nunknown_key = bad\n")
        with pytest.raises(ValueError, match="unknown_key"):
            load_ini(SampleConfig, ini)

    def test_custom_section_name(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text("[mysection]\nname = custom\n")
        config = load_ini(SampleConfig, ini, section="mysection")
        assert config.name == "custom"

    def test_comments_ignored(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text(
            "[sample]\n"
            "# this is a comment\n"
            "name = works # inline comment\n"
        )
        config = load_ini(SampleConfig, ini)
        assert config.name == "works"


class TestAddConfigArguments:
    """add_config_arguments(): one ``--flag`` per dataclass field."""

    def test_generates_flags(self):
        parser = ArgumentParser()
        add_config_arguments(SampleConfig, parser)
        args = parser.parse_args(["--name", "cli", "--count", "99"])
        assert args.name == "cli"
        assert args.count == 99

    def test_defaults_are_none(self):
        parser = ArgumentParser()
        add_config_arguments(SampleConfig, parser)
        args = parser.parse_args([])
        assert args.name is None
        assert args.count is None

    def test_underscores_become_dashes(self):
        @dataclass
        class DashConfig:
            my_long_name: str = "x"

        parser = ArgumentParser()
        add_config_arguments(DashConfig, parser)
        args = parser.parse_args(["--my-long-name", "val"])
        assert args.my_long_name == "val"


class TestApplyCliOverrides:
    """apply_cli_overrides(): non-None Namespace values replace fields."""

    def test_overrides_set_values(self):
        config = SampleConfig()
        result = apply_cli_overrides(config, _cli_args(name="override"))
        assert result.name == "override"
        assert result.count == 10  # unchanged

    def test_no_overrides_returns_same(self):
        config = SampleConfig()
        result = apply_cli_overrides(config, _cli_args())
        assert result.name == "default"
        assert result is config


class TestEnsureIni:
    """ensure_ini(): write a commented default file if missing, then load."""

    def test_creates_file_if_missing(self, tmp_path):
        ini = tmp_path / "new.ini"
        assert not ini.exists()
        config = ensure_ini(SampleConfig, ini)
        assert ini.exists()
        assert config.name == "default"
        assert config.count == 10

    def test_created_file_has_all_fields(self, tmp_path):
        ini = tmp_path / "new.ini"
        ensure_ini(SampleConfig, ini)
        content = ini.read_text()
        assert "name" in content
        assert "count" in content
        assert "ratio" in content
        assert "enabled" in content

    def test_created_file_has_comments(self, tmp_path):
        ini = tmp_path / "new.ini"
        ensure_ini(SampleConfig, ini)
        content = ini.read_text()
        assert "# name (str)" in content
        assert "# count (int)" in content

    def test_reads_existing_file(self, tmp_path):
        ini = tmp_path / "existing.ini"
        ini.write_text("[sample]\ncount = 42\n")
        config = ensure_ini(SampleConfig, ini)
        assert config.count == 42

    def test_does_not_overwrite_existing(self, tmp_path):
        ini = tmp_path / "existing.ini"
        ini.write_text("[sample]\ncount = 42\n")
        ensure_ini(SampleConfig, ini)
        content = ini.read_text()
        assert content == "[sample]\ncount = 42\n"

    def test_created_file_is_loadable(self, tmp_path):
        ini = tmp_path / "new.ini"
        ensure_ini(SampleConfig, ini)
        config = load_ini(SampleConfig, ini)
        assert config.name == "default"
        assert config.count == 10
        assert config.ratio == 0.5


class TestIntegration:
    """End-to-end: file values first, CLI overrides layered on top."""

    def test_ini_then_cli_override(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text("[sample]\ncount = 42\n")
        config = load_ini(SampleConfig, ini)
        assert config.count == 42

        config = apply_cli_overrides(config, _cli_args(count=99))
        assert config.count == 99
        assert config.name == "default"

    def test_ensure_then_cli_override(self, tmp_path):
        ini = tmp_path / "new.ini"
        config = ensure_ini(SampleConfig, ini)
        assert config.count == 10

        config = apply_cli_overrides(config, _cli_args(count=99))
        assert config.count == 99
        assert config.name == "default"

        # CLI overrides must not be written back to the config file.
        reloaded = load_ini(SampleConfig, ini)
        assert reloaded.count == 10


class TestResolveHints:
    """resolve_hints(): map field names to their annotated types."""

    def test_returns_type_dict(self):
        hints = resolve_hints(SampleConfig)
        assert hints["name"] is str
        assert hints["count"] is int
        assert hints["ratio"] is float
        assert hints["enabled"] is bool


class TestFormatSection:
    """format_section(): render a dataclass as a commented INI section."""

    def test_includes_section_header(self):
        text = format_section(SampleConfig)
        assert "[sample]" in text

    def test_custom_section_name(self):
        text = format_section(SampleConfig, "custom")
        assert "[custom]" in text

    def test_includes_all_fields(self):
        text = format_section(SampleConfig)
        assert "name = default" in text
        assert "count = 10" in text
        assert "ratio = 0.5" in text
        assert "enabled = True" in text

    def test_includes_type_comments(self):
        text = format_section(SampleConfig)
        assert "# name (str)" in text
        assert "# count (int)" in text

    def test_is_loadable(self, tmp_path):
        ini = tmp_path / "test.ini"
        ini.write_text(format_section(SampleConfig) + "\n")
        config = load_ini(SampleConfig, ini)
        assert config.name == "default"
        assert config.count == 10


class TestEnsureIniMulti:
    """ensure_ini_multi(): one file holding several dataclass sections."""

    def test_creates_file_with_multiple_sections(self, tmp_path):
        @dataclass
        class OtherConfig:
            host: str = "localhost"
            port: int = 8080

        ini = tmp_path / "multi.ini"
        ensure_ini_multi([
            (SampleConfig, None),
            (OtherConfig, "server"),
        ], ini)
        content = ini.read_text()
        assert "[sample]" in content
        assert "[server]" in content
        assert "name = default" in content
        assert "host = localhost" in content

    def test_does_not_overwrite_existing(self, tmp_path):
        ini = tmp_path / "multi.ini"
        ini.write_text("[existing]\nfoo = bar\n")
        ensure_ini_multi([(SampleConfig, None)], ini)
        content = ini.read_text()
        assert content == "[existing]\nfoo = bar\n"

    def test_sections_are_loadable(self, tmp_path):
        @dataclass
        class DbConfig:
            url: str = "sqlite:///test.db"

        ini = tmp_path / "multi.ini"
        ensure_ini_multi([
            (SampleConfig, None),
            (DbConfig, "database"),
        ], ini)
        sample = load_ini(SampleConfig, ini)
        db = load_ini(DbConfig, ini, section="database")
        assert sample.name == "default"
        assert db.url == "sqlite:///test.db"


class TestApplyOverrides:
    """apply_overrides(): string key/value overrides, optionally prefixed."""

    def test_applies_dotted_path(self):
        config = SampleConfig()
        result = apply_overrides(config, {
            "provider.name": "custom",
            "provider.count": "99",
        }, prefix="provider")
        assert result.name == "custom"
        assert result.count == 99

    def test_without_prefix(self):
        config = SampleConfig()
        result = apply_overrides(config, {
            "name": "direct",
            "count": "42",
        })
        assert result.name == "direct"
        assert result.count == 42

    def test_no_matching_keys_returns_same(self):
        config = SampleConfig()
        result = apply_overrides(
            config, {"other.key": "val"}, prefix="provider"
        )
        assert result is config

    def test_bool_coercion(self):
        config = SampleConfig()
        result = apply_overrides(config, {"enabled": "false"})
        assert result.enabled is False

    def test_preserves_unset_fields(self):
        config = SampleConfig()
        result = apply_overrides(config, {"name": "changed"})
        assert result.name == "changed"
        assert result.count == 10  # unchanged
        assert result.ratio == 0.5  # unchanged


class TestFormatHelp:
    """format_help(): one help line per dataclass field."""

    def test_lists_all_fields(self):
        lines = format_help(SampleConfig)
        assert len(lines) == 4
        assert any("name" in line for line in lines)
        assert any("count" in line for line in lines)

    def test_includes_types(self):
        lines = format_help(SampleConfig)
        text = "\n".join(lines)
        assert "str" in text
        assert "int" in text

    def test_includes_defaults(self):
        lines = format_help(SampleConfig)
        text = "\n".join(lines)
        assert "default" in text
        assert "10" in text

    def test_with_prefix(self):
        lines = format_help(SampleConfig, prefix="provider")
        assert any("provider.name" in line for line in lines)
        assert any("provider.count" in line for line in lines)