347 lines
11 KiB
Python
347 lines
11 KiB
Python
"""Unit tests for the config framework."""
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from byteb4rb1e.utils.config import (
|
|
add_config_arguments,
|
|
apply_cli_overrides,
|
|
apply_overrides,
|
|
ensure_ini,
|
|
ensure_ini_multi,
|
|
format_help,
|
|
format_section,
|
|
load_ini,
|
|
resolve_hints,
|
|
)
|
|
|
|
|
|
@dataclass
class SampleConfig:
    """Fixture dataclass shared by the tests in this module.

    It carries one field of each scalar type the config helpers are
    exercised with: ``str``, ``int``, ``float`` and ``bool``. When no
    explicit section is given, the tests expect this class to map to the
    ``[sample]`` INI section (presumably derived from the class name --
    confirm in ``byteb4rb1e.utils.config``).
    """

    # Text field; its default "default" is asserted throughout the tests.
    name: str = "default"
    # Integer field, used for string-to-int coercion checks.
    count: int = 10
    # Float field.
    ratio: float = 0.5
    # Boolean field, used for "false"-string coercion checks.
    enabled: bool = True
|
|
|
|
|
|
class TestLoadIni:
    """Behaviour of ``load_ini``: reading one INI section into a dataclass."""

    def test_loads_values(self, tmp_path):
        # ``enabled`` is deliberately absent from the file, so the
        # dataclass default has to fill it in.
        path = tmp_path / "test.ini"
        body = "".join([
            "[sample]\n",
            "name = custom\n",
            "count = 42\n",
            "ratio = 0.75\n",
        ])
        path.write_text(body)

        loaded = load_ini(SampleConfig, path)

        assert loaded.name == "custom"
        assert loaded.count == 42
        assert loaded.ratio == 0.75
        assert loaded.enabled is True  # default

    def test_missing_section_uses_defaults(self, tmp_path):
        # Only an unrelated section exists; [sample] is absent.
        path = tmp_path / "test.ini"
        path.write_text("[other]\nfoo = bar\n")

        loaded = load_ini(SampleConfig, path)

        assert loaded.name == "default"
        assert loaded.count == 10

    def test_missing_file_uses_defaults(self, tmp_path):
        # The path is never created, so there is nothing to read.
        loaded = load_ini(SampleConfig, tmp_path / "missing.ini")
        assert loaded.name == "default"

    def test_unknown_key_raises(self, tmp_path):
        path = tmp_path / "test.ini"
        path.write_text("[sample]\nunknown_key = bad\n")
        # The error message is expected to name the offending key.
        with pytest.raises(ValueError, match="unknown_key"):
            load_ini(SampleConfig, path)

    def test_custom_section_name(self, tmp_path):
        path = tmp_path / "test.ini"
        path.write_text("[mysection]\nname = custom\n")

        loaded = load_ini(SampleConfig, path, section="mysection")

        assert loaded.name == "custom"

    def test_comments_ignored(self, tmp_path):
        # Both a full-line comment and a trailing inline comment must
        # be stripped from the parsed value.
        path = tmp_path / "test.ini"
        path.write_text(
            "[sample]\n"
            "# this is a comment\n"
            "name = works # inline comment\n"
        )

        assert load_ini(SampleConfig, path).name == "works"
|
|
|
|
|
|
class TestAddConfigArguments:
    """``add_config_arguments`` registers one ``--flag`` per dataclass field."""

    def test_generates_flags(self):
        cli = ArgumentParser()
        add_config_arguments(SampleConfig, cli)

        argv = ["--name", "cli", "--count", "99"]
        ns = cli.parse_args(argv)

        assert ns.name == "cli"
        # The int-typed field arrives already converted.
        assert ns.count == 99

    def test_defaults_are_none(self):
        # Omitted flags stay None, which is what ``apply_cli_overrides``
        # treats as "not given on the command line".
        cli = ArgumentParser()
        add_config_arguments(SampleConfig, cli)

        ns = cli.parse_args([])

        assert ns.name is None
        assert ns.count is None

    def test_underscores_become_dashes(self):
        @dataclass
        class DashConfig:
            my_long_name: str = "x"

        cli = ArgumentParser()
        add_config_arguments(DashConfig, cli)

        ns = cli.parse_args(["--my-long-name", "val"])

        # argparse maps the dashed flag back to an underscored attribute.
        assert ns.my_long_name == "val"
|
|
|
|
|
|
class TestApplyCliOverrides:
    """``apply_cli_overrides`` layers non-None argparse values onto a config."""

    @staticmethod
    def _namespace(**given):
        """Build a Namespace holding every SampleConfig field, None unless given."""
        values = dict.fromkeys(("name", "count", "ratio", "enabled"))
        values.update(given)
        return Namespace(**values)

    def test_overrides_set_values(self):
        base = SampleConfig()

        merged = apply_cli_overrides(base, self._namespace(name="override"))

        assert merged.name == "override"
        assert merged.count == 10  # unchanged

    def test_no_overrides_returns_same(self):
        base = SampleConfig()

        merged = apply_cli_overrides(base, self._namespace())

        assert merged.name == "default"
        # With nothing to override, the very same object comes back.
        assert merged is base
|
|
|
|
|
|
class TestEnsureIni:
    """``ensure_ini`` writes a default file only when none exists, then loads it."""

    def test_creates_file_if_missing(self, tmp_path):
        target = tmp_path / "new.ini"
        assert not target.exists()

        cfg = ensure_ini(SampleConfig, target)

        assert target.exists()
        assert cfg.name == "default"
        assert cfg.count == 10

    def test_created_file_has_all_fields(self, tmp_path):
        target = tmp_path / "new.ini"
        ensure_ini(SampleConfig, target)

        text = target.read_text()

        for key in ("name", "count", "ratio", "enabled"):
            assert key in text

    def test_created_file_has_comments(self, tmp_path):
        target = tmp_path / "new.ini"
        ensure_ini(SampleConfig, target)

        text = target.read_text()

        # The generated file documents fields as "# <field> (<type>)".
        assert "# name (str)" in text
        assert "# count (int)" in text

    def test_reads_existing_file(self, tmp_path):
        target = tmp_path / "existing.ini"
        target.write_text("[sample]\ncount = 42\n")

        assert ensure_ini(SampleConfig, target).count == 42

    def test_does_not_overwrite_existing(self, tmp_path):
        original = "[sample]\ncount = 42\n"
        target = tmp_path / "existing.ini"
        target.write_text(original)

        ensure_ini(SampleConfig, target)

        # Byte-for-byte identical: missing keys were not appended.
        assert target.read_text() == original

    def test_created_file_is_loadable(self, tmp_path):
        target = tmp_path / "new.ini"
        ensure_ini(SampleConfig, target)

        # Round trip: the generated file parses back to the defaults.
        cfg = load_ini(SampleConfig, target)

        assert cfg.name == "default"
        assert cfg.count == 10
        assert cfg.ratio == 0.5
|
|
|
|
|
|
class TestIntegration:
    """End-to-end precedence: values from the INI file, then CLI overrides."""

    def test_ini_then_cli_override(self, tmp_path):
        path = tmp_path / "test.ini"
        path.write_text("[sample]\ncount = 42\n")

        cfg = load_ini(SampleConfig, path)
        assert cfg.count == 42

        cli = Namespace(name=None, count=99, ratio=None, enabled=None)
        cfg = apply_cli_overrides(cfg, cli)

        # The CLI value wins; untouched fields keep their default.
        assert cfg.count == 99
        assert cfg.name == "default"

    def test_ensure_then_cli_override(self, tmp_path):
        path = tmp_path / "new.ini"

        cfg = ensure_ini(SampleConfig, path)
        assert cfg.count == 10

        cli = Namespace(name=None, count=99, ratio=None, enabled=None)
        cfg = apply_cli_overrides(cfg, cli)
        assert cfg.count == 99
        assert cfg.name == "default"

        # CLI overrides live only in memory; the file on disk is untouched.
        assert load_ini(SampleConfig, path).count == 10
|
|
|
|
|
|
class TestResolveHints:
    """``resolve_hints`` maps each dataclass field name to its concrete type."""

    def test_returns_type_dict(self):
        hints = resolve_hints(SampleConfig)
        expected = {
            "name": str,
            "count": int,
            "ratio": float,
            "enabled": bool,
        }
        # Identity checks: the hints must be the builtin types themselves.
        for field_name, field_type in expected.items():
            assert hints[field_name] is field_type
|
|
|
|
|
|
class TestFormatSection:
    """``format_section`` renders a dataclass as INI text populated with defaults."""

    def test_includes_section_header(self):
        assert "[sample]" in format_section(SampleConfig)

    def test_custom_section_name(self):
        assert "[custom]" in format_section(SampleConfig, "custom")

    def test_includes_all_fields(self):
        rendered = format_section(SampleConfig)
        expected_lines = (
            "name = default",
            "count = 10",
            "ratio = 0.5",
            "enabled = True",
        )
        for expected in expected_lines:
            assert expected in rendered

    def test_includes_type_comments(self):
        rendered = format_section(SampleConfig)
        assert "# name (str)" in rendered
        assert "# count (int)" in rendered

    def test_is_loadable(self, tmp_path):
        path = tmp_path / "test.ini"
        # A trailing newline is appended to the rendered text before writing.
        path.write_text(format_section(SampleConfig) + "\n")

        cfg = load_ini(SampleConfig, path)

        assert cfg.name == "default"
        assert cfg.count == 10
|
|
|
|
|
|
class TestEnsureIniMulti:
    """``ensure_ini_multi`` writes several dataclass sections into one file."""

    def test_creates_file_with_multiple_sections(self, tmp_path):
        @dataclass
        class OtherConfig:
            host: str = "localhost"
            port: int = 8080

        path = tmp_path / "multi.ini"
        # A section name of None falls back to the class's default
        # section, which is "[sample]" for SampleConfig.
        specs = [(SampleConfig, None), (OtherConfig, "server")]
        ensure_ini_multi(specs, path)

        text = path.read_text()
        for needle in ("[sample]", "[server]", "name = default", "host = localhost"):
            assert needle in text

    def test_does_not_overwrite_existing(self, tmp_path):
        original = "[existing]\nfoo = bar\n"
        path = tmp_path / "multi.ini"
        path.write_text(original)

        ensure_ini_multi([(SampleConfig, None)], path)

        assert path.read_text() == original

    def test_sections_are_loadable(self, tmp_path):
        @dataclass
        class DbConfig:
            url: str = "sqlite:///test.db"

        path = tmp_path / "multi.ini"
        ensure_ini_multi([(SampleConfig, None), (DbConfig, "database")], path)

        # Each section round-trips through load_ini independently.
        sample = load_ini(SampleConfig, path)
        db = load_ini(DbConfig, path, section="database")

        assert sample.name == "default"
        assert db.url == "sqlite:///test.db"
|
|
|
|
|
|
class TestApplyOverrides:
    """``apply_overrides`` applies string key/value overrides with type coercion."""

    def test_applies_dotted_path(self):
        # Keys take the form "<prefix>.<field>"; the string values are
        # expected to be coerced to each field's declared type.
        overrides = {"provider.name": "custom", "provider.count": "99"}

        result = apply_overrides(SampleConfig(), overrides, prefix="provider")

        assert result.name == "custom"
        assert result.count == 99

    def test_without_prefix(self):
        overrides = {"name": "direct", "count": "42"}

        result = apply_overrides(SampleConfig(), overrides)

        assert result.name == "direct"
        assert result.count == 42

    def test_no_matching_keys_returns_same(self):
        base = SampleConfig()

        result = apply_overrides(base, {"other.key": "val"}, prefix="provider")

        # No key matched the prefix, so the input object is returned as-is.
        assert result is base

    def test_bool_coercion(self):
        result = apply_overrides(SampleConfig(), {"enabled": "false"})
        assert result.enabled is False

    def test_preserves_unset_fields(self):
        result = apply_overrides(SampleConfig(), {"name": "changed"})

        assert result.name == "changed"
        assert result.count == 10  # unchanged
        assert result.ratio == 0.5  # unchanged
|
|
|
|
|
|
class TestFormatHelp:
    """``format_help`` returns a list of help lines, one per dataclass field."""

    def test_lists_all_fields(self):
        lines = format_help(SampleConfig)
        # SampleConfig declares exactly four fields.
        assert len(lines) == 4
        assert any("name" in line for line in lines)
        assert any("count" in line for line in lines)

    def test_includes_types(self):
        text = "\n".join(format_help(SampleConfig))
        assert "str" in text
        assert "int" in text

    def test_includes_defaults(self):
        # Use distinctive default values. SampleConfig's "default" would
        # also match a "default:" label in the help text, and "10" is
        # easy to match by accident. Either way the assertions could pass
        # even if default values were never rendered.
        @dataclass
        class DistinctConfig:
            label: str = "zq-sentinel"
            limit: int = 73519

        text = "\n".join(format_help(DistinctConfig))
        assert "zq-sentinel" in text
        assert "73519" in text

    def test_with_prefix(self):
        lines = format_help(SampleConfig, prefix="provider")
        assert any("provider.name" in line for line in lines)
        assert any("provider.count" in line for line in lines)
|