Files
microsoft-SkillOpt/skillopt/config.py
Cuzyoung 00602df9e9 feat(slow-update): add config-controlled gated / force-injected modes
Add optimizer.slow_update_gate_with_selection to control how epoch-boundary
slow-update guidance is applied:
- false (default): force-injected - inject guidance into current & best
  unconditionally (unchanged behavior).
- true: gated - evaluate the slow-update candidate on the selection set and
  accept/reject via the same validation gate as step-level updates
  (logic follows the SkillReflection ablation).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-31 02:02:23 +00:00

269 lines
11 KiB
Python

"""ReflACT config loading engine — structured YAML with inheritance.

Supports two config formats:

1. **Structured** (new): sections like ``model``, ``train``, ``gradient``,
   ``optimizer``, ``evaluation``, ``env`` — with ``_base_`` inheritance.
2. **Flat** (legacy): all keys at top level — fully backward compatible.

Usage::

    from skillopt.config import load_config, flatten_config
    cfg = load_config("configs/searchqa_default.yaml")
    flat = flatten_config(cfg)  # always returns a flat dict for the trainer
"""
from __future__ import annotations
import copy
import os
from typing import Any
import yaml
# ── Section names that indicate a structured config ──────────────────────
_STRUCTURED_SECTIONS = frozenset({
"model", "train", "gradient", "optimizer", "evaluation", "env",
})
# ── Structured → flat key mapping ────────────────────────────────────────
# Maps "<section>.<key>" in a structured config to the flat key the trainer
# reads. flatten_config() walks this dict in insertion order, so the order of
# entries here is the order the mapped keys appear in its output. Most entries
# simply drop the section prefix; the exceptions are marked "renamed".
_FLATTEN_MAP: dict[str, str] = {
    # ── model: backends, CLI executors, per-provider endpoints/credentials ──
    "model.backend": "model_backend",  # renamed
    "model.optimizer": "optimizer_model",  # renamed
    "model.target": "target_model",  # renamed
    "model.optimizer_backend": "optimizer_backend",
    "model.target_backend": "target_backend",
    "model.reasoning_effort": "reasoning_effort",
    "model.rewrite_reasoning_effort": "rewrite_reasoning_effort",
    "model.rewrite_max_completion_tokens": "rewrite_max_completion_tokens",
    "model.codex_exec_path": "codex_exec_path",
    "model.codex_exec_sandbox": "codex_exec_sandbox",
    "model.codex_exec_profile": "codex_exec_profile",
    "model.codex_exec_full_auto": "codex_exec_full_auto",
    "model.codex_exec_reasoning_effort": "codex_exec_reasoning_effort",
    "model.codex_exec_use_sdk": "codex_exec_use_sdk",
    "model.codex_exec_network_access": "codex_exec_network_access",
    "model.codex_exec_web_search": "codex_exec_web_search",
    "model.codex_exec_approval_policy": "codex_exec_approval_policy",
    "model.claude_code_exec_path": "claude_code_exec_path",
    "model.claude_code_exec_profile": "claude_code_exec_profile",
    "model.claude_code_exec_use_sdk": "claude_code_exec_use_sdk",
    "model.claude_code_exec_effort": "claude_code_exec_effort",
    "model.claude_code_exec_max_thinking_tokens": "claude_code_exec_max_thinking_tokens",
    "model.codex_trace_to_optimizer": "codex_trace_to_optimizer",
    "model.azure_endpoint": "azure_endpoint",
    "model.azure_api_version": "azure_api_version",
    "model.azure_api_key": "azure_api_key",
    # Shared Azure OpenAI settings, then optimizer- and target-specific
    # variants with the same six fields each.
    "model.azure_openai_endpoint": "azure_openai_endpoint",
    "model.azure_openai_api_version": "azure_openai_api_version",
    "model.azure_openai_api_key": "azure_openai_api_key",
    "model.azure_openai_auth_mode": "azure_openai_auth_mode",
    "model.azure_openai_ad_scope": "azure_openai_ad_scope",
    "model.azure_openai_managed_identity_client_id": "azure_openai_managed_identity_client_id",
    "model.optimizer_azure_openai_endpoint": "optimizer_azure_openai_endpoint",
    "model.optimizer_azure_openai_api_version": "optimizer_azure_openai_api_version",
    "model.optimizer_azure_openai_api_key": "optimizer_azure_openai_api_key",
    "model.optimizer_azure_openai_auth_mode": "optimizer_azure_openai_auth_mode",
    "model.optimizer_azure_openai_ad_scope": "optimizer_azure_openai_ad_scope",
    "model.optimizer_azure_openai_managed_identity_client_id": "optimizer_azure_openai_managed_identity_client_id",
    "model.target_azure_openai_endpoint": "target_azure_openai_endpoint",
    "model.target_azure_openai_api_version": "target_azure_openai_api_version",
    "model.target_azure_openai_api_key": "target_azure_openai_api_key",
    "model.target_azure_openai_auth_mode": "target_azure_openai_auth_mode",
    "model.target_azure_openai_ad_scope": "target_azure_openai_ad_scope",
    "model.target_azure_openai_managed_identity_client_id": "target_azure_openai_managed_identity_client_id",
    "model.qwen_chat_base_url": "qwen_chat_base_url",
    "model.qwen_chat_api_key": "qwen_chat_api_key",
    "model.qwen_chat_temperature": "qwen_chat_temperature",
    "model.qwen_chat_timeout_seconds": "qwen_chat_timeout_seconds",
    "model.qwen_chat_max_tokens": "qwen_chat_max_tokens",
    "model.qwen_chat_enable_thinking": "qwen_chat_enable_thinking",
    # ── train ──
    "train.num_epochs": "num_epochs",
    "train.train_size": "train_size",
    "train.steps_per_epoch": "steps_per_epoch",
    "train.batch_size": "batch_size",
    "train.accumulation": "accumulation",
    "train.seed": "seed",
    # ── gradient ──
    "gradient.minibatch_size": "minibatch_size",
    "gradient.merge_batch_size": "merge_batch_size",
    "gradient.analyst_workers": "analyst_workers",
    "gradient.failure_only": "failure_only",
    "gradient.max_analyst_rounds": "max_analyst_rounds",
    # ── optimizer ──
    # The "learning rate" knobs are stored under *edit_budget* flat names.
    "optimizer.learning_rate": "edit_budget",  # renamed
    "optimizer.min_learning_rate": "min_edit_budget",  # renamed
    "optimizer.lr_scheduler": "lr_scheduler",
    "optimizer.lr_control_mode": "lr_control_mode",
    "optimizer.skill_update_mode": "skill_update_mode",
    "optimizer.meta_learning_rate": "meta_edit_budget",  # renamed
    "optimizer.use_slow_update": "use_slow_update",
    "optimizer.slow_update_samples": "slow_update_samples",
    # Per the change description: false (default) force-injects the
    # epoch-boundary slow-update guidance; true gates it on the selection set.
    "optimizer.slow_update_gate_with_selection": "slow_update_gate_with_selection",
    "optimizer.longitudinal_pair_policy": "longitudinal_pair_policy",
    "optimizer.use_meta_skill": "use_meta_skill",
    # ── evaluation ──
    # use_gate is mapped, but flatten_config rejects an explicit `false`.
    "evaluation.use_gate": "use_gate",
    "evaluation.gate_metric": "gate_metric",
    "evaluation.gate_mixed_weight": "gate_mixed_weight",
    "evaluation.sel_env_num": "sel_env_num",
    "evaluation.test_env_num": "test_env_num",
    "evaluation.eval_test": "eval_test",
    # ── env ──
    # flatten_config passes any env key NOT listed here through verbatim.
    "env.name": "env",  # renamed: the env name becomes flat key "env"
    "env.skill_init": "skill_init",
    "env.out_root": "out_root",
}
# ── Deep merge ───────────────────────────────────────────────────────────
def _deep_merge(base: dict, override: dict) -> dict:
"""Recursively merge *override* into *base* (returns new dict)."""
result = copy.deepcopy(base)
for key, val in override.items():
if key in result and isinstance(result[key], dict) and isinstance(val, dict):
result[key] = _deep_merge(result[key], val)
else:
result[key] = copy.deepcopy(val)
return result
# ── YAML loading with _base_ inheritance ─────────────────────────────────
def _load_yaml(path: str, _visited: set[str] | None = None) -> dict:
"""Load a YAML file, resolving ``_base_`` inheritance recursively."""
abs_path = os.path.abspath(path)
if _visited is None:
_visited = set()
if abs_path in _visited:
raise ValueError(f"Circular _base_ inheritance: {abs_path}")
_visited.add(abs_path)
with open(abs_path) as f:
cfg = yaml.safe_load(f) or {}
base_ref = cfg.pop("_base_", None)
if base_ref:
base_path = os.path.join(os.path.dirname(abs_path), base_ref)
base_cfg = _load_yaml(base_path, _visited)
cfg = _deep_merge(base_cfg, cfg)
return cfg
# ── Format detection ─────────────────────────────────────────────────────
def is_structured(cfg: dict) -> bool:
    """Report whether *cfg* uses the sectioned (structured) format.

    A config counts as structured as soon as any one of the known section
    names maps to a dict. A section name bound to a non-dict value (for
    example a top-level ``env`` string) does not count.
    """
    for name, value in cfg.items():
        if name in _STRUCTURED_SECTIONS and isinstance(value, dict):
            return True
    return False
# ── Flatten ──────────────────────────────────────────────────────────────
def flatten_config(cfg: dict) -> dict:
    """Convert a structured config to the flat dict expected by the trainer.

    If *cfg* is not structured (see ``is_structured``), a shallow copy is
    returned unchanged. Otherwise the result is built in two passes:

    1. every ``section.key`` listed in ``_FLATTEN_MAP`` that is present in
       *cfg* is copied to its flat name, in map order;
    2. every key of the ``env`` section that is not itself listed in
       ``_FLATTEN_MAP`` is copied through under its own name. This pass runs
       after pass 1, so a pass-through key that shares a flat name with a
       mapped key (e.g. ``env.seed`` vs ``train.seed``) wins.

    Keys in other sections that are not in ``_FLATTEN_MAP``, and top-level
    keys outside any section, are not copied.

    Raises
    ------
    ValueError
        If ``evaluation.use_gate`` is explicitly set to a falsy value other
        than ``None`` (e.g. ``false`` or ``0``): gate validation is mandatory.
    """
    if not is_structured(cfg):
        # NOTE(review): flat configs are returned as-is, so a flat
        # ``use_gate: false`` is not rejected here. Confirm the trainer
        # enforces the gate for legacy configs.
        return dict(cfg)

    evaluation_section = cfg.get("evaluation")
    if isinstance(evaluation_section, dict):
        use_gate = evaluation_section.get("use_gate")
        # Reject any explicit falsy value (``false``, ``0``, ``""``), not just
        # the literal ``False``; otherwise ``use_gate: 0`` would bypass the
        # check. A missing key or ``null`` leaves the default in place.
        if use_gate is not None and not use_gate:
            raise ValueError(
                "Gate validation is mandatory in this branch. Remove "
                "`evaluation.use_gate: false` from the config."
            )

    flat: dict[str, Any] = {}

    # Pass 1: the explicit section.key -> flat_key mapping.
    for dotted, flat_key in _FLATTEN_MAP.items():
        section, key = dotted.split(".", 1)
        section_dict = cfg.get(section)
        if isinstance(section_dict, dict) and key in section_dict:
            flat[flat_key] = section_dict[key]

    # Pass 2: pass through env-specific keys not covered by the mapping.
    env_section = cfg.get("env")
    if isinstance(env_section, dict):
        mapped_env_keys = {
            dotted.split(".", 1)[1]
            for dotted in _FLATTEN_MAP
            if dotted.startswith("env.")
        }
        for key, val in env_section.items():
            if key not in mapped_env_keys:
                flat[key] = val
    return flat
# ── Override application ─────────────────────────────────────────────────
def _cast_value(val_str: str) -> Any:
"""Auto-cast a CLI string value to int / float / bool / str."""
if val_str.lower() in ("true", "yes"):
return True
if val_str.lower() in ("false", "no"):
return False
try:
return int(val_str)
except ValueError:
pass
try:
return float(val_str)
except ValueError:
pass
return val_str
def apply_overrides(cfg: dict, overrides: list[str]) -> None:
    """Apply ``key=value`` overrides to a config dict (in place).

    Each item is split on its first ``=`` and the value is converted with
    ``_cast_value``.

    * ``section.key=value`` sets ``cfg[section][key]``, creating the section
      as an empty dict if it does not exist yet. Only the first ``.`` is
      significant, so ``a.b.c=1`` sets ``cfg["a"]["b.c"]``.
    * ``key=value`` (no dot) sets ``cfg[key]`` at the top level, which is
      what flat (legacy) configs expect. For a structured config,
      ``flatten_config`` does not read top-level keys outside the known
      sections, so such an override does not reach the flattened result;
      use ``env.key=value`` to pass an extra flat key through the ``env``
      section instead.

    Raises
    ------
    ValueError
        If an item has no ``=``, has an empty key or key component, or
        targets a section whose existing value is not a dict.
    """
    for item in overrides:
        if "=" not in item:
            raise ValueError(f"Invalid override (expected key=value): {item!r}")
        key, val_str = item.split("=", 1)
        # Validate the key fully before casting so a bad item never
        # partially mutates *cfg*.
        if "." in key:
            section, subkey = key.split(".", 1)
            if not section or not subkey:
                raise ValueError(
                    f"Invalid override (empty key component): {item!r}"
                )
            target = cfg.setdefault(section, {})
            if not isinstance(target, dict):
                # e.g. a flat config with ``env: searchqa`` and override
                # ``env.name=x``; previously this failed with a confusing
                # TypeError on item assignment.
                raise ValueError(
                    f"Cannot apply override {item!r}: config key "
                    f"{section!r} holds a {type(target).__name__}, "
                    "not a section mapping"
                )
            target[subkey] = _cast_value(val_str)
        else:
            if not key:
                raise ValueError(f"Invalid override (empty key): {item!r}")
            # Flat key — apply to top level (for legacy compat).
            cfg[key] = _cast_value(val_str)
# ── Public API ───────────────────────────────────────────────────────────
def load_config(
    path: str,
    overrides: list[str] | None = None,
) -> dict:
    """Read a YAML config, resolve its ``_base_`` chain, then apply overrides.

    Parameters
    ----------
    path : str
        Location of the YAML config file.
    overrides : list[str] | None
        ``key=value`` strings (from ``--cfg-options``). ``None`` or an empty
        list leaves the loaded config untouched.

    Returns
    -------
    dict
        The merged config, still in whichever format (structured or flat)
        the YAML used. Pass it to ``flatten_config`` for the trainer's flat
        form.
    """
    merged = _load_yaml(path)
    # An empty override list makes apply_overrides a no-op.
    apply_overrides(merged, overrides or [])
    return merged