mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
Split failure reflections into SKILL_DEFECT (body edit) vs EXECUTION_LAPSE (protected appendix note that re-emphasizes an existing rule, never edited by step-level analysts). Toggle: optimizer.use_skill_aware_reflection (default false; baseline byte-identical when off). - optimizer/appendix.py: protected APPENDIX region (inject/extract/append with dedup), mirrors the slow_update protected-field pattern - optimizer/skill_aware.py: analyst prompt augmentation, appendix_notes parsing, threshold-gated LLM consolidation, and a process-wide runtime switch (configure_skill_aware_reflection) set once by the trainer - gradient/reflect.py: augment error/success analyst prompts at runtime; None-sentinel kwargs resolve from the global switch, so env adapters need no per-benchmark wiring (works for all envs, present and future) - optimizer/skill.py: generalize the protected-region check to (slow_update, appendix); edits inside any protected region are skipped - engine/trainer.py: inject appendix at init, flush per-step EXECUTION_LAPSE notes after the gate settles, optional consolidation - tests: regression suite incl. toggle-off byte-identical guarantee and env-independent global-switch resolution (6/6 passing + live smoke) Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
283 lines
12 KiB
Python
283 lines
12 KiB
Python
"""ReflACT config loading engine — structured YAML with inheritance.
|
|
|
|
Supports two config formats:
|
|
1. **Structured** (new): sections like ``model``, ``train``, ``gradient``,
|
|
``optimizer``, ``evaluation``, ``env`` — with ``_base_`` inheritance.
|
|
2. **Flat** (legacy): all keys at top level — fully backward compatible.
|
|
|
|
Usage::
|
|
|
|
from skillopt.config import load_config, flatten_config
|
|
|
|
cfg = load_config("configs/searchqa_default.yaml")
|
|
flat = flatten_config(cfg) # always returns flat dict for trainer
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import os
|
|
from typing import Any
|
|
|
|
import yaml
|
|
|
|
# ── Section names that indicate a structured config ──────────────────────
|
|
|
|
# Top-level keys that make up the structured config format. ``is_structured``
# reports a config as structured when at least one of these keys maps to a
# dict; the section prefixes used in ``_FLATTEN_MAP`` are drawn from this set.
_STRUCTURED_SECTIONS = frozenset({
    "model", "train", "gradient", "optimizer", "evaluation", "env",
})
|
|
|
|
# ── Structured → flat key mapping ────────────────────────────────────────
|
|
|
|
# Maps ``"section.key"`` in a structured config to the flat key name that
# ``flatten_config`` emits. Only keys actually present in the YAML are copied;
# absent keys are simply omitted from the flat dict. Most entries just drop the
# section prefix, but some rename the key (marked "renamed" below). Keys under
# ``env`` that are not listed here are passed through by ``flatten_config``
# under their own names; unlisted keys in the other sections are not copied.
_FLATTEN_MAP: dict[str, str] = {
    # ── model: backend / model selection (renamed where noted) ──
    "model.backend": "model_backend",  # renamed
    "model.optimizer": "optimizer_model",  # renamed
    "model.target": "target_model",  # renamed
    "model.optimizer_backend": "optimizer_backend",
    "model.target_backend": "target_backend",
    "model.reasoning_effort": "reasoning_effort",
    "model.rewrite_reasoning_effort": "rewrite_reasoning_effort",
    "model.rewrite_max_completion_tokens": "rewrite_max_completion_tokens",
    # ── model: codex_exec_* options ──
    "model.codex_exec_path": "codex_exec_path",
    "model.codex_exec_sandbox": "codex_exec_sandbox",
    "model.codex_exec_profile": "codex_exec_profile",
    "model.codex_exec_full_auto": "codex_exec_full_auto",
    "model.codex_exec_reasoning_effort": "codex_exec_reasoning_effort",
    "model.codex_exec_use_sdk": "codex_exec_use_sdk",
    "model.codex_exec_network_access": "codex_exec_network_access",
    "model.codex_exec_web_search": "codex_exec_web_search",
    "model.codex_exec_approval_policy": "codex_exec_approval_policy",
    # ── model: claude_code_exec_* options ──
    "model.claude_code_exec_path": "claude_code_exec_path",
    "model.claude_code_exec_profile": "claude_code_exec_profile",
    "model.claude_code_exec_use_sdk": "claude_code_exec_use_sdk",
    "model.claude_code_exec_effort": "claude_code_exec_effort",
    "model.claude_code_exec_max_thinking_tokens": "claude_code_exec_max_thinking_tokens",
    "model.codex_trace_to_optimizer": "codex_trace_to_optimizer",
    # ── model: Azure / Azure OpenAI connection settings ──
    # (unprefixed variants, then optimizer_* and target_* variants)
    "model.azure_endpoint": "azure_endpoint",
    "model.azure_api_version": "azure_api_version",
    "model.azure_api_key": "azure_api_key",
    "model.azure_openai_endpoint": "azure_openai_endpoint",
    "model.azure_openai_api_version": "azure_openai_api_version",
    "model.azure_openai_api_key": "azure_openai_api_key",
    "model.azure_openai_auth_mode": "azure_openai_auth_mode",
    "model.azure_openai_ad_scope": "azure_openai_ad_scope",
    "model.azure_openai_managed_identity_client_id": "azure_openai_managed_identity_client_id",
    "model.optimizer_azure_openai_endpoint": "optimizer_azure_openai_endpoint",
    "model.optimizer_azure_openai_api_version": "optimizer_azure_openai_api_version",
    "model.optimizer_azure_openai_api_key": "optimizer_azure_openai_api_key",
    "model.optimizer_azure_openai_auth_mode": "optimizer_azure_openai_auth_mode",
    "model.optimizer_azure_openai_ad_scope": "optimizer_azure_openai_ad_scope",
    "model.optimizer_azure_openai_managed_identity_client_id": "optimizer_azure_openai_managed_identity_client_id",
    "model.target_azure_openai_endpoint": "target_azure_openai_endpoint",
    "model.target_azure_openai_api_version": "target_azure_openai_api_version",
    "model.target_azure_openai_api_key": "target_azure_openai_api_key",
    "model.target_azure_openai_auth_mode": "target_azure_openai_auth_mode",
    "model.target_azure_openai_ad_scope": "target_azure_openai_ad_scope",
    "model.target_azure_openai_managed_identity_client_id": "target_azure_openai_managed_identity_client_id",
    # ── model: qwen_chat_* settings (unprefixed, optimizer_*, target_*) ──
    "model.qwen_chat_base_url": "qwen_chat_base_url",
    "model.qwen_chat_api_key": "qwen_chat_api_key",
    "model.qwen_chat_temperature": "qwen_chat_temperature",
    "model.qwen_chat_timeout_seconds": "qwen_chat_timeout_seconds",
    "model.qwen_chat_max_tokens": "qwen_chat_max_tokens",
    "model.qwen_chat_enable_thinking": "qwen_chat_enable_thinking",
    "model.optimizer_qwen_chat_base_url": "optimizer_qwen_chat_base_url",
    "model.optimizer_qwen_chat_api_key": "optimizer_qwen_chat_api_key",
    "model.optimizer_qwen_chat_temperature": "optimizer_qwen_chat_temperature",
    "model.optimizer_qwen_chat_timeout_seconds": "optimizer_qwen_chat_timeout_seconds",
    "model.optimizer_qwen_chat_max_tokens": "optimizer_qwen_chat_max_tokens",
    "model.optimizer_qwen_chat_enable_thinking": "optimizer_qwen_chat_enable_thinking",
    "model.target_qwen_chat_base_url": "target_qwen_chat_base_url",
    "model.target_qwen_chat_api_key": "target_qwen_chat_api_key",
    "model.target_qwen_chat_temperature": "target_qwen_chat_temperature",
    "model.target_qwen_chat_timeout_seconds": "target_qwen_chat_timeout_seconds",
    "model.target_qwen_chat_max_tokens": "target_qwen_chat_max_tokens",
    "model.target_qwen_chat_enable_thinking": "target_qwen_chat_enable_thinking",
    # ── model: minimax_* settings ──
    "model.minimax_base_url": "minimax_base_url",
    "model.minimax_api_key": "minimax_api_key",
    "model.minimax_model": "minimax_model",
    "model.minimax_temperature": "minimax_temperature",
    "model.minimax_max_tokens": "minimax_max_tokens",
    "model.minimax_enable_thinking": "minimax_enable_thinking",
    # ── train ──
    "train.num_epochs": "num_epochs",
    "train.train_size": "train_size",
    "train.steps_per_epoch": "steps_per_epoch",
    "train.batch_size": "batch_size",
    "train.accumulation": "accumulation",
    "train.seed": "seed",
    # ── gradient ──
    "gradient.minibatch_size": "minibatch_size",
    "gradient.merge_batch_size": "merge_batch_size",
    "gradient.analyst_workers": "analyst_workers",
    "gradient.failure_only": "failure_only",
    "gradient.max_analyst_rounds": "max_analyst_rounds",
    # ── optimizer (the *learning_rate keys are renamed to *edit_budget) ──
    "optimizer.learning_rate": "edit_budget",  # renamed
    "optimizer.min_learning_rate": "min_edit_budget",  # renamed
    "optimizer.lr_scheduler": "lr_scheduler",
    "optimizer.lr_control_mode": "lr_control_mode",
    "optimizer.skill_update_mode": "skill_update_mode",
    "optimizer.meta_learning_rate": "meta_edit_budget",  # renamed
    "optimizer.use_slow_update": "use_slow_update",
    "optimizer.slow_update_samples": "slow_update_samples",
    "optimizer.slow_update_gate_with_selection": "slow_update_gate_with_selection",
    "optimizer.longitudinal_pair_policy": "longitudinal_pair_policy",
    "optimizer.use_meta_skill": "use_meta_skill",
    "optimizer.use_skill_aware_reflection": "use_skill_aware_reflection",
    "optimizer.skill_aware_appendix_source": "skill_aware_appendix_source",
    "optimizer.skill_aware_consolidate_threshold": "skill_aware_consolidate_threshold",
    # ── evaluation ──
    "evaluation.use_gate": "use_gate",
    "evaluation.gate_metric": "gate_metric",
    "evaluation.gate_mixed_weight": "gate_mixed_weight",
    "evaluation.sel_env_num": "sel_env_num",
    "evaluation.test_env_num": "test_env_num",
    "evaluation.eval_test": "eval_test",
    # ── env (other env.* keys pass through flatten_config unmapped) ──
    "env.name": "env",  # renamed: the flat key is just "env"
    "env.skill_init": "skill_init",
    "env.out_root": "out_root",
}
|
|
|
|
|
|
# ── Deep merge ───────────────────────────────────────────────────────────
|
|
|
|
def _deep_merge(base: dict, override: dict) -> dict:
|
|
"""Recursively merge *override* into *base* (returns new dict)."""
|
|
result = copy.deepcopy(base)
|
|
for key, val in override.items():
|
|
if key in result and isinstance(result[key], dict) and isinstance(val, dict):
|
|
result[key] = _deep_merge(result[key], val)
|
|
else:
|
|
result[key] = copy.deepcopy(val)
|
|
return result
|
|
|
|
|
|
# ── YAML loading with _base_ inheritance ─────────────────────────────────
|
|
|
|
def _load_yaml(path: str, _visited: set[str] | None = None) -> dict:
    """Load a YAML file, resolving ``_base_`` inheritance recursively.

    ``_base_`` may be a single path or a list of paths. Relative paths are
    resolved against the directory of the file that references them. Bases
    are merged in the order listed (later bases override earlier ones) and
    the file's own keys override all of its bases. The ``_base_`` key itself
    is removed from the result. An empty file yields ``{}``.

    Parameters
    ----------
    path : str
        Path to the YAML file.
    _visited : set[str] | None
        Internal: absolute paths on the current inheritance chain, used to
        detect cycles. Callers should leave this as ``None``.

    Returns
    -------
    dict
        The fully merged config for *path*.

    Raises
    ------
    ValueError
        On circular ``_base_`` inheritance.
    TypeError
        If the file's top-level YAML value is not a mapping, or ``_base_`` is
        neither a string nor a list of strings.
    """
    abs_path = os.path.abspath(path)
    if _visited is None:
        _visited = set()
    if abs_path in _visited:
        raise ValueError(f"Circular _base_ inheritance: {abs_path}")
    _visited.add(abs_path)

    try:
        # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(abs_path, encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}
        if not isinstance(cfg, dict):
            raise TypeError(
                f"Config file {abs_path} must contain a YAML mapping at the "
                f"top level, got {type(cfg).__name__}"
            )

        base_ref = cfg.pop("_base_", None)
        if not base_ref:
            return cfg

        base_refs = [base_ref] if isinstance(base_ref, str) else base_ref
        if not isinstance(base_refs, list) or not all(
            isinstance(ref, str) for ref in base_refs
        ):
            raise TypeError(
                f"_base_ in {abs_path} must be a path or a list of paths, "
                f"got {base_ref!r}"
            )

        base_dir = os.path.dirname(abs_path)
        merged: dict = {}
        for ref in base_refs:
            base_cfg = _load_yaml(os.path.join(base_dir, ref), _visited)
            merged = _deep_merge(merged, base_cfg)
        return _deep_merge(merged, cfg)
    finally:
        # Cycle detection only concerns the current inheritance chain; removing
        # this file on the way out lets two bases share a common ancestor
        # (diamond inheritance) without a false "circular" error.
        _visited.discard(abs_path)
|
|
|
|
|
|
# ── Format detection ─────────────────────────────────────────────────────
|
|
|
|
def is_structured(cfg: dict) -> bool:
    """Tell whether *cfg* uses the structured (sectioned) config format.

    The config counts as structured as soon as any known section name from
    ``_STRUCTURED_SECTIONS`` is a top-level key whose value is a dict;
    otherwise it is treated as a flat legacy config.
    """
    for section in _STRUCTURED_SECTIONS:
        if isinstance(cfg.get(section), dict):
            return True
    return False
|
|
|
|
|
|
# ── Flatten ──────────────────────────────────────────────────────────────
|
|
|
|
def flatten_config(cfg: dict) -> dict:
    """Collapse a structured config into the flat dict the trainer consumes.

    A flat (legacy) config is returned as a shallow copy, otherwise untouched.

    For a structured config, every ``section.key`` listed in ``_FLATTEN_MAP``
    that is present (in a section that is a dict) is copied to its flat name.
    Any other key found under ``env`` is then passed through under its own
    name, so such a pass-through key overwrites a mapped key with the same
    flat name. Unlisted keys in other sections, and top-level keys that are
    not sections, do not appear in the result.
    """
    if not is_structured(cfg):
        return dict(cfg)

    flat: dict[str, Any] = {}
    mapped_env_keys: set[str] = set()

    for dotted, flat_key in _FLATTEN_MAP.items():
        section, key = dotted.split(".", 1)
        if section == "env":
            mapped_env_keys.add(key)
        section_dict = cfg.get(section, {})
        if not isinstance(section_dict, dict) or key not in section_dict:
            continue
        flat[flat_key] = section_dict[key]

    # Env-specific extras that have no explicit mapping go through verbatim.
    env_section = cfg.get("env", {})
    if isinstance(env_section, dict):
        flat.update(
            (key, val)
            for key, val in env_section.items()
            if key not in mapped_env_keys
        )

    return flat
|
|
|
|
|
|
# ── Override application ─────────────────────────────────────────────────
|
|
|
|
def _cast_value(val_str: str) -> Any:
|
|
"""Auto-cast a CLI string value to int / float / bool / str."""
|
|
if val_str.lower() in ("true", "yes"):
|
|
return True
|
|
if val_str.lower() in ("false", "no"):
|
|
return False
|
|
try:
|
|
return int(val_str)
|
|
except ValueError:
|
|
pass
|
|
try:
|
|
return float(val_str)
|
|
except ValueError:
|
|
pass
|
|
return val_str
|
|
|
|
|
|
def _set_section_value(cfg: dict, section: str, subkey: str, val: Any) -> None:
    """Set ``cfg[section][subkey] = val``, creating the section if absent or None."""
    current = cfg.get(section)
    if current is None:
        # Missing key, or an empty YAML section (``model:`` parses to None).
        current = cfg[section] = {}
    elif not isinstance(current, dict):
        raise TypeError(
            f"Cannot apply override {section}.{subkey}: {section!r} holds a "
            f"{type(current).__name__}, not a mapping"
        )
    current[subkey] = val


def _override_structured(
    cfg: dict, key: str, val: Any, flat_to_dotted: dict[str, str]
) -> None:
    """Route one override into a structured config so flatten_config keeps it."""
    if "." in key:
        section, subkey = key.split(".", 1)
    elif key in flat_to_dotted:
        # A bare flat name with a structured home, e.g. num_epochs -> train.num_epochs.
        section, subkey = flat_to_dotted[key].split(".", 1)
    else:
        # Any other bare key goes into ``env``: flatten_config passes unmapped
        # env keys straight through, whereas top-level keys would be dropped.
        section, subkey = "env", key
    _set_section_value(cfg, section, subkey, val)


def _override_flat(cfg: dict, key: str, val: Any) -> None:
    """Route one override into a flat (legacy) config without turning it structured."""
    if "." not in key:
        cfg[key] = val
        return
    if key in _FLATTEN_MAP:
        # e.g. model.target -> target_model
        cfg[_FLATTEN_MAP[key]] = val
        return
    section, subkey = key.split(".", 1)
    if section == "env":
        # Mirrors flatten_config's pass-through of unmapped env keys.
        cfg[subkey] = val
    elif section in _STRUCTURED_SECTIONS:
        # Creating a dict under a section name would make is_structured() true
        # and flatten_config would then silently drop every legacy flat key.
        raise ValueError(
            f"Override {key!r} has no flat-config equivalent; "
            f"use the flat key name instead"
        )
    else:
        # Not a section name: set a nested value inside a flat config.
        _set_section_value(cfg, section, subkey, val)


def apply_overrides(cfg: dict, overrides: list[str]) -> None:
    """Apply ``key=value`` overrides to a loaded config (in place).

    Values are auto-cast with ``_cast_value``. Each override is written where
    ``flatten_config`` will pick it up:

    * Structured config: ``section.key=value`` sets ``cfg[section][key]``
      (creating the section if missing or empty). A bare ``key=value`` that is
      a flat name from ``_FLATTEN_MAP`` (e.g. ``num_epochs``) is written to its
      structured location (``train.num_epochs``); any other bare key is put in
      the ``env`` section, whose unmapped keys pass through to the flat dict.
    * Flat config: ``key=value`` sets the top-level key. A dotted key listed
      in ``_FLATTEN_MAP`` is written to its flat name, ``env.<key>`` is
      written to ``<key>``, and a dotted key whose prefix is not a section
      name sets a nested value.

    Raises
    ------
    ValueError
        If an item lacks ``=``, or if a dotted key under a structured section
        has no flat equivalent on a flat config.
    TypeError
        If a dotted override targets a key that holds a non-dict value.
    """
    structured = is_structured(cfg)
    flat_to_dotted = {flat: dotted for dotted, flat in _FLATTEN_MAP.items()}

    for item in overrides:
        if "=" not in item:
            raise ValueError(f"Invalid override (expected key=value): {item!r}")
        key, val_str = item.split("=", 1)
        val = _cast_value(val_str)

        if structured:
            _override_structured(cfg, key, val, flat_to_dotted)
        else:
            _override_flat(cfg, key, val)
|
|
|
|
|
|
# ── Public API ───────────────────────────────────────────────────────────
|
|
|
|
def load_config(
    path: str,
    overrides: list[str] | None = None,
) -> dict:
    """Read a YAML config, resolve its ``_base_`` chain, then apply overrides.

    Parameters
    ----------
    path : str
        Location of the YAML config file.
    overrides : list[str] | None
        Optional ``key=value`` strings (as given via ``--cfg-options``),
        applied in order after inheritance has been resolved.

    Returns
    -------
    dict
        The merged config (structured or flat depending on the YAML). Pass it
        to ``flatten_config`` to obtain the flat dict the trainer expects.
    """
    merged = _load_yaml(path)
    if not overrides:
        return merged
    apply_overrides(merged, overrides)
    return merged
|