diff --git a/configs/_base_/default.yaml b/configs/_base_/default.yaml index 9ec2270..b1c9068 100644 --- a/configs/_base_/default.yaml +++ b/configs/_base_/default.yaml @@ -71,6 +71,12 @@ optimizer: evaluation: use_gate: true + # gate_metric: 'hard' (default, backward-compatible), + # 'soft' (use soft/F1 score), + # 'mixed' ((1 - w) * hard + w * soft, w = gate_mixed_weight). + # See skillopt/evaluation/gate.py for details. + gate_metric: hard + gate_mixed_weight: 0.5 sel_env_num: 0 test_env_num: 0 eval_test: true diff --git a/skillopt/config.py b/skillopt/config.py index cc61a27..bf56bda 100644 --- a/skillopt/config.py +++ b/skillopt/config.py @@ -101,6 +101,8 @@ _FLATTEN_MAP: dict[str, str] = { "optimizer.longitudinal_pair_policy": "longitudinal_pair_policy", "optimizer.use_meta_skill": "use_meta_skill", "evaluation.use_gate": "use_gate", + "evaluation.gate_metric": "gate_metric", + "evaluation.gate_mixed_weight": "gate_mixed_weight", "evaluation.sel_env_num": "sel_env_num", "evaluation.test_env_num": "test_env_num", "evaluation.eval_test": "eval_test", diff --git a/skillopt/engine/trainer.py b/skillopt/engine/trainer.py index d12982c..d3a2508 100644 --- a/skillopt/engine/trainer.py +++ b/skillopt/engine/trainer.py @@ -24,7 +24,7 @@ from collections import defaultdict from skillopt.datasets.base import BatchSpec from skillopt.envs.base import EnvAdapter -from skillopt.evaluation.gate import evaluate_gate +from skillopt.evaluation.gate import evaluate_gate, select_gate_score from skillopt.gradient.aggregate import merge_patches from skillopt.optimizer.meta_skill import run_meta_skill from skillopt.optimizer.clip import rank_and_select @@ -845,6 +845,26 @@ class ReflACTTrainer: "Gate validation is mandatory in this branch. Remove " "`evaluation.use_gate=false` from the config." 
) + gate_metric = str(cfg.get("gate_metric", "hard")).strip().lower() + if gate_metric not in {"hard", "soft", "mixed"}: + raise ValueError( + f"evaluation.gate_metric must be 'hard' | 'soft' | 'mixed', " + f"got {gate_metric!r}" + ) + gate_mixed_weight = float(cfg.get("gate_mixed_weight", 0.5)) + if not 0.0 <= gate_mixed_weight <= 1.0: + raise ValueError( + f"evaluation.gate_mixed_weight must be in [0, 1], " + f"got {gate_mixed_weight}" + ) + print( + f" [gate] metric={gate_metric}" + + ( + f" mixed_weight={gate_mixed_weight}" + if gate_metric == "mixed" + else "" + ) + ) if current_score < 0: print(f"\n{'='*60}") print(" BASELINE — evaluate initial skill on Selection set (valid_seen)") @@ -857,16 +877,20 @@ class ReflACTTrainer: print(f" Selection items: {sel_n}") baseline_dir = os.path.join(out_root, "selection_eval_baseline") baseline_results = adapter.rollout(sel_env, skill_init, baseline_dir) - current_score, baseline_soft = compute_score(baseline_results) + baseline_hard, baseline_soft = compute_score(baseline_results) + current_score = select_gate_score( + baseline_hard, baseline_soft, gate_metric, gate_mixed_weight, + ) best_score = current_score sh = skill_hash(skill_init) - sel_cache[sh] = (current_score, baseline_soft) + sel_cache[sh] = (baseline_hard, baseline_soft) current_origin = "initial_skill" best_origin = "initial_skill" _persist_runtime_state(0) print( - f" [baseline result] selection hard={current_score:.4f} " - f"soft={baseline_soft:.4f}" + f" [baseline result] selection hard={baseline_hard:.4f} " + f"soft={baseline_soft:.4f} " + f"gate[{gate_metric}]={current_score:.4f}" ) # ── Training loop ──────────────────────────────────────────────── @@ -1287,7 +1311,15 @@ class ReflACTTrainer: best_score=best_score, best_step=best_step, global_step=global_step, + cand_soft=cand_soft, + metric=gate_metric, + mixed_weight=gate_mixed_weight, ) + cand_gate_score = select_gate_score( + cand_hard, cand_soft, gate_metric, gate_mixed_weight, + ) + 
step_rec["gate_metric"] = gate_metric + step_rec["candidate_gate_score"] = cand_gate_score step_rec["action"] = gate.action prev_current = current_score prev_best = best_score @@ -1301,20 +1333,29 @@ class ReflACTTrainer: if gate.action == "accept_new_best": best_origin = current_origin + if gate_metric == "hard": + score_label = f"hard={cand_hard:.4f}" + elif gate_metric == "soft": + score_label = f"soft={cand_soft:.4f}" + else: + score_label = ( + f"mixed[w={gate_mixed_weight}]={cand_gate_score:.4f} " + f"(hard={cand_hard:.4f} soft={cand_soft:.4f})" + ) if gate.action == "accept_new_best": print( f" [6/6 EVALUATE] ACCEPT (new best) " - f"hard={cand_hard:.4f} > prev best {prev_best:.4f}" + f"{score_label} > prev best {prev_best:.4f}" ) elif gate.action == "accept": print( f" [6/6 EVALUATE] ACCEPT " - f"hard={cand_hard:.4f} > current={prev_current:.4f}" + f"{score_label} > current={prev_current:.4f}" ) else: print( f" [6/6 EVALUATE] REJECT " - f"hard={cand_hard:.4f} <= current={current_score:.4f}" + f"{score_label} <= current={current_score:.4f}" ) step_rec["timing"]["evaluate_s"] = round(time.time() - t_phase, 1) @@ -1343,7 +1384,7 @@ class ReflACTTrainer: if isinstance(item, dict) ] buf_entry["score_before"] = current_score - buf_entry["score_after"] = cand_hard + buf_entry["score_after"] = cand_gate_score buf_entry["rejected_edits"] = rejected_edits step_buffer.append(buf_entry) diff --git a/skillopt/evaluation/__init__.py b/skillopt/evaluation/__init__.py index 87e0e1f..bb89670 100644 --- a/skillopt/evaluation/__init__.py +++ b/skillopt/evaluation/__init__.py @@ -4,4 +4,10 @@ Analogous to validation-based early stopping and model selection in neural network training: evaluates candidate skills on held-out selection sets and decides whether to accept or reject proposed updates. 
""" -from skillopt.evaluation.gate import evaluate_gate, GateAction, GateResult # noqa: F401 +from skillopt.evaluation.gate import ( # noqa: F401 + GateAction, + GateMetric, + GateResult, + evaluate_gate, + select_gate_score, +) diff --git a/skillopt/evaluation/gate.py b/skillopt/evaluation/gate.py index f4f2c40..18564b0 100644 --- a/skillopt/evaluation/gate.py +++ b/skillopt/evaluation/gate.py @@ -6,6 +6,20 @@ best scores, then returns an accept/reject decision. The trainer owns side-effects (cache lookup, rollout, printing, state mutation). This module is the pure decision function. + +Metric selection +---------------- +Three gate metrics are supported: + +* ``"hard"`` (default, backward-compatible): + Compare candidate vs current/best using *hard* exact-match accuracy. +* ``"soft"``: + Compare using *soft* per-item score (F1 / partial credit / etc.). + Use this when a small held-out selection set has too few items for + hard accuracy to be sensitive to incremental skill improvements. +* ``"mixed"``: + Compare using a weighted average ``(1 - w) * hard + w * soft``. + ``w`` is configurable via ``mixed_weight`` (default ``0.5``). """ from __future__ import annotations @@ -14,6 +28,7 @@ from typing import Literal GateAction = Literal["accept_new_best", "accept", "reject"] +GateMetric = Literal["hard", "soft", "mixed"] @dataclass(frozen=True) @@ -28,6 +43,36 @@ class GateResult: best_step: int +def select_gate_score( + hard: float, + soft: float, + metric: GateMetric = "hard", + mixed_weight: float = 0.5, +) -> float: + """Project (hard, soft) onto a single comparison metric. + + Parameters + ---------- + hard, soft + Aggregate hard / soft scores from a rollout batch (both 0..1). + metric + Which metric to compare on. + mixed_weight + For ``"mixed"``: weight given to ``soft``. Must be in ``[0, 1]``. + Ignored for ``"hard"`` / ``"soft"``. 
+ """ + if metric == "hard": + return float(hard) + if metric == "soft": + return float(soft) + if metric == "mixed": + w = max(0.0, min(1.0, float(mixed_weight))) + return (1.0 - w) * float(hard) + w * float(soft) + raise ValueError( + f"unknown gate metric {metric!r}; expected 'hard', 'soft', or 'mixed'" + ) + + def evaluate_gate( candidate_skill: str, cand_hard: float, @@ -37,28 +82,58 @@ def evaluate_gate( best_score: float, best_step: int, global_step: int, + *, + cand_soft: float = 0.0, + metric: GateMetric = "hard", + mixed_weight: float = 0.5, ) -> GateResult: """Pure gate decision: compare candidate score to current/best. - Returns a *GateResult* with updated state; the caller decides what - to do with it (print, mutate trainer state, log, etc.). + Parameters + ---------- + candidate_skill + The candidate skill content being evaluated. + cand_hard, cand_soft + Aggregate hard / soft scores of the candidate on the selection set. + current_skill, current_score + The currently-active skill and its *metric-space* score. + best_skill, best_score, best_step + The best-so-far skill, its *metric-space* score, and the step + at which it was accepted. + global_step + Current global training step (recorded if a new best is accepted). + cand_soft + Soft score of the candidate; only consulted when ``metric != "hard"``. + Defaults to ``0.0`` for backward compatibility with callers that + previously passed only ``cand_hard``. + metric + Which metric to compare on. Defaults to ``"hard"`` to preserve + the original gate behavior. + mixed_weight + Weight on ``soft`` when ``metric == "mixed"``. + + Returns + ------- + GateResult + Updated state; the caller decides what to do with it (print, + mutate trainer state, log, etc.). 
""" - if cand_hard > current_score: - new_current_skill = candidate_skill - new_current_score = cand_hard - if cand_hard > best_score: + cand_score = select_gate_score(cand_hard, cand_soft, metric, mixed_weight) + + if cand_score > current_score: + if cand_score > best_score: return GateResult( action="accept_new_best", - current_skill=new_current_skill, - current_score=new_current_score, + current_skill=candidate_skill, + current_score=cand_score, best_skill=candidate_skill, - best_score=cand_hard, + best_score=cand_score, best_step=global_step, ) return GateResult( action="accept", - current_skill=new_current_skill, - current_score=new_current_score, + current_skill=candidate_skill, + current_score=cand_score, best_skill=best_skill, best_score=best_score, best_step=best_step,