feat(trainer): final-skill val + best promotion; keep best unpolluted by slow_update

- slow_update force-inject now writes current_skill ONLY (best_skill stays a
  faithful val-best snapshot, never receives un-validated slow_update content)
- after training, run one val on the final skill; if its gate score beats the
  incumbent best, promote final to best (updates best_skill/best_step/best_origin)
- trainer itself now evaluates the final skill on the test split (reusing the
  best skill's test result when final==best); records final_selection_* and
  final_test_* in summary.json
- spreadsheetbench: head+tail truncate the post-execution verification report at
  source to fix multi-MB conversation bloat

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Cuzyoung
2026-06-02 05:55:31 +00:00
parent 372fd56c1e
commit ffe581098b
2 changed files with 208 additions and 28 deletions

View File

@@ -1543,13 +1543,13 @@ class ReflACTTrainer:
elif action in {
"accept", "accept_new_best", "force_accept",
}:
# Force-accept mode: re-apply to both current & best.
# Force-accept mode: re-apply guidance to
# current_skill only. best_skill must remain a
# faithful snapshot of the val-best step and must
# NOT receive force-injected slow-update content.
current_skill = replace_slow_update_field(
current_skill, slow_saved["slow_update_content"],
)
best_skill = replace_slow_update_field(
best_skill, slow_saved["slow_update_content"],
)
elif epoch == 1:
# Epoch 1: inject empty placeholder
os.makedirs(slow_dir, exist_ok=True)
@@ -1557,7 +1557,7 @@ class ReflACTTrainer:
current_origin = f"slow_update_placeholder_epoch_{epoch:02d}"
_save_skill(out_root, global_step, current_skill)
with open(os.path.join(out_root, "best_skill.md"), "w") as f:
f.write(best_skill if best_score > current_score else current_skill)
f.write(best_skill)
with open(slow_done_path, "w") as f:
json.dump({"action": "inject_placeholder", "epoch": epoch}, f, indent=2)
_persist_runtime_state(global_step)
@@ -1778,16 +1778,15 @@ class ReflACTTrainer:
else:
# ── Force-accept mode (default) ──────────────────
# The epoch-level longitudinal guidance is injected
# into both current_skill and best_skill
# unconditionally — it must not be gated by
# step-level selection scores.
# into current_skill ONLY, so training continues
# with the accumulated slow memory. best_skill is
# left untouched: it must remain a faithful snapshot
# of the val-best step (which may be a pre-slow step
# such as S_0 carrying no slow_update field at all).
slow_content = slow_result["slow_update_content"]
current_skill = replace_slow_update_field(
current_skill, slow_content,
)
best_skill = replace_slow_update_field(
best_skill, slow_content,
)
# Update caches so downstream steps use the
# slow-update-injected skill for hashing.
slow_candidate_hash = skill_hash(current_skill)
@@ -1798,7 +1797,7 @@ class ReflACTTrainer:
print(
f" [slow update] force-injected into "
f"current & best "
f"current only "
f"({len(slow_content)} chars), "
f"{slow_time}s"
)
@@ -1951,10 +1950,70 @@ class ReflACTTrainer:
baseline_test_soft = None
test_hard = None
test_soft = None
final_test_hard = None
final_test_soft = None
final_selection_hard = None
final_selection_soft = None
if cfg["eval_test"]:
task_types = adapter.get_task_types()
# ── Final skill validation (valid_seen) + best promotion ─────
# The final (last) skill may carry an epoch-end slow_update that
# was force-injected WITHOUT a val pass (use_gate=false or
# slow_update_gate_with_selection=false), so it never competed for
# best. Run one real val on the final skill; if its gate score
# beats the incumbent best, PROMOTE it to best so that best is the
# true val-argmax over all skills (including the final slow_update).
# When final == best, skip the rollout and reuse best_score (the gate
# score) as final_selection_hard; final_selection_soft stays None.
try:
if skill_hash(current_skill) == skill_hash(best_skill):
final_selection_hard, final_selection_soft = best_score, None
print(
"\n [final skill == best skill] "
f"final_selection_hard={best_score:.4f} (reused)"
)
else:
fval_env, fval_n = _build_eval_env(
split="valid_seen",
env_num=cfg["sel_env_num"],
seed=seed,
)
fval_dir = os.path.join(out_root, "final_selection_eval")
fval_results = adapter.rollout(fval_env, current_skill, fval_dir)
final_selection_hard, final_selection_soft = compute_score(fval_results)
final_gate_score = select_gate_score(
final_selection_hard, final_selection_soft,
gate_metric, gate_mixed_weight,
)
print(
f"\n [final skill val] items={fval_n} "
f"final_selection_hard={final_selection_hard:.4f} "
f"gate={final_gate_score:.4f} "
f"(best={best_score:.4f})"
)
if final_gate_score > best_score:
# Promote: the final (slow-updated) skill is val-better
# than the incumbent best. Make it the new best so the
# subsequent BEST-skill test rollout evaluates it and
# best/final test scores coincide.
print(
f" [promote] final {final_gate_score:.4f} > "
f"best {best_score:.4f} → final becomes new best "
f"(step {global_step}, origin {current_origin})"
)
best_skill = current_skill
best_score = final_gate_score
best_step = global_step
best_origin = current_origin
with open(os.path.join(out_root, "best_skill.md"), "w") as f:
f.write(best_skill)
_persist_runtime_state(global_step)
except Exception as _e: # noqa: BLE001
final_selection_hard = None
final_selection_soft = None
print(f"\n [final skill val FAILED: {_e!r}]")
# Baseline: S_0 on test set (valid_unseen)
print(f"\n{'='*60}")
print(" BASELINE TEST — evaluate initial skill on Test set (valid_unseen)")
@@ -2023,13 +2082,87 @@ class ReflACTTrainer:
f, indent=2, ensure_ascii=False,
)
# Final skill (last skill in trajectory) on test set.
# Distinct from best_skill: with use_gate=False every candidate is
# force-accepted so the final skill is whatever the last step
# produced; with use_gate=True it is the last accepted skill, which
# may differ from the best-on-val skill. We always evaluate it so
# every run reports baseline / best-on-val / final on test.
# Guarded so a failure here never prevents summary.json from being
# written (the orchestrator's post-hoc safety net fills it in).
try:
if skill_hash(current_skill) == skill_hash(best_skill):
# Final == best: reuse results, skip a redundant rollout.
final_test_hard, final_test_soft = test_hard, test_soft
final_test_dir = os.path.join(out_root, "test_eval_final")
os.makedirs(final_test_dir, exist_ok=True)
with open(os.path.join(final_test_dir, "summary.json"), "w") as f:
json.dump(
{
k: {
"total": b["total"],
"hard_acc": b["hard"] / max(b["total"], 1),
}
for k, b in best_buckets.items()
},
f, indent=2, ensure_ascii=False,
)
print(
"\n [final skill == best skill] "
f"final_test_hard={final_test_hard:.4f} (reused)"
)
else:
print(f"\n{'='*60}")
print(" FINAL SKILL TEST — evaluate last skill on Test set (valid_unseen)")
print(f"{'='*60}")
test_env3, test_n3 = _build_eval_env(
split="valid_unseen",
env_num=cfg["test_env_num"],
seed=seed,
)
print(f" Test items: {test_n3}")
final_test_dir = os.path.join(out_root, "test_eval_final")
final_test_results = adapter.rollout(test_env3, current_skill, final_test_dir)
final_test_hard, final_test_soft = compute_score(final_test_results)
final_buckets = _compute_task_type_buckets(final_test_results, task_types)
print("\n === Final Skill Test Results ===")
for task_type in task_types + ["overall"]:
b = final_buckets.get(task_type, {"total": 0, "hard": 0})
t = max(b["total"], 1)
print(
f" {task_type:<40s}: "
f"hard={b['hard']}/{b['total']}={b['hard']/t:.4f}"
)
with open(os.path.join(final_test_dir, "summary.json"), "w") as f:
json.dump(
{
k: {
"total": b["total"],
"hard_acc": b["hard"] / max(b["total"], 1),
}
for k, b in final_buckets.items()
},
f, indent=2, ensure_ascii=False,
)
except Exception as _e: # noqa: BLE001
final_test_hard = None
final_test_soft = None
print(f"\n [final skill test FAILED: {_e!r}] "
"— will be filled by post-hoc eval")
# Comparison
delta_hard = (test_hard or 0) - (baseline_test_hard or 0)
print(f"\n === Improvement (best vs baseline) ===")
print(f"\n === Improvement vs baseline (init S_0) ===")
print(
f" hard: {baseline_test_hard:.4f} -> {test_hard:.4f} "
f" [2] best-on-val hard: {baseline_test_hard:.4f} -> {test_hard:.4f} "
f"(delta={delta_hard:+.4f})"
)
if final_test_hard is not None:
final_delta_hard = (final_test_hard or 0) - (baseline_test_hard or 0)
print(
f" [3] final/last hard: {baseline_test_hard:.4f} -> {final_test_hard:.4f} "
f"(delta={final_delta_hard:+.4f})"
)
# ── Global summary ───────────────────────────────────────────────
total_wall = time.time() - t_loop_start
@@ -2061,6 +2194,8 @@ class ReflACTTrainer:
skill_hash(skill_init), (None, None),
)[0],
"best_selection_hard": best_score,
"final_selection_hard": final_selection_hard,
"final_selection_soft": final_selection_soft,
"best_step": best_step,
"current_origin": current_origin,
"best_origin": best_origin,
@@ -2073,11 +2208,18 @@ class ReflACTTrainer:
"baseline_test_soft": baseline_test_soft,
"test_hard": test_hard,
"test_soft": test_soft,
"final_test_hard": final_test_hard,
"final_test_soft": final_test_soft,
"test_delta_hard": (
(test_hard or 0) - (baseline_test_hard or 0)
if test_hard is not None
else None
),
"final_test_delta_hard": (
(final_test_hard or 0) - (baseline_test_hard or 0)
if final_test_hard is not None
else None
),
"total_wall_time_s": round(total_wall, 1),
"token_summary": token_summary,
}
@@ -2098,8 +2240,22 @@ class ReflACTTrainer:
f" epoch {es['epoch']}: accept={es['accepts']} reject={es['rejects']} "
f"best={es['best_score_at_epoch_end']:.4f}"
)
if baseline_test_hard is not None:
print("\n === TEST scores (3 skills, split=valid_unseen) ===")
print(
f" [1] init/baseline (S_0) : "
f"test_hard={baseline_test_hard:.4f}"
)
if test_hard is not None:
print(f" test_hard={test_hard:.4f} test_soft={test_soft:.4f}")
print(
f" [2] best-on-val (step {best_step})".ljust(37)
+ f": test_hard={test_hard:.4f} test_soft={test_soft:.4f}"
)
if final_test_hard is not None:
print(
f" [3] final/last skill : "
f"test_hard={final_test_hard:.4f} test_soft={final_test_soft:.4f}"
)
if token_summary.get("_total"):
t = token_summary["_total"]
print(

View File

@@ -89,6 +89,21 @@ def _find_test_cases(task_dir: str) -> list[tuple[str, str, str]]:
# ── Auto-verify helper ──────────────────────────────────────────────────────
# The official SpreadsheetBench evaluator never serialises cells to text — it
# compares in memory and returns only a pass/fail bool. The per-cell report
# below is a repo-local training aid (fed back to the model on retry and saved
# into the trajectory for reflection). On most tasks the answer range is a
# handful of cells, so the full report is tiny. But a few tasks have answer
# ranges spanning tens of thousands of cells (e.g. 80-42 =
# 'Consolidate_ALL'!A2:L8000 ≈ 96k cells); dumping every cell explodes the
# report to several MB, floods the model's context and bloats conversation
# files. We therefore apply the same head+tail character truncation the rest of
# the codebase uses for oversized trajectory text (cf. reflect.py / slow_update.py
# `text[:half] + "...[truncated]...\n" + text[-half:]`): keep the first and last
# `_MAX_REPORT_CHARS // 2` chars so both the leading and trailing wrong cells
# stay visible. Small reports are unchanged.
_MAX_REPORT_CHARS = 12000 # head+tail char budget (~6000 head + 6000 tail)
def _auto_verify_output(
pred_path: str,
@@ -99,7 +114,8 @@ def _auto_verify_output(
Returns a human-readable verification report that can be appended to the
trajectory so the error analyst can see exactly what went wrong (e.g.
``cell A1: got=None, expected=420``).
``cell A1: got=None, expected=420``). Oversized reports are head+tail
truncated to about `_MAX_REPORT_CHARS` chars (plus a one-line truncation
marker), matching the rest of the codebase.
"""
if not os.path.exists(pred_path):
return "Verification: output file does not exist."
@@ -131,7 +147,7 @@ def _auto_verify_output(
lines.append(f" Sheet '{sheet_name}' NOT FOUND in output.")
continue
n_correct_skipped = 0
n_empty_correct = 0 # empty-on-both correct cells collapsed to a count
for cn in cell_names:
gv = ws_gold[cn].value if ws_gold else "N/A"
pv = ws_pred[cn].value
@@ -140,20 +156,18 @@ def _auto_verify_output(
# flag e.g. 5 vs 5.0 or None vs "" as mismatches and mislead the
# model into "fixing" cells that already pass scoring.
ok_cell = ws_gold is not None and _compare_cell_value(gv, pv)
match = "" if ok_cell else ""
# Skip cells that are correct AND empty on both sides: for large
# answer ranges (e.g. C2:C5000) the vast majority are empty
# (got=None, expected=None ✓) and would otherwise flood the
# report with hundreds of thousands of noise chars, burying the
# few real ✗ lines. We only emit wrong cells and non-empty
# correct cells; empty-correct cells are collapsed into a count.
# Collapse only cells that are correct AND empty on both sides
# (got=None, expected=None ✓): pure noise. Every other cell —
# including non-empty correct cells — is listed in full; the
# final head+tail char cap keeps the report bounded.
if ok_cell and gv in (None, "") and pv in (None, ""):
n_correct_skipped += 1
n_empty_correct += 1
continue
match = "" if ok_cell else ""
lines.append(f" {sheet_name}!{cn}: got={pv!r}, expected={gv!r} {match}")
if n_correct_skipped:
if n_empty_correct:
lines.append(
f" (+{n_correct_skipped} empty cells correct, omitted)"
f" (+{n_empty_correct} empty cells correct, omitted)"
)
# Also check if any cells in the output contain formula strings
@@ -180,7 +194,17 @@ def _auto_verify_output(
wb_pred.close()
wb_gold.close()
return "\n".join(lines)
report = "\n".join(lines)
# Head+tail truncation, matching reflect.py / slow_update.py: keep the first
# and last half so both leading and trailing wrong cells remain visible.
if len(report) > _MAX_REPORT_CHARS:
half = _MAX_REPORT_CHARS // 2
report = (
report[:half]
+ f"\n ...[verification report truncated, {len(report)} chars total]...\n"
+ report[-half:]
)
return report
# ── Per-task worker ──────────────────────────────────────────────────────────