mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
484 lines
21 KiB
Python
484 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
"""ReflACT unified training entry point.
|
|
|
|
Usage
|
|
-----
|
|
python scripts/train.py --config configs/alfworld/default.yaml
|
|
|
|
Any YAML key can be overridden from the command line::
|
|
|
|
python scripts/train.py --config configs/alfworld/default.yaml \\
|
|
--batch_size 40 --num_epochs 2 --seed 123
|
|
|
|
Run ``python scripts/train.py --help`` for a full list of options.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import datetime
|
|
import os
|
|
import sys
|
|
|
|
# Ensure the project root is on sys.path so ``import skillopt`` works
# regardless of where the script is invoked from.
# The project root is taken to be the parent of the directory that holds
# this script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
if _PROJECT_ROOT not in sys.path:
    # Prepend so this location is searched before the existing entries.
    sys.path.insert(0, _PROJECT_ROOT)
|
|
|
|
from skillopt.model.common import default_model_for_backend, normalize_backend_name
|
|
|
|
# Model names that ``load_config`` treats as "replaceable": when the teacher
# or student backend is a Claude backend, the configured model is one of
# these names, and the user did not override the model explicitly, the model
# is swapped for ``default_model_for_backend("claude_chat")``.
_OPENAI_DEFAULT_MODEL_SENTINELS = {"gpt-5.4", "gpt-5.5"}
|
|
|
|
|
|
# ── Environment registry ────────────────────────────────────────────────────
|
|
|
|
# Maps environment name (the ``env`` config value) to its adapter class.
_ENV_REGISTRY: dict[str, type] = {}

# (environment name, adapter module, adapter class name) for every built-in
# environment. The modules are imported lazily by ``_register_builtins``.
_BUILTIN_ADAPTERS: tuple[tuple[str, str, str], ...] = (
    ("alfworld", "skillopt.envs.alfworld.adapter", "ALFWorldAdapter"),
    ("searchqa", "skillopt.envs.searchqa.adapter", "SearchQAAdapter"),
    (
        "livemathematicianbench",
        "skillopt.envs.livemathematicianbench.adapter",
        "LiveMathematicianBenchAdapter",
    ),
    ("babyvision", "skillopt.envs.babyvision.adapter", "BabyVisionAdapter"),
    (
        "spreadsheetbench",
        "skillopt.envs.spreadsheetbench.adapter",
        "SpreadsheetBenchAdapter",
    ),
    ("mmrb", "skillopt.envs.mmrb.adapter", "MMRBAdapter"),
    ("docvqa", "skillopt.envs.docvqa.adapter", "DocVQAAdapter"),
    ("mathverse", "skillopt.envs.mathverse.adapter", "MathVerseAdapter"),
    ("officeqa", "skillopt.envs.officeqa.adapter", "OfficeQAAdapter"),
    ("sealqa", "skillopt.envs.sealqa.adapter", "SealQAAdapter"),
    ("swebench", "skillopt.envs.swebench.adapter", "SWEBenchAdapter"),
)


def _register_builtins() -> None:
    """Lazy-import built-in adapters so we don't pull heavy deps at CLI parse time.

    Every entry of ``_BUILTIN_ADAPTERS`` whose module imports successfully is
    stored in ``_ENV_REGISTRY`` under its environment name. An adapter whose
    module cannot be imported (or does not expose the expected class) is
    skipped; this function never raises for that reason. Each skip is logged
    at DEBUG level so the cause can be inspected when an environment is
    unexpectedly reported as unknown.
    """
    import importlib
    import logging

    log = logging.getLogger(__name__)
    for env_name, module_path, class_name in _BUILTIN_ADAPTERS:
        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            # Optional dependencies of this environment are not installed.
            log.debug("Skipping built-in env %r: %s", env_name, exc)
            continue
        adapter_cls = getattr(module, class_name, None)
        if adapter_cls is None:
            # ``from module import Name`` would raise ImportError here, so a
            # missing class is skipped the same way as a missing module.
            log.debug(
                "Skipping built-in env %r: %s has no attribute %s",
                env_name, module_path, class_name,
            )
            continue
        _ENV_REGISTRY[env_name] = adapter_cls
|
|
|
|
|
|
def get_adapter(cfg: dict):
    """Instantiate the environment adapter specified in ``cfg["env"]``.

    Args:
        cfg: Flat configuration dict. ``cfg["env"]`` selects the adapter
            (``"alfworld"`` when the key is absent). Any other key whose name
            matches a keyword-passable parameter of the adapter's
            ``__init__`` is forwarded to the constructor; all remaining keys
            are ignored.

    Returns:
        A new instance of the registered adapter class.

    Raises:
        ValueError: If the requested environment is not registered (for
            example because its optional dependencies are not installed).
    """
    import inspect

    _register_builtins()
    env_name = cfg.get("env", "alfworld")
    if env_name not in _ENV_REGISTRY:
        raise ValueError(
            f"Unknown environment '{env_name}'. "
            f"Available: {list(_ENV_REGISTRY.keys())}"
        )
    adapter_cls = _ENV_REGISTRY[env_name]

    # Inspect adapter __init__ signature and only pass accepted kwargs.
    # Only parameters that can legally be supplied by keyword are considered:
    # the *names* of ``*args`` / ``**kwargs`` and positional-only parameters
    # are not config keys, and passing them by keyword would either raise or
    # smuggle unrelated config values into ``**kwargs``.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    init_params = inspect.signature(adapter_cls.__init__).parameters
    adapter_kwargs: dict = {
        name: cfg[name]
        for name, param in init_params.items()
        if name != "self" and param.kind in keyword_kinds and name in cfg
    }

    return adapter_cls(**adapter_kwargs)
|
|
|
|
|
|
# ── CLI ──────────────────────────────────────────────────────────────────────
|
|
|
|
def _parse_bool_flag(x):
    """Return True when *x* spells a truthy flag value, case-insensitively.

    Only ``"true"``, ``"1"`` and ``"yes"`` count as true; every other string
    (including the empty string) yields False.
    """
    lowered = x.lower()
    return lowered in ("true", "1", "yes")


# argparse ``type=`` converter used by the boolean command-line flags.
_BOOL = _parse_bool_flag
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    ``--config`` is required. ``--cfg-options`` collects ``section.key=value``
    overrides. Every other flag is a legacy flat override with no default, so
    a flag that was not given is ``None`` in the returned namespace.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what a plain ``parse_args()`` call
            does; pass an explicit list to drive the parser programmatically.

    Returns:
        The parsed namespace.
    """
    p = argparse.ArgumentParser(
        description="ReflACT: Reflective Agent Tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    p.add_argument("--config", type=str, required=True,
                   help="Path to YAML config file")
    p.add_argument("--cfg-options", nargs="+", default=[],
                   help="Override config: section.key=value (e.g. train.batch_size=40)")

    # Legacy flat CLI overrides (still work, prefer --cfg-options for new usage)
    p.add_argument("--env", type=str)

    # Model / backend selection
    p.add_argument("--backend", type=str,
                   choices=["azure_openai", "codex", "codex_exec", "claude", "claude_chat", "claude_code_exec"])
    p.add_argument("--teacher_model", type=str)
    p.add_argument("--student_model", type=str)
    p.add_argument("--teacher_backend", type=str)
    p.add_argument("--student_backend", type=str)
    p.add_argument("--reasoning_effort", type=str,
                   choices=["", "low", "medium", "high", "xhigh", "max"])
    p.add_argument("--rewrite_reasoning_effort", type=str)
    p.add_argument("--rewrite_max_completion_tokens", type=int)

    # Azure OpenAI connection settings (shared, then teacher/student specific)
    p.add_argument("--azure_endpoint", type=str)
    p.add_argument("--azure_api_version", type=str)
    p.add_argument("--azure_api_key", type=str)
    p.add_argument("--azure_openai_endpoint", type=str)
    p.add_argument("--azure_openai_api_version", type=str)
    p.add_argument("--azure_openai_api_key", type=str)
    p.add_argument("--azure_openai_auth_mode", type=str)
    p.add_argument("--azure_openai_ad_scope", type=str)
    p.add_argument("--azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--teacher_azure_openai_endpoint", type=str)
    p.add_argument("--teacher_azure_openai_api_version", type=str)
    p.add_argument("--teacher_azure_openai_api_key", type=str)
    p.add_argument("--teacher_azure_openai_auth_mode", type=str)
    p.add_argument("--teacher_azure_openai_ad_scope", type=str)
    p.add_argument("--teacher_azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--student_azure_openai_endpoint", type=str)
    p.add_argument("--student_azure_openai_api_version", type=str)
    p.add_argument("--student_azure_openai_api_key", type=str)
    p.add_argument("--student_azure_openai_auth_mode", type=str)
    p.add_argument("--student_azure_openai_ad_scope", type=str)
    p.add_argument("--student_azure_openai_managed_identity_client_id", type=str)

    # Codex / Claude Code executor settings
    p.add_argument("--codex_exec_path", type=str)
    p.add_argument("--codex_exec_sandbox", type=str)
    p.add_argument("--codex_exec_profile", type=str)
    p.add_argument("--codex_exec_full_auto", type=_BOOL)
    p.add_argument("--codex_exec_reasoning_effort", type=str)
    p.add_argument("--codex_exec_use_sdk", type=str)
    p.add_argument("--codex_exec_network_access", type=_BOOL)
    p.add_argument("--codex_exec_web_search", type=_BOOL)
    p.add_argument("--codex_exec_approval_policy", type=str)
    p.add_argument("--claude_code_exec_path", type=str)
    p.add_argument("--claude_code_exec_profile", type=str)
    p.add_argument("--claude_code_exec_use_sdk", type=str)
    p.add_argument("--claude_code_exec_effort", type=str)
    p.add_argument("--claude_code_exec_max_thinking_tokens", type=int)
    p.add_argument("--codex_trace_to_teacher", type=_BOOL)

    # Training loop, optimizer and evaluation settings
    p.add_argument("--skill_init", type=str)
    p.add_argument("--num_epochs", type=int)
    p.add_argument("--train_size", type=int)
    p.add_argument("--steps_per_epoch", type=int)
    p.add_argument("--batch_size", type=int)
    p.add_argument("--accumulation", type=int)
    p.add_argument("--seed", type=int)
    p.add_argument("--edit_budget", type=int)
    p.add_argument("--min_edit_budget", type=int)
    p.add_argument("--lr_scheduler", type=str,
                   choices=["constant", "linear", "cosine", "autonomous"])
    p.add_argument("--lr_control_mode", type=str,
                   choices=["fixed", "autonomous", "none"])
    p.add_argument("--merge_batch_size", type=int)
    p.add_argument("--max_analyst_rounds", type=int)
    p.add_argument("--sel_env_num", type=int)
    p.add_argument("--test_env_num", type=int)
    p.add_argument("--eval_test", type=_BOOL)
    p.add_argument("--use_gate", type=_BOOL)
    p.add_argument("--max_steps", type=int)
    p.add_argument("--max_api_workers", type=int)
    p.add_argument("--analyst_workers", type=int)
    p.add_argument("--failure_only", type=_BOOL)
    p.add_argument("--minibatch_size", type=int)
    p.add_argument("--use_meta_reflect", type=_BOOL)
    p.add_argument("--meta_edit_budget", type=int)
    p.add_argument("--skill_update_mode", type=str,
                   choices=[
                       "patch",
                       "rewrite_from_suggestions",
                       "rewrite",
                       "suggestions",
                       "full_rewrite",
                       "full_rewrite_minibatch",
                       "minibatch_full_rewrite",
                   ])
    p.add_argument("--use_deep_reflect", type=_BOOL)
    p.add_argument("--deep_reflect_failures", type=int)
    p.add_argument("--deep_reflect_successes", type=int)
    p.add_argument("--use_slow_update", type=_BOOL)
    p.add_argument("--slow_update_samples", type=int)
    p.add_argument("--longitudinal_pair_policy", type=str,
                   choices=["mixed", "changed", "unchanged"])
    p.add_argument("--use_meta_skill", type=_BOOL)

    # Dataset / environment-specific settings
    p.add_argument("--data_path", type=str)
    p.add_argument("--split_mode", type=str,
                   choices=["ratio", "split_dir"])
    p.add_argument("--split_ratio", type=str)
    p.add_argument("--split_seed", type=int)
    p.add_argument("--split_dir", type=str)
    p.add_argument("--split_output_dir", type=str)
    p.add_argument("--data_root", type=str)
    p.add_argument("--max_turns", type=int)
    p.add_argument("--workers", type=int)
    p.add_argument("--limit", type=int)
    p.add_argument("--shuffle_choices", type=_BOOL)
    p.add_argument("--use_theorem", type=_BOOL)
    p.add_argument("--use_sketch", type=_BOOL)
    p.add_argument("--image_detail", type=str)
    p.add_argument("--judge_model", type=str)
    p.add_argument("--judge_max_completion_tokens", type=int)
    p.add_argument("--judge_retries", type=int)
    p.add_argument("--out_root", type=str)
    p.add_argument("--mode", type=str)

    # argparse falls back to sys.argv[1:] when argv is None.
    return p.parse_args(argv)
|
|
|
|
|
|
# ── Flat key → structured path mapping (for legacy CLI → structured config) ──
|
|
|
|
def _same_name_paths(section: str, *flags: str) -> dict[str, str]:
    """Map each legacy flag to ``<section>.<flag>`` (flag name unchanged)."""
    return {flag: f"{section}.{flag}" for flag in flags}


# Legacy flat CLI flag -> dotted path in the structured config. Most flags
# keep their name and only gain a section prefix; the handful whose key is
# renamed in the structured config are listed explicitly at the end.
_LEGACY_TO_STRUCTURED: dict[str, str] = {
    **_same_name_paths(
        "model",
        "backend",
        "teacher_backend",
        "student_backend",
        "reasoning_effort",
        "rewrite_reasoning_effort",
        "rewrite_max_completion_tokens",
        "azure_endpoint",
        "azure_api_version",
        "azure_api_key",
        "azure_openai_endpoint",
        "azure_openai_api_version",
        "azure_openai_api_key",
        "azure_openai_auth_mode",
        "azure_openai_ad_scope",
        "azure_openai_managed_identity_client_id",
        "teacher_azure_openai_endpoint",
        "teacher_azure_openai_api_version",
        "teacher_azure_openai_api_key",
        "teacher_azure_openai_auth_mode",
        "teacher_azure_openai_ad_scope",
        "teacher_azure_openai_managed_identity_client_id",
        "student_azure_openai_endpoint",
        "student_azure_openai_api_version",
        "student_azure_openai_api_key",
        "student_azure_openai_auth_mode",
        "student_azure_openai_ad_scope",
        "student_azure_openai_managed_identity_client_id",
        "codex_exec_path",
        "codex_exec_sandbox",
        "codex_exec_profile",
        "codex_exec_full_auto",
        "codex_exec_reasoning_effort",
        "codex_exec_use_sdk",
        "codex_exec_network_access",
        "codex_exec_web_search",
        "codex_exec_approval_policy",
        "claude_code_exec_path",
        "claude_code_exec_profile",
        "claude_code_exec_use_sdk",
        "claude_code_exec_effort",
        "claude_code_exec_max_thinking_tokens",
        "codex_trace_to_teacher",
    ),
    **_same_name_paths(
        "train",
        "num_epochs",
        "train_size",
        "steps_per_epoch",
        "batch_size",
        "accumulation",
        "seed",
    ),
    **_same_name_paths(
        "gradient",
        "minibatch_size",
        "merge_batch_size",
        "analyst_workers",
        "max_analyst_rounds",
        "failure_only",
        "use_deep_reflect",
        "deep_reflect_failures",
        "deep_reflect_successes",
    ),
    **_same_name_paths(
        "optimizer",
        "lr_scheduler",
        "lr_control_mode",
        "skill_update_mode",
        "use_meta_reflect",
        "use_slow_update",
        "slow_update_samples",
        "longitudinal_pair_policy",
        "use_meta_skill",
    ),
    **_same_name_paths(
        "evaluation",
        "use_gate",
        "sel_env_num",
        "test_env_num",
        "eval_test",
    ),
    **_same_name_paths("env", "skill_init", "out_root"),
    # Flags whose structured key differs from the legacy flag name.
    "teacher_model": "model.teacher",
    "student_model": "model.student",
    "edit_budget": "optimizer.learning_rate",
    "min_edit_budget": "optimizer.min_learning_rate",
    "meta_edit_budget": "optimizer.meta_learning_rate",
    "env": "env.name",
}
|
|
|
|
|
|
def load_config(args: argparse.Namespace) -> dict:
    """Load config with _base_ inheritance, then apply CLI overrides.

    Steps, in order:

    1. Load ``args.config`` through ``skillopt.config.load_config``, passing
       ``args.cfg_options`` as overrides.
    2. Apply the legacy flat flags (every non-``None`` attribute of *args*
       except ``config`` / ``cfg_options``). For a structured config each
       flag becomes a dotted override via ``_LEGACY_TO_STRUCTURED``, falling
       back to ``env.<flag>``; for a flat config the flags are merged in
       directly.
    3. Flatten a structured config into the flat dict used downstream.
    4. Back-fill empty ``azure_openai_*`` keys from the older ``azure_*`` keys.
    5. Fill in ``teacher_backend`` / ``student_backend`` defaults. They depend
       on the backend only when it was given explicitly (``--backend`` or a
       ``model.backend=...`` entry in ``--cfg-options``).
    6. For Claude backends, replace an OpenAI sentinel model name with the
       Claude default unless the user overrode that model explicitly.
    7. Generate ``out_root`` when unset and make it absolute.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        The flat configuration dict.
    """
    from skillopt.config import load_config as _load, flatten_config, is_structured

    cfg_options = args.cfg_options or []

    def _split_option(option) -> tuple[str, str | None]:
        """Split ``key=value``; the value is None when there is no ``=``."""
        key, sep, value = str(option).partition("=")
        return key.strip(), (value.strip() if sep else None)

    def _has_model_override(dotted_key: str, legacy_key: str) -> bool:
        """Tell whether the user set this model via a flag or --cfg-options."""
        if getattr(args, legacy_key, None) is not None:
            return True
        return any(_split_option(option)[0] == dotted_key for option in cfg_options)

    cfg = _load(args.config, overrides=args.cfg_options)
    structured = is_structured(cfg)

    # Apply legacy --key value overrides
    cli = {
        key: value
        for key, value in vars(args).items()
        if value is not None and key not in ("config", "cfg_options")
    }
    if cli:
        if structured:
            from skillopt.config import apply_overrides

            mapped = []
            for key, value in cli.items():
                # Flags without an explicit mapping land under the env section.
                dotted = _LEGACY_TO_STRUCTURED.get(key) or f"env.{key}"
                mapped.append(f"{dotted}={value}")
            apply_overrides(cfg, mapped)
        else:
            cfg.update(cli)

    # Flatten structured config → flat dict for trainer/adapter
    flat = flatten_config(cfg) if structured else cfg

    # Older configs use azure_*; copy them over only when the newer
    # azure_openai_* key is missing or empty.
    for new_key, old_key in (
        ("azure_openai_endpoint", "azure_endpoint"),
        ("azure_openai_api_version", "azure_api_version"),
        ("azure_openai_api_key", "azure_api_key"),
    ):
        if flat.get(new_key) in (None, "") and flat.get(old_key) not in (None, ""):
            flat[new_key] = flat[old_key]

    # An explicit backend is one given on the command line, either as
    # --backend or as the first well-formed model.backend=<value> option.
    explicit_backend = getattr(args, "backend", None)
    if explicit_backend is None:
        for option in cfg_options:
            key, value = _split_option(option)
            # A "model.backend" entry without "=" carries no value; skip it
            # instead of failing while indexing the missing right-hand side.
            if key == "model.backend" and value is not None:
                explicit_backend = value
                break

    # NOTE(review): when no backend is given explicitly this value is not used
    # below; the call is kept so normalize_backend_name still sees the
    # configured name, as before.
    backend = normalize_backend_name(
        flat.get("model_backend") or flat.get("student_backend") or "azure_openai"
    )

    teacher_default = "openai_chat"
    student_default = "openai_chat"
    if explicit_backend is not None:
        backend = normalize_backend_name(explicit_backend)
        flat["model_backend"] = backend
        if backend in {"claude", "claude_chat"}:
            teacher_default = student_default = "claude_chat"
        elif backend in {"codex", "codex_exec"}:
            student_default = "codex_exec"
        elif backend == "claude_code_exec":
            student_default = "claude_code_exec"
    # setdefault: backends already present in the config are left untouched.
    flat.setdefault("teacher_backend", teacher_default)
    flat.setdefault("student_backend", student_default)

    # A Claude backend paired with a leftover OpenAI default model name gets
    # the Claude default model, unless the user chose that model explicitly.
    claude_backends_by_role = (
        ("teacher", ("claude_chat",)),
        ("student", ("claude_chat", "claude_code_exec")),
    )
    for role, claude_backends in claude_backends_by_role:
        model_key = f"{role}_model"
        if flat.get(f"{role}_backend") not in claude_backends:
            continue
        configured = str(flat.get(model_key, "") or "").strip()
        if configured not in _OPENAI_DEFAULT_MODEL_SENTINELS:
            continue
        if _has_model_override(f"model.{role}", model_key):
            continue
        flat[model_key] = default_model_for_backend("claude_chat")

    # Auto-generate output root
    if not flat.get("out_root"):
        env = flat.get("env", "unknown")
        # teacher_model may be present but None/empty; fall back to "unknown"
        # rather than calling .replace on a non-string.
        model = str(flat.get("teacher_model") or "unknown").replace("/", "-")
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        flat["out_root"] = os.path.join("outputs", f"skillopt_{env}_{model}_{ts}")

    flat["out_root"] = os.path.abspath(flat["out_root"])
    return flat
|
|
|
|
|
|
# ── Main ─────────────────────────────────────────────────────────────────────
|
|
|
|
def main() -> None:
    """Run one training job from the command line.

    Parses the CLI, loads the flat config, prints a summary banner of the
    main settings, builds the environment adapter and the trainer, runs
    training, and finally prints the output directory and, when the returned
    summary has a non-None ``test_hard`` entry, that score.
    """
    args = parse_args()
    cfg = load_config(args)

    # Summary banner of the effective configuration.
    print(f"\n{'='*60}")
    print(f" ReflACT — Reflective Agent Tuning")
    print(f"{'='*60}")
    print(f" env: {cfg.get('env')}")
    print(f" teacher_model: {cfg.get('teacher_model')}")
    print(f" student_model: {cfg.get('student_model')}")
    print(f" teacher_backend:{cfg.get('teacher_backend', 'openai_chat')}")
    print(f" student_backend:{cfg.get('student_backend', 'openai_chat')}")
    # Empty / missing reasoning effort is shown as "off".
    print(f" reasoning: {cfg.get('reasoning_effort') or 'off'}")
    print(f" rewrite_effort: {cfg.get('rewrite_reasoning_effort') or 'off'}")
    print(f" epochs: {cfg.get('num_epochs')}")
    print(f" train_size: {cfg.get('train_size') or 'from dataset'}")
    # NOTE(review): always prints "auto", even if steps_per_epoch is set in cfg.
    print(f" steps/epoch: auto")
    print(f" batch_size: {cfg.get('batch_size')}")
    print(f" edit_budget: {cfg.get('edit_budget')}")
    print(f" lr_scheduler: {cfg.get('lr_scheduler', 'constant')}")
    print(f" update_mode: {cfg.get('skill_update_mode', 'patch')}")
    print(f" min_edit_budget:{cfg.get('min_edit_budget', 2)}")
    print(f" minibatch_size: {cfg.get('minibatch_size')}")
    print(f" seed: {cfg.get('seed')}")
    print(f" meta_reflect: {cfg.get('use_meta_reflect', False)}")
    print(f" meta_skill: {cfg.get('use_meta_skill', False)}")
    print(f" out_root: {cfg.get('out_root')}")
    print(f"{'='*60}\n")

    # Build adapter
    adapter = get_adapter(cfg)

    # Build trainer and run
    # Imported here rather than at module top, so it is only loaded once the
    # CLI and config have been processed.
    from skillopt.engine.trainer import ReflACTTrainer
    trainer = ReflACTTrainer(cfg, adapter)
    summary = trainer.train()

    print(f"\n Output saved to: {cfg['out_root']}")
    # The test score is reported only when the summary provides one.
    if summary.get("test_hard") is not None:
        print(f" Final test: {summary['test_hard']:.4f}")


if __name__ == "__main__":
    main()
|