#!/usr/bin/env python3
"""ReflACT unified training entry point.

Usage
-----
    python scripts/train.py --config configs/alfworld/default.yaml

Any YAML key can be overridden from the command line::

    python scripts/train.py --config configs/alfworld/default.yaml \\
        --batch_size 40 --num_epochs 2 --seed 123

Run ``python scripts/train.py --help`` for a full list of options.
"""
from __future__ import annotations

import argparse
import datetime
import importlib
import inspect
import os
import sys

# Ensure the project root is on sys.path so ``import skillopt`` works
# regardless of where the script is invoked from.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from skillopt.model.common import default_model_for_backend, normalize_backend_name  # noqa: E402

# Model names that are replaced by the Claude default model when a role runs on
# a Claude backend and the user did not pick a model explicitly.
_OPENAI_DEFAULT_MODEL_SENTINELS = {"gpt-5.4", "gpt-5.5"}


# ── Environment registry ────────────────────────────────────────────────────

_ENV_REGISTRY: dict[str, type] = {}

# Built-in environments: env name -> adapter class name.  The adapter of
# environment ``<name>`` lives in module ``skillopt.envs.<name>.adapter``.
_BUILTIN_ADAPTERS: dict[str, str] = {
    "alfworld": "ALFWorldAdapter",
    "searchqa": "SearchQAAdapter",
    "livemathematicianbench": "LiveMathematicianBenchAdapter",
    "babyvision": "BabyVisionAdapter",
    "spreadsheetbench": "SpreadsheetBenchAdapter",
    "mmrb": "MMRBAdapter",
    "docvqa": "DocVQAAdapter",
    "mathverse": "MathVerseAdapter",
    "officeqa": "OfficeQAAdapter",
    "sealqa": "SealQAAdapter",
    "swebench": "SWEBenchAdapter",
}

# env name -> description of why its built-in adapter could not be imported.
# Only used to produce a helpful error message in ``get_adapter``.
_ENV_IMPORT_ERRORS: dict[str, str] = {}


def _register_builtins() -> None:
    """Lazy-import built-in adapters so we don't pull heavy deps at CLI parse time.

    Every adapter that imports successfully is stored in ``_ENV_REGISTRY``.
    An adapter whose import fails (typically because optional dependencies
    are not installed) is skipped, and the reason is remembered in
    ``_ENV_IMPORT_ERRORS`` instead of being discarded.
    """
    for env_name, class_name in _BUILTIN_ADAPTERS.items():
        module_path = f"skillopt.envs.{env_name}.adapter"
        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            # Optional dependency missing — skip, but keep the reason around.
            _ENV_IMPORT_ERRORS[env_name] = f"{type(exc).__name__}: {exc}"
            continue

        adapter_cls = getattr(module, class_name, None)
        if adapter_cls is None:
            # Mirrors the ImportError ``from module import name`` would raise.
            _ENV_IMPORT_ERRORS[env_name] = (
                f"ImportError: cannot import name {class_name!r} from {module_path!r}"
            )
            continue

        _ENV_REGISTRY[env_name] = adapter_cls
        _ENV_IMPORT_ERRORS.pop(env_name, None)


def get_adapter(cfg: dict):
    """Instantiate the environment adapter specified in ``cfg["env"]``.

    Args:
        cfg: Flat configuration dict.  ``cfg["env"]`` selects the adapter
            (``"alfworld"`` when the key is absent); any other key whose name
            matches a keyword-passable parameter of the adapter's
            ``__init__`` is forwarded to the constructor.

    Returns:
        A new instance of the selected adapter class.

    Raises:
        ValueError: If the environment is not registered.  When it is a
            built-in environment whose adapter failed to import, the message
            includes the import error.
    """
    _register_builtins()

    env_name = cfg.get("env", "alfworld")
    if env_name not in _ENV_REGISTRY:
        available = list(_ENV_REGISTRY.keys())
        import_error = _ENV_IMPORT_ERRORS.get(env_name)
        if import_error is not None:
            raise ValueError(
                f"Environment '{env_name}' is built in but its adapter could "
                f"not be imported ({import_error}). "
                f"Available: {available}"
            )
        raise ValueError(
            f"Unknown environment '{env_name}'. "
            f"Available: {available}"
        )

    adapter_cls = _ENV_REGISTRY[env_name]

    # Inspect the adapter __init__ signature and only pass accepted kwargs.
    # ``*args`` / ``**kwargs`` and positional-only parameters cannot be
    # supplied by keyword, so they are not treated as accepted names.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    parameters = inspect.signature(adapter_cls.__init__).parameters
    adapter_kwargs = {
        name: cfg[name]
        for name, param in parameters.items()
        if name != "self" and param.kind in keyword_kinds and name in cfg
    }
    return adapter_cls(**adapter_kwargs)


# ── CLI ──────────────────────────────────────────────────────────────────────

def _BOOL(x: str) -> bool:
    """Parse a CLI boolean: ``true``/``1``/``yes`` (any case) are True, all else False."""
    return x.lower() in ("true", "1", "yes")


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Every option except ``--config`` and ``--cfg-options`` defaults to
    ``None`` so that ``load_config`` can tell which options were actually
    given on the command line.

    Returns:
        The parsed argument namespace.
    """
    p = argparse.ArgumentParser(
        description="ReflACT: Reflective Agent Tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    p.add_argument("--config", type=str, required=True,
                   help="Path to YAML config file")
    p.add_argument("--cfg-options", nargs="+", default=[],
                   help="Override config: section.key=value (e.g. train.batch_size=40)")

    # Legacy flat CLI overrides (still work, prefer --cfg-options for new usage)
    p.add_argument("--env", type=str)
    p.add_argument("--backend", type=str,
                   choices=["azure_openai", "codex", "codex_exec", "claude",
                            "claude_chat", "claude_code_exec"])
    p.add_argument("--teacher_model", type=str)
    p.add_argument("--student_model", type=str)
    p.add_argument("--teacher_backend", type=str)
    p.add_argument("--student_backend", type=str)
    p.add_argument("--reasoning_effort", type=str,
                   choices=["", "low", "medium", "high", "xhigh", "max"])
    p.add_argument("--rewrite_reasoning_effort", type=str)
    p.add_argument("--rewrite_max_completion_tokens", type=int)
    p.add_argument("--azure_endpoint", type=str)
    p.add_argument("--azure_api_version", type=str)
    p.add_argument("--azure_api_key", type=str)
    p.add_argument("--azure_openai_endpoint", type=str)
    p.add_argument("--azure_openai_api_version", type=str)
    p.add_argument("--azure_openai_api_key", type=str)
    p.add_argument("--azure_openai_auth_mode", type=str)
    p.add_argument("--azure_openai_ad_scope", type=str)
    p.add_argument("--azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--teacher_azure_openai_endpoint", type=str)
    p.add_argument("--teacher_azure_openai_api_version", type=str)
    p.add_argument("--teacher_azure_openai_api_key", type=str)
    p.add_argument("--teacher_azure_openai_auth_mode", type=str)
    p.add_argument("--teacher_azure_openai_ad_scope", type=str)
    p.add_argument("--teacher_azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--student_azure_openai_endpoint", type=str)
    p.add_argument("--student_azure_openai_api_version", type=str)
    p.add_argument("--student_azure_openai_api_key", type=str)
    p.add_argument("--student_azure_openai_auth_mode", type=str)
    p.add_argument("--student_azure_openai_ad_scope", type=str)
    p.add_argument("--student_azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--codex_exec_path", type=str)
    p.add_argument("--codex_exec_sandbox", type=str)
    p.add_argument("--codex_exec_profile", type=str)
    p.add_argument("--codex_exec_full_auto", type=_BOOL)
    p.add_argument("--codex_exec_reasoning_effort", type=str)
    p.add_argument("--codex_exec_use_sdk", type=str)
    p.add_argument("--codex_exec_network_access", type=_BOOL)
    p.add_argument("--codex_exec_web_search", type=_BOOL)
    p.add_argument("--codex_exec_approval_policy", type=str)
    p.add_argument("--claude_code_exec_path", type=str)
    p.add_argument("--claude_code_exec_profile", type=str)
    p.add_argument("--claude_code_exec_use_sdk", type=str)
    p.add_argument("--claude_code_exec_effort", type=str)
    p.add_argument("--claude_code_exec_max_thinking_tokens", type=int)
    p.add_argument("--codex_trace_to_teacher", type=_BOOL)
    p.add_argument("--skill_init", type=str)
    p.add_argument("--num_epochs", type=int)
    p.add_argument("--train_size", type=int)
    p.add_argument("--steps_per_epoch", type=int)
    p.add_argument("--batch_size", type=int)
    p.add_argument("--accumulation", type=int)
    p.add_argument("--seed", type=int)
    p.add_argument("--edit_budget", type=int)
    p.add_argument("--min_edit_budget", type=int)
    p.add_argument("--lr_scheduler", type=str,
                   choices=["constant", "linear", "cosine", "autonomous"])
    p.add_argument("--lr_control_mode", type=str,
                   choices=["fixed", "autonomous", "none"])
    p.add_argument("--merge_batch_size", type=int)
    p.add_argument("--max_analyst_rounds", type=int)
    p.add_argument("--sel_env_num", type=int)
    p.add_argument("--test_env_num", type=int)
    p.add_argument("--eval_test", type=_BOOL)
    p.add_argument("--use_gate", type=_BOOL)
    p.add_argument("--max_steps", type=int)
    p.add_argument("--max_api_workers", type=int)
    p.add_argument("--analyst_workers", type=int)
    p.add_argument("--failure_only", type=_BOOL)
    p.add_argument("--minibatch_size", type=int)
    p.add_argument("--use_meta_reflect", type=_BOOL)
    p.add_argument("--meta_edit_budget", type=int)
    p.add_argument("--skill_update_mode", type=str, choices=[
        "patch",
        "rewrite_from_suggestions",
        "rewrite",
        "suggestions",
        "full_rewrite",
        "full_rewrite_minibatch",
        "minibatch_full_rewrite",
    ])
    p.add_argument("--use_deep_reflect", type=_BOOL)
    p.add_argument("--deep_reflect_failures", type=int)
    p.add_argument("--deep_reflect_successes", type=int)
    p.add_argument("--use_slow_update", type=_BOOL)
    p.add_argument("--slow_update_samples", type=int)
    p.add_argument("--longitudinal_pair_policy", type=str,
                   choices=["mixed", "changed", "unchanged"])
    p.add_argument("--use_meta_skill", type=_BOOL)
    p.add_argument("--data_path", type=str)
    p.add_argument("--split_mode", type=str, choices=["ratio", "split_dir"])
    p.add_argument("--split_ratio", type=str)
    p.add_argument("--split_seed", type=int)
    p.add_argument("--split_dir", type=str)
    p.add_argument("--split_output_dir", type=str)
    p.add_argument("--data_root", type=str)
    p.add_argument("--max_turns", type=int)
    p.add_argument("--workers", type=int)
    p.add_argument("--limit", type=int)
    p.add_argument("--shuffle_choices", type=_BOOL)
    p.add_argument("--use_theorem", type=_BOOL)
    p.add_argument("--use_sketch", type=_BOOL)
    p.add_argument("--image_detail", type=str)
    p.add_argument("--judge_model", type=str)
    p.add_argument("--judge_max_completion_tokens", type=int)
    p.add_argument("--judge_retries", type=int)
    p.add_argument("--out_root", type=str)
    p.add_argument("--mode", type=str)
    return p.parse_args()


# ── Flat key → structured path mapping (for legacy CLI → structured config) ──

_LEGACY_TO_STRUCTURED: dict[str, str] = {
    "backend": "model.backend",
    "teacher_model": "model.teacher",
    "student_model": "model.student",
    "teacher_backend": "model.teacher_backend",
    "student_backend": "model.student_backend",
    "reasoning_effort": "model.reasoning_effort",
    "rewrite_reasoning_effort": "model.rewrite_reasoning_effort",
    "rewrite_max_completion_tokens": "model.rewrite_max_completion_tokens",
    "azure_endpoint": "model.azure_endpoint",
    "azure_api_version": "model.azure_api_version",
    "azure_api_key": "model.azure_api_key",
    "azure_openai_endpoint": "model.azure_openai_endpoint",
    "azure_openai_api_version": "model.azure_openai_api_version",
    "azure_openai_api_key": "model.azure_openai_api_key",
    "azure_openai_auth_mode": "model.azure_openai_auth_mode",
    "azure_openai_ad_scope": "model.azure_openai_ad_scope",
    "azure_openai_managed_identity_client_id":
        "model.azure_openai_managed_identity_client_id",
    "teacher_azure_openai_endpoint": "model.teacher_azure_openai_endpoint",
    "teacher_azure_openai_api_version": "model.teacher_azure_openai_api_version",
    "teacher_azure_openai_api_key": "model.teacher_azure_openai_api_key",
    "teacher_azure_openai_auth_mode": "model.teacher_azure_openai_auth_mode",
    "teacher_azure_openai_ad_scope": "model.teacher_azure_openai_ad_scope",
    "teacher_azure_openai_managed_identity_client_id":
        "model.teacher_azure_openai_managed_identity_client_id",
    "student_azure_openai_endpoint": "model.student_azure_openai_endpoint",
    "student_azure_openai_api_version": "model.student_azure_openai_api_version",
    "student_azure_openai_api_key": "model.student_azure_openai_api_key",
    "student_azure_openai_auth_mode": "model.student_azure_openai_auth_mode",
    "student_azure_openai_ad_scope": "model.student_azure_openai_ad_scope",
    "student_azure_openai_managed_identity_client_id":
        "model.student_azure_openai_managed_identity_client_id",
    "codex_exec_path": "model.codex_exec_path",
    "codex_exec_sandbox": "model.codex_exec_sandbox",
    "codex_exec_profile": "model.codex_exec_profile",
    "codex_exec_full_auto": "model.codex_exec_full_auto",
    "codex_exec_reasoning_effort": "model.codex_exec_reasoning_effort",
    "codex_exec_use_sdk": "model.codex_exec_use_sdk",
    "codex_exec_network_access": "model.codex_exec_network_access",
    "codex_exec_web_search": "model.codex_exec_web_search",
    "codex_exec_approval_policy": "model.codex_exec_approval_policy",
    "claude_code_exec_path": "model.claude_code_exec_path",
    "claude_code_exec_profile": "model.claude_code_exec_profile",
    "claude_code_exec_use_sdk": "model.claude_code_exec_use_sdk",
    "claude_code_exec_effort": "model.claude_code_exec_effort",
    "claude_code_exec_max_thinking_tokens":
        "model.claude_code_exec_max_thinking_tokens",
    "codex_trace_to_teacher": "model.codex_trace_to_teacher",
    "num_epochs": "train.num_epochs",
    "train_size": "train.train_size",
    "steps_per_epoch": "train.steps_per_epoch",
    "batch_size": "train.batch_size",
    "accumulation": "train.accumulation",
    "seed": "train.seed",
    "minibatch_size": "gradient.minibatch_size",
    "merge_batch_size": "gradient.merge_batch_size",
    "analyst_workers": "gradient.analyst_workers",
    "max_analyst_rounds": "gradient.max_analyst_rounds",
    "failure_only": "gradient.failure_only",
    "use_deep_reflect": "gradient.use_deep_reflect",
    "deep_reflect_failures": "gradient.deep_reflect_failures",
    "deep_reflect_successes": "gradient.deep_reflect_successes",
    "edit_budget": "optimizer.learning_rate",
    "min_edit_budget": "optimizer.min_learning_rate",
    "lr_scheduler": "optimizer.lr_scheduler",
    "lr_control_mode": "optimizer.lr_control_mode",
    "skill_update_mode": "optimizer.skill_update_mode",
    "use_meta_reflect": "optimizer.use_meta_reflect",
    "meta_edit_budget": "optimizer.meta_learning_rate",
    "use_slow_update": "optimizer.use_slow_update",
    "slow_update_samples": "optimizer.slow_update_samples",
    "longitudinal_pair_policy": "optimizer.longitudinal_pair_policy",
    "use_meta_skill": "optimizer.use_meta_skill",
    "use_gate": "evaluation.use_gate",
    "sel_env_num": "evaluation.sel_env_num",
    "test_env_num": "evaluation.test_env_num",
    "eval_test": "evaluation.eval_test",
    "env": "env.name",
    "skill_init": "env.skill_init",
    "out_root": "env.out_root",
}

# Backend name (as returned by ``normalize_backend_name``) ->
# (teacher_backend default, student_backend default).  Backends not listed
# here fall back to ``_FALLBACK_ROLE_BACKENDS``.
_ROLE_BACKEND_DEFAULTS: dict[str, tuple[str, str]] = {
    "claude": ("claude_chat", "claude_chat"),
    "claude_chat": ("claude_chat", "claude_chat"),
    "codex": ("openai_chat", "codex_exec"),
    "codex_exec": ("openai_chat", "codex_exec"),
    "claude_code_exec": ("openai_chat", "claude_code_exec"),
}
_FALLBACK_ROLE_BACKENDS: tuple[str, str] = ("openai_chat", "openai_chat")


def load_config(args: argparse.Namespace) -> dict:
    """Load config with _base_ inheritance, then apply CLI overrides.

    Steps, in order:

    1. Load the YAML file with the ``--cfg-options`` overrides.
    2. Apply the legacy flat ``--key value`` options.  For a structured
       config each option is translated to its dotted path through
       ``_LEGACY_TO_STRUCTURED`` (unmapped keys go under ``env.``).
    3. Flatten a structured config to a flat dict.
    4. Copy the legacy ``azure_*`` settings to their ``azure_openai_*``
       names when the latter are unset or empty.
    5. Fill in default teacher/student backends, and replace OpenAI
       default model names for roles that run on a Claude backend.
    6. Make ``out_root`` an absolute path, generating one when unset.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        The flat configuration dict.
    """
    from skillopt.config import load_config as _load, flatten_config, is_structured

    cfg = _load(args.config, overrides=args.cfg_options)
    structured = is_structured(cfg)
    cfg_options = args.cfg_options or []

    # Apply legacy --key value overrides
    cli = {
        key: value
        for key, value in vars(args).items()
        if value is not None and key not in ("config", "cfg_options")
    }
    if cli:
        if structured:
            from skillopt.config import apply_overrides

            mapped = []
            for key, value in cli.items():
                dotted = _LEGACY_TO_STRUCTURED.get(key) or f"env.{key}"
                mapped.append(f"{dotted}={value}")
            apply_overrides(cfg, mapped)
        else:
            cfg.update(cli)

    # Flatten structured config → flat dict for trainer/adapter
    flat = flatten_config(cfg) if structured else cfg

    # Legacy azure_* settings feed the azure_openai_* keys when those are unset.
    for new_key, old_key in (
        ("azure_openai_endpoint", "azure_endpoint"),
        ("azure_openai_api_version", "azure_api_version"),
        ("azure_openai_api_key", "azure_api_key"),
    ):
        if flat.get(new_key) in (None, "") and flat.get(old_key) not in (None, ""):
            flat[new_key] = flat[old_key]

    # An "explicit" backend is one given on the command line, either as
    # --backend or as a ``model.backend=...`` entry of --cfg-options.
    explicit_backend = getattr(args, "backend", None)
    if explicit_backend is None:
        for option in cfg_options:
            key, sep, value = str(option).partition("=")
            # Require the "=" so a malformed bare "model.backend" entry
            # cannot raise IndexError here.
            if sep and key.strip() == "model.backend":
                explicit_backend = value.strip()
                break

    # NOTE(review): this value is only ever used after being replaced by the
    # explicit backend below, so a backend set purely in the config file does
    # not influence the role defaults.  The call is kept in case
    # normalize_backend_name validates its input — confirm the intent.
    backend = normalize_backend_name(
        flat.get("model_backend") or flat.get("student_backend") or "azure_openai"
    )

    def _has_model_override(dotted_key: str, legacy_key: str) -> bool:
        """Return True if the model was set via a legacy flag or --cfg-options."""
        if getattr(args, legacy_key, None) is not None:
            return True
        return any(
            str(option).partition("=")[0].strip() == dotted_key
            for option in cfg_options
        )

    def _use_claude_default_model(role: str) -> None:
        """Swap an un-overridden OpenAI default model of *role* for the Claude default."""
        model_key = f"{role}_model"
        current = str(flat.get(model_key, "") or "").strip()
        if (
            current in _OPENAI_DEFAULT_MODEL_SENTINELS
            and not _has_model_override(f"model.{role}", model_key)
        ):
            flat[model_key] = default_model_for_backend("claude_chat")

    role_defaults = _FALLBACK_ROLE_BACKENDS
    if explicit_backend is not None:
        backend = normalize_backend_name(explicit_backend)
        flat["model_backend"] = backend
        role_defaults = _ROLE_BACKEND_DEFAULTS.get(backend, _FALLBACK_ROLE_BACKENDS)
    # setdefault: role backends already present in the config win.
    flat.setdefault("teacher_backend", role_defaults[0])
    flat.setdefault("student_backend", role_defaults[1])

    if flat.get("teacher_backend") == "claude_chat":
        _use_claude_default_model("teacher")
    if flat.get("student_backend") in ("claude_chat", "claude_code_exec"):
        _use_claude_default_model("student")

    # Auto-generate output root
    if not flat.get("out_root"):
        # ``or "unknown"`` also covers keys that are present but None/empty,
        # which would otherwise crash on ``.replace``.
        env = flat.get("env") or "unknown"
        model = str(flat.get("teacher_model") or "unknown").replace("/", "-")
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        flat["out_root"] = os.path.join("outputs", f"skillopt_{env}_{model}_{ts}")
    flat["out_root"] = os.path.abspath(flat["out_root"])

    return flat


# ── Main ─────────────────────────────────────────────────────────────────────

def main() -> None:
    """Parse the CLI, print the effective settings, and run training."""
    args = parse_args()
    cfg = load_config(args)

    rule = "=" * 60
    print(f"\n{rule}")
    print(" ReflACT — Reflective Agent Tuning")
    print(rule)
    print(f" env: {cfg.get('env')}")
    print(f" teacher_model: {cfg.get('teacher_model')}")
    print(f" student_model: {cfg.get('student_model')}")
    print(f" teacher_backend:{cfg.get('teacher_backend', 'openai_chat')}")
    print(f" student_backend:{cfg.get('student_backend', 'openai_chat')}")
    print(f" reasoning: {cfg.get('reasoning_effort') or 'off'}")
    print(f" rewrite_effort: {cfg.get('rewrite_reasoning_effort') or 'off'}")
    print(f" epochs: {cfg.get('num_epochs')}")
    print(f" train_size: {cfg.get('train_size') or 'from dataset'}")
    print(" steps/epoch: auto")
    print(f" batch_size: {cfg.get('batch_size')}")
    print(f" edit_budget: {cfg.get('edit_budget')}")
    print(f" lr_scheduler: {cfg.get('lr_scheduler', 'constant')}")
    print(f" update_mode: {cfg.get('skill_update_mode', 'patch')}")
    print(f" min_edit_budget:{cfg.get('min_edit_budget', 2)}")
    print(f" minibatch_size: {cfg.get('minibatch_size')}")
    print(f" seed: {cfg.get('seed')}")
    print(f" meta_reflect: {cfg.get('use_meta_reflect', False)}")
    print(f" meta_skill: {cfg.get('use_meta_skill', False)}")
    print(f" out_root: {cfg.get('out_root')}")
    print(f"{rule}\n")

    # Build adapter
    adapter = get_adapter(cfg)

    # Build trainer and run (imported here so --help stays fast)
    from skillopt.engine.trainer import ReflACTTrainer

    trainer = ReflACTTrainer(cfg, adapter)
    summary = trainer.train()

    print(f"\n Output saved to: {cfg['out_root']}")
    if summary.get("test_hard") is not None:
        print(f" Final test: {summary['test_hard']:.4f}")


if __name__ == "__main__":
    main()