diff --git a/TODO b/TODO index 6fb9efd..1039c22 100644 --- a/TODO +++ b/TODO @@ -248,7 +248,7 @@ Content-Type: application/issue ID: 19 Type: feature Title: config framework with CLI integration -Status: in-progress +Status: done Priority: medium Created: 2026-06-06 Relationships: diff --git a/src/byteb4rb1e/utils/argparse/__init__.py b/src/byteb4rb1e/utils/argparse/__init__.py index 84ae3ed..5bf1156 100644 --- a/src/byteb4rb1e/utils/argparse/__init__.py +++ b/src/byteb4rb1e/utils/argparse/__init__.py @@ -1,6 +1,7 @@ """Utilities for building composable CLIs from command dataclasses.""" +from byteb4rb1e.utils.argparse.actions import KeyValueAction from byteb4rb1e.utils.argparse.command import CLICommand from byteb4rb1e.utils.argparse.dispatcher import CLI -__all__ = ["CLI", "CLICommand"] +__all__ = ["CLI", "CLICommand", "KeyValueAction"] diff --git a/src/byteb4rb1e/utils/argparse/actions.py b/src/byteb4rb1e/utils/argparse/actions.py new file mode 100644 index 0000000..79a5e8f --- /dev/null +++ b/src/byteb4rb1e/utils/argparse/actions.py @@ -0,0 +1,33 @@ +"""Custom argparse actions.""" + +from __future__ import annotations + +import argparse +from typing import Any + + +class KeyValueAction(argparse.Action): + """Argparse action that accumulates ``KEY=VALUE`` pairs into a dict. + + Usage:: + + parser.add_argument("--config", action=KeyValueAction, + default={}, metavar="KEY=VALUE", + help="Set a config option (can be repeated)") + + Then ``args.config`` is a ``dict[str, str]``. + """ + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Any, + option_string: str | None = None, + ) -> None: + d = getattr(namespace, self.dest, None) or {} + if "=" not in values: + parser.error(f"Invalid format: {values!r} (expected KEY=VALUE)") + key, _, value = values.partition("=") + d[key.strip()] = value.strip() + setattr(namespace, self.dest, d) diff --git a/src/byteb4rb1e/utils/config.py b/src/byteb4rb1e/utils/config.py new file mode 100644 index 0000000..8bece72 --- /dev/null +++ b/src/byteb4rb1e/utils/config.py @@ -0,0 +1,369 @@ +"""Config framework — INI-backed dataclasses with CLI integration. + +A config dataclass is the single source of truth for settings. Values +come from three layers (later wins): + + 1. Dataclass field defaults + 2. INI file sections + 3. CLI overrides (via argparse flags or ``--config KEY=VALUE``) + +Two CLI integration styles: + +- ``add_config_arguments`` — generates one ``--flag`` per field. +- ``apply_overrides`` — accepts a ``dict[str, str]`` of dotted-path + overrides from a unified ``--config KEY=VALUE`` flag. +""" + +import configparser +from argparse import ArgumentParser, Namespace +from dataclasses import MISSING, fields +from pathlib import Path +from typing import Any, Type, TypeVar, get_type_hints + +T = TypeVar("T") + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _parse_bool(value: str) -> bool: + """Parse a boolean from INI/CLI string.""" + return value.lower() in ("true", "yes", "1", "on") + + +_TYPE_MAP = { + int: int, + float: float, + str: str, + bool: _parse_bool, +} + + +def resolve_hints(cls: Type) -> dict[str, type]: + """Resolve type hints for a dataclass, handling both evaluated + and string annotations. + + :param cls: a dataclass class. + :returns: dict mapping field names to resolved types. + """ + try: + return get_type_hints(cls) + except Exception: + return { + f.name: f.type if isinstance(f.type, type) + else str + for f in fields(cls) + } + + +def _section_name(cls: Type, section: str | None = None) -> str: + """Derive INI section name from class name if not provided.""" + if section is not None: + return section + name = cls.__name__ + if name.endswith("Config"): + name = name[: -len("Config")] + return name.lower() + + +# --------------------------------------------------------------------------- +# INI loading +# --------------------------------------------------------------------------- + +def load_ini( + cls: Type[T], + path: Path, + section: str | None = None, +) -> T: + """Load a config dataclass from an INI file. + + If *section* is not given, the dataclass name (lowercased, + without trailing "Config") is used. + + Unknown keys in the INI file raise ValueError. Missing keys + use the dataclass default. + """ + section = _section_name(cls, section) + + parser = configparser.ConfigParser( + comment_prefixes=("#", ";"), + inline_comment_prefixes=("#", ";"), + ) + parser.read(path) + + if not parser.has_section(section): + return cls() # type: ignore[call-arg] + + hints = resolve_hints(cls) + field_names = {f.name for f in fields(cls) if f.init} + kwargs: dict[str, Any] = {} + + for key, raw_value in parser.items(section): + if key not in field_names: + raise ValueError( + f"Unknown config key '{key}' in" + f" [{section}]. Valid keys:" + f" {sorted(field_names)}" + ) + + field_type = hints.get(key, str) + coerce = _TYPE_MAP.get(field_type, field_type) + kwargs[key] = coerce(raw_value) + + return cls(**kwargs) # type: ignore[call-arg] + + +# --------------------------------------------------------------------------- +# INI writing +# --------------------------------------------------------------------------- + +def format_section(cls: Type, section: str | None = None) -> str: + """Format a config dataclass as an INI section string. + + Returns the section header and all fields with their defaults + as commented key-value pairs. + + :param cls: a dataclass class. + :param section: section name (derived from class name if None). + :returns: INI section string. + """ + section = _section_name(cls, section) + hints = resolve_hints(cls) + lines = [f"[{section}]"] + + for f in fields(cls): + if not f.init: + continue + field_type = hints.get(f.name, str) + type_name = getattr(field_type, "__name__", str(field_type)) + + if f.default is not MISSING: + default = f.default + elif f.default_factory is not MISSING: # type: ignore[arg-type] + default = f.default_factory() # type: ignore[misc] + else: + continue + + lines.append(f"# {f.name} ({type_name})") + lines.append(f"{f.name} = {default}") + lines.append("") + + return "\n".join(lines) + + +def ensure_ini( + cls: Type[T], + path: Path, + section: str | None = None, +) -> T: + """Load config from INI, creating the file with defaults if + it does not exist. + + On first run, writes a commented INI file with all fields and + their default values. On subsequent runs, reads the existing + file. Never writes back CLI overrides. + """ + section = _section_name(cls, section) + + if not path.exists(): + _write_default_ini(cls, path, section) + + return load_ini(cls, path, section) + + +def ensure_ini_multi( + configs: list[tuple[Type, str | None]], + path: Path, +) -> None: + """Create an INI file with multiple sections if it does not exist. + + Each entry is a (dataclass_cls, section_name) tuple. If + section_name is None, it is derived from the class name. + + Does not overwrite an existing file. + + :param configs: list of (cls, section) tuples. + :param path: path to the INI file. + """ + if path.exists(): + return + + sections = [format_section(cls, section) for cls, section in configs] + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text("\n".join(sections) + "\n") + + +def _write_default_ini( + cls: Type, + path: Path, + section: str, +) -> None: + """Write an INI file with all fields as commented defaults.""" + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(format_section(cls, section) + "\n") + + +# --------------------------------------------------------------------------- +# CLI: per-flag style (add_config_arguments / apply_cli_overrides) +# --------------------------------------------------------------------------- + +def add_config_arguments( + cls: Type[T], + parser: ArgumentParser, + prefix: str = "", +) -> None: + """Add CLI arguments for each field in a config dataclass. + + Field names are converted to CLI flags: ``heart_rate_resolution`` + becomes ``--heart-rate-resolution`` (or ``---heart-rate-resolution`` + if a prefix is given). + """ + hints = resolve_hints(cls) + + for f in fields(cls): + if not f.init: + continue + flag_name = f.name.replace("_", "-") + if prefix: + flag_name = f"{prefix}-{flag_name}" + + field_type = hints.get(f.name, str) + + kwargs: dict[str, Any] = { + "dest": f.name, + } + + if field_type is bool: + kwargs["action"] = ( + "store_false" + if f.default is True + else "store_true" + ) + kwargs["default"] = None + else: + kwargs["type"] = _TYPE_MAP.get( + field_type, field_type + ) + kwargs["default"] = None + kwargs["metavar"] = field_type.__name__.upper() + + parser.add_argument(f"--{flag_name}", **kwargs) + + +def apply_cli_overrides( + config: T, + args: Namespace, +) -> T: + """Apply CLI argument values to a config instance. + + Only overrides fields that were explicitly set on the command + line (not None). Returns a new instance. + """ + overrides = {} + for f in fields(config): # type: ignore[arg-type] + if not f.init: + continue + cli_value = getattr(args, f.name, None) + if cli_value is not None: + overrides[f.name] = cli_value + + if not overrides: + return config + + from dataclasses import asdict + merged = asdict(config) # type: ignore[arg-type] + merged.update(overrides) + return type(config)(**merged) # type: ignore[return-value] + + +# --------------------------------------------------------------------------- +# CLI: dotted-path style (apply_overrides) +# --------------------------------------------------------------------------- + +def apply_overrides( + config: T, + overrides: dict[str, str], + prefix: str = "", +) -> T: + """Apply dotted-path string overrides to a config dataclass. + + Used with a unified ``--config KEY=VALUE`` CLI flag. Each key + is a dotted path relative to the prefix. + + Example:: + + overrides = { + "provider.base_url": "http://localhost:4000", + "provider.model": "qwen2.5:7b", + } + config = apply_overrides(config, overrides, prefix="provider") + # config.base_url == "http://localhost:4000" + # config.model == "qwen2.5:7b" + + :param config: a dataclass instance. + :param overrides: dict of dotted keys to string values. + :param prefix: only apply keys starting with this prefix. + :returns: new config instance with overrides applied. + """ + hints = resolve_hints(type(config)) + kwargs: dict[str, Any] = {} + changed = False + + for f in fields(config): + if not f.init: + continue + full_key = f"{prefix}.{f.name}" if prefix else f.name + if full_key in overrides: + raw = overrides[full_key] + field_type = hints.get(f.name, str) + coerce = _TYPE_MAP.get(field_type, field_type) + kwargs[f.name] = coerce(raw) + changed = True + else: + kwargs[f.name] = getattr(config, f.name) + + if not changed: + return config + + return type(config)(**kwargs) # type: ignore[return-value] + + +def format_help(cls: Type, prefix: str = "") -> list[str]: + """Generate help lines for a config dataclass. + + Each line shows the dotted key path, type, and default value. + Suitable for CLI epilog text. + + :param cls: a dataclass class. + :param prefix: prepended to each key path. + :returns: list of formatted help strings. + """ + hints = resolve_hints(cls) + lines = [] + + for f in fields(cls): + if not f.init: + continue + field_type = hints.get(f.name, str) + type_name = getattr(field_type, "__name__", str(field_type)) + key = f"{prefix}.{f.name}" if prefix else f.name + + if f.default is not MISSING: + default = f.default + elif f.default_factory is not MISSING: # type: ignore[arg-type] + default = repr(f.default_factory()) # type: ignore[misc] + else: + default = "(required)" + + lines.append(f" {key} ({type_name}, default: {default})") + + return lines + + +# --------------------------------------------------------------------------- +# Backwards compat +# --------------------------------------------------------------------------- + +# keep the old private name working for existing callers +_resolve_hints = resolve_hints diff --git a/tests/unit/byteb4rb1e/utils/argparse/test_actions.py b/tests/unit/byteb4rb1e/utils/argparse/test_actions.py new file mode 100644 index 0000000..3d0c150 --- /dev/null +++ b/tests/unit/byteb4rb1e/utils/argparse/test_actions.py @@ -0,0 +1,52 @@ +"""Tests for custom argparse actions.""" + +from argparse import ArgumentParser + +import pytest + +from byteb4rb1e.utils.argparse.actions import KeyValueAction + + +def _parse(*args): + parser = ArgumentParser() + parser.add_argument("--config", action=KeyValueAction, default={}, metavar="KEY=VALUE") + return parser.parse_args(list(args)) + + +class TestKeyValueAction: + + def test_single_pair(self): + args = _parse("--config", "key=value") + assert args.config == {"key": "value"} + + def test_multiple_pairs(self): + args = _parse("--config", "a=1", "--config", "b=2") + assert args.config == {"a": "1", "b": "2"} + + def test_dotted_key(self): + args = _parse("--config", "provider.base_url=http://localhost") + assert args.config == {"provider.base_url": "http://localhost"} + + def test_value_with_equals(self): + args = _parse("--config", "url=http://host?a=1&b=2") + assert args.config == {"url": "http://host?a=1&b=2"} + + def test_empty_value(self): + args = _parse("--config", "key=") + assert args.config == {"key": ""} + + def test_strips_whitespace(self): + args = _parse("--config", " key = value ") + assert args.config == {"key": "value"} + + def test_overwrites_duplicate_key(self): + args = _parse("--config", "key=first", "--config", "key=second") + assert args.config == {"key": "second"} + + def test_default_empty_dict(self): + args = _parse() + assert args.config == {} + + def test_no_equals_raises(self): + with pytest.raises(SystemExit): + _parse("--config", "no_equals_here") diff --git a/tests/unit/byteb4rb1e/utils/config/test_config.py b/tests/unit/byteb4rb1e/utils/config/test_config.py new file mode 100644 index 0000000..6b25306 --- /dev/null +++ b/tests/unit/byteb4rb1e/utils/config/test_config.py @@ -0,0 +1,347 @@ +"""Unit tests for the config framework.""" + +from argparse import ArgumentParser, Namespace +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from byteb4rb1e.utils.config import ( + add_config_arguments, + apply_cli_overrides, + apply_overrides, + ensure_ini, + ensure_ini_multi, + format_help, + format_section, + load_ini, + resolve_hints, +) + + +@dataclass +class SampleConfig: + name: str = "default" + count: int = 10 + ratio: float = 0.5 + enabled: bool = True + + +class TestLoadIni: + def test_loads_values(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text( + "[sample]\n" + "name = custom\n" + "count = 42\n" + "ratio = 0.75\n" + ) + config = load_ini(SampleConfig, ini) + assert config.name == "custom" + assert config.count == 42 + assert config.ratio == 0.75 + assert config.enabled is True # default + + def test_missing_section_uses_defaults(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[other]\nfoo = bar\n") + config = load_ini(SampleConfig, ini) + assert config.name == "default" + assert config.count == 10 + + def test_missing_file_uses_defaults(self, tmp_path): + config = load_ini( + SampleConfig, tmp_path / "missing.ini" + ) + assert config.name == "default" + + def test_unknown_key_raises(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[sample]\nunknown_key = bad\n") + with pytest.raises(ValueError, match="unknown_key"): + load_ini(SampleConfig, ini) + + def test_custom_section_name(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[mysection]\nname = custom\n") + config = load_ini( + SampleConfig, ini, section="mysection" + ) + assert config.name == "custom" + + def test_comments_ignored(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text( + "[sample]\n" + "# this is a comment\n" + "name = works # inline comment\n" + ) + config = load_ini(SampleConfig, ini) + assert config.name == "works" + + +class TestAddConfigArguments: + def test_generates_flags(self): + parser = ArgumentParser() + add_config_arguments(SampleConfig, parser) + args = parser.parse_args( + ["--name", "cli", "--count", "99"] + ) + assert args.name == "cli" + assert args.count == 99 + + def test_defaults_are_none(self): + parser = ArgumentParser() + add_config_arguments(SampleConfig, parser) + args = parser.parse_args([]) + assert args.name is None + assert args.count is None + + def test_underscores_become_dashes(self): + @dataclass + class DashConfig: + my_long_name: str = "x" + + parser = ArgumentParser() + add_config_arguments(DashConfig, parser) + args = parser.parse_args( + ["--my-long-name", "val"] + ) + assert args.my_long_name == "val" + + +class TestApplyCliOverrides: + def test_overrides_set_values(self): + config = SampleConfig() + args = Namespace(name="override", count=None, + ratio=None, enabled=None) + result = apply_cli_overrides(config, args) + assert result.name == "override" + assert result.count == 10 # unchanged + + def test_no_overrides_returns_same(self): + config = SampleConfig() + args = Namespace(name=None, count=None, + ratio=None, enabled=None) + result = apply_cli_overrides(config, args) + assert result.name == "default" + assert result is config + + +class TestEnsureIni: + def test_creates_file_if_missing(self, tmp_path): + ini = tmp_path / "new.ini" + assert not ini.exists() + config = ensure_ini(SampleConfig, ini) + assert ini.exists() + assert config.name == "default" + assert config.count == 10 + + def test_created_file_has_all_fields(self, tmp_path): + ini = tmp_path / "new.ini" + ensure_ini(SampleConfig, ini) + content = ini.read_text() + assert "name" in content + assert "count" in content + assert "ratio" in content + assert "enabled" in content + + def test_created_file_has_comments(self, tmp_path): + ini = tmp_path / "new.ini" + ensure_ini(SampleConfig, ini) + content = ini.read_text() + assert "# name (str)" in content + assert "# count (int)" in content + + def test_reads_existing_file(self, tmp_path): + ini = tmp_path / "existing.ini" + ini.write_text("[sample]\ncount = 42\n") + config = ensure_ini(SampleConfig, ini) + assert config.count == 42 + + def test_does_not_overwrite_existing(self, tmp_path): + ini = tmp_path / "existing.ini" + ini.write_text("[sample]\ncount = 42\n") + ensure_ini(SampleConfig, ini) + content = ini.read_text() + assert content == "[sample]\ncount = 42\n" + + def test_created_file_is_loadable(self, tmp_path): + ini = tmp_path / "new.ini" + ensure_ini(SampleConfig, ini) + config = load_ini(SampleConfig, ini) + assert config.name == "default" + assert config.count == 10 + assert config.ratio == 0.5 + + +class TestIntegration: + def test_ini_then_cli_override(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text("[sample]\ncount = 42\n") + config = load_ini(SampleConfig, ini) + assert config.count == 42 + + args = Namespace(name=None, count=99, + ratio=None, enabled=None) + config = apply_cli_overrides(config, args) + assert config.count == 99 + assert config.name == "default" + + def test_ensure_then_cli_override(self, tmp_path): + ini = tmp_path / "new.ini" + config = ensure_ini(SampleConfig, ini) + assert config.count == 10 + + args = Namespace(name=None, count=99, + ratio=None, enabled=None) + config = apply_cli_overrides(config, args) + assert config.count == 99 + assert config.name == "default" + + # Config file unchanged + reloaded = load_ini(SampleConfig, ini) + assert reloaded.count == 10 + + +class TestResolveHints: + def test_returns_type_dict(self): + hints = resolve_hints(SampleConfig) + assert hints["name"] is str + assert hints["count"] is int + assert hints["ratio"] is float + assert hints["enabled"] is bool + + +class TestFormatSection: + def test_includes_section_header(self): + text = format_section(SampleConfig) + assert "[sample]" in text + + def test_custom_section_name(self): + text = format_section(SampleConfig, "custom") + assert "[custom]" in text + + def test_includes_all_fields(self): + text = format_section(SampleConfig) + assert "name = default" in text + assert "count = 10" in text + assert "ratio = 0.5" in text + assert "enabled = True" in text + + def test_includes_type_comments(self): + text = format_section(SampleConfig) + assert "# name (str)" in text + assert "# count (int)" in text + + def test_is_loadable(self, tmp_path): + ini = tmp_path / "test.ini" + ini.write_text(format_section(SampleConfig) + "\n") + config = load_ini(SampleConfig, ini) + assert config.name == "default" + assert config.count == 10 + + +class TestEnsureIniMulti: + def test_creates_file_with_multiple_sections(self, tmp_path): + @dataclass + class OtherConfig: + host: str = "localhost" + port: int = 8080 + + ini = tmp_path / "multi.ini" + ensure_ini_multi([ + (SampleConfig, None), + (OtherConfig, "server"), + ], ini) + + content = ini.read_text() + assert "[sample]" in content + assert "[server]" in content + assert "name = default" in content + assert "host = localhost" in content + + def test_does_not_overwrite_existing(self, tmp_path): + ini = tmp_path / "multi.ini" + ini.write_text("[existing]\nfoo = bar\n") + ensure_ini_multi([(SampleConfig, None)], ini) + content = ini.read_text() + assert content == "[existing]\nfoo = bar\n" + + def test_sections_are_loadable(self, tmp_path): + @dataclass + class DbConfig: + url: str = "sqlite:///test.db" + + ini = tmp_path / "multi.ini" + ensure_ini_multi([ + (SampleConfig, None), + (DbConfig, "database"), + ], ini) + + sample = load_ini(SampleConfig, ini) + db = load_ini(DbConfig, ini, section="database") + assert sample.name == "default" + assert db.url == "sqlite:///test.db" + + +class TestApplyOverrides: + def test_applies_dotted_path(self): + config = SampleConfig() + result = apply_overrides(config, { + "provider.name": "custom", + "provider.count": "99", + }, prefix="provider") + assert result.name == "custom" + assert result.count == 99 + + def test_without_prefix(self): + config = SampleConfig() + result = apply_overrides(config, { + "name": "direct", + "count": "42", + }) + assert result.name == "direct" + assert result.count == 42 + + def test_no_matching_keys_returns_same(self): + config = SampleConfig() + result = apply_overrides(config, {"other.key": "val"}, prefix="provider") + assert result is config + + def test_bool_coercion(self): + config = SampleConfig() + result = apply_overrides(config, {"enabled": "false"}) + assert result.enabled is False + + def test_preserves_unset_fields(self): + config = SampleConfig() + result = apply_overrides(config, {"name": "changed"}) + assert result.name == "changed" + assert result.count == 10 # unchanged + assert result.ratio == 0.5 # unchanged + + +class TestFormatHelp: + def test_lists_all_fields(self): + lines = format_help(SampleConfig) + assert len(lines) == 4 + assert any("name" in l for l in lines) + assert any("count" in l for l in lines) + + def test_includes_types(self): + lines = format_help(SampleConfig) + text = "\n".join(lines) + assert "str" in text + assert "int" in text + + def test_includes_defaults(self): + lines = format_help(SampleConfig) + text = "\n".join(lines) + assert "default" in text + assert "10" in text + + def test_with_prefix(self): + lines = format_help(SampleConfig, prefix="provider") + assert any("provider.name" in l for l in lines) + assert any("provider.count" in l for l in lines)