mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
feat(trainer): final-skill val + best promotion; keep best unpolluted by slow_update
- slow_update force-inject now writes current_skill ONLY (best_skill stays a faithful val-best snapshot, never receives un-validated slow_update content)
- after training, run one val on the final skill; if its gate score beats the incumbent best, promote final to best (updates best_skill/best_step/best_origin)
- trainer now evaluates final skill on test itself (reuses best test result when final==best); records final_selection_* and final_test_* in summary.json
- spreadsheetbench: head+tail truncate the post-execution verification report at source to fix multi-MB conversation bloat

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -1543,13 +1543,13 @@ class ReflACTTrainer:
|
||||
elif action in {
|
||||
"accept", "accept_new_best", "force_accept",
|
||||
}:
|
||||
# Force-accept mode: re-apply to both current & best.
|
||||
# Force-accept mode: re-apply guidance to
|
||||
# current_skill only. best_skill must remain a
|
||||
# faithful snapshot of the val-best step and must
|
||||
# NOT receive force-injected slow-update content.
|
||||
current_skill = replace_slow_update_field(
|
||||
current_skill, slow_saved["slow_update_content"],
|
||||
)
|
||||
best_skill = replace_slow_update_field(
|
||||
best_skill, slow_saved["slow_update_content"],
|
||||
)
|
||||
elif epoch == 1:
|
||||
# Epoch 1: inject empty placeholder
|
||||
os.makedirs(slow_dir, exist_ok=True)
|
||||
@@ -1557,7 +1557,7 @@ class ReflACTTrainer:
|
||||
current_origin = f"slow_update_placeholder_epoch_{epoch:02d}"
|
||||
_save_skill(out_root, global_step, current_skill)
|
||||
with open(os.path.join(out_root, "best_skill.md"), "w") as f:
|
||||
f.write(best_skill if best_score > current_score else current_skill)
|
||||
f.write(best_skill)
|
||||
with open(slow_done_path, "w") as f:
|
||||
json.dump({"action": "inject_placeholder", "epoch": epoch}, f, indent=2)
|
||||
_persist_runtime_state(global_step)
|
||||
@@ -1778,16 +1778,15 @@ class ReflACTTrainer:
|
||||
else:
|
||||
# ── Force-accept mode (default) ──────────────────
|
||||
# The epoch-level longitudinal guidance is injected
|
||||
# into both current_skill and best_skill
|
||||
# unconditionally — it must not be gated by
|
||||
# step-level selection scores.
|
||||
# into current_skill ONLY, so training continues
|
||||
# with the accumulated slow memory. best_skill is
|
||||
# left untouched: it must remain a faithful snapshot
|
||||
# of the val-best step (which may be a pre-slow step
|
||||
# such as S_0 carrying no slow_update field at all).
|
||||
slow_content = slow_result["slow_update_content"]
|
||||
current_skill = replace_slow_update_field(
|
||||
current_skill, slow_content,
|
||||
)
|
||||
best_skill = replace_slow_update_field(
|
||||
best_skill, slow_content,
|
||||
)
|
||||
# Update caches so downstream steps use the
|
||||
# slow-update-injected skill for hashing.
|
||||
slow_candidate_hash = skill_hash(current_skill)
|
||||
@@ -1798,7 +1797,7 @@ class ReflACTTrainer:
|
||||
|
||||
print(
|
||||
f" [slow update] force-injected into "
|
||||
f"current & best "
|
||||
f"current only "
|
||||
f"({len(slow_content)} chars), "
|
||||
f"{slow_time}s"
|
||||
)
|
||||
@@ -1951,10 +1950,70 @@ class ReflACTTrainer:
|
||||
baseline_test_soft = None
|
||||
test_hard = None
|
||||
test_soft = None
|
||||
final_test_hard = None
|
||||
final_test_soft = None
|
||||
final_selection_hard = None
|
||||
final_selection_soft = None
|
||||
|
||||
if cfg["eval_test"]:
|
||||
task_types = adapter.get_task_types()
|
||||
|
||||
# ── Final skill validation (valid_seen) + best promotion ─────
|
||||
# The final (last) skill may carry an epoch-end slow_update that
|
||||
# was force-injected WITHOUT a val pass (use_gate=false or
|
||||
# slow_update_gate_with_selection=false), so it never competed for
|
||||
# best. Run one real val on the final skill; if its gate score
|
||||
# beats the incumbent best, PROMOTE it to best so that best is the
|
||||
# true val-argmax over all skills (including the final slow_update).
|
||||
# When final == best, reuse the existing val score (no rollout).
|
||||
try:
|
||||
if skill_hash(current_skill) == skill_hash(best_skill):
|
||||
final_selection_hard, final_selection_soft = best_score, None
|
||||
print(
|
||||
"\n [final skill == best skill] "
|
||||
f"final_selection_hard={best_score:.4f} (reused)"
|
||||
)
|
||||
else:
|
||||
fval_env, fval_n = _build_eval_env(
|
||||
split="valid_seen",
|
||||
env_num=cfg["sel_env_num"],
|
||||
seed=seed,
|
||||
)
|
||||
fval_dir = os.path.join(out_root, "final_selection_eval")
|
||||
fval_results = adapter.rollout(fval_env, current_skill, fval_dir)
|
||||
final_selection_hard, final_selection_soft = compute_score(fval_results)
|
||||
final_gate_score = select_gate_score(
|
||||
final_selection_hard, final_selection_soft,
|
||||
gate_metric, gate_mixed_weight,
|
||||
)
|
||||
print(
|
||||
f"\n [final skill val] items={fval_n} "
|
||||
f"final_selection_hard={final_selection_hard:.4f} "
|
||||
f"gate={final_gate_score:.4f} "
|
||||
f"(best={best_score:.4f})"
|
||||
)
|
||||
if final_gate_score > best_score:
|
||||
# Promote: the final (slow-updated) skill is val-better
|
||||
# than the incumbent best. Make it the new best so the
|
||||
# subsequent BEST-skill test rollout evaluates it and
|
||||
# best/final test scores coincide.
|
||||
print(
|
||||
f" [promote] final {final_gate_score:.4f} > "
|
||||
f"best {best_score:.4f} → final becomes new best "
|
||||
f"(step {global_step}, origin {current_origin})"
|
||||
)
|
||||
best_skill = current_skill
|
||||
best_score = final_gate_score
|
||||
best_step = global_step
|
||||
best_origin = current_origin
|
||||
with open(os.path.join(out_root, "best_skill.md"), "w") as f:
|
||||
f.write(best_skill)
|
||||
_persist_runtime_state(global_step)
|
||||
except Exception as _e: # noqa: BLE001
|
||||
final_selection_hard = None
|
||||
final_selection_soft = None
|
||||
print(f"\n [final skill val FAILED: {_e!r}]")
|
||||
|
||||
# Baseline: S_0 on test set (valid_unseen)
|
||||
print(f"\n{'='*60}")
|
||||
print(" BASELINE TEST — evaluate initial skill on Test set (valid_unseen)")
|
||||
@@ -2023,13 +2082,87 @@ class ReflACTTrainer:
|
||||
f, indent=2, ensure_ascii=False,
|
||||
)
|
||||
|
||||
# Final skill (last skill in trajectory) on test set.
|
||||
# Distinct from best_skill: with use_gate=False every candidate is
|
||||
# force-accepted so the final skill is whatever the last step
|
||||
# produced; with use_gate=True it is the last accepted skill, which
|
||||
# may differ from the best-on-val skill. We always evaluate it so
|
||||
# every run reports baseline / best-on-val / final on test.
|
||||
# Guarded so a failure here never prevents summary.json from being
|
||||
# written (the orchestrator's post-hoc safety net fills it in).
|
||||
try:
|
||||
if skill_hash(current_skill) == skill_hash(best_skill):
|
||||
# Final == best: reuse results, skip a redundant rollout.
|
||||
final_test_hard, final_test_soft = test_hard, test_soft
|
||||
final_test_dir = os.path.join(out_root, "test_eval_final")
|
||||
os.makedirs(final_test_dir, exist_ok=True)
|
||||
with open(os.path.join(final_test_dir, "summary.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
k: {
|
||||
"total": b["total"],
|
||||
"hard_acc": b["hard"] / max(b["total"], 1),
|
||||
}
|
||||
for k, b in best_buckets.items()
|
||||
},
|
||||
f, indent=2, ensure_ascii=False,
|
||||
)
|
||||
print(
|
||||
"\n [final skill == best skill] "
|
||||
f"final_test_hard={final_test_hard:.4f} (reused)"
|
||||
)
|
||||
else:
|
||||
print(f"\n{'='*60}")
|
||||
print(" FINAL SKILL TEST — evaluate last skill on Test set (valid_unseen)")
|
||||
print(f"{'='*60}")
|
||||
test_env3, test_n3 = _build_eval_env(
|
||||
split="valid_unseen",
|
||||
env_num=cfg["test_env_num"],
|
||||
seed=seed,
|
||||
)
|
||||
print(f" Test items: {test_n3}")
|
||||
final_test_dir = os.path.join(out_root, "test_eval_final")
|
||||
final_test_results = adapter.rollout(test_env3, current_skill, final_test_dir)
|
||||
final_test_hard, final_test_soft = compute_score(final_test_results)
|
||||
final_buckets = _compute_task_type_buckets(final_test_results, task_types)
|
||||
print("\n === Final Skill Test Results ===")
|
||||
for task_type in task_types + ["overall"]:
|
||||
b = final_buckets.get(task_type, {"total": 0, "hard": 0})
|
||||
t = max(b["total"], 1)
|
||||
print(
|
||||
f" {task_type:<40s}: "
|
||||
f"hard={b['hard']}/{b['total']}={b['hard']/t:.4f}"
|
||||
)
|
||||
with open(os.path.join(final_test_dir, "summary.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
k: {
|
||||
"total": b["total"],
|
||||
"hard_acc": b["hard"] / max(b["total"], 1),
|
||||
}
|
||||
for k, b in final_buckets.items()
|
||||
},
|
||||
f, indent=2, ensure_ascii=False,
|
||||
)
|
||||
except Exception as _e: # noqa: BLE001
|
||||
final_test_hard = None
|
||||
final_test_soft = None
|
||||
print(f"\n [final skill test FAILED: {_e!r}] "
|
||||
"— will be filled by post-hoc eval")
|
||||
|
||||
# Comparison
|
||||
delta_hard = (test_hard or 0) - (baseline_test_hard or 0)
|
||||
print(f"\n === Improvement (best vs baseline) ===")
|
||||
print(f"\n === Improvement vs baseline (init S_0) ===")
|
||||
print(
|
||||
f" hard: {baseline_test_hard:.4f} -> {test_hard:.4f} "
|
||||
f" [2] best-on-val hard: {baseline_test_hard:.4f} -> {test_hard:.4f} "
|
||||
f"(delta={delta_hard:+.4f})"
|
||||
)
|
||||
if final_test_hard is not None:
|
||||
final_delta_hard = (final_test_hard or 0) - (baseline_test_hard or 0)
|
||||
print(
|
||||
f" [3] final/last hard: {baseline_test_hard:.4f} -> {final_test_hard:.4f} "
|
||||
f"(delta={final_delta_hard:+.4f})"
|
||||
)
|
||||
|
||||
# ── Global summary ───────────────────────────────────────────────
|
||||
total_wall = time.time() - t_loop_start
|
||||
@@ -2061,6 +2194,8 @@ class ReflACTTrainer:
|
||||
skill_hash(skill_init), (None, None),
|
||||
)[0],
|
||||
"best_selection_hard": best_score,
|
||||
"final_selection_hard": final_selection_hard,
|
||||
"final_selection_soft": final_selection_soft,
|
||||
"best_step": best_step,
|
||||
"current_origin": current_origin,
|
||||
"best_origin": best_origin,
|
||||
@@ -2073,11 +2208,18 @@ class ReflACTTrainer:
|
||||
"baseline_test_soft": baseline_test_soft,
|
||||
"test_hard": test_hard,
|
||||
"test_soft": test_soft,
|
||||
"final_test_hard": final_test_hard,
|
||||
"final_test_soft": final_test_soft,
|
||||
"test_delta_hard": (
|
||||
(test_hard or 0) - (baseline_test_hard or 0)
|
||||
if test_hard is not None
|
||||
else None
|
||||
),
|
||||
"final_test_delta_hard": (
|
||||
(final_test_hard or 0) - (baseline_test_hard or 0)
|
||||
if final_test_hard is not None
|
||||
else None
|
||||
),
|
||||
"total_wall_time_s": round(total_wall, 1),
|
||||
"token_summary": token_summary,
|
||||
}
|
||||
@@ -2098,8 +2240,22 @@ class ReflACTTrainer:
|
||||
f" epoch {es['epoch']}: accept={es['accepts']} reject={es['rejects']} "
|
||||
f"best={es['best_score_at_epoch_end']:.4f}"
|
||||
)
|
||||
if baseline_test_hard is not None:
|
||||
print("\n === TEST scores (3 skills, split=valid_unseen) ===")
|
||||
print(
|
||||
f" [1] init/baseline (S_0) : "
|
||||
f"test_hard={baseline_test_hard:.4f}"
|
||||
)
|
||||
if test_hard is not None:
|
||||
print(f" test_hard={test_hard:.4f} test_soft={test_soft:.4f}")
|
||||
print(
|
||||
f" [2] best-on-val (step {best_step})".ljust(37)
|
||||
+ f": test_hard={test_hard:.4f} test_soft={test_soft:.4f}"
|
||||
)
|
||||
if final_test_hard is not None:
|
||||
print(
|
||||
f" [3] final/last skill : "
|
||||
f"test_hard={final_test_hard:.4f} test_soft={final_test_soft:.4f}"
|
||||
)
|
||||
if token_summary.get("_total"):
|
||||
t = token_summary["_total"]
|
||||
print(
|
||||
|
||||
@@ -89,6 +89,21 @@ def _find_test_cases(task_dir: str) -> list[tuple[str, str, str]]:
|
||||
|
||||
# ── Auto-verify helper ──────────────────────────────────────────────────────
|
||||
|
||||
# The official SpreadsheetBench evaluator never serialises cells to text — it
|
||||
# compares in memory and returns only a pass/fail bool. The per-cell report
|
||||
# below is a repo-local training aid (fed back to the model on retry and saved
|
||||
# into the trajectory for reflection). On most tasks the answer range is a
|
||||
# handful of cells, so the full report is tiny. But a few tasks have answer
|
||||
# ranges spanning tens of thousands of cells (e.g. 80-42 =
|
||||
# 'Consolidate_ALL'!A2:L8000 ≈ 96k cells); dumping every cell explodes the
|
||||
# report to several MB, floods the model's context and bloats conversation
|
||||
# files. We therefore apply the same head+tail character truncation the rest of
|
||||
# the codebase uses for oversized trajectory text (cf. reflect.py / slow_update.py
|
||||
# `text[:half] + "...[truncated]...\n" + text[-half:]`): keep the first and last
|
||||
# `_MAX_REPORT_CHARS // 2` chars so both the leading and trailing wrong cells
|
||||
# stay visible. Small reports are unchanged.
|
||||
_MAX_REPORT_CHARS = 12000 # head+tail char budget (~6000 head + 6000 tail)
|
||||
|
||||
|
||||
def _auto_verify_output(
|
||||
pred_path: str,
|
||||
@@ -99,7 +114,8 @@ def _auto_verify_output(
|
||||
|
||||
Returns a human-readable verification report that can be appended to the
|
||||
trajectory so the error analyst can see exactly what went wrong (e.g.
|
||||
``cell A1: got=None, expected=420``).
|
||||
``cell A1: got=None, expected=420``). Oversized reports are head+tail
|
||||
truncated to `_MAX_REPORT_CHARS` chars, matching the rest of the codebase.
|
||||
"""
|
||||
if not os.path.exists(pred_path):
|
||||
return "Verification: output file does not exist."
|
||||
@@ -131,7 +147,7 @@ def _auto_verify_output(
|
||||
lines.append(f" Sheet '{sheet_name}' NOT FOUND in output.")
|
||||
continue
|
||||
|
||||
n_correct_skipped = 0
|
||||
n_empty_correct = 0 # empty-on-both correct cells collapsed to a count
|
||||
for cn in cell_names:
|
||||
gv = ws_gold[cn].value if ws_gold else "N/A"
|
||||
pv = ws_pred[cn].value
|
||||
@@ -140,20 +156,18 @@ def _auto_verify_output(
|
||||
# flag e.g. 5 vs 5.0 or None vs "" as mismatches and mislead the
|
||||
# model into "fixing" cells that already pass scoring.
|
||||
ok_cell = ws_gold is not None and _compare_cell_value(gv, pv)
|
||||
match = "✓" if ok_cell else "✗"
|
||||
# Skip cells that are correct AND empty on both sides: for large
|
||||
# answer ranges (e.g. C2:C5000) the vast majority are empty
|
||||
# (got=None, expected=None ✓) and would otherwise flood the
|
||||
# report with hundreds of thousands of noise chars, burying the
|
||||
# few real ✗ lines. We only emit wrong cells and non-empty
|
||||
# correct cells; empty-correct cells are collapsed into a count.
|
||||
# Collapse only cells that are correct AND empty on both sides
|
||||
# (got=None, expected=None ✓): pure noise. Every other cell —
|
||||
# including non-empty correct cells — is listed in full; the
|
||||
# final head+tail char cap keeps the report bounded.
|
||||
if ok_cell and gv in (None, "") and pv in (None, ""):
|
||||
n_correct_skipped += 1
|
||||
n_empty_correct += 1
|
||||
continue
|
||||
match = "✓" if ok_cell else "✗"
|
||||
lines.append(f" {sheet_name}!{cn}: got={pv!r}, expected={gv!r} {match}")
|
||||
if n_correct_skipped:
|
||||
if n_empty_correct:
|
||||
lines.append(
|
||||
f" (+{n_correct_skipped} empty cells correct, omitted)"
|
||||
f" (+{n_empty_correct} empty cells correct, omitted)"
|
||||
)
|
||||
|
||||
# Also check if any cells in the output contain formula strings
|
||||
@@ -180,7 +194,17 @@ def _auto_verify_output(
|
||||
wb_pred.close()
|
||||
wb_gold.close()
|
||||
|
||||
return "\n".join(lines)
|
||||
report = "\n".join(lines)
|
||||
# Head+tail truncation, matching reflect.py / slow_update.py: keep the first
|
||||
# and last half so both leading and trailing wrong cells remain visible.
|
||||
if len(report) > _MAX_REPORT_CHARS:
|
||||
half = _MAX_REPORT_CHARS // 2
|
||||
report = (
|
||||
report[:half]
|
||||
+ f"\n ...[verification report truncated, {len(report)} chars total]...\n"
|
||||
+ report[-half:]
|
||||
)
|
||||
return report
|
||||
|
||||
|
||||
# ── Per-task worker ──────────────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user