From 9fda7311ea2dd45aff68efab0a0b6eaf6906ade7 Mon Sep 17 00:00:00 2001 From: Cuzyoung Date: Fri, 8 May 2026 18:16:18 +0000 Subject: [PATCH] Initial commit --- .gitignore | 15 + README.md | 557 +++++ configs/_base_/default.yaml | 93 + configs/ablation_study/README.md | 305 +++ configs/ablation_study/launch_commands.sh | 81 + configs/ablation_study/matrix.yaml | 257 ++ configs/ablation_study/validation.md | 141 ++ configs/alfworld/default.yaml | 30 + configs/alfworld/meta_reflect.yaml | 4 + configs/babyvision/default.yaml | 21 + configs/docvqa/default.yaml | 28 + configs/livemathematicianbench/default.yaml | 22 + configs/mathverse/default.yaml | 23 + configs/mmrb/default.yaml | 18 + configs/officeqa/default.yaml | 25 + configs/sealqa/default.yaml | 23 + configs/searchqa/default.yaml | 32 + configs/spreadsheetbench/default.yaml | 34 + configs/swebench/default.yaml | 36 + scripts/codex_azure_mi.sh | 20 + scripts/download_babyvision.py | 53 + .../eval_livemathematicianbench_baseline.py | 333 +++ scripts/eval_only.py | 451 ++++ scripts/eval_prompt_custom.py | 361 +++ scripts/eval_prompt_official.py | 352 +++ scripts/eval_searchqa_val500.sh | 37 + scripts/eval_verified400.sh | 41 + scripts/eval_verified400_multi.sh | 42 + scripts/eval_verified400_single.sh | 42 + ...launch_harness_bestsetting_from_scratch.sh | 120 + scripts/launch_harness_canonical_claude18.sh | 178 ++ .../launch_harness_canonical_claude4_smoke.sh | 130 + scripts/launch_harness_canonical_wave1.sh | 168 ++ scripts/launch_harness_initial_claude4.sh | 128 + ...aunch_harness_initial_spreadsheet_clean.sh | 103 + scripts/launch_lrctrl_fullrewrite_neutral3.sh | 116 + ...ullrewrite_neutral3_spreadsheet_repeats.sh | 111 + ...ctrl_fullrewrite_neutral3_sq_lm_repeats.sh | 115 + .../launch_spreadsheet_full_replacements.sh | 175 ++ scripts/monitor_harness_claude18.sh | 61 + scripts/prepare_ablation_splits.py | 130 + scripts/run_ablation_matrix.py | 680 +++++ scripts/run_alfworld.sh | 68 + 
scripts/run_meta_skill_ablation.sh | 94 + scripts/run_missing_meta_parallel.sh | 43 + scripts/run_searchqa.sh | 43 + scripts/run_spreadsheetbench.sh | 48 + scripts/train.py | 483 ++++ scripts/train_searchqa.sh | 43 + scripts/train_spreadsheet_multi.sh | 48 + scripts/train_spreadsheet_single.sh | 48 + scripts/watch_ablation.py | 218 ++ skillopt/__init__.py | 29 + skillopt/config.py | 263 ++ skillopt/datasets/__init__.py | 7 + skillopt/datasets/base.py | 512 ++++ skillopt/engine/__init__.py | 9 + skillopt/engine/trainer.py | 2195 +++++++++++++++++ skillopt/envs/__init__.py | 1 + skillopt/envs/alfworld/__init__.py | 5 + skillopt/envs/alfworld/adapter.py | 585 +++++ skillopt/envs/alfworld/dataloader.py | 123 + .../envs/alfworld/prompts/analyst_error.md | 55 + .../envs/alfworld/prompts/analyst_success.md | 33 + skillopt/envs/alfworld/prompts/deep_probe.md | 35 + .../alfworld/prompts/rollout_no_history.md | 8 + .../alfworld/prompts/rollout_with_history.md | 9 + .../alfworld/prompts/rollout_with_memory.md | 16 + skillopt/envs/alfworld/reflect.py | 4 + skillopt/envs/alfworld/rollout.py | 359 +++ skillopt/envs/alfworld/skills/initial.md | 45 + skillopt/envs/alfworld/vendor/__init__.py | 9 + .../envs/alfworld/vendor/alfworld_envs.py | 221 ++ .../alfworld/vendor/alfworld_projection.py | 60 + .../envs/alfworld/vendor/alfworld_prompts.py | 8 + skillopt/envs/alfworld/vendor/config_tw.yaml | 145 ++ skillopt/envs/alfworld/vendor/env_base.py | 84 + skillopt/envs/alfworld/vendor/env_manager.py | 139 ++ skillopt/envs/alfworld/vendor/memory.py | 87 + skillopt/envs/babyvision/__init__.py | 1 + skillopt/envs/babyvision/adapter.py | 267 ++ skillopt/envs/babyvision/dataloader.py | 214 ++ skillopt/envs/babyvision/evaluator.py | 160 ++ .../envs/babyvision/prompts/analyst_error.md | 36 + .../babyvision/prompts/analyst_success.md | 25 + .../envs/babyvision/prompts/deep_probe.md | 25 + skillopt/envs/babyvision/prompts/judge.md | 35 + .../envs/babyvision/prompts/rollout_system.md | 13 + 
skillopt/envs/babyvision/reflect.py | 4 + skillopt/envs/babyvision/rollout.py | 467 ++++ skillopt/envs/babyvision/skills/initial.md | 18 + skillopt/envs/base.py | 396 +++ skillopt/envs/deep_reflect.py | 114 + skillopt/envs/docvqa/__init__.py | 1 + skillopt/envs/docvqa/adapter.py | 153 ++ skillopt/envs/docvqa/dataloader.py | 61 + skillopt/envs/docvqa/evaluator.py | 113 + skillopt/envs/docvqa/prompts/analyst_error.md | 35 + .../envs/docvqa/prompts/analyst_success.md | 24 + .../envs/docvqa/prompts/rollout_system.md | 12 + skillopt/envs/docvqa/rollout.py | 365 +++ skillopt/envs/docvqa/skills/initial.md | 11 + .../envs/livemathematicianbench/__init__.py | 1 + .../envs/livemathematicianbench/adapter.py | 284 +++ .../envs/livemathematicianbench/dataloader.py | 308 +++ .../envs/livemathematicianbench/evaluator.py | 62 + .../prompts/analyst_error.md | 37 + .../prompts/analyst_success.md | 25 + .../prompts/deep_probe.md | 23 + .../prompts/deep_probe_codex.md | 26 + .../prompts/rollout_system.md | 12 + .../envs/livemathematicianbench/reflect.py | 4 + .../envs/livemathematicianbench/rollout.py | 401 +++ .../livemathematicianbench/skills/initial.md | 16 + skillopt/envs/mathverse/__init__.py | 5 + skillopt/envs/mathverse/adapter.py | 280 +++ skillopt/envs/mathverse/dataloader.py | 228 ++ skillopt/envs/mathverse/evaluator.py | 180 ++ .../envs/mathverse/prompts/analyst_error.md | 37 + .../envs/mathverse/prompts/analyst_success.md | 26 + skillopt/envs/mathverse/prompts/deep_probe.md | 25 + skillopt/envs/mathverse/prompts/judge.md | 25 + .../envs/mathverse/prompts/rollout_system.md | 11 + skillopt/envs/mathverse/reflect.py | 4 + skillopt/envs/mathverse/rollout.py | 415 ++++ skillopt/envs/mathverse/skills/initial.md | 15 + skillopt/envs/mmrb/__init__.py | 2 + skillopt/envs/mmrb/adapter.py | 283 +++ skillopt/envs/mmrb/dataloader.py | 146 ++ skillopt/envs/mmrb/evaluator.py | 102 + skillopt/envs/mmrb/prompts/rollout_system.md | 10 + skillopt/envs/mmrb/rollout.py | 439 ++++ 
skillopt/envs/mmrb/skills/initial.md | 17 + skillopt/envs/officeqa/__init__.py | 1 + skillopt/envs/officeqa/adapter.py | 133 + skillopt/envs/officeqa/dataloader.py | 71 + skillopt/envs/officeqa/evaluator.py | 46 + .../envs/officeqa/prompts/analyst_error.md | 37 + .../envs/officeqa/prompts/analyst_success.md | 25 + .../envs/officeqa/prompts/rollout_system.md | 15 + skillopt/envs/officeqa/rollout.py | 363 +++ skillopt/envs/officeqa/skills/initial.md | 15 + skillopt/envs/officeqa/tool_runtime.py | 134 + skillopt/envs/sealqa/__init__.py | 1 + skillopt/envs/sealqa/adapter.py | 130 + skillopt/envs/sealqa/dataloader.py | 37 + skillopt/envs/sealqa/evaluator.py | 126 + skillopt/envs/sealqa/prompts/analyst_error.md | 30 + .../envs/sealqa/prompts/analyst_success.md | 19 + .../envs/sealqa/prompts/rollout_system.md | 3 + skillopt/envs/sealqa/rollout.py | 268 ++ skillopt/envs/sealqa/skills/initial.md | 11 + skillopt/envs/sealqa/tool_runtime.py | 30 + skillopt/envs/searchqa/__init__.py | 1 + skillopt/envs/searchqa/adapter.py | 250 ++ skillopt/envs/searchqa/dataloader.py | 42 + skillopt/envs/searchqa/evaluator.py | 100 + .../envs/searchqa/prompts/analyst_error.md | 46 + .../envs/searchqa/prompts/analyst_success.md | 32 + skillopt/envs/searchqa/prompts/deep_probe.md | 27 + .../envs/searchqa/prompts/rollout_system.md | 13 + skillopt/envs/searchqa/reflect.py | 4 + skillopt/envs/searchqa/rollout.py | 455 ++++ skillopt/envs/searchqa/skills/initial.md | 3 + skillopt/envs/spreadsheetbench/__init__.py | 5 + skillopt/envs/spreadsheetbench/adapter.py | 309 +++ .../envs/spreadsheetbench/codegen_agent.py | 704 ++++++ skillopt/envs/spreadsheetbench/dataloader.py | 37 + skillopt/envs/spreadsheetbench/evaluator.py | 158 ++ skillopt/envs/spreadsheetbench/executor.py | 67 + .../spreadsheetbench/prompts/analyst_error.md | 46 + .../prompts/analyst_success.md | 32 + .../prompts/codegen_system.md | 1 + .../prompts/critical_rules.md | 9 + .../spreadsheetbench/prompts/deep_probe.md | 35 + 
.../spreadsheetbench/prompts/react_system.md | 21 + skillopt/envs/spreadsheetbench/react_agent.py | 395 +++ skillopt/envs/spreadsheetbench/reflect.py | 4 + skillopt/envs/spreadsheetbench/rollout.py | 921 +++++++ .../envs/spreadsheetbench/skills/initial.md | 56 + .../spreadsheetbench/skills/xlsx_initial.md | 56 + .../spreadsheetbench/skills/xlsx_skill0.md | 4 + skillopt/envs/swebench/__init__.py | 1 + skillopt/envs/swebench/adapter.py | 137 + skillopt/envs/swebench/dataloader.py | 151 ++ skillopt/envs/swebench/rollout.py | 346 +++ skillopt/envs/swebench/skills/initial.md | 23 + skillopt/evaluation/__init__.py | 7 + skillopt/evaluation/gate.py | 73 + skillopt/gradient/__init__.py | 17 + skillopt/gradient/aggregate.py | 253 ++ skillopt/gradient/deep_probe.py | 77 + skillopt/gradient/reflect.py | 588 +++++ skillopt/model/__init__.py | 343 +++ skillopt/model/azure_openai.py | 871 +++++++ skillopt/model/backend_config.py | 185 ++ skillopt/model/claude_backend.py | 359 +++ skillopt/model/codex_backend.py | 664 +++++ skillopt/model/codex_harness.py | 1057 ++++++++ skillopt/model/common.py | 222 ++ skillopt/model/router.py | 236 ++ skillopt/optimizer/__init__.py | 15 + skillopt/optimizer/clip.py | 109 + skillopt/optimizer/lr_autonomous.py | 108 + skillopt/optimizer/meta_reflect.py | 198 ++ skillopt/optimizer/meta_skill.py | 87 + skillopt/optimizer/rewrite.py | 59 + skillopt/optimizer/scheduler.py | 127 + skillopt/optimizer/select.py | 4 + skillopt/optimizer/skill.py | 154 ++ skillopt/optimizer/slow_update.py | 374 +++ skillopt/optimizer/update_modes.py | 136 + skillopt/prompts/__init__.py | 63 + skillopt/prompts/analyst_error.md | 41 + .../prompts/analyst_error_full_rewrite.md | 32 + skillopt/prompts/analyst_error_rewrite.md | 44 + skillopt/prompts/analyst_success.md | 36 + .../prompts/analyst_success_full_rewrite.md | 30 + skillopt/prompts/analyst_success_rewrite.md | 33 + skillopt/prompts/deep_probe.md | 34 + skillopt/prompts/deep_probe_codex.md | 35 + 
skillopt/prompts/lr_autonomous.md | 20 + skillopt/prompts/merge_failure.md | 30 + .../prompts/merge_failure_full_rewrite.md | 28 + skillopt/prompts/merge_failure_rewrite.md | 26 + skillopt/prompts/merge_final.md | 33 + skillopt/prompts/merge_final_full_rewrite.md | 28 + skillopt/prompts/merge_final_rewrite.md | 25 + skillopt/prompts/merge_success.md | 28 + .../prompts/merge_success_full_rewrite.md | 28 + skillopt/prompts/merge_success_rewrite.md | 25 + skillopt/prompts/meta_reflect.md | 63 + skillopt/prompts/meta_reflect_rewrite.md | 28 + skillopt/prompts/meta_skill.md | 40 + skillopt/prompts/ranking.md | 20 + skillopt/prompts/ranking_rewrite.md | 15 + skillopt/prompts/rewrite_skill.md | 25 + skillopt/prompts/slow_update.md | 60 + skillopt/scheduler/__init__.py | 8 + skillopt/types.py | 357 +++ skillopt/utils/__init__.py | 4 + skillopt/utils/json_utils.py | 42 + skillopt/utils/scoring.py | 29 + 243 files changed, 31492 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 configs/_base_/default.yaml create mode 100644 configs/ablation_study/README.md create mode 100755 configs/ablation_study/launch_commands.sh create mode 100644 configs/ablation_study/matrix.yaml create mode 100644 configs/ablation_study/validation.md create mode 100644 configs/alfworld/default.yaml create mode 100644 configs/alfworld/meta_reflect.yaml create mode 100644 configs/babyvision/default.yaml create mode 100644 configs/docvqa/default.yaml create mode 100644 configs/livemathematicianbench/default.yaml create mode 100644 configs/mathverse/default.yaml create mode 100644 configs/mmrb/default.yaml create mode 100644 configs/officeqa/default.yaml create mode 100644 configs/sealqa/default.yaml create mode 100644 configs/searchqa/default.yaml create mode 100644 configs/spreadsheetbench/default.yaml create mode 100644 configs/swebench/default.yaml create mode 100755 scripts/codex_azure_mi.sh create mode 100644 scripts/download_babyvision.py create mode 100644 
scripts/eval_livemathematicianbench_baseline.py create mode 100644 scripts/eval_only.py create mode 100644 scripts/eval_prompt_custom.py create mode 100644 scripts/eval_prompt_official.py create mode 100755 scripts/eval_searchqa_val500.sh create mode 100755 scripts/eval_verified400.sh create mode 100755 scripts/eval_verified400_multi.sh create mode 100755 scripts/eval_verified400_single.sh create mode 100755 scripts/launch_harness_bestsetting_from_scratch.sh create mode 100755 scripts/launch_harness_canonical_claude18.sh create mode 100755 scripts/launch_harness_canonical_claude4_smoke.sh create mode 100755 scripts/launch_harness_canonical_wave1.sh create mode 100755 scripts/launch_harness_initial_claude4.sh create mode 100755 scripts/launch_harness_initial_spreadsheet_clean.sh create mode 100755 scripts/launch_lrctrl_fullrewrite_neutral3.sh create mode 100755 scripts/launch_lrctrl_fullrewrite_neutral3_spreadsheet_repeats.sh create mode 100755 scripts/launch_lrctrl_fullrewrite_neutral3_sq_lm_repeats.sh create mode 100755 scripts/launch_spreadsheet_full_replacements.sh create mode 100755 scripts/monitor_harness_claude18.sh create mode 100644 scripts/prepare_ablation_splits.py create mode 100755 scripts/run_ablation_matrix.py create mode 100755 scripts/run_alfworld.sh create mode 100755 scripts/run_meta_skill_ablation.sh create mode 100644 scripts/run_missing_meta_parallel.sh create mode 100755 scripts/run_searchqa.sh create mode 100755 scripts/run_spreadsheetbench.sh create mode 100644 scripts/train.py create mode 100755 scripts/train_searchqa.sh create mode 100755 scripts/train_spreadsheet_multi.sh create mode 100755 scripts/train_spreadsheet_single.sh create mode 100644 scripts/watch_ablation.py create mode 100644 skillopt/__init__.py create mode 100644 skillopt/config.py create mode 100644 skillopt/datasets/__init__.py create mode 100644 skillopt/datasets/base.py create mode 100644 skillopt/engine/__init__.py create mode 100644 skillopt/engine/trainer.py create 
mode 100644 skillopt/envs/__init__.py create mode 100644 skillopt/envs/alfworld/__init__.py create mode 100644 skillopt/envs/alfworld/adapter.py create mode 100644 skillopt/envs/alfworld/dataloader.py create mode 100644 skillopt/envs/alfworld/prompts/analyst_error.md create mode 100644 skillopt/envs/alfworld/prompts/analyst_success.md create mode 100644 skillopt/envs/alfworld/prompts/deep_probe.md create mode 100644 skillopt/envs/alfworld/prompts/rollout_no_history.md create mode 100644 skillopt/envs/alfworld/prompts/rollout_with_history.md create mode 100644 skillopt/envs/alfworld/prompts/rollout_with_memory.md create mode 100644 skillopt/envs/alfworld/reflect.py create mode 100644 skillopt/envs/alfworld/rollout.py create mode 100644 skillopt/envs/alfworld/skills/initial.md create mode 100644 skillopt/envs/alfworld/vendor/__init__.py create mode 100644 skillopt/envs/alfworld/vendor/alfworld_envs.py create mode 100644 skillopt/envs/alfworld/vendor/alfworld_projection.py create mode 100644 skillopt/envs/alfworld/vendor/alfworld_prompts.py create mode 100644 skillopt/envs/alfworld/vendor/config_tw.yaml create mode 100644 skillopt/envs/alfworld/vendor/env_base.py create mode 100644 skillopt/envs/alfworld/vendor/env_manager.py create mode 100644 skillopt/envs/alfworld/vendor/memory.py create mode 100644 skillopt/envs/babyvision/__init__.py create mode 100644 skillopt/envs/babyvision/adapter.py create mode 100644 skillopt/envs/babyvision/dataloader.py create mode 100644 skillopt/envs/babyvision/evaluator.py create mode 100644 skillopt/envs/babyvision/prompts/analyst_error.md create mode 100644 skillopt/envs/babyvision/prompts/analyst_success.md create mode 100644 skillopt/envs/babyvision/prompts/deep_probe.md create mode 100644 skillopt/envs/babyvision/prompts/judge.md create mode 100644 skillopt/envs/babyvision/prompts/rollout_system.md create mode 100644 skillopt/envs/babyvision/reflect.py create mode 100644 skillopt/envs/babyvision/rollout.py create mode 100644 
skillopt/envs/babyvision/skills/initial.md create mode 100644 skillopt/envs/base.py create mode 100644 skillopt/envs/deep_reflect.py create mode 100644 skillopt/envs/docvqa/__init__.py create mode 100644 skillopt/envs/docvqa/adapter.py create mode 100644 skillopt/envs/docvqa/dataloader.py create mode 100644 skillopt/envs/docvqa/evaluator.py create mode 100644 skillopt/envs/docvqa/prompts/analyst_error.md create mode 100644 skillopt/envs/docvqa/prompts/analyst_success.md create mode 100644 skillopt/envs/docvqa/prompts/rollout_system.md create mode 100644 skillopt/envs/docvqa/rollout.py create mode 100644 skillopt/envs/docvqa/skills/initial.md create mode 100644 skillopt/envs/livemathematicianbench/__init__.py create mode 100644 skillopt/envs/livemathematicianbench/adapter.py create mode 100644 skillopt/envs/livemathematicianbench/dataloader.py create mode 100644 skillopt/envs/livemathematicianbench/evaluator.py create mode 100644 skillopt/envs/livemathematicianbench/prompts/analyst_error.md create mode 100644 skillopt/envs/livemathematicianbench/prompts/analyst_success.md create mode 100644 skillopt/envs/livemathematicianbench/prompts/deep_probe.md create mode 100644 skillopt/envs/livemathematicianbench/prompts/deep_probe_codex.md create mode 100644 skillopt/envs/livemathematicianbench/prompts/rollout_system.md create mode 100644 skillopt/envs/livemathematicianbench/reflect.py create mode 100644 skillopt/envs/livemathematicianbench/rollout.py create mode 100644 skillopt/envs/livemathematicianbench/skills/initial.md create mode 100644 skillopt/envs/mathverse/__init__.py create mode 100644 skillopt/envs/mathverse/adapter.py create mode 100644 skillopt/envs/mathverse/dataloader.py create mode 100644 skillopt/envs/mathverse/evaluator.py create mode 100644 skillopt/envs/mathverse/prompts/analyst_error.md create mode 100644 skillopt/envs/mathverse/prompts/analyst_success.md create mode 100644 skillopt/envs/mathverse/prompts/deep_probe.md create mode 100644 
skillopt/envs/mathverse/prompts/judge.md create mode 100644 skillopt/envs/mathverse/prompts/rollout_system.md create mode 100644 skillopt/envs/mathverse/reflect.py create mode 100644 skillopt/envs/mathverse/rollout.py create mode 100644 skillopt/envs/mathverse/skills/initial.md create mode 100644 skillopt/envs/mmrb/__init__.py create mode 100644 skillopt/envs/mmrb/adapter.py create mode 100644 skillopt/envs/mmrb/dataloader.py create mode 100644 skillopt/envs/mmrb/evaluator.py create mode 100644 skillopt/envs/mmrb/prompts/rollout_system.md create mode 100644 skillopt/envs/mmrb/rollout.py create mode 100644 skillopt/envs/mmrb/skills/initial.md create mode 100644 skillopt/envs/officeqa/__init__.py create mode 100644 skillopt/envs/officeqa/adapter.py create mode 100644 skillopt/envs/officeqa/dataloader.py create mode 100644 skillopt/envs/officeqa/evaluator.py create mode 100644 skillopt/envs/officeqa/prompts/analyst_error.md create mode 100644 skillopt/envs/officeqa/prompts/analyst_success.md create mode 100644 skillopt/envs/officeqa/prompts/rollout_system.md create mode 100644 skillopt/envs/officeqa/rollout.py create mode 100644 skillopt/envs/officeqa/skills/initial.md create mode 100644 skillopt/envs/officeqa/tool_runtime.py create mode 100644 skillopt/envs/sealqa/__init__.py create mode 100644 skillopt/envs/sealqa/adapter.py create mode 100644 skillopt/envs/sealqa/dataloader.py create mode 100644 skillopt/envs/sealqa/evaluator.py create mode 100644 skillopt/envs/sealqa/prompts/analyst_error.md create mode 100644 skillopt/envs/sealqa/prompts/analyst_success.md create mode 100644 skillopt/envs/sealqa/prompts/rollout_system.md create mode 100644 skillopt/envs/sealqa/rollout.py create mode 100644 skillopt/envs/sealqa/skills/initial.md create mode 100644 skillopt/envs/sealqa/tool_runtime.py create mode 100644 skillopt/envs/searchqa/__init__.py create mode 100644 skillopt/envs/searchqa/adapter.py create mode 100644 skillopt/envs/searchqa/dataloader.py create mode 100644 
skillopt/envs/searchqa/evaluator.py create mode 100644 skillopt/envs/searchqa/prompts/analyst_error.md create mode 100644 skillopt/envs/searchqa/prompts/analyst_success.md create mode 100644 skillopt/envs/searchqa/prompts/deep_probe.md create mode 100644 skillopt/envs/searchqa/prompts/rollout_system.md create mode 100644 skillopt/envs/searchqa/reflect.py create mode 100644 skillopt/envs/searchqa/rollout.py create mode 100644 skillopt/envs/searchqa/skills/initial.md create mode 100644 skillopt/envs/spreadsheetbench/__init__.py create mode 100644 skillopt/envs/spreadsheetbench/adapter.py create mode 100644 skillopt/envs/spreadsheetbench/codegen_agent.py create mode 100644 skillopt/envs/spreadsheetbench/dataloader.py create mode 100644 skillopt/envs/spreadsheetbench/evaluator.py create mode 100644 skillopt/envs/spreadsheetbench/executor.py create mode 100644 skillopt/envs/spreadsheetbench/prompts/analyst_error.md create mode 100644 skillopt/envs/spreadsheetbench/prompts/analyst_success.md create mode 100644 skillopt/envs/spreadsheetbench/prompts/codegen_system.md create mode 100644 skillopt/envs/spreadsheetbench/prompts/critical_rules.md create mode 100644 skillopt/envs/spreadsheetbench/prompts/deep_probe.md create mode 100644 skillopt/envs/spreadsheetbench/prompts/react_system.md create mode 100644 skillopt/envs/spreadsheetbench/react_agent.py create mode 100644 skillopt/envs/spreadsheetbench/reflect.py create mode 100644 skillopt/envs/spreadsheetbench/rollout.py create mode 100644 skillopt/envs/spreadsheetbench/skills/initial.md create mode 100644 skillopt/envs/spreadsheetbench/skills/xlsx_initial.md create mode 100644 skillopt/envs/spreadsheetbench/skills/xlsx_skill0.md create mode 100644 skillopt/envs/swebench/__init__.py create mode 100644 skillopt/envs/swebench/adapter.py create mode 100644 skillopt/envs/swebench/dataloader.py create mode 100644 skillopt/envs/swebench/rollout.py create mode 100644 skillopt/envs/swebench/skills/initial.md create mode 100644 
skillopt/evaluation/__init__.py create mode 100644 skillopt/evaluation/gate.py create mode 100644 skillopt/gradient/__init__.py create mode 100644 skillopt/gradient/aggregate.py create mode 100644 skillopt/gradient/deep_probe.py create mode 100644 skillopt/gradient/reflect.py create mode 100644 skillopt/model/__init__.py create mode 100644 skillopt/model/azure_openai.py create mode 100644 skillopt/model/backend_config.py create mode 100644 skillopt/model/claude_backend.py create mode 100644 skillopt/model/codex_backend.py create mode 100644 skillopt/model/codex_harness.py create mode 100644 skillopt/model/common.py create mode 100644 skillopt/model/router.py create mode 100644 skillopt/optimizer/__init__.py create mode 100644 skillopt/optimizer/clip.py create mode 100644 skillopt/optimizer/lr_autonomous.py create mode 100644 skillopt/optimizer/meta_reflect.py create mode 100644 skillopt/optimizer/meta_skill.py create mode 100644 skillopt/optimizer/rewrite.py create mode 100644 skillopt/optimizer/scheduler.py create mode 100644 skillopt/optimizer/select.py create mode 100644 skillopt/optimizer/skill.py create mode 100644 skillopt/optimizer/slow_update.py create mode 100644 skillopt/optimizer/update_modes.py create mode 100644 skillopt/prompts/__init__.py create mode 100644 skillopt/prompts/analyst_error.md create mode 100644 skillopt/prompts/analyst_error_full_rewrite.md create mode 100644 skillopt/prompts/analyst_error_rewrite.md create mode 100644 skillopt/prompts/analyst_success.md create mode 100644 skillopt/prompts/analyst_success_full_rewrite.md create mode 100644 skillopt/prompts/analyst_success_rewrite.md create mode 100644 skillopt/prompts/deep_probe.md create mode 100644 skillopt/prompts/deep_probe_codex.md create mode 100644 skillopt/prompts/lr_autonomous.md create mode 100644 skillopt/prompts/merge_failure.md create mode 100644 skillopt/prompts/merge_failure_full_rewrite.md create mode 100644 skillopt/prompts/merge_failure_rewrite.md create mode 100644 
skillopt/prompts/merge_final.md create mode 100644 skillopt/prompts/merge_final_full_rewrite.md create mode 100644 skillopt/prompts/merge_final_rewrite.md create mode 100644 skillopt/prompts/merge_success.md create mode 100644 skillopt/prompts/merge_success_full_rewrite.md create mode 100644 skillopt/prompts/merge_success_rewrite.md create mode 100644 skillopt/prompts/meta_reflect.md create mode 100644 skillopt/prompts/meta_reflect_rewrite.md create mode 100644 skillopt/prompts/meta_skill.md create mode 100644 skillopt/prompts/ranking.md create mode 100644 skillopt/prompts/ranking_rewrite.md create mode 100644 skillopt/prompts/rewrite_skill.md create mode 100644 skillopt/prompts/slow_update.md create mode 100644 skillopt/scheduler/__init__.py create mode 100644 skillopt/types.py create mode 100644 skillopt/utils/__init__.py create mode 100644 skillopt/utils/json_utils.py create mode 100644 skillopt/utils/scoring.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ac777f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +__pycache__/ +*.pyc +data/ +outputs/ +logs/ +external/ +/BabyVision/ +/MMRB/ +/SpreadsheetBench/ +/dl4ir-searchQA/ +configs/local/ +configs/**/*.local.yaml +*.local.md +.secrets/ +.codex_azure*/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..c7e23ed --- /dev/null +++ b/README.md @@ -0,0 +1,557 @@ +# ReflACT: Reflective Agent Tuning + +ReflACT is a framework for optimizing an external skill document through iterative rollout, reflection, editing, and gated validation. + +It does **not** fine-tune model weights. 
Instead, it treats the skill document as the optimization target: + +- the **student** model executes tasks with the current skill +- the **teacher** model analyzes trajectories and proposes edits +- the framework merges, ranks, applies, and validates those edits +- only validated skill updates are kept + +This branch implements a full training loop with step-level skill optimization and optional epoch-level memory mechanisms (`slow_update`, `meta_skill`, `meta_reflect`). + +## Method Overview + +### Optimization Target + +Each run maintains a mutable markdown skill document. The framework repeatedly improves that document instead of changing model parameters. + +This gives a training-style loop for prompt / policy optimization: + +1. Roll out the current skill on a batch of tasks. +2. Reflect on failures and successes. +3. Merge patch proposals into a coherent candidate update. +4. Rank and select a bounded number of edits. +5. Apply those edits to produce a candidate skill. +6. Validate the candidate skill on a held-out selection split. +7. Keep the update only if the gate accepts it. + +### Per-Step Pipeline + +Every training step executes the following pipeline in `skillopt/engine/trainer.py`: + +1. **Rollout** + The student model runs a batch of tasks using the current skill. + +2. **Reflect** + The teacher analyzes minibatches of trajectories and emits raw patches. + Failure-driven and success-driven patches are tracked separately. + +3. **Aggregate** + Raw patches are merged hierarchically. Metadata such as `support_count` and `source_type` is carried into the merged patch so later ranking can use it. + +4. **Select** + The teacher ranks the merged edit pool and keeps up to `edit_budget` edits. + +5. **Update** + The selected edits are applied to the skill document. The framework records an `edit_apply_report.json` so you can see which edits actually landed, which were skipped, and why. + +6. 
**Evaluate / Gate** + The candidate skill is evaluated on the selection split. Gate validation is mandatory in this branch. A candidate update is accepted only if it improves over the current selection score; a new global best is tracked separately. + +### Within-Epoch Memory + +Inside an epoch, the trainer maintains a step buffer containing: + +- compact failure-pattern summaries from previous steps +- rejected edits and their score deltas + +That context is fed back into later reflection calls so the teacher can avoid repeating ineffective edits and can focus on unsolved error patterns. + +### Epoch-Level Mechanisms + +This branch supports three optional epoch-level mechanisms. + +#### Slow Update + +At the end of each epoch, `slow_update` compares the previous epoch’s terminal skill and current epoch’s terminal skill on a sampled train subset. It then writes longitudinal guidance into a protected slow-update region inside the skill document. + +Importantly, this guidance is **not** blindly written through. It is converted into a candidate skill and sent through the same selection gate as step-level updates. + +#### Meta Skill + +`meta_skill` is teacher-side cross-epoch memory. It does not directly edit the current skill. Instead, it writes a compact memory artifact describing longer-term patterns across adjacent epochs. That memory is loaded into later reflection / merge / ranking calls as extra context. + +#### Meta Reflect + +`meta_reflect` runs at epoch end over the step history of the current epoch. It looks at accepted and rejected directions from the whole epoch, proposes higher-level patch edits, applies them to a meta candidate, and then sends that candidate through the same selection gate. 
+ +## What This Branch Guarantees + +The current implementation assumes the following as the mainline method contract: + +- gate validation is always on +- the current skill, current score, best skill, and best score stay aligned +- `slow_update` is gated before being committed +- patch provenance (`source_type`, `support_count`) reaches selection +- patch application is observable through per-edit reports +- resume state is restored from `runtime_state.json` rather than inferred only from history +- all benchmark model calls go through the unified backend router + +## Model Backends + +All model access now goes through the split teacher/student model layer in `skillopt.model`. + +Supported teacher backends: + +- `openai_chat` +- `claude_chat` + +Supported student backends: + +- `openai_chat` +- `claude_chat` +- `codex_exec` +- `claude_code_exec` + +Recommended config shape: + +```yaml +model: + teacher_backend: openai_chat + student_backend: codex_exec + teacher: gpt-5.4 + student: gpt-5.4-codex + reasoning_effort: medium +``` + +Legacy `model.backend` and CLI flags like `--backend codex` still work. They are mapped onto the split backend model for backward compatibility. + +The same routing is used by: + +- training (`scripts/train.py`) +- eval-only runs (`scripts/eval_only.py`) +- SpreadsheetBench standalone prompt eval scripts +- LiveMathematicianBench baseline eval script +- benchmark rollout code inside the main framework + +### Azure OpenAI + +If you use `openai_chat`, configure either environment variables or config values: + +```bash +export AZURE_OPENAI_ENDPOINT="https://your-endpoint.openai.azure.com/" +export AZURE_OPENAI_API_KEY="your-api-key" +export AZURE_OPENAI_API_VERSION="2025-04-01-preview" +``` + +The config supports both the old keys and the new explicit names: + +```yaml +model: + azure_openai_endpoint: "..." + azure_openai_api_version: "..." 
+ azure_openai_api_key: "" + azure_openai_auth_mode: api_key + azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default" + azure_openai_managed_identity_client_id: "" +``` + +`azure_openai_auth_mode` can be used for API-key auth or Azure AD / managed identity flows. + +### Exec Harness + +`codex_exec` and `claude_code_exec` run the student inside a workspace harness instead of a plain chat call. The harness writes task files, renders a dynamic `SKILL.md`, runs the student CLI, and saves raw execution artifacts such as: + +- `codex_raw.txt` +- `codex_trace_summary.txt` +- workspace-local task / skill files + +This branch keeps `meta_skill` and `apply_patch_with_report`, while upgrading the student path to the more realistic workspace-exec setup. + +### Trace-Aware Deep Reflect + +When `student_backend=codex_exec` and `gradient.use_deep_reflect=true`, deep reflection can probe a specific earlier Codex attempt: + +- the teacher sees a compact Codex trace summary +- deep probe can target `probe_target_id` +- the follow-up rollout can resume from `probe_after_step` + +This is wired for the dataset-backed environments in this branch. + +### Rewrite Mode + +Skill updates support two modes: + +- `optimizer.skill_update_mode=patch` +- `optimizer.skill_update_mode=rewrite_from_suggestions` + +`patch` keeps the existing fine-grained edit application path and still records `edit_apply_report.json`. + +`rewrite_from_suggestions` asks the teacher to emit higher-level rewrite suggestions, then rewrites the whole skill in one pass. This is useful when patch edits become too fragmented. 
+ +## Repository Layout + +```text +skillopt/ + engine/ + trainer.py main training loop + gradient/ + reflect.py minibatch reflection + aggregate.py hierarchical patch merge + deep_probe.py diagnostic probing for deep reflect + optimizer/ + clip.py edit ranking / selection + skill.py patch application + apply report + slow_update.py epoch-level longitudinal guidance + meta_skill.py teacher-side cross-epoch memory + meta_reflect.py epoch-level macro editing + evaluation/ + gate.py pure gate decision logic + model/ + backend_config.py teacher/student backend routing + azure_openai.py Azure backend + codex_harness.py workspace exec harness + Codex trace parsing + claude_backend.py Claude backend + envs/ + ... environment adapters and rollout logic +scripts/ + train.py unified training entry + eval_only.py evaluate one skill without training +configs/ + _base_/default.yaml shared defaults + <env>/default.yaml environment-specific configs +``` + +## Configuration + +Configs use structured YAML with `_base_` inheritance. + +The base config is `configs/_base_/default.yaml`. 
Key defaults in this branch are: + +- `model.teacher_backend = openai_chat` +- `model.student_backend = openai_chat` +- `model.reasoning_effort = medium` +- `optimizer.use_slow_update = true` +- `optimizer.use_meta_skill = true` +- `optimizer.use_meta_reflect = false` +- `gradient.use_deep_reflect = false` +- `optimizer.skill_update_mode = patch` + +Default setting snapshot: + +```yaml +model: + backend: azure_openai + teacher: gpt-5.4 + student: gpt-5.4 + teacher_backend: openai_chat + student_backend: openai_chat + reasoning_effort: medium + rewrite_reasoning_effort: "" + rewrite_max_completion_tokens: 64000 + codex_exec_path: codex + codex_exec_sandbox: workspace-write + codex_exec_profile: "" + codex_exec_full_auto: false + codex_exec_reasoning_effort: none + claude_code_exec_path: claude + claude_code_exec_profile: "" + codex_trace_to_teacher: true + +train: + num_epochs: 4 + train_size: 0 + batch_size: 80 + accumulation: 1 + seed: 42 + +gradient: + minibatch_size: 16 + merge_batch_size: 16 + analyst_workers: 16 + max_analyst_rounds: 3 + failure_only: false + use_deep_reflect: false + deep_reflect_failures: 4 + deep_reflect_successes: 2 + +optimizer: + learning_rate: 8 + min_learning_rate: 2 + lr_scheduler: cosine + skill_update_mode: patch + use_meta_reflect: false + meta_learning_rate: 8 + use_slow_update: true + slow_update_samples: 20 + use_meta_skill: true + +evaluation: + use_gate: true + sel_env_num: 0 + test_env_num: 0 + eval_test: true + +env: + split_mode: ratio + split_ratio: "2:1:7" + split_seed: 42 +``` + +For the full source of truth, see [configs/_base_/default.yaml](configs/_base_/default.yaml). 
+ +Selected fields: + +| Section | Key | Meaning | +|---|---|---| +| `model` | `teacher_backend` | teacher backend: `openai_chat` or `claude_chat` | +| `model` | `student_backend` | student backend: chat backend or exec backend | +| `model` | `teacher` | teacher model / deployment | +| `model` | `student` | student model / deployment | +| `model` | `reasoning_effort` | reasoning budget passed to the backend when supported | +| `model` | `codex_trace_to_teacher` | include Codex trace summaries in teacher reflection context | +| `train` | `num_epochs` | number of epochs | +| `train` | `train_size` | expected train split size, or `0` to infer | +| `train` | `batch_size` | tasks per rollout batch | +| `train` | `accumulation` | number of rollout/reflect minibatches per step | +| `gradient` | `minibatch_size` | trajectories per analyst minibatch | +| `gradient` | `merge_batch_size` | patches per aggregate batch | +| `gradient` | `use_deep_reflect` | enable diagnostic probe rollouts | +| `gradient` | `max_analyst_rounds` | teacher reflection retries / refinement budget | +| `optimizer` | `learning_rate` | max edits kept after selection | +| `optimizer` | `lr_scheduler` | edit-budget scheduler | +| `optimizer` | `use_slow_update` | epoch-level longitudinal guidance | +| `optimizer` | `use_meta_skill` | teacher-side epoch memory | +| `optimizer` | `use_meta_reflect` | epoch-level macro editing | +| `optimizer` | `skill_update_mode` | `patch` or `rewrite_from_suggestions` | +| `evaluation` | `sel_env_num` | selection set size (`0` means full split) | +| `evaluation` | `test_env_num` | test set size (`0` means full split) | + +### Important Branch Rule + +`use_gate=false` is intentionally not supported in this branch. Gate validation is part of the method contract here. + +If an old config still contains `evaluation.use_gate: false`, the loader / trainer will raise instead of silently continuing. 
+ +## Supported Environments + +The main training entry and eval-only entry now register 11 environments: + +| Env | Default rollout shape | Current default split / data setting | Branch alignment | +|---|---|---|---| +| `alfworld` | environment-backed episodic rollout | native ALFWorld train/eval splits | in `skillopt_new_zzw` | +| `babyvision` | single-round multimodal QA | `split_mode=ratio` from raw metadata/images, or prepared `split_dir` | in `skillopt_new_zzw` | +| `docvqa` | single-round multimodal QA | `split_dir: data/docvqa_split` | in `skillopt_new_zzw` | +| `livemathematicianbench` | single-round QA | `split_mode=ratio` or prepared `split_dir` | in `skillopt_new_zzw` | +| `mathverse` | single-round multimodal math QA | `data_root: data/MathVerse`, split files loaded from `split_dir` when provided | in `skillopt_new_zzw` | +| `mmrb` | single-round multimodal reasoning QA | `split_mode=ratio` or prepared `split_dir` | in `skillopt_new_zzw` | +| `officeqa` | multi-turn tool loop | `split_dir: data/officeqa_split` plus `data_dirs: [data/officeqa_docs_official]` | in `skillopt_new_zzw` | +| `sealqa` | multi-turn tool loop | `split_dir: data/sealqa_split` | in `skillopt_new_zzw` | +| `searchqa` | single-round QA (`max_turns=1`) | `split_dir: data/searchqa_split` | in `skillopt_new_zzw` | +| `spreadsheetbench` | codegen loop, default `mode=multi`, `max_turns=30` | `split_dir: data/spreadsheetbench_split`, `data_root: data/spreadsheetbench_verified_400` | in `skillopt_new_zzw`, default adjusted here to multi-round | +| `swebench` | mini-swe-agent multi-step bug-fixing rollout | `split_mode=ratio`, `dataset_name=lite`, repo-stratified `2:1:7` split materialized under `out_root/_generated_splits/...` unless `split_dir` is provided | added here, aligned to `swe-bench-old` | + +## Data Expectations + +The standard two-mode dataset entry path is: + +- `split_mode: ratio` + - load raw data from `env.data_path` + - build a deterministic `train/`, `val/`, `test/` 
split under `env.split_output_dir` (or under `out_root/_generated_splits/` if unset) + - default ratio is explicitly `2:1:7` +- `split_mode: split_dir` + - load an existing `env.split_dir` with `train/`, `val/`, `test/` subdirectories + +This currently applies to: + +- `searchqa` +- `spreadsheetbench` +- `babyvision` +- `livemathematicianbench` +- `mmrb` +- `swebench` + +`ALFWorld` is the exception: it is environment-backed rather than JSON split-backed. + +The following environments currently expect prepared split directories or extra rooted assets rather than the generic ratio-split path: + +- `docvqa` +- `mathverse` +- `officeqa` +- `sealqa` + +At a high level: + +- `SearchQA`: raw QA json / jsonl or pre-split QA json files +- `SpreadsheetBench`: raw task manifest json plus spreadsheet task directory, or a pre-split task manifest +- `ALFWorld`: installed game environment and configured eval/train splits +- `BabyVision`: raw `meta_data.jsonl` plus images, or a pre-split directory +- `DocVQA`: pre-split CSV / JSON data under `split_dir` +- `LiveMathematicianBench`: raw monthly QA json files, or a pre-split directory +- `MathVerse`: split files plus `data_root` image assets +- `MMRB`: raw extracted dataset json files, or a pre-split directory +- `OfficeQA`: pre-split metadata plus resolved office document directories +- `SealQA`: pre-split metadata for tool-augmented QA tasks +- `SWEBench`: HuggingFace SWE-bench dataset alias (`lite` / `verified` / `full`) or a prepared split directory + +### Split References Across Branches + +The split-related defaults are not identical across `skillopt-final`, `skillopt_new_zzw`, `gepa`, and `swe-bench-old`. 
The practical reference points are: + +| Source branch | Explicit split settings / dirs | +|---|---| +| `skillopt-final` | `searchqa -> data/searchqa_split`; `spreadsheetbench -> data/spreadsheetbench_split`; `docvqa -> data/docvqa_split`; `officeqa -> data/officeqa_split`; `sealqa -> data/sealqa_split`; `swebench -> ratio split 2:1:7 over the default lite dataset, materialized under out_root/_generated_splits/...` | +| `skillopt_new_zzw` | Same 10-benchmark env set as above except no `swebench`; explicit split dirs are `data/searchqa_split`, `data/spreadsheetbench_split`, `data/docvqa_split`, `data/officeqa_split`, `data/sealqa_split`; `spreadsheetbench` there defaults to `mode=single`; `officeqa` uses `max_tool_turns=24`; `sealqa` uses `max_tool_turns=12` | +| `gepa` | `configs/spreadsheetbench.yaml` uses `data.splits_dir = data/spreadsheetbench/splits`, `eval.mode = react`, `eval.max_turns = 20`; `configs/swebench.yaml` uses `dataset = SWE-bench/SWE-bench_Verified` with `train_size = 100`, `val_size = 50`, `test_size = 350` | +| `swe-bench-old` | Repo-stratified `2:1:7` split over `SWE-Bench_Lite`, persisted as `outputs/.../split/train.json`, `selection.json`, `test.json`; the example split in that branch is `train=60`, `selection=33`, `test=207` | + +For the 10 benches shared with `skillopt_new_zzw`, the current branch is now aligned on env coverage. The main intentional delta is `spreadsheetbench`: this branch defaults to multi-round codegen, while `skillopt_new_zzw` kept `mode=single` by default. 
+ +## Running Training + +Example: + +```bash +python scripts/train.py --config configs/searchqa/default.yaml +``` + +Explicit 2:1:7 split from raw data: + +```bash +python scripts/train.py \ + --config configs/searchqa/default.yaml \ + --split_mode ratio \ + --data_path /path/to/searchqa_train_2000.json +``` + +Directly consume a prepared split directory: + +```bash +python scripts/train.py \ + --config configs/searchqa/default.yaml \ + --split_mode split_dir \ + --split_dir /path/to/searchqa_split +``` + +You can override structured config keys from the CLI: + +```bash +python scripts/train.py \ + --config configs/spreadsheetbench/default.yaml \ + --cfg-options model.teacher_backend=openai_chat model.student_backend=codex_exec train.batch_size=40 optimizer.learning_rate=4 +``` + +Legacy flat overrides still work for common keys: + +```bash +python scripts/train.py \ + --config configs/searchqa/default.yaml \ + --backend azure_openai \ + --teacher_model gpt-5.4 \ + --student_model gpt-5.4 \ + --reasoning_effort medium +``` + +Exec harness example: + +```bash +python scripts/train.py \ + --config configs/searchqa/default.yaml \ + --teacher_backend openai_chat \ + --student_backend codex_exec \ + --teacher_model gpt-5.4 \ + --student_model gpt-5.4-codex \ + --use_deep_reflect true \ + --skill_update_mode rewrite_from_suggestions +``` + +SWEBench example: + +```bash +python scripts/train.py \ + --config configs/swebench/default.yaml \ + --cfg-options env.dataset_name=lite env.split_ratio=2:1:7 +``` + +## Eval-Only and Standalone Evaluation + +Evaluate a specific skill without training: + +```bash +python scripts/eval_only.py \ + --config configs/searchqa/default.yaml \ + --skill skillopt/envs/searchqa/skills/initial.md +``` + +The same dataset entry modes apply in eval-only runs: + +- `--split_mode ratio --data_path ...` +- `--split_mode split_dir --split_dir ...` + +Standalone scripts also exist for benchmark-specific comparisons, including: + +- 
`scripts/eval_prompt_custom.py` +- `scripts/eval_prompt_official.py` +- `scripts/eval_livemathematicianbench_baseline.py` + +These scripts now also support backend selection through the unified model layer. + +## Output Structure + +Each run writes a structured output directory under `out_root`. + +Important top-level artifacts: + +- `config.json` — flattened runtime config +- `history.json` — per-step history records +- `runtime_state.json` — resume state for current/best skill tracking +- `best_skill.md` — current best validated skill +- `skills/skill_vXXXX.md` — persisted skill snapshot per step + +Per-step artifacts live under `steps/step_XXXX/`, including: + +- `merged_patch.json` +- `ranked_edits.json` +- `candidate_skill.md` +- `edit_apply_report.json` +- `rewrite_result.json` when rewrite mode is enabled +- `selection_eval/` +- `trajectory_digest.json` +- rollout and patch subdirectories + +Epoch-level artifacts live under: + +- `slow_update/epoch_XX/` +- `meta_skill/epoch_XX/` +- `meta_reflect/epoch_XX/` + +## Resume Behavior + +The trainer resumes from `runtime_state.json` when present. That state tracks: + +- last completed step +- current skill path +- current score +- best skill path +- best score +- origin tags for current and best skill + +This is important because skill state can change at both step level and epoch level; resuming only from `history.json` is not sufficient for this branch’s method logic. + +## Notes + +- This repository focuses on skill optimization logic; datasets are not included. +- Patch application is intentionally observable. Inspect `edit_apply_report.json` when candidate skills do not behave as expected. +- `SpreadsheetBench` now defaults to `mode=multi`. If you run an exec student backend there, override back to `env.mode=single` because exec backends are still only wired for SpreadsheetBench single-mode rollout. 
+- `SWEBench` follows the older mini-swe-agent + `swebench.harness.run_evaluation` path, so it requires the SWE-bench / Docker toolchain rather than the generic chat-only stack. +- `slow_update` writes into a protected skill region and normal edits are prevented from overwriting that region directly. +- `meta_skill` is context memory, not a direct skill edit. +- `meta_reflect` is a gated skill edit stage, not just logging. + +## Minimal Setup + +```bash +conda create -n skillopt python=3.11 +conda activate skillopt +pip install openai pyyaml openpyxl +``` + +Depending on the environment, you may also need: + +```bash +pip install datasets gymnasium numpy ray regex +``` + +For `SWEBench`, you also need a working Docker environment plus the SWE-bench / mini-swe-agent dependencies used in `swe-bench-old`. diff --git a/configs/_base_/default.yaml b/configs/_base_/default.yaml new file mode 100644 index 0000000..3a20c0d --- /dev/null +++ b/configs/_base_/default.yaml @@ -0,0 +1,93 @@ +# ReflACT default configuration — base for all environments. 
+# Environment configs should inherit via: _base_: default.yaml + +model: + backend: azure_openai + teacher: gpt-5.5 + student: gpt-5.5 + teacher_backend: openai_chat + student_backend: openai_chat + reasoning_effort: medium + rewrite_reasoning_effort: "" + rewrite_max_completion_tokens: 64000 + codex_exec_path: codex + codex_exec_sandbox: workspace-write + codex_exec_profile: "" + codex_exec_full_auto: false + codex_exec_reasoning_effort: none + codex_exec_use_sdk: auto + codex_exec_network_access: false + codex_exec_web_search: false + codex_exec_approval_policy: never + claude_code_exec_path: claude + claude_code_exec_profile: "" + claude_code_exec_use_sdk: auto + claude_code_exec_effort: medium + claude_code_exec_max_thinking_tokens: 16384 + codex_trace_to_teacher: true + azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/" + azure_openai_api_version: "2024-12-01-preview" + azure_openai_api_key: "" # Fill locally if you do not export AZURE_OPENAI_API_KEY + azure_openai_auth_mode: azure_cli + azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default" + azure_openai_managed_identity_client_id: "" + teacher_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/" + teacher_azure_openai_api_version: "2024-12-01-preview" + teacher_azure_openai_api_key: "" + teacher_azure_openai_auth_mode: azure_cli + teacher_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default" + teacher_azure_openai_managed_identity_client_id: "" + student_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/" + student_azure_openai_api_version: "2024-12-01-preview" + student_azure_openai_api_key: "" + student_azure_openai_auth_mode: azure_cli + student_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default" + student_azure_openai_managed_identity_client_id: "" + +train: + num_epochs: 4 + train_size: 0 # 0 = derive from dataset split when available + batch_size: 40 + accumulation: 1 + seed: 42 + +gradient: + minibatch_size: 8 + 
merge_batch_size: 8 + analyst_workers: 16 + max_analyst_rounds: 3 + failure_only: false + use_deep_reflect: false + deep_reflect_failures: 4 + deep_reflect_successes: 2 + +optimizer: + learning_rate: 4 # max edits per step (edit_budget) + min_learning_rate: 2 # min edits for decay schedulers + lr_scheduler: cosine # constant / linear / cosine / autonomous + lr_control_mode: fixed # fixed / autonomous / none + skill_update_mode: patch # patch / rewrite_from_suggestions / full_rewrite_minibatch + use_meta_reflect: false + meta_learning_rate: 4 # max edits per epoch-level meta-reflect + use_slow_update: true + slow_update_samples: 20 + longitudinal_pair_policy: mixed # mixed / changed / unchanged + use_meta_skill: true + +evaluation: + use_gate: true + sel_env_num: 0 + test_env_num: 0 + eval_test: true + +env: + name: "" + skill_init: "" + split_mode: ratio # ratio = build deterministic split from data_path; split_dir = use pre-split train/val/test + split_ratio: "2:1:7" # explicit default for dataset-backed benchmarks: train:val:test + split_seed: 42 + split_dir: "" + data_path: "" + split_output_dir: "" + exec_timeout: 120 # per student model/code-agent call timeout in seconds + out_root: "" diff --git a/configs/ablation_study/README.md b/configs/ablation_study/README.md new file mode 100644 index 0000000..c0a4618 --- /dev/null +++ b/configs/ablation_study/README.md @@ -0,0 +1,305 @@ +# Ablation Study Configuration Manifest + +This folder records the final, reproducible settings for the ablation runs used +in `docs/ablation_paper_tables.md`. + +It is intentionally separate from the benchmark default configs. The benchmark +configs under `configs/<benchmark>/default.yaml` remain the source task configs; +this folder records the exact matrix-level overrides, run roots, launch commands, +and validation rules used for the paper ablations. + +## Files + +- `matrix.yaml`: canonical ablation matrix, common overrides, benchmark splits, + token/output caps, and invalid-run rules. 
+- `launch_commands.sh`: exact launcher commands for the valid run roots. +- `validation.md`: monitoring, result extraction, and invalidation checklist. + +## Source Of Truth + +Use the matrix launcher: + +```bash +/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python scripts/run_ablation_matrix.py +``` + +The launcher builds runs from the same defaults and values recorded in +`matrix.yaml`. It skips completed runs by checking `summary.json` and skips +active runs by checking `env.out_root` in active `scripts/train.py` processes. + +Do not manually rerun a completed run into the same `env.out_root`. If a run is +invalid, archive or remove its output directory first, then let the launcher +start it cleanly. + +## Current Correct Run Roots + +- SearchQA / SpreadsheetBench original ablations: + `outputs/ablation_20260502_040604_unique48` +- SearchQA / SpreadsheetBench batch-size ablations: + `outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run` +- LiveMathBench / ALFWorld clean ablations: + `outputs/ablation_livemath_alfworld_clean_20260503_155155_run` +- DocVQA ablations: + `outputs/ablation_docvqa_20260503_160225_run` + +Archived, superseded, misaligned, dry-run, or pre-fix directories must not be +used for paper tables. 
+ +## End-To-End Runbook + +### Environment + +Run from the repository root: + +```bash +cd /home/azureuser/workspace-gzy/SkillReflection +``` + +Always use: + +```bash +PY=/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python +export ALFWORLD_DATA=/home/azureuser/.cache/alfworld +``` + +Default model/auth settings are generated by `scripts/run_ablation_matrix.py`: + +```text +teacher=gpt-5.5 +student=gpt-5.5 +teacher_backend=openai_chat +student_backend=openai_chat +reasoning_effort=medium +teacher/student endpoint=https://t2vgoaigpt4o3.openai.azure.com/ +teacher/student api_version=2024-12-01-preview +teacher/student auth_mode=azure_cli +``` + +Core training settings: + +```text +train.num_epochs=4 +train.train_size=0 +train.batch_size=40 +train.accumulation=1 +train.seed=42 +gradient.minibatch_size=8 +gradient.merge_batch_size=8 +gradient.analyst_workers=16 +gradient.use_deep_reflect=false +optimizer.learning_rate=4 +optimizer.min_learning_rate=2 +optimizer.lr_scheduler=cosine +optimizer.lr_control_mode=fixed +optimizer.use_slow_update=true +optimizer.slow_update_samples=20 +optimizer.use_meta_skill=true +optimizer.use_meta_reflect=false +optimizer.longitudinal_pair_policy=mixed +evaluation.use_gate=true +evaluation.eval_test=true +env.split_mode=split_dir +``` + +`train.train_size=0` is intentional. The dataloader derives the train size from +the fixed split. Batch-size ablations rely on the default `ceil(train_size / +batch_size)` behavior; the last batch can be smaller than `train.batch_size`. 
+ +### Fixed Splits + +Default split directories: + +```text +searchqa: data/ablation_splits/searchqa/2-1-7_seed42 +spreadsheetbench: data/ablation_splits/spreadsheetbench/2-1-7_seed42 +livemathematicianbench: data/ablation_splits/livemathematicianbench/2-1-7_seed42 +alfworld: data/ablation_splits/alfworld/2-1-7_seed42 +docvqa: /home/azureuser/zisu/SkillReflection/data/docvqa/splits +``` + +Default train/val/test sizes: + +| Benchmark | Train | Val | Test | +| --- | ---: | ---: | ---: | +| SearchQA | 400 | 200 | 1400 | +| SpreadsheetBench | 80 | 40 | 280 | +| LiveMathBench | 35 | 18 | 124 | +| ALFWorld | 39 | 18 | 134 | +| DocVQA | 1070 | 535 | 3744 | + +DocVQA images are not copied. The valid setup uses: + +```text +data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images +``` + +2026-05-05 DocVQA data correction: all DocVQA final reruns should use the zisu +10% split above and a fresh output root such as +`outputs/ablation_docvqa_zisu10pct_20260505_run`. The older local +`data/ablation_splits/docvqa/2-1-7_seed42` contains the same 5349 questionId +pool but a different train/val/test assignment, so its completed summaries are +historical only. + +### Matrix Groups + +Use these group names with `scripts/run_ablation_matrix.py`: + +```text +default split batch mbs lr sched slown mod smodel longpair lrctrl +``` + +`longpair` is the slow-update/meta-skill comparison-example ablation. It keeps +all prompts and training settings unchanged and only overrides: + +```text +optimizer.longitudinal_pair_policy=changed +optimizer.longitudinal_pair_policy=unchanged +``` + +The default paper setting remains `mixed`. + +`lrctrl` contains the two learning-rate-control baselines: + +```text +optimizer.lr_control_mode=autonomous +optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch +``` + +The autonomous run logs the chosen integer per step in `lr_decision.json` and +`lr_history.jsonl`. 
The full-rewrite run removes the LR/edit-selection concept: +each minibatch analyst produces a complete skill candidate, and aggregate/merge +produces the candidate skill directly. + +Batch-size values are: + +```text +8 / 24 / 40 / 56 / full +``` + +`40` is the default point. `full` expands to the benchmark train size. + +### Launch Commands Used In This Session + +The exact commands are recorded in `launch_commands.sh` and in +`docs/ablation_plan.md`. The important current policy is: + +- SearchQA / SpreadsheetBench batch-only matrix can run at `--max-parallel 8`. +- DocVQA matrix can run with its launcher at `--max-parallel 8`; later top-up used `--max-parallel 16` only because completed runs were skipped and active roots were checked. +- LiveMathBench is safe as an API-only benchmark after the token cap fix. +- ALFWorld must not be mixed into a 24-way run on this shared machine. Use `--bench alfworld --max-parallel 1` only after memory is available. + +### Token And Timeout Fixes + +LiveMathBench must use a large student completion cap: + +```text +max_completion_tokens=16384 +timeout=300 +``` + +The old 768/512 cap produced many empty visible responses because hidden +reasoning consumed the budget. 
+ +ALFWorld must use: + +```text +max_completion_tokens=2048 +empty response fallback -> look +missing action fallback -> look +``` + +### Invalid Runs + +Never fill paper tables from these archive directories: + +```text +outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258/ +outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417/ +outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311/ +outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402/ +outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517/ +outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_serial_lowmem_20260504_1300/ +``` + +### Current Resource Notes + +ALFWorld model calls are API calls. The ablation branch now creates local +ALFWorld/TextWorld environments through multiprocessing workers ported from +`skillopt_final_zzw`, not through Ray actors. Old Ray-based archived runs are +not valid for table fill. The observed historical failure mode in this session +was system RAM pressure and Ray OOM prevention, not model GPU memory. + +GPU memory currently shown by `nvidia-smi` came from unrelated Ray Serve visual +models under: + +```text +/home/azureuser/workspace-gzy/zyf/gca-skill +``` + +Those processes are `GroundingDINOModel` / `DA3Model`, not the +SkillReflection ablation ALFWorld run. + +There are also unrelated ALFWorld jobs under: + +```text +/home/azureuser/zisu/skill_distill +``` + +Do not confuse those with this repository's ablation outputs. 
+ +### Monitoring + +Active run and duplicate output-root check: + +```bash +$PY - <<'PY' +import subprocess, re, collections, time +try: + raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True) +except subprocess.CalledProcessError: + raw = "" +roots = [] +for line in raw.splitlines(): + m = re.search(r"env\.out_root=([^\s]+)", line) + if m: + roots.append(m.group(1)) +ctr = collections.Counter(roots) +print("time", time.strftime("%F %T")) +print("active_count", len(roots)) +print("duplicates", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1]) +for root in sorted(roots): + print(root.rsplit("/", 1)[-1]) +PY +``` + +Error scan: + +```bash +rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|CUDA out of memory|\\[FAIL\\]|LLM call failed" \ + outputs/ablation_docvqa_20260503_160225_run/logs \ + outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \ + outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \ + -g '*.log' | tail -100 || true +``` + +Resource checks: + +```bash +free -h | sed -n '1,3p' +df -h /tmp +du -sh /tmp/ray 2>/dev/null || true +nvidia-smi +``` + +### Filling Tables + +Only use top-level `summary.json` from valid run roots. Fill +`docs/ablation_paper_tables.md` from: + +```text +best_selection_hard +baseline_test_hard +test_hard +test_delta_hard +token_summary._total.total_tokens +``` diff --git a/configs/ablation_study/launch_commands.sh b/configs/ablation_study/launch_commands.sh new file mode 100755 index 0000000..5313203 --- /dev/null +++ b/configs/ablation_study/launch_commands.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd /home/azureuser/workspace-gzy/SkillReflection + +PY=/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python +export ALFWORLD_DATA=/home/azureuser/.cache/alfworld + +# Original SearchQA / SpreadsheetBench full matrix reproduction command. 
+# Do not run this into the existing root unless intentionally reproducing from +# scratch; the current valid root is already populated: +# outputs/ablation_20260502_040604_unique48 +# +# setsid "$PY" scripts/run_ablation_matrix.py \ +# --groups default split mbs lr sched slown mod smodel \ +# --bench searchqa spreadsheetbench \ +# --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48 \ +# --max-parallel 24 \ +# --execute \ +# > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48/launcher_reproduce_full_matrix.log 2>&1 < /dev/null & +# +# SearchQA / SpreadsheetBench batch-size ablations only. +# Original non-batch SearchQA/SpreadsheetBench ablations live in: +# outputs/ablation_20260502_040604_unique48 +setsid "$PY" scripts/run_ablation_matrix.py \ + --groups batch \ + --bench searchqa spreadsheetbench \ + --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run \ + --max-parallel 8 \ + --execute \ + > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/launcher_parallel8.log 2>&1 < /dev/null & + +# DocVQA full matrix. +setsid "$PY" scripts/run_ablation_matrix.py \ + --groups default split batch mbs lr sched slown mod smodel \ + --bench docvqa \ + --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run \ + --max-parallel 8 \ + --execute \ + > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>&1 < /dev/null & + +# LiveMathBench clean matrix. ALFWorld should be launched separately at lower +# concurrency because Ray OOM occurred when many ALFWorld runs were mixed into a +# 24-way run. 
+setsid "$PY" scripts/run_ablation_matrix.py \ + --groups default split batch mbs lr sched slown mod smodel \ + --bench livemathematicianbench \ + --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \ + --max-parallel 8 \ + --execute \ + > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>&1 < /dev/null & + +# ALFWorld clean matrix. Increase to 2 only after checking memory, /tmp/ray, +# and that no other ALFWorld run is active. Do not use 8/16/24 for ALFWorld on +# the current shared machine unless resources are explicitly reserved. +setsid "$PY" scripts/run_ablation_matrix.py \ + --groups default split batch mbs lr sched slown mod smodel \ + --bench alfworld \ + --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \ + --max-parallel 1 \ + --execute \ + > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>&1 < /dev/null & + +# Longitudinal comparison-example policy ablations. This intentionally excludes +# ALFWorld. The only varied setting is optimizer.longitudinal_pair_policy. +setsid "$PY" scripts/run_ablation_matrix.py \ + --groups longpair \ + --bench searchqa spreadsheetbench livemathematicianbench docvqa \ + --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run \ + --max-parallel 8 \ + --execute \ + > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run/launcher_longpair_parallel8.log 2>&1 < /dev/null & + +# Learning-rate-control baselines. This intentionally excludes ALFWorld. 
+setsid "$PY" scripts/run_ablation_matrix.py \ + --groups lrctrl \ + --bench searchqa spreadsheetbench livemathematicianbench docvqa \ + --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run \ + --max-parallel 8 \ + --execute \ + > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run/launcher_lrctrl_parallel8.log 2>&1 < /dev/null & diff --git a/configs/ablation_study/matrix.yaml b/configs/ablation_study/matrix.yaml new file mode 100644 index 0000000..767fbed --- /dev/null +++ b/configs/ablation_study/matrix.yaml @@ -0,0 +1,257 @@ +version: 2026-05-04 +purpose: "Canonical paper ablation settings matching the valid current runs." + +launcher: + script: scripts/run_ablation_matrix.py + python: /home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python + skip_completed_by: summary.json + skip_active_by: "active scripts/train.py env.out_root" + +environment: + working_directory: /home/azureuser/workspace-gzy/SkillReflection + required_env: + ALFWORLD_DATA: /home/azureuser/.cache/alfworld + docvqa_images_symlink: "data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images" + +common_overrides: + model.teacher_backend: openai_chat + model.student_backend: openai_chat + model.teacher: gpt-5.5 + model.student: gpt-5.5 + model.teacher_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/ + model.teacher_azure_openai_api_version: 2024-12-01-preview + model.teacher_azure_openai_auth_mode: azure_cli + model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/ + model.student_azure_openai_api_version: 2024-12-01-preview + model.student_azure_openai_auth_mode: azure_cli + model.reasoning_effort: medium + train.num_epochs: 4 + train.train_size: 0 + train.batch_size: 40 + train.accumulation: 1 + train.seed: 42 + gradient.minibatch_size: 8 + gradient.merge_batch_size: 8 + gradient.analyst_workers: 16 + gradient.use_deep_reflect: false + optimizer.learning_rate: 
4 + optimizer.min_learning_rate: 2 + optimizer.lr_scheduler: cosine + optimizer.skill_update_mode: patch + optimizer.use_slow_update: true + optimizer.slow_update_samples: 20 + optimizer.use_meta_skill: true + optimizer.use_meta_reflect: false + evaluation.use_gate: true + evaluation.eval_test: true + env.split_mode: split_dir + +benchmarks: + searchqa: + config: configs/searchqa/default.yaml + run_roots: + original_matrix: outputs/ablation_20260502_040604_unique48 + batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run + default_split: data/ablation_splits/searchqa/2-1-7_seed42 + train: 400 + val: 200 + test: 1400 + student_rollout: + function: skillopt/envs/searchqa/rollout.py::chat_student + max_completion_tokens: + first_turn: 512 + refinement: 512 + rationale: "Short-answer QA; sampled empties are low and not LiveMath-like." + + spreadsheetbench: + config: configs/spreadsheetbench/default.yaml + run_roots: + original_matrix: outputs/ablation_20260502_040604_unique48 + batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run + default_split: data/ablation_splits/spreadsheetbench/2-1-7_seed42 + train: 80 + val: 40 + test: 280 + student_rollout: + function: skillopt/envs/spreadsheetbench/codegen_agent.py::run_multi + max_output_tokens: 16384 + result_note: "results.jsonl stores execution fields, not a response field." 
+ + livemathematicianbench: + config: configs/livemathematicianbench/default.yaml + run_roots: + clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run + default_split: data/ablation_splits/livemathematicianbench/2-1-7_seed42 + train: 35 + val: 18 + test: 124 + student_rollout: + function: skillopt/envs/livemathematicianbench/rollout.py::chat_student + max_completion_tokens: + first_turn: 16384 + refinement: 16384 + timeout_seconds: 300 + invalid_old_caps: + first_turn: 768 + refinement: 512 + invalid_archive: outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258 + rationale: "GPT-5 reasoning consumed small budgets and produced many empty visible responses." + + alfworld: + config: configs/alfworld/default.yaml + run_roots: + clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run + default_split: data/ablation_splits/alfworld/2-1-7_seed42 + train: 39 + val: 18 + test: 134 + student_rollout: + function: skillopt/envs/alfworld/rollout.py::chat_student + max_completion_tokens: 2048 + timeout_seconds: 120 + max_steps: 50 + fallback_action: look + invalid_old_cap: 512 + invalid_archives: + - outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417 + - outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311 + - outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402 + - outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517 + concurrency_note: "Do not mix many ALFWorld runs into 24-way total concurrency; Ray OOM occurred. Prefer 1-2 ALFWorld runs at a time unless resources are clearly free." 
+ + docvqa: + config: configs/docvqa/default.yaml + run_roots: + matrix: outputs/ablation_docvqa_zisu10pct_20260505_run + default_split: /home/azureuser/zisu/SkillReflection/data/docvqa/splits + train: 1070 + val: 535 + test: 3744 + data_note: "2026-05-05: use zisu-provided 5349-item DocVQA split directly; previous local 2-1-7_seed42 used the same item pool but a different train/val/test assignment and must not be used for final DocVQA reruns." + student_rollout: + function: skillopt/envs/docvqa/rollout.py::chat_student_messages + max_completion_tokens: + first_turn: 768 + refinement: 512 + rationale: "Short-answer VQA output; preserve current setting for alignment unless explicitly rerunning all affected DocVQA." + +splits: + tags: + 1shot: + extra_overrides: + optimizer.slow_update_samples: 1 + 1-1-8: {} + 2-1-7: + default: true + 4-1-5: {} + paths: + searchqa: + 1shot: data/ablation_splits/searchqa/1shot_seed42 + 1-1-8: data/ablation_splits/searchqa/1-1-8_seed42 + 2-1-7: data/ablation_splits/searchqa/2-1-7_seed42 + 4-1-5: data/ablation_splits/searchqa/4-1-5_seed42 + spreadsheetbench: + 1shot: data/ablation_splits/spreadsheetbench/1shot_seed42 + 1-1-8: data/ablation_splits/spreadsheetbench/1-1-8_seed42 + 2-1-7: data/ablation_splits/spreadsheetbench/2-1-7_seed42 + 4-1-5: data/ablation_splits/spreadsheetbench/4-1-5_seed42 + livemathematicianbench: + 1shot: data/ablation_splits/livemathematicianbench/1shot_seed42 + 1-1-8: data/ablation_splits/livemathematicianbench/1-1-8_seed42 + 2-1-7: data/ablation_splits/livemathematicianbench/2-1-7_seed42 + 4-1-5: data/ablation_splits/livemathematicianbench/4-1-5_seed42 + alfworld: + 1shot: data/ablation_splits/alfworld/1shot_seed42 + 1-1-8: data/ablation_splits/alfworld/1-1-8_seed42 + 2-1-7: data/ablation_splits/alfworld/2-1-7_seed42 + 4-1-5: data/ablation_splits/alfworld/4-1-5_seed42 + docvqa: + 1shot: data/ablation_splits/docvqa/1shot_seed42 + 1-1-8: data/ablation_splits/docvqa/1-1-8_seed42 + 2-1-7: 
/home/azureuser/zisu/SkillReflection/data/docvqa/splits + 4-1-5: data/ablation_splits/docvqa/4-1-5_seed42 + +groups: + default: + run_id: "DEFAULT-{benchmark}-5.5" + overrides: {} + split: + values: [1shot, 1-1-8, 4-1-5] + skip_default_2_1_7: true + override_template: "env.split_dir={split_path}" + batch: + values: [8, 24, 56, full] + default_value_reused: 40 + full_values: + searchqa: 400 + spreadsheetbench: 80 + livemathematicianbench: 35 + alfworld: 39 + docvqa: 1070 + fixed_overrides: + gradient.minibatch_size: 8 + mbs: + values: [1, 2, 4, 16, 32] + default_value_reused: 8 + override_template: "gradient.minibatch_size={value}" + lr: + values: [1, 2, 4, 8, 16] + fixed_overrides: + optimizer.lr_scheduler: constant + optimizer.min_learning_rate: 1 + override_template: "optimizer.learning_rate={value}" + sched: + values: [constant, linear] + default_value_reused: cosine + override_template: "optimizer.lr_scheduler={value}" + slown: + values: [5, 10, 40] + default_value_reused: 20 + override_template: "optimizer.slow_update_samples={value}" + mod: + values: + slow-only: + optimizer.use_slow_update: true + optimizer.use_meta_skill: false + meta-only: + optimizer.use_slow_update: false + optimizer.use_meta_skill: true + none: + optimizer.use_slow_update: false + optimizer.use_meta_skill: false + default_value_reused: slow-meta + longpair: + values: [changed, unchanged] + default_value_reused: mixed + override_template: "optimizer.longitudinal_pair_policy={value}" + note: "Only changes slow-update/meta-skill comparison examples; prompts and other settings remain unchanged." + lrctrl: + values: + autonomous: + optimizer.lr_control_mode: autonomous + full-rewrite: + optimizer.lr_control_mode: none + optimizer.skill_update_mode: full_rewrite_minibatch + default_value_reused: "fixed patch learning_rate=4" + note: "autonomous records lr_decision.json/lr_history.jsonl; full-rewrite removes LR/select/apply-edit and uses full skill candidates." 
+ smodel: + values: + "5.4": + model.student: gpt-5.4-pro + model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/ + model.student_azure_openai_api_version: 2025-03-01-preview + model.student_azure_openai_auth_mode: azure_cli + "5.4-mini": + model.student: gpt-5.4-mini + model.student_azure_openai_endpoint: https://searchagent5.cognitiveservices.azure.com/ + model.student_azure_openai_api_version: 2024-12-01-preview + model.student_azure_openai_auth_mode: azure_cli + default_value_reused: "5.5" + +validity_rules: + use_for_tables: + - "Only runs with summary.json in valid run roots." + - "Do not use archive, archived, MISALIGNED, SUPERSEDED, dryrun, smoke, or debug directories." + - "Do not use ALFWorld runs started before empty/missing-action fallback." + - "Do not use old LiveMath runs with 768/512 token caps." + rerun_rule: "Archive or remove invalid out_root before relaunch; never write a rerun into a polluted output directory." diff --git a/configs/ablation_study/validation.md b/configs/ablation_study/validation.md new file mode 100644 index 0000000..95e2730 --- /dev/null +++ b/configs/ablation_study/validation.md @@ -0,0 +1,141 @@ +# Ablation Validation Checklist + +Use this checklist before launch, during monitoring, and before filling +`docs/ablation_paper_tables.md`. 
+ +## Before Launch + +Run from repo root: + +```bash +cd /home/azureuser/workspace-gzy/SkillReflection +export ALFWORLD_DATA=/home/azureuser/.cache/alfworld +``` + +Verify syntax for edited files: + +```bash +/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python -m py_compile \ + scripts/run_ablation_matrix.py \ + scripts/train.py \ + skillopt/model/azure_openai.py \ + skillopt/envs/searchqa/rollout.py \ + skillopt/envs/spreadsheetbench/rollout.py \ + skillopt/envs/livemathematicianbench/rollout.py \ + skillopt/envs/alfworld/rollout.py \ + skillopt/envs/docvqa/rollout.py +``` + +Check active runs and duplicate `env.out_root` before starting more: + +```bash +/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python - <<'PY' +import subprocess, re, collections +try: + raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True) +except subprocess.CalledProcessError: + raw = "" +roots = [] +for line in raw.splitlines(): + m = re.search(r"env\.out_root=([^\s]+)", line) + if m: + roots.append(m.group(1)) +ctr = collections.Counter(roots) +print("train_count", len(roots)) +print("duplicate_roots", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1]) +for root in sorted(roots): + print(root.rsplit("/", 1)[-1]) +PY +``` + +## During Monitoring + +Check launchers: + +```bash +pgrep -af 'scripts/run_ablation_matrix.py' || true +tail -80 outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>/dev/null || true +tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>/dev/null || true +tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>/dev/null || true +``` + +Scan current logs for new hard failures: + +```bash +rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|\\[FAIL\\]|\\[RETRY\\]" \ + outputs/ablation_docvqa_20260503_160225_run/logs \ + 
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \ + outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \ + -g '*.log' | tail -160 || true +``` + +Check resource pressure: + +```bash +df -h /tmp +du -sh /tmp/ray 2>/dev/null || true +free -h | sed -n '1,3p' +``` + +## Quality Checks + +LiveMathBench current valid runs should not look like old 768/512 runs: + +```bash +/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python - <<'PY' +import json, pathlib +root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run") +for run in sorted(root.glob("*livemathematicianbench*")): + if not run.is_dir() or "archive" in str(run): + continue + for rel in ["test_eval_baseline/results.jsonl", "test_eval/results.jsonl"]: + p = run / rel + if not p.exists(): + continue + rows = [json.loads(l) for l in p.open(errors="ignore") if l.strip()] + empty = sum(1 for r in rows if not str(r.get("response", "")).strip()) + answer = sum(1 for r in rows if "<answer>" in str(r.get("response", "")).lower()) + if empty: + print(run.name, rel, "empty", empty, "answer", answer, "n", len(rows)) +PY +``` + +ALFWorld valid runs must not contain empty action or missing action: + +```bash +/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python - <<'PY' +import json, pathlib +root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run") +for run in sorted(root.glob("*alfworld*")): + if not run.is_dir() or "archive" in str(run): + continue + bad = [] + fallback = 0 + for c in run.glob("**/conversation.json"): + data = json.load(c.open(errors="ignore")) + for step in data: + if step.get("step") is None: + continue + if not step.get("action"): + bad.append(str(c.relative_to(run))) + break + mr = str(step.get("model_response", "")) + if "empty model response" in mr or "missing action tag" in mr: + fallback += 1 + print(run.name, "bad_action_files", len(bad), "fallback", fallback) +PY +``` + +## Filling Tables + +Use
only `summary.json` fields: + +- `best_selection_hard` -> Best Sel +- `baseline_test_hard` -> Base Test +- `test_hard` -> Best Test +- `test_delta_hard` -> Delta +- `total_accepts` -> Accept +- `total_rejects` -> Reject +- `token_summary._total.total_tokens` -> Tokens + +Do not fill table rows from logs alone. diff --git a/configs/alfworld/default.yaml b/configs/alfworld/default.yaml new file mode 100644 index 0000000..69b1049 --- /dev/null +++ b/configs/alfworld/default.yaml @@ -0,0 +1,30 @@ +_base_: ../_base_/default.yaml + +train: + train_size: 0 + accumulation: 1 + +gradient: + minibatch_size: 8 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + use_meta_reflect: false + +evaluation: + sel_env_num: 0 + test_env_num: 0 + +env: + name: alfworld + skill_init: skillopt/envs/alfworld/skills/initial.md + split_mode: split_dir + split_ratio: "2:1:7" + split_dir: data/ablation_splits/alfworld/2-1-7_seed42 + data_path: "" + split_output_dir: "" + max_steps: 50 + workers: 8 + max_api_workers: 8 + limit: 0 diff --git a/configs/alfworld/meta_reflect.yaml b/configs/alfworld/meta_reflect.yaml new file mode 100644 index 0000000..bb19f64 --- /dev/null +++ b/configs/alfworld/meta_reflect.yaml @@ -0,0 +1,4 @@ +_base_: default.yaml + +optimizer: + use_meta_reflect: true diff --git a/configs/babyvision/default.yaml b/configs/babyvision/default.yaml new file mode 100644 index 0000000..cc4a0d6 --- /dev/null +++ b/configs/babyvision/default.yaml @@ -0,0 +1,21 @@ +_base_: ../_base_/default.yaml + +train: + batch_size: 64 + accumulation: 1 + +env: + name: babyvision + skill_init: skillopt/envs/babyvision/skills/initial.md + split_mode: ratio + split_ratio: "2:1:7" + split_dir: "" + data_path: "" + split_output_dir: "" + max_turns: 1 + workers: 16 + limit: 0 + image_detail: auto + judge_model: gpt-5.4 + judge_max_completion_tokens: 256 + judge_retries: 5 diff --git a/configs/docvqa/default.yaml b/configs/docvqa/default.yaml new file mode 100644 index 0000000..36aeb1a --- /dev/null 
+++ b/configs/docvqa/default.yaml @@ -0,0 +1,28 @@ +_base_: ../_base_/default.yaml + +model: + reasoning_effort: medium + +train: + batch_size: 40 + accumulation: 1 + +gradient: + minibatch_size: 8 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + +env: + name: docvqa + skill_init: skillopt/envs/docvqa/skills/initial.md + split_mode: split_dir + split_ratio: "2:1:7" + split_dir: /home/azureuser/zisu/SkillReflection/data/docvqa/splits + data_path: "" + split_output_dir: "" + max_turns: 1 + workers: 16 + image_detail: auto + limit: 0 diff --git a/configs/livemathematicianbench/default.yaml b/configs/livemathematicianbench/default.yaml new file mode 100644 index 0000000..c337094 --- /dev/null +++ b/configs/livemathematicianbench/default.yaml @@ -0,0 +1,22 @@ +_base_: ../_base_/default.yaml + +train: + train_size: 0 + batch_size: 40 + accumulation: 1 + +env: + name: livemathematicianbench + skill_init: skillopt/envs/livemathematicianbench/skills/initial.md + split_mode: split_dir + split_ratio: "2:1:7" + split_dir: data/ablation_splits/livemathematicianbench/2-1-7_seed42 + data_path: "" + split_output_dir: "" + max_turns: 1 + exec_timeout: 300 + workers: 64 + limit: 0 + shuffle_choices: true + use_theorem: false + use_sketch: false diff --git a/configs/mathverse/default.yaml b/configs/mathverse/default.yaml new file mode 100644 index 0000000..ac1dfd9 --- /dev/null +++ b/configs/mathverse/default.yaml @@ -0,0 +1,23 @@ +_base_: ../_base_/default.yaml + +model: + codex_exec_sandbox: danger-full-access + +train: + batch_size: 64 + accumulation: 1 + +env: + name: mathverse + skill_init: skillopt/envs/mathverse/skills/initial.md + split_dir: "" + data_root: data/MathVerse + problem_version: Text Lite + use_text_dominant_reference: false + max_turns: 1 + workers: 16 + limit: 0 + image_detail: auto + judge_model: gpt-5.4 + judge_max_completion_tokens: 256 + judge_retries: 5 diff --git a/configs/mmrb/default.yaml b/configs/mmrb/default.yaml new file mode 100644 index 
0000000..aa55e33 --- /dev/null +++ b/configs/mmrb/default.yaml @@ -0,0 +1,18 @@ +_base_: ../_base_/default.yaml + +train: + batch_size: 128 + accumulation: 1 + +env: + name: mmrb + skill_init: skillopt/envs/mmrb/skills/initial.md + split_mode: ratio + split_ratio: "2:1:7" + split_dir: "" + data_path: "" + split_output_dir: "" + max_turns: 1 + workers: 16 + limit: 0 + image_detail: auto diff --git a/configs/officeqa/default.yaml b/configs/officeqa/default.yaml new file mode 100644 index 0000000..db29803 --- /dev/null +++ b/configs/officeqa/default.yaml @@ -0,0 +1,25 @@ +_base_: ../_base_/default.yaml + +model: + reasoning_effort: medium + +train: + batch_size: 40 + accumulation: 1 + +gradient: + minibatch_size: 8 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + +env: + name: officeqa + skill_init: skillopt/envs/officeqa/skills/initial.md + split_dir: data/officeqa_split + data_dirs: + - data/officeqa_docs_official + workers: 4 + max_tool_turns: 24 + limit: 0 diff --git a/configs/sealqa/default.yaml b/configs/sealqa/default.yaml new file mode 100644 index 0000000..6616dfb --- /dev/null +++ b/configs/sealqa/default.yaml @@ -0,0 +1,23 @@ +_base_: ../_base_/default.yaml + +model: + reasoning_effort: medium + +train: + batch_size: 10 + accumulation: 1 + +gradient: + minibatch_size: 8 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + +env: + name: sealqa + skill_init: skillopt/envs/sealqa/skills/initial.md + split_dir: data/sealqa_split + workers: 4 + max_tool_turns: 12 + limit: 0 diff --git a/configs/searchqa/default.yaml b/configs/searchqa/default.yaml new file mode 100644 index 0000000..bd75a7b --- /dev/null +++ b/configs/searchqa/default.yaml @@ -0,0 +1,32 @@ +_base_: ../_base_/default.yaml + +model: + reasoning_effort: medium + +train: + train_size: 400 + batch_size: 40 + accumulation: 1 + +gradient: + minibatch_size: 8 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + +evaluation: + sel_env_num: 0 + test_env_num: 0 + +env: + name: searchqa + 
skill_init: skillopt/envs/searchqa/skills/initial.md + split_mode: split_dir + split_ratio: "2:1:7" + split_dir: data/searchqa_split + data_path: "" + split_output_dir: "" + max_turns: 1 + workers: 24 + limit: 0 diff --git a/configs/spreadsheetbench/default.yaml b/configs/spreadsheetbench/default.yaml new file mode 100644 index 0000000..13e919f --- /dev/null +++ b/configs/spreadsheetbench/default.yaml @@ -0,0 +1,34 @@ +_base_: ../_base_/default.yaml + +model: + reasoning_effort: medium + +train: + train_size: 80 + batch_size: 40 + accumulation: 1 + +gradient: + minibatch_size: 8 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + +evaluation: + sel_env_num: 0 + test_env_num: 0 + +env: + name: spreadsheetbench + skill_init: skillopt/envs/spreadsheetbench/skills/initial.md + split_mode: split_dir + split_ratio: "2:1:7" + split_dir: data/spreadsheetbench_split + data_path: "" + split_output_dir: "" + data_root: data/spreadsheetbench_verified_400 + mode: multi + max_turns: 30 + exec_timeout: 600 + workers: 24 diff --git a/configs/swebench/default.yaml b/configs/swebench/default.yaml new file mode 100644 index 0000000..ed22386 --- /dev/null +++ b/configs/swebench/default.yaml @@ -0,0 +1,36 @@ +_base_: ../_base_/default.yaml + +model: + reasoning_effort: medium + +train: + batch_size: 20 + accumulation: 1 + +gradient: + minibatch_size: 4 + merge_batch_size: 8 + +optimizer: + learning_rate: 4 + +evaluation: + sel_env_num: 0 + test_env_num: 0 + +env: + name: swebench + skill_init: skillopt/envs/swebench/skills/initial.md + split_mode: ratio + split_ratio: "2:1:7" + split_dir: "" + data_path: "" + split_output_dir: "" + dataset_name: lite + hf_split: test + workers: 8 + eval_workers: 8 + step_limit: 50 + cost_limit: 3.0 + timeout_per_instance: 600 + limit: 0 diff --git a/scripts/codex_azure_mi.sh b/scripts/codex_azure_mi.sh new file mode 100755 index 0000000..f8c74e7 --- /dev/null +++ b/scripts/codex_azure_mi.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +set -euo pipefail 
+ +ROOT="/home/azureuser/workspace-gzy/SkillReflection" +PY="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" +REAL_CODEX="/home/azureuser/.nvm/versions/node/v18.20.8/bin/codex" +CLIENT_ID="8cafa2b1-a2a7-4ad9-814a-ffe4aed7e800" +SCOPE="https://cognitiveservices.azure.com/.default" + +token="$("$PY" - < argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--out_dir", type=str, required=True) + p.add_argument("--dataset", type=str, default="UnipatAI/BabyVision") + p.add_argument("--split", type=str, default="train") + return p.parse_args() + + +def main() -> None: + args = parse_args() + + try: + from datasets import load_dataset + except ImportError as exc: # pragma: no cover + raise SystemExit("Please install `datasets` first: pip install datasets pillow") from exc + + out_dir = Path(args.out_dir).resolve() + images_dir = out_dir / "images" + meta_path = out_dir / "meta_data.jsonl" + images_dir.mkdir(parents=True, exist_ok=True) + + dataset = load_dataset(args.dataset, split=args.split) + with open(meta_path, "w", encoding="utf-8") as outf: + for idx, row in enumerate(dataset): + image = row.get("image") + if image is None: + continue + task_id = str(row.get("taskId") or row.get("id") or idx + 1) + image_name = f"{task_id}.png" + image_path = images_dir / image_name + image.save(image_path) + + record = dict(row) + record["image"] = image_name + outf.write(json.dumps(record, ensure_ascii=False) + "\n") + + print(f"Saved BabyVision to {out_dir}") + print(f"Metadata: {meta_path}") + print(f"Images: {images_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/eval_livemathematicianbench_baseline.py b/scripts/eval_livemathematicianbench_baseline.py new file mode 100644 index 0000000..a66ace0 --- /dev/null +++ b/scripts/eval_livemathematicianbench_baseline.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +"""Evaluate LiveMathematicianBench under current or official-style prompts.""" +from __future__ 
import annotations + +import argparse +import json +import os +import random +import re +import sys +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +from skillopt.envs.livemathematicianbench.dataloader import load_items +from skillopt.envs.livemathematicianbench.evaluator import evaluate as current_evaluate +from skillopt.envs.livemathematicianbench.rollout import _build_system, _build_user +from skillopt.model import ( + chat_with_deployment, + configure_azure_openai, + set_backend, + set_reasoning_effort, +) + +_LABELS = ["A", "B", "C", "D", "E"] + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--data_path", type=str, required=True) + p.add_argument("--model", type=str, default="gpt-5.4") + p.add_argument("--backend", type=str, choices=["azure_openai", "codex", "claude"], default="azure_openai") + p.add_argument("--mode", type=str, choices=["current", "official"], required=True) + p.add_argument("--reasoning_effort", type=str, default=None) + p.add_argument("--azure_endpoint", type=str, default="") + p.add_argument("--azure_api_version", type=str, default="") + p.add_argument("--azure_api_key", type=str, default="") + p.add_argument("--max_completion_tokens", type=int, default=0) + p.add_argument("--workers", type=int, default=8) + p.add_argument("--seed", type=int, default=20260227) + p.add_argument("--skill_path", type=str, default="skillopt/envs/livemathematicianbench/skills/initial.md") + p.add_argument("--limit", type=int, default=0) + p.add_argument("--resume", action="store_true") + p.add_argument("--output_json", type=str, required=True) + return p.parse_args() + +def read_skill(skill_path: str) -> str: + with open(skill_path, encoding="utf-8") as f: + return 
f.read() + + +def official_extract_answer(response_text: str) -> str | None: + if not response_text: + return None + boxed_match = re.search(r"\\boxed\{([A-Ea-e])\}", response_text) + if boxed_match: + return boxed_match.group(1).upper() + boxed_match = re.search(r"boxed\{([A-Ea-e])\}", response_text) + if boxed_match: + return boxed_match.group(1).upper() + answer_match = re.search(r"answer is[:\s]*([A-Ea-e])", response_text, re.IGNORECASE) + if answer_match: + return answer_match.group(1).upper() + answer_match = re.search(r"Answer[:\s]*\(?([A-Ea-e])\)?", response_text) + if answer_match: + return answer_match.group(1).upper() + final_match = re.search(r"\b([A-Ea-e])\b\s*[.)]?\s*$", response_text.strip()) + if final_match: + return final_match.group(1).upper() + return None + + +def official_format_mcq_prompt(question: str, choices: list[dict]) -> str: + lines = [ + "Answer the following multiple-choice question.", + "Think carefully, then provide your final answer in the format: \\boxed{X} where X is A, B, C, D, or E.", + "", + "Question:", + question, + "", + "Choices:", + ] + for choice in choices: + lines.append(f"{choice['label']}. 
{choice['text']}") + lines.append("") + lines.append("Your answer:") + return "\n".join(lines) + + +def shuffle_choices(item: dict, seed: int) -> tuple[list[dict], dict]: + correct_choice = dict(item["correct_choice"]) + all_choices = [dict(choice) for choice in item["choices"]] + rng = random.Random(f"{seed}:{item['id']}") + rng.shuffle(all_choices) + + shuffled: list[dict] = [] + new_correct = dict(correct_choice) + correct_text = correct_choice["text"] + + for idx, choice in enumerate(all_choices[: len(_LABELS)]): + relabeled = {"label": _LABELS[idx], "text": choice["text"]} + shuffled.append(relabeled) + if choice["text"] == correct_text: + new_correct = dict(relabeled) + + return shuffled, new_correct + + +def load_existing(output_path: Path) -> dict[str, dict]: + if not output_path.exists(): + return {} + with open(output_path, encoding="utf-8") as f: + payload = json.load(f) + existing = {} + for row in payload.get("results", []): + existing[str(row["id"])] = row + return existing + + +def save_results(output_path: Path, meta: dict, results: list[dict]) -> None: + correct = sum(1 for row in results if row.get("is_correct")) + total = len(results) + payload = { + **meta, + "summary": { + "correct": correct, + "total": total, + "accuracy": (correct / total) if total else 0.0, + }, + "results": sorted(results, key=lambda row: str(row["id"])), + } + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + + +def call_model( + *, + model: str, + system: str, + user: str, + max_completion_tokens: int | None, + reasoning_effort: str | None, +) -> str: + last_error: Exception | None = None + for attempt in range(5): + try: + set_reasoning_effort(reasoning_effort) + raw, _ = chat_with_deployment( + deployment=model, + system=system, + user=user, + max_completion_tokens=max_completion_tokens if max_completion_tokens and max_completion_tokens > 0 else 4096, + 
retries=1, + stage="rollout", + ) + return str(raw or "") + except Exception as exc: # noqa: BLE001 + last_error = exc + if attempt == 4: + break + time.sleep(min(2 ** attempt, 10)) + raise RuntimeError(f"LLM call failed after retries: {last_error}") + + +def evaluate_one( + item: dict, + *, + mode: str, + model: str, + skill_content: str, + max_completion_tokens: int, + reasoning_effort: str | None, + seed: int, +) -> dict: + shuffled_choices, correct_choice = shuffle_choices(item, seed) + + if mode == "official": + system = "You are an expert mathematician. Answer accurately." + user = official_format_mcq_prompt(item["question"], shuffled_choices) + effective_max_completion_tokens = max_completion_tokens if max_completion_tokens > 0 else None + else: + materialized = dict(item) + materialized["choices"] = shuffled_choices + materialized["correct_choice"] = correct_choice + system = _build_system(skill_content) + user = _build_user(materialized, use_theorem=False, use_sketch=False) + effective_max_completion_tokens = max_completion_tokens if max_completion_tokens > 0 else 768 + + t0 = time.time() + response = call_model( + model=model, + system=system, + user=user, + max_completion_tokens=effective_max_completion_tokens, + reasoning_effort=reasoning_effort, + ) + elapsed = time.time() - t0 + + if mode == "official": + predicted = official_extract_answer(response) + predicted_text = "" + for choice in shuffled_choices: + if choice["label"] == predicted: + predicted_text = choice["text"] + break + is_correct = predicted == correct_choice["label"] + return { + "id": item["id"], + "question": item["question"], + "correct_label": correct_choice["label"], + "correct_text": correct_choice["text"], + "predicted_label": predicted, + "predicted_text": predicted_text, + "is_correct": is_correct, + "elapsed_seconds": elapsed, + "response": response, + } + + eval_result = current_evaluate(response, correct_choice, shuffled_choices) + return { + "id": item["id"], + "question": 
item["question"], + "correct_label": correct_choice["label"], + "correct_text": correct_choice["text"], + "predicted_label": eval_result["predicted_label"], + "predicted_text": eval_result["predicted_text"], + "is_correct": bool(eval_result["em"]), + "elapsed_seconds": elapsed, + "response": response, + } + + +def main() -> None: + args = parse_args() + set_backend(args.backend) + configure_azure_openai( + endpoint=args.azure_endpoint or None, + api_version=args.azure_api_version or None, + api_key=args.azure_api_key or None, + ) + set_reasoning_effort(args.reasoning_effort) + output_path = Path(args.output_json).resolve() + skill_content = read_skill(args.skill_path) if args.mode == "current" else "" + + items = load_items(args.data_path) + if args.limit: + items = items[:args.limit] + + existing = load_existing(output_path) if args.resume else {} + pending = [item for item in items if str(item["id"]) not in existing] + results = list(existing.values()) + + print("=" * 72, flush=True) + print("LiveMathematicianBench baseline eval", flush=True) + print("=" * 72, flush=True) + print(f"Mode: {args.mode}", flush=True) + print(f"Model: {args.model}", flush=True) + print(f"Reasoning effort: {args.reasoning_effort}", flush=True) + print(f"Items: {len(items)} total, {len(pending)} pending, {len(existing)} resumed", flush=True) + print(f"Output: {output_path}", flush=True) + print("=" * 72, flush=True) + + meta = { + "mode": args.mode, + "model": args.model, + "reasoning_effort": args.reasoning_effort, + "seed": args.seed, + "max_completion_tokens": args.max_completion_tokens, + } + + if not pending: + save_results(output_path, meta, results) + summary = json.loads(output_path.read_text(encoding="utf-8"))["summary"] + print(f"Accuracy: {summary['correct']}/{summary['total']} = {summary['accuracy']:.4f}", flush=True) + return + + with ThreadPoolExecutor(max_workers=args.workers) as ex: + futs = { + ex.submit( + evaluate_one, + item, + mode=args.mode, + model=args.model, + 
skill_content=skill_content, + max_completion_tokens=args.max_completion_tokens, + reasoning_effort=args.reasoning_effort, + seed=args.seed, + ): item + for item in pending + } + completed = 0 + for fut in as_completed(futs): + item = futs[fut] + try: + row = fut.result() + except Exception as exc: # noqa: BLE001 + row = { + "id": item["id"], + "question": item["question"], + "correct_label": None, + "correct_text": item["correct_choice"]["text"], + "predicted_label": None, + "predicted_text": "", + "is_correct": False, + "elapsed_seconds": 0.0, + "response": "", + "error": str(exc), + } + results.append(row) + completed += 1 + correct = sum(1 for result in results if result.get("is_correct")) + total = len(results) + print( + f"[{completed}/{len(pending)}] id={row['id']} " + f"pred={row['predicted_label']} gold={row['correct_label']} " + f"acc={correct}/{total}={correct/total:.4f}", + flush=True, + ) + save_results(output_path, meta, results) + + summary = json.loads(output_path.read_text(encoding="utf-8"))["summary"] + print("=" * 72, flush=True) + print(f"Accuracy: {summary['correct']}/{summary['total']} = {summary['accuracy']:.4f}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval_only.py b/scripts/eval_only.py new file mode 100644 index 0000000..a14d3c5 --- /dev/null +++ b/scripts/eval_only.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python3 +"""ReflACT eval-only: run a single skill on a dataset without training. + +Usage +----- + python scripts/eval_only.py \ + --config configs/spreadsheetbench/default.yaml \ + --skill skillopt/envs/spreadsheetbench/skills/initial.md \ + --split_dir /path/to/split \ + --out_root outputs/eval_skill0 + +All YAML keys can be overridden from the CLI, same as train.py. 
def _register_builtins() -> None:
    """Fill ``_ENV_REGISTRY`` with every built-in adapter that can be imported.

    Each entry is attempted independently.  An adapter whose module cannot be
    imported (for example because an optional dependency is missing) is
    skipped silently, so the remaining environments stay usable.  Entries are
    registered in the order listed below.
    """
    # Function-scope import keeps this block self-contained.
    import importlib

    # (environment name, adapter module, adapter class)
    builtin_adapters = (
        ("alfworld", "skillopt.envs.alfworld.adapter", "ALFWorldAdapter"),
        ("searchqa", "skillopt.envs.searchqa.adapter", "SearchQAAdapter"),
        (
            "livemathematicianbench",
            "skillopt.envs.livemathematicianbench.adapter",
            "LiveMathematicianBenchAdapter",
        ),
        ("babyvision", "skillopt.envs.babyvision.adapter", "BabyVisionAdapter"),
        (
            "spreadsheetbench",
            "skillopt.envs.spreadsheetbench.adapter",
            "SpreadsheetBenchAdapter",
        ),
        ("mmrb", "skillopt.envs.mmrb.adapter", "MMRBAdapter"),
        ("docvqa", "skillopt.envs.docvqa.adapter", "DocVQAAdapter"),
        ("mathverse", "skillopt.envs.mathverse.adapter", "MathVerseAdapter"),
        ("officeqa", "skillopt.envs.officeqa.adapter", "OfficeQAAdapter"),
        ("sealqa", "skillopt.envs.sealqa.adapter", "SealQAAdapter"),
        ("swebench", "skillopt.envs.swebench.adapter", "SWEBenchAdapter"),
    )

    for env_name, module_path, class_name in builtin_adapters:
        try:
            module = importlib.import_module(module_path)
        except ImportError:
            # Optional environment: its dependencies are not installed.
            continue
        # ``from module import Name`` raises ImportError for a missing name;
        # treat a missing class the same way and skip the entry.
        adapter_cls = getattr(module, class_name, None)
        if adapter_cls is not None:
            _ENV_REGISTRY[env_name] = adapter_cls
type=str) + p.add_argument("--student_backend", type=str) + p.add_argument("--reasoning_effort", type=str, + choices=["", "low", "medium", "high", "xhigh", "max"]) + p.add_argument("--azure_endpoint", type=str) + p.add_argument("--azure_api_version", type=str) + p.add_argument("--azure_api_key", type=str) + p.add_argument("--azure_openai_endpoint", type=str) + p.add_argument("--azure_openai_api_version", type=str) + p.add_argument("--azure_openai_api_key", type=str) + p.add_argument("--azure_openai_auth_mode", type=str) + p.add_argument("--azure_openai_ad_scope", type=str) + p.add_argument("--azure_openai_managed_identity_client_id", type=str) + p.add_argument("--teacher_azure_openai_endpoint", type=str) + p.add_argument("--teacher_azure_openai_api_version", type=str) + p.add_argument("--teacher_azure_openai_api_key", type=str) + p.add_argument("--teacher_azure_openai_auth_mode", type=str) + p.add_argument("--teacher_azure_openai_ad_scope", type=str) + p.add_argument("--teacher_azure_openai_managed_identity_client_id", type=str) + p.add_argument("--student_azure_openai_endpoint", type=str) + p.add_argument("--student_azure_openai_api_version", type=str) + p.add_argument("--student_azure_openai_api_key", type=str) + p.add_argument("--student_azure_openai_auth_mode", type=str) + p.add_argument("--student_azure_openai_ad_scope", type=str) + p.add_argument("--student_azure_openai_managed_identity_client_id", type=str) + p.add_argument("--codex_exec_path", type=str) + p.add_argument("--codex_exec_sandbox", type=str) + p.add_argument("--codex_exec_profile", type=str) + p.add_argument("--codex_exec_full_auto", type=_BOOL) + p.add_argument("--codex_exec_reasoning_effort", type=str) + p.add_argument("--codex_exec_use_sdk", type=str) + p.add_argument("--codex_exec_network_access", type=_BOOL) + p.add_argument("--codex_exec_web_search", type=_BOOL) + p.add_argument("--codex_exec_approval_policy", type=str) + p.add_argument("--claude_code_exec_path", type=str) + 
def main() -> None:
    """Evaluate a single skill file on an environment without training.

    Loads the YAML config (plus ``--cfg-options``), folds the legacy flat CLI
    flags into it, resolves the teacher/student backends and default models,
    configures the model clients, builds the environment adapter, rolls the
    skill out on the requested split(s), prints the hard/soft scores and
    writes ``eval_summary.json`` into ``out_root``.
    """
    args = parse_args()

    from skillopt.config import load_config as _load, flatten_config, is_structured

    cfg = _load(args.config, overrides=args.cfg_options)
    structured = is_structured(cfg)

    # Apply legacy --key value overrides
    cli = {k: v for k, v in vars(args).items()
           if v is not None and k not in ("config", "skill", "split", "cfg_options")}
    if cli:
        if structured:
            from skillopt.config import apply_overrides

            # Flags that live under ``model.<same name>`` in a structured config.
            same_name_model_keys = (
                "backend",
                "teacher_backend",
                "student_backend",
                "reasoning_effort",
                "azure_endpoint",
                "azure_api_version",
                "azure_api_key",
                "azure_openai_endpoint",
                "azure_openai_api_version",
                "azure_openai_api_key",
                "azure_openai_auth_mode",
                "azure_openai_ad_scope",
                "azure_openai_managed_identity_client_id",
                "teacher_azure_openai_endpoint",
                "teacher_azure_openai_api_version",
                "teacher_azure_openai_api_key",
                "teacher_azure_openai_auth_mode",
                "teacher_azure_openai_ad_scope",
                "teacher_azure_openai_managed_identity_client_id",
                "student_azure_openai_endpoint",
                "student_azure_openai_api_version",
                "student_azure_openai_api_key",
                "student_azure_openai_auth_mode",
                "student_azure_openai_ad_scope",
                "student_azure_openai_managed_identity_client_id",
                "codex_exec_path",
                "codex_exec_sandbox",
                "codex_exec_profile",
                "codex_exec_full_auto",
                "codex_exec_reasoning_effort",
                "codex_exec_use_sdk",
                "codex_exec_network_access",
                "codex_exec_web_search",
                "codex_exec_approval_policy",
                "claude_code_exec_path",
                "claude_code_exec_profile",
                "claude_code_exec_use_sdk",
                "claude_code_exec_effort",
                "claude_code_exec_max_thinking_tokens",
            )
            _MAP = {key: f"model.{key}" for key in same_name_model_keys}
            # Flags whose structured key differs from the flag name.
            _MAP.update({
                "teacher_model": "model.teacher",
                "student_model": "model.student",
                "seed": "train.seed",
                "test_env_num": "evaluation.test_env_num",
                "env": "env.name",
                "out_root": "env.out_root",
            })
            # Anything not in the map is treated as an ``env.*`` setting.
            mapped = [f"{_MAP.get(k) or f'env.{k}'}={v}" for k, v in cli.items()]
            apply_overrides(cfg, mapped)
        else:
            cfg.update(cli)

    cfg = flatten_config(cfg) if structured else cfg

    # Back-fill the newer ``azure_openai_*`` keys from the legacy ``azure_*`` ones.
    for new_key, old_key in (
        ("azure_openai_endpoint", "azure_endpoint"),
        ("azure_openai_api_version", "azure_api_version"),
        ("azure_openai_api_key", "azure_api_key"),
    ):
        if cfg.get(new_key) in (None, "") and cfg.get(old_key) not in (None, ""):
            cfg[new_key] = cfg[old_key]

    # A backend counts as "explicit" when given via --backend or via
    # --cfg-options model.backend=...; only then does it drive the defaults.
    explicit_backend = getattr(args, "backend", None)
    if explicit_backend is None:
        for option in args.cfg_options or []:
            # partition() tolerates an entry without "=" instead of raising
            # IndexError like split("=", 1)[1] would.
            key, sep, value = str(option).partition("=")
            if sep and key.strip() == "model.backend":
                explicit_backend = value.strip()
                break

    backend = normalize_backend_name(cfg.get("model_backend") or cfg.get("student_backend") or "azure_openai")

    def _has_model_override(dotted_key: str, legacy_key: str) -> bool:
        """Return True if the user set this model via a flag or --cfg-options."""
        if getattr(args, legacy_key, None) is not None:
            return True
        for option in args.cfg_options or []:
            key = str(option).split("=", 1)[0].strip()
            if key == dotted_key:
                return True
        return False

    if explicit_backend is not None:
        backend = normalize_backend_name(explicit_backend)
        cfg["model_backend"] = backend
        if backend in {"claude", "claude_chat"}:
            cfg.setdefault("teacher_backend", "claude_chat")
            cfg.setdefault("student_backend", "claude_chat")
        elif backend in {"codex", "codex_exec"}:
            cfg.setdefault("teacher_backend", "openai_chat")
            cfg.setdefault("student_backend", "codex_exec")
        elif backend == "claude_code_exec":
            cfg.setdefault("teacher_backend", "openai_chat")
            cfg.setdefault("student_backend", "claude_code_exec")
        else:
            cfg.setdefault("teacher_backend", "openai_chat")
            cfg.setdefault("student_backend", "openai_chat")
    else:
        cfg.setdefault("teacher_backend", "openai_chat")
        cfg.setdefault("student_backend", "openai_chat")

    # If a Claude backend is paired with an untouched OpenAI default model
    # name, swap in the Claude default unless the user chose the model.
    for role, claude_backends in (
        ("teacher", ("claude_chat",)),
        ("student", ("claude_chat", "claude_code_exec")),
    ):
        model_key = f"{role}_model"
        if (
            cfg.get(f"{role}_backend") in claude_backends
            and str(cfg.get(model_key, "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
            and not _has_model_override(f"model.{role}", model_key)
        ):
            cfg[model_key] = default_model_for_backend("claude_chat")

    if not cfg.get("out_root"):
        env = cfg.get("env", "unknown")
        # ``or`` also covers a key that is present but null in the YAML.
        model = str(cfg.get("student_model") or "unknown").replace("/", "-")
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        cfg["out_root"] = os.path.join("outputs", f"eval_{env}_{model}_{ts}")

    cfg["out_root"] = os.path.abspath(cfg["out_root"])

    out_root = cfg["out_root"]
    os.makedirs(out_root, exist_ok=True)

    # Load skill (explicit UTF-8 so the result does not depend on the locale)
    skill_path = os.path.abspath(args.skill)
    with open(skill_path, encoding="utf-8") as f:
        skill_content = f.read()
    print(f" [skill] {skill_path} ({len(skill_content)} chars)")

    # Configure models
    configure_azure_openai(
        endpoint=(cfg.get("azure_openai_endpoint") or cfg.get("azure_endpoint") or None),
        api_version=(cfg.get("azure_openai_api_version") or cfg.get("azure_api_version") or None),
        api_key=(cfg.get("azure_openai_api_key") or cfg.get("azure_api_key") or None),
        auth_mode=cfg.get("azure_openai_auth_mode") or None,
        ad_scope=cfg.get("azure_openai_ad_scope") or None,
        managed_identity_client_id=cfg.get("azure_openai_managed_identity_client_id") or None,
        teacher_endpoint=cfg.get("teacher_azure_openai_endpoint") or None,
        teacher_api_version=cfg.get("teacher_azure_openai_api_version") or None,
        teacher_api_key=cfg.get("teacher_azure_openai_api_key") or None,
        teacher_auth_mode=cfg.get("teacher_azure_openai_auth_mode") or None,
        teacher_ad_scope=cfg.get("teacher_azure_openai_ad_scope") or None,
        teacher_managed_identity_client_id=(
            cfg.get("teacher_azure_openai_managed_identity_client_id") or None
        ),
        student_endpoint=cfg.get("student_azure_openai_endpoint") or None,
        student_api_version=cfg.get("student_azure_openai_api_version") or None,
        student_api_key=cfg.get("student_azure_openai_api_key") or None,
        student_auth_mode=cfg.get("student_azure_openai_auth_mode") or None,
        student_ad_scope=cfg.get("student_azure_openai_ad_scope") or None,
        student_managed_identity_client_id=(
            cfg.get("student_azure_openai_managed_identity_client_id") or None
        ),
    )
    set_teacher_backend(cfg.get("teacher_backend", "openai_chat"))
    set_student_backend(cfg.get("student_backend", "openai_chat"))
    # ``or`` rather than a .get() default: an empty or null model entry must
    # also fall back to the backend's default deployment.
    set_teacher_deployment(cfg.get("teacher_model") or default_model_for_backend(backend))
    set_student_deployment(cfg.get("student_model") or default_model_for_backend(backend))
    configure_codex_exec(
        path=cfg.get("codex_exec_path", "codex"),
        sandbox=cfg.get("codex_exec_sandbox", "workspace-write"),
        profile=cfg.get("codex_exec_profile", ""),
        full_auto=cfg.get("codex_exec_full_auto", False),
        reasoning_effort=cfg.get("codex_exec_reasoning_effort", "none"),
        use_sdk=cfg.get("codex_exec_use_sdk", None),
        network_access=cfg.get("codex_exec_network_access", False),
        web_search=cfg.get("codex_exec_web_search", False),
        approval_policy=cfg.get("codex_exec_approval_policy", "never"),
    )
    configure_claude_code_exec(
        path=cfg.get("claude_code_exec_path", "claude"),
        profile=cfg.get("claude_code_exec_profile", ""),
        use_sdk=cfg.get("claude_code_exec_use_sdk", None),
        effort=cfg.get("claude_code_exec_effort", cfg.get("reasoning_effort", "medium")),
        max_thinking_tokens=cfg.get("claude_code_exec_max_thinking_tokens", 16384),
    )
    set_reasoning_effort(cfg.get("reasoning_effort", "") or None)

    # Build adapter
    adapter = get_adapter(cfg)
    adapter.setup(cfg)

    seed = cfg.get("seed", 42)
    split = args.split or "all"

    if split == "all":
        # NOTE(review): "all" always passes env_num=0 and ignores
        # test_env_num; presumably 0 means "no limit" -- confirm with the
        # adapters' build_eval_env.
        items = (
            adapter.build_eval_env(0, "train", seed)
            + adapter.build_eval_env(0, "valid_seen", seed)
            + adapter.build_eval_env(0, "valid_unseen", seed)
        )
    else:
        env_num = cfg.get("test_env_num", 0)
        items = adapter.build_eval_env(env_num, split, seed)

    print(f"\n [eval] split={split} items={len(items)}")
    print(f" [eval] out_root={out_root}")
    print(f"{'='*60}")

    # Run rollout
    results = adapter.rollout(items, skill_content, out_root)

    # Score
    hard, soft = compute_score(results)
    print(f"\n{'='*60}")
    print(f" Results: hard={hard:.4f} soft={soft:.4f} (n={len(results)})")
    print(f"{'='*60}")

    # Save summary; ensure_ascii=False emits raw non-ASCII characters, so the
    # file must be opened as UTF-8 rather than in the locale encoding.
    summary = {
        "skill": skill_path,
        "split": split,
        "n_items": len(results),
        "hard": hard,
        "soft": soft,
    }
    with open(os.path.join(out_root, "eval_summary.json"), "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    print(f" Saved to: {out_root}")
+ +Usage: + python scripts/eval_prompt_custom.py --workers 8 + python scripts/eval_prompt_custom.py --workers 32 --limit 20 +""" +from __future__ import annotations + +import argparse +import glob +import json +import os +import random +import re +import subprocess +import sys +import tempfile +import textwrap +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError + +import openpyxl + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +from skillopt.model import ( + chat_messages_with_deployment, + configure_azure_openai, + set_backend, + set_student_deployment, +) +from skillopt.envs.spreadsheetbench.evaluator import evaluate + + +# ── Config ────────────────────────────────────────────────────────────────── + +DATA_ROOT = "/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400" +JSONL_PATH = os.path.join(DATA_ROOT, "dataset.json") +MODEL = "gpt-5-mini" + +# ── Custom Prompt (with Critical Rules) ───────────────────────────────────── + +_SYSTEM_TEMPLATE = """\ +You are an expert Python programmer specializing in spreadsheet manipulation. +You will be given a user instruction together with a preview of an input .xlsx file. +Your job is to write a single self-contained Python script that reads the input file +at the path stored in the variable INPUT_PATH, performs the requested manipulation, +and saves the result to OUTPUT_PATH. + +## Critical Rules +1. NEVER write Excel formulas to cells. openpyxl does NOT compute formulas — + the evaluator will see None. Compute results in Python and write literal values. +2. Use only: standard library, openpyxl, pandas. +3. Do NOT hardcode cell values from the preview — iterate over actual rows. +4. The script must define INPUT_PATH and OUTPUT_PATH at the top. 
def build_user(instruction, input_xlsx, instruction_type="", answer_position=""):
    """Assemble the user message: instruction, optional hints, preview, task.

    The workbook preview comes from ``_preview_workbook``; if that raises, the
    error text is embedded in place of the preview rather than propagated.
    ``instruction_type`` and ``answer_position`` each add one hint line when
    non-empty.
    """
    try:
        preview = _preview_workbook(input_xlsx)
    except Exception as exc:
        preview = f"(failed to preview: {exc})"

    hints = []
    if instruction_type:
        hints.append(f"\nInstruction type: {instruction_type}")
    if answer_position:
        hints.append(f"\nExpected answer position: {answer_position}")
    extra = "".join(hints)

    sections = [
        f"# Instruction\n{instruction}\n{extra}\n\n",
        f"# Input spreadsheet preview\n{preview}\n\n",
        "# Task\n",
        "Write a Python script that reads the workbook from the variable `INPUT_PATH`, ",
        "applies the instruction, and writes the modified workbook to `OUTPUT_PATH`. ",
        "Preserve all other cells unchanged. ",
        "The preview may be truncated — do not hardcode row counts; ",
        "iterate over all actual rows in the workbook instead.\n",
        "Return only a ```python``` code block.",
    ]
    return "".join(sections)
# One whole-line ``INPUT_PATH = ...`` / ``OUTPUT_PATH = ...`` assignment,
# optionally annotated.  Only spaces/tabs are consumed as indentation so a
# match never spans lines, and ``(?!=)`` keeps ``==`` comparisons intact.
_PATH_RE = re.compile(
    r"^(?P<indent>[ \t]*)(?:INPUT_PATH|OUTPUT_PATH)[ \t]*(?::[^=\n]*)?=(?!=).*$",
    re.MULTILINE,
)


def strip_paths(code):
    """Neutralise the model's own ``INPUT_PATH`` / ``OUTPUT_PATH`` assignments.

    The runner injects the real paths, so assignments in the generated code
    must not override them.  Each one is replaced by ``pass`` at the same
    indentation rather than deleted, so a block whose only statement was the
    assignment (for example a function body) stays syntactically valid.
    """
    return _PATH_RE.sub(r"\g<indent>pass", code)


# Wrapper script: defines the two path variables, runs the generated code
# inside ``try`` and exits with status 2 after printing the traceback if it
# raises.  ``{code_indented}`` must already be indented by four spaces.
RUNNER_TEMPLATE = textwrap.dedent("""
    import os, sys, traceback
    INPUT_PATH = {input_path!r}
    OUTPUT_PATH = {output_path!r}
    try:
    {code_indented}
    except Exception:
        traceback.print_exc()
        sys.exit(2)
""")


def run_code(code, input_path, output_path, timeout=120):
    """Run generated ``code`` in a subprocess and report whether it wrote output.

    Args:
        code: Python source produced by the model.
        input_path: Value injected as ``INPUT_PATH``.
        output_path: Value injected as ``OUTPUT_PATH``; must exist afterwards.
        timeout: Seconds to wait for the subprocess.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, reason)`` where the
        reason is the captured stdout/stderr, a timeout note or
        ``"output file was not created"``.
    """
    out_dir = os.path.dirname(output_path)
    if out_dir:  # a bare file name has no directory part; makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)

    # A file left over from an earlier run must not be mistaken for fresh
    # output.  Never delete the input when a caller edits a workbook in place.
    if os.path.abspath(output_path) != os.path.abspath(input_path):
        try:
            os.unlink(output_path)
        except OSError:
            pass

    cleaned = strip_paths(code)
    indented = textwrap.indent(cleaned, "    ")
    script = RUNNER_TEMPLATE.format(
        input_path=input_path, output_path=output_path, code_indented=indented
    )
    # Python reads source files as UTF-8 by default, so write the script as
    # UTF-8 regardless of the locale encoding.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".py", delete=False, encoding="utf-8"
    ) as f:
        f.write(script)
        tmp = f.name
    try:
        proc = subprocess.run(
            [sys.executable, tmp],
            capture_output=True,
            text=True,
            errors="replace",  # undecodable child output must not abort the run
            timeout=timeout,
        )
        if proc.returncode != 0:
            return False, (proc.stdout + "\n" + proc.stderr).strip()
        if not os.path.exists(output_path):
            return False, "output file was not created"
        return True, ""
    except subprocess.TimeoutExpired:
        return False, f"timeout after {timeout}s"
    finally:
        try:
            os.unlink(tmp)
        except OSError:
            pass
def load_items(path):
    """Load task items from a ``.json`` document or a JSON-Lines file.

    A path ending in ``.json`` is parsed as one JSON document.  If that
    document is an object, its ``"data"`` entry is used when present (even
    when it is an empty list); otherwise the object's values are the items.
    Any other path is read as JSON Lines, skipping blank lines.

    Returns:
        A list of the loaded items.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(path, encoding="utf-8") as f:
        if path.endswith(".json"):
            data = json.load(f)
            if isinstance(data, dict):
                wrapped = data.get("data")
                # ``is not None`` rather than truthiness: an empty "data"
                # list means "no items", not "use the object's values".
                data = wrapped if wrapped is not None else list(data.values())
            return list(data)
        return [json.loads(line) for line in f if line.strip()]
def main():
    """Run the CUSTOM-prompt evaluation over the dataset and write results.

    Parses the CLI, configures the model backend, evaluates every item with
    ``process_one`` on a thread pool, prints per-item progress and the
    aggregate hard/soft scores, then writes ``results.jsonl`` and
    ``summary.json`` into the output directory.
    """
    ap = argparse.ArgumentParser(description="Eval CUSTOM prompt on verified-400")
    ap.add_argument("--model", default=MODEL)
    ap.add_argument("--backend", choices=["azure_openai", "codex", "claude"], default="azure_openai")
    ap.add_argument("--azure_endpoint", default="")
    ap.add_argument("--azure_api_version", default="")
    ap.add_argument("--azure_api_key", default="")
    ap.add_argument("--workers", type=int, default=8)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--out_root", default="")
    args = ap.parse_args()

    set_backend(args.backend)
    configure_azure_openai(
        endpoint=args.azure_endpoint or None,
        api_version=args.azure_api_version or None,
        api_key=args.azure_api_key or None,
    )
    set_student_deployment(args.model)
    ts = time.strftime("%Y%m%d_%H%M%S")
    # A model id such as "org/name" must not create nested directories.
    model_tag = str(args.model).replace("/", "-")
    out_root = args.out_root or os.path.join(_PROJECT_ROOT, "outputs", f"prompt_custom_{model_tag}_{ts}")
    out_root = os.path.abspath(out_root)
    os.makedirs(out_root, exist_ok=True)

    items = load_items(JSONL_PATH)
    if args.limit:
        items = items[:args.limit]

    print(f"{'='*60}")
    print(" Prompt: CUSTOM (Critical Rules)")
    print(f" Model: {args.model}")
    print(f" Items: {len(items)}")
    print(f" Output: {out_root}")
    print(f"{'='*60}")

    t0 = time.time()
    results = []
    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futs = {ex.submit(process_one, it, DATA_ROOT, out_root, args.model): it for it in items}
        for i, fut in enumerate(as_completed(futs), 1):
            item = futs[fut]
            try:
                # Futures yielded by as_completed() are already finished, so
                # this timeout never waits; the branch below only fires if
                # the task itself raised a timeout error.
                res = fut.result(timeout=300)
            except FuturesTimeoutError:
                res = {"id": str(item["id"]), "ok": False, "hard": 0, "soft": 0.0,
                       "n_cases": 0, "n_pass": 0, "fail_reason": "timeout"}
            except Exception as e:
                res = {"id": str(item["id"]), "ok": False, "hard": 0, "soft": 0.0,
                       "n_cases": 0, "n_pass": 0, "fail_reason": str(e)}
            results.append(res)
            status = "PASS" if res.get("hard") else "FAIL"
            dt = time.time() - t0
            print(f" {i}/{len(items)} id={res['id']:<10} {status} cases={res.get('n_pass',0)}/{res.get('n_cases',0)} dt={dt:.0f}s")

    # Summary
    hard_sum = sum(r.get("hard", 0) for r in results)
    soft_sum = sum(r.get("soft", 0.0) for r in results)
    n = len(results)
    # An empty dataset must produce an empty report, not a ZeroDivisionError.
    hard_rate = hard_sum / n if n else 0.0
    soft_rate = soft_sum / n if n else 0.0
    print(f"\n{'='*60}")
    print(f" CUSTOM prompt: hard={hard_sum}/{n}={hard_rate:.4f} soft={soft_rate:.4f}")
    print(f"{'='*60}")

    # ensure_ascii=False writes raw non-ASCII text, so pin the encoding.
    with open(os.path.join(out_root, "results.jsonl"), "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    with open(os.path.join(out_root, "summary.json"), "w", encoding="utf-8") as f:
        json.dump({"prompt": "custom", "model": args.model, "n": n,
                   "hard": hard_rate, "soft": soft_rate}, f, indent=2)
+ +Usage: + python scripts/eval_prompt_official.py --workers 8 + python scripts/eval_prompt_official.py --workers 32 --limit 20 +""" +from __future__ import annotations + +import argparse +import glob +import json +import os +import random +import re +import subprocess +import sys +import tempfile +import textwrap +import time +import traceback +from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError + +import openpyxl + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR) +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +from skillopt.model import ( + chat_messages_with_deployment, + configure_azure_openai, + set_backend, + set_student_deployment, +) +from skillopt.envs.spreadsheetbench.evaluator import evaluate + + +# ── Config ────────────────────────────────────────────────────────────────── + +DATA_ROOT = "/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400" +JSONL_PATH = os.path.join(DATA_ROOT, "dataset.json") +MODEL = "gpt-5-mini" + +# ── Official Prompt (from SpreadsheetBench src/prompt.py) ─────────────────── + +_SYSTEM_PROMPT = ( + "You are an expert Python programmer specializing in spreadsheet manipulation. " + "You will be given a user instruction together with a preview of an input .xlsx file. " + "Your job is to write a single self-contained Python script that reads the input file " + "at the path stored in the variable INPUT_PATH, performs the requested manipulation, " + "and saves the result to OUTPUT_PATH. Use only the standard library, openpyxl, and pandas. " + "Do not print anything. Do not use input(). Do not hardcode file paths. " + "Return ONLY the Python code inside a single ```python ... ``` fenced block." 
def build_system(skill_content: str = "") -> str:
    """Return the system prompt, optionally extended with a skill section.

    Args:
        skill_content: Extra guidance appended under a ``## Skill`` heading.
            Ignored when empty or whitespace-only.

    Returns:
        The base ``_SYSTEM_PROMPT`` (defined at module level), with the
        stripped skill text appended when one was supplied.
    """
    base = _SYSTEM_PROMPT
    skill = skill_content.strip()
    if skill:
        base += f"\n\n## Skill\n{skill}"
    return base


def build_user(instruction, input_xlsx, instruction_type="", answer_position=""):
    """Build the user message: instruction, workbook preview and task rules.

    Args:
        instruction: Natural-language instruction for the task.
        input_xlsx: Path of the workbook to preview for the model.
        instruction_type: Optional label added to the prompt when non-empty.
        answer_position: Optional expected answer location, added when
            non-empty.

    Returns:
        The prompt text. A workbook that cannot be previewed does not raise;
        the preview section then carries a ``(failed to preview: ...)`` note.
    """
    try:
        preview = _preview_workbook(input_xlsx)
    except Exception as e:
        # Best effort: an unreadable workbook still yields a usable prompt.
        preview = f"(failed to preview: {e})"
    extra = ""
    if instruction_type:
        extra += f"\nInstruction type: {instruction_type}"
    if answer_position:
        extra += f"\nExpected answer position: {answer_position}"
    return (
        f"# Instruction\n{instruction}\n{extra}\n\n"
        f"# Input spreadsheet preview\n{preview}\n\n"
        "# Task\n"
        "Write a Python script that reads the workbook from the variable `INPUT_PATH`, "
        "applies the instruction, and writes the modified workbook to `OUTPUT_PATH`. "
        "Preserve all other cells unchanged. "
        "The preview may be truncated — do not hardcode row counts or assume the data ends at the last previewed row; "
        "iterate over all actual rows in the workbook instead. "
        "Return only a ```python``` code block."
    )


# ── Shared utilities (identical to custom version) ──────────────────────────

def _preview_workbook(path, max_rows=5, max_cols=20):
    """Render the first rows of every sheet as ``COORD=value`` text lines.

    Args:
        path: Workbook path, opened with ``data_only=False`` so formulas are
            shown as written rather than as cached results.
        max_rows: Maximum number of rows previewed per sheet.
        max_cols: Maximum number of columns previewed per sheet.

    Returns:
        A newline-joined preview. Cell text longer than 40 characters is cut
        to 37 characters plus ``...``.
    """
    wb = openpyxl.load_workbook(path, data_only=False)
    try:
        chunks = []
        for sn in wb.sheetnames:
            ws = wb[sn]
            chunks.append(
                f"## Sheet: {sn} (dim={ws.dimensions}, max_row={ws.max_row}, max_col={ws.max_column})"
            )
            for row in ws.iter_rows(min_row=1, max_row=min(ws.max_row, max_rows),
                                    max_col=min(ws.max_column, max_cols), values_only=False):
                cells = []
                for c in row:
                    v = c.value
                    s = "" if v is None else str(v)
                    if len(s) > 40:
                        s = s[:37] + "..."
                    cells.append(f"{c.coordinate}={s}")
                chunks.append(" | ".join(cells))
            if ws.max_row > max_rows:
                chunks.append(f"... ({ws.max_row - max_rows} more rows)")
            chunks.append("")
        return "\n".join(chunks)
    finally:
        # Close even when rendering fails so the handle is not leaked.
        wb.close()


def extract_code(text):
    """Extract the body of the first fenced code block from a model reply.

    Returns the stripped text unchanged when it contains no fence. When the
    opening fence has no matching closing fence (for example a truncated
    completion), everything after the opening fence line is returned, so the
    fence marker itself never ends up in the executed script.
    """
    fence = "```"
    start = text.find(fence)
    if start == -1:
        return text.strip()
    nl = text.find("\n", start)
    if nl == -1:
        # A fence with no following line: there is no block body to extract.
        return text.strip()
    end = text.find(fence, nl + 1)
    if end == -1:
        if fence in text[start + len(fence):nl]:
            # Opening and closing fences share one line (inline code); what
            # follows is prose, not a block body.
            return text.strip()
        return text[nl + 1:].strip()
    return text[nl + 1:end].strip()


# Matches a whole-line ``INPUT_PATH = ...`` / ``OUTPUT_PATH = ...`` assignment.
# Group 1 is the indentation; ``=(?!=)`` keeps ``==`` comparisons untouched.
_PATH_RE = re.compile(r'^([ \t]*)(?:INPUT_PATH|OUTPUT_PATH)[ \t]*=(?!=).*$', re.MULTILINE)


def strip_paths(code):
    """Neutralise generated assignments to ``INPUT_PATH`` / ``OUTPUT_PATH``.

    The runner injects both variables itself, so a hardcoded assignment in the
    generated script must not win. Each matching line becomes ``pass`` at the
    same indentation: an assignment that was the only statement of an indented
    suite still leaves valid Python, and neighbouring blank lines are kept.
    """
    return _PATH_RE.sub(lambda m: f"{m.group(1)}pass", code)


# Wrapper script run in a subprocess: the generated code is placed inside the
# ``try`` so an uncaught exception prints a traceback and exits with status 2.
RUNNER_TEMPLATE = textwrap.dedent("""
    import os, sys, traceback
    INPUT_PATH = {input_path!r}
    OUTPUT_PATH = {output_path!r}
    try:
    {code_indented}
    except Exception:
        traceback.print_exc()
        sys.exit(2)
""")


def run_code(code, input_path, output_path, timeout=120):
    """Run generated code in a fresh interpreter and check that it wrote output.

    Args:
        code: Generated Python source.
        input_path: Value bound to ``INPUT_PATH`` inside the script.
        output_path: Value bound to ``OUTPUT_PATH``; must exist afterwards.
        timeout: Wall-clock limit for the subprocess, in seconds.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, message)`` where the
        message is the captured output, a timeout note, or
        ``"output file was not created"``.
    """
    out_dir = os.path.dirname(output_path)
    if out_dir:
        # ``os.makedirs("")`` raises, so a bare filename needs no directory.
        os.makedirs(out_dir, exist_ok=True)
    if (os.path.isfile(output_path)
            and os.path.abspath(output_path) != os.path.abspath(input_path)):
        # A leftover file from an earlier run must not count as this run's
        # output; never delete the input when both paths are the same file.
        os.remove(output_path)
    cleaned = strip_paths(code)
    indented = textwrap.indent(cleaned, "    ")
    script = RUNNER_TEMPLATE.format(input_path=input_path, output_path=output_path,
                                    code_indented=indented)
    # UTF-8 is Python's default source encoding; the locale encoding may not
    # be able to represent non-ASCII characters in the generated code.
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(script)
        tmp = f.name
    try:
        proc = subprocess.run([sys.executable, tmp], capture_output=True, text=True,
                              timeout=timeout)
        if proc.returncode != 0:
            return False, (proc.stdout + "\n" + proc.stderr).strip()
        if not os.path.exists(output_path):
            return False, "output file was not created"
        return True, ""
    except subprocess.TimeoutExpired:
        return False, f"timeout after {timeout}s"
    finally:
        try:
            os.unlink(tmp)
        except OSError:
            pass


def find_test_cases(task_dir):
    """Locate ``(case_no, input_path, answer_path)`` triples in a task folder.

    Recognised layouts, in this order: ``*_input.xlsx`` with a sibling
    ``*_answer.xlsx``, then ``*_init.xlsx`` with ``*_golden.xlsx``. When
    neither yields a case, a bare ``initial.xlsx`` / ``golden.xlsx`` pair is
    reported as case ``"1"``. Inputs without an answer file are skipped.
    """
    cases = []
    # Escape the directory so ``[``, ``*`` or ``?`` in its name are literal.
    safe_dir = glob.escape(task_dir)
    for in_suffix, ans_suffix in (("_input.xlsx", "_answer.xlsx"),
                                  ("_init.xlsx", "_golden.xlsx")):
        for ip in sorted(glob.glob(os.path.join(safe_dir, f"*{in_suffix}"))):
            no = os.path.basename(ip).split("_", 1)[0]
            # Swap only the trailing suffix; ``str.replace`` would also
            # rewrite a matching substring inside a directory name.
            ap = ip[:-len(in_suffix)] + ans_suffix
            if os.path.exists(ap):
                cases.append((no, ip, ap))
    if not cases:
        bare_init = os.path.join(task_dir, "initial.xlsx")
        bare_gold = os.path.join(task_dir, "golden.xlsx")
        if os.path.exists(bare_init) and os.path.exists(bare_gold):
            cases.append(("1", bare_init, bare_gold))
    return cases


def load_items(path):
    """Load task records from a ``.json`` file or a JSON-lines file.

    For ``.json``: a top-level dict yields its ``"data"`` entry when that is
    truthy, otherwise its values. Any other path is read as one JSON document
    per non-blank line.
    """
    if path.endswith(".json"):
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict):
            data = data.get("data") or list(data.values())
        return list(data)
    items = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                items.append(json.loads(line))
    return items


# ── LLM call ────────────────────────────────────────────────────────────────

def llm_call(messages, deployment, max_tokens=16384, retries=5, llm_timeout=120):
    """Send ``messages`` to ``deployment`` and return the reply as a string.

    Thin wrapper over ``chat_messages_with_deployment``; only the first
    element of its result is used, and a falsy reply becomes ``""``.
    """
    raw, _ = chat_messages_with_deployment(
        deployment=deployment,
        messages=messages,
        max_completion_tokens=max_tokens,
        retries=retries,
        stage="rollout",
        timeout=llm_timeout,
    )
    return str(raw or "")


# ── Process one task ────────────────────────────────────────────────────────

def process_one(item, data_root, out_root, model):
    """Generate, execute and score code for a single benchmark task.

    Args:
        item: Task record; ``id`` and ``instruction`` are required, while
            ``instruction_type``, ``answer_position``, ``answer_sheet`` and
            ``spreadsheet_path`` are optional.
        data_root: Base directory for a relative ``spreadsheet_path``.
        out_root: Run output directory; artifacts are written to
            ``predictions/<id>/`` beneath it.
        model: Deployment name passed to ``llm_call``.

    Returns:
        A dict with ``id``, ``ok``, ``hard`` (1 only when every case passes),
        ``soft`` (fraction of passing cases), ``n_cases``, ``n_pass``,
        ``fail_reason`` and ``error``. Failures after the record fields are
        read are captured in the dict rather than raised.
    """
    task_id = str(item["id"])
    instruction = item["instruction"]
    instruction_type = item.get("instruction_type", "")
    answer_position = item.get("answer_position", "")
    answer_sheet = item.get("answer_sheet", "")
    if answer_position and answer_sheet and "!" not in answer_position:
        # Qualify a bare range with its sheet name.
        answer_position = f"{answer_sheet}!{answer_position}"

    sp = item.get("spreadsheet_path", f"spreadsheet/{task_id}")
    task_dir = sp if os.path.isabs(sp) else os.path.join(data_root, sp)

    result = {"id": task_id, "ok": False, "hard": 0, "soft": 0.0,
              "n_cases": 0, "n_pass": 0, "fail_reason": "", "error": ""}
    try:
        cases = find_test_cases(task_dir)
        result["n_cases"] = len(cases)
        if not cases:
            result["fail_reason"] = "no-test-cases"
            return result

        task_out = os.path.join(out_root, "predictions", task_id)
        os.makedirs(task_out, exist_ok=True)

        # LLM call: one completion, prompted with the first case's input.
        system = build_system("")
        user = build_user(instruction, cases[0][1], instruction_type, answer_position)
        messages = [{"role": "system", "content": system},
                    {"role": "user", "content": user}]

        raw = llm_call(messages, model)
        # Fixed pause after each request (presumably rate limiting; the
        # reason is not stated in this file).
        time.sleep(3)
        code = extract_code(raw)

        # Model output can contain any Unicode; do not depend on the locale.
        with open(os.path.join(task_out, "code.py"), "w", encoding="utf-8") as f:
            f.write(code)
        with open(os.path.join(task_out, "raw.txt"), "w", encoding="utf-8") as f:
            f.write(raw)

        if not code.strip():
            result["fail_reason"] = "empty-code"
            return result

        # Execute + evaluate each test case with the same generated script.
        for no, ip, ap in cases:
            pred = os.path.join(task_out, f"{no}_pred.xlsx")
            ok_exec, err = run_code(code, ip, pred)
            if not ok_exec:
                if not result["fail_reason"]:
                    # Keep only the first execution failure, truncated.
                    result["fail_reason"] = f"exec: {err[:200]}"
                continue
            try:
                ev = evaluate(pred, ap, instruction_type, answer_position)
            except Exception as e:
                ev = {"ok": False, "reason": str(e)}
            if ev["ok"]:
                result["n_pass"] += 1

        n_cases, n_pass = result["n_cases"], result["n_pass"]
        result["soft"] = n_pass / n_cases if n_cases else 0.0
        result["hard"] = 1 if n_cases > 0 and n_pass == n_cases else 0
        result["ok"] = bool(result["hard"])
        if result["ok"]:
            result["fail_reason"] = ""
        return result
    except Exception as e:
        result["fail_reason"] = f"unexpected: {e}"
        result["error"] = traceback.format_exc()
        return result
# ── Main ────────────────────────────────────────────────────────────────────

def main():
    """Run the OFFICIAL-prompt evaluation and write results plus a summary.

    Parses the command line, configures the model backend, evaluates every
    task from ``JSONL_PATH`` on a thread pool, prints one progress line per
    finished task, and writes ``results.jsonl`` and ``summary.json`` into the
    output directory. Rates are reported as 0.0 when no task was evaluated.
    """
    ap = argparse.ArgumentParser(description="Eval OFFICIAL prompt on verified-400")
    ap.add_argument("--model", default=MODEL)
    ap.add_argument("--backend", choices=["azure_openai", "codex", "claude"], default="azure_openai")
    ap.add_argument("--azure_endpoint", default="")
    ap.add_argument("--azure_api_version", default="")
    ap.add_argument("--azure_api_key", default="")
    ap.add_argument("--workers", type=int, default=8)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--out_root", default="")
    args = ap.parse_args()
    if args.workers < 1:
        # ThreadPoolExecutor rejects max_workers <= 0 with a bare ValueError.
        ap.error("--workers must be >= 1")

    set_backend(args.backend)
    configure_azure_openai(
        endpoint=args.azure_endpoint or None,
        api_version=args.azure_api_version or None,
        api_key=args.azure_api_key or None,
    )
    set_student_deployment(args.model)
    ts = time.strftime("%Y%m%d_%H%M%S")
    out_root = args.out_root or os.path.join(_PROJECT_ROOT, "outputs", f"prompt_official_{args.model}_{ts}")
    out_root = os.path.abspath(out_root)
    os.makedirs(out_root, exist_ok=True)

    items = load_items(JSONL_PATH)
    if args.limit > 0:
        # A negative limit would otherwise silently drop items from the end.
        items = items[:args.limit]

    print(f"{'='*60}")
    print(f" Prompt: OFFICIAL (SpreadsheetBench original)")
    print(f" Model: {args.model}")
    print(f" Items: {len(items)}")
    print(f" Output: {out_root}")
    print(f"{'='*60}")

    t0 = time.time()
    results = []
    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futs = {ex.submit(process_one, it, DATA_ROOT, out_root, args.model): it for it in items}
        for i, fut in enumerate(as_completed(futs), 1):
            item = futs[fut]
            try:
                # as_completed() only yields finished futures, so result()
                # returns at once; a timeout here could never fire. Per-task
                # time limits come from the LLM call and subprocess timeouts.
                res = fut.result()
            except Exception as e:
                item_id = str(item.get("id", "?")) if isinstance(item, dict) else "?"
                # Same keys as process_one() so every results.jsonl row has
                # one schema.
                res = {"id": item_id, "ok": False, "hard": 0, "soft": 0.0,
                       "n_cases": 0, "n_pass": 0, "fail_reason": str(e),
                       "error": traceback.format_exc()}
            results.append(res)
            status = "PASS" if res.get("hard") else "FAIL"
            dt = time.time() - t0
            print(f" {i}/{len(items)} id={res['id']:<10} {status} cases={res.get('n_pass',0)}/{res.get('n_cases',0)} dt={dt:.0f}s")

    # Summary
    hard_sum = sum(r.get("hard", 0) for r in results)
    soft_sum = sum(r.get("soft", 0.0) for r in results)
    n = len(results)
    # An empty dataset must not crash with ZeroDivisionError.
    hard_rate = hard_sum / n if n else 0.0
    soft_rate = soft_sum / n if n else 0.0
    print(f"\n{'='*60}")
    print(f" OFFICIAL prompt: hard={hard_sum}/{n}={hard_rate:.4f} soft={soft_rate:.4f}")
    print(f"{'='*60}")

    # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding.
    with open(os.path.join(out_root, "results.jsonl"), "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    with open(os.path.join(out_root, "summary.json"), "w", encoding="utf-8") as f:
        json.dump({"prompt": "official", "model": args.model, "n": n,
                   "hard": hard_rate, "soft": soft_rate}, f, indent=2)


if __name__ == "__main__":
    main()
Student: ${STUDENT_DEPLOYMENT}" +echo " Data: ${VAL_PATH}" +echo "============================================================" + +cd "${PROJECT_ROOT}" + +python scripts/eval_only.py \ + --config configs/searchqa_default.yaml \ + --data_path "${VAL_PATH}" \ + --out_root "${DEFAULT_OUT_ROOT}" \ + "$@" + +echo "" +echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}" diff --git a/scripts/eval_verified400.sh b/scripts/eval_verified400.sh new file mode 100755 index 0000000..96d6261 --- /dev/null +++ b/scripts/eval_verified400.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash +# ────────────────────────────────────────────────────────────────────────────── +# Eval skill0 on full SpreadsheetBench verified-400 +# +# Usage: +# bash scripts/eval_verified400.sh +# bash scripts/eval_verified400.sh --workers 64 +# ────────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")" + +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +# ── Paths ──────────────────────────────────────────────────────────────────── +DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400" +JSONL_PATH="${DATA_ROOT}/dataset.json" +SKILL_PATH="${PROJECT_ROOT}/skillopt/envs/spreadsheetbench/skills/initial.md" + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_ROOT="${PROJECT_ROOT}/outputs/eval_verified400_${TIMESTAMP}" + +echo "============================================================" +echo " Eval skill0 on verified-400 (full)" +echo "============================================================" +echo " data_root: ${DATA_ROOT}" +echo " skill: ${SKILL_PATH}" +echo " out_root: ${OUT_ROOT}" +echo "============================================================" + +cd "${PROJECT_ROOT}" + +python scripts/eval_only.py \ + --config configs/spreadsheetbench_default.yaml \ + --skill "${SKILL_PATH}" \ + --split all \ + --data_root "${DATA_ROOT}" \ + --jsonl_path 
"${JSONL_PATH}" \ + --out_root "${OUT_ROOT}" \ + "$@" diff --git a/scripts/eval_verified400_multi.sh b/scripts/eval_verified400_multi.sh new file mode 100755 index 0000000..288c83c --- /dev/null +++ b/scripts/eval_verified400_multi.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# ────────────────────────────────────────────────────────────────────────────── +# Eval skill0 on full SpreadsheetBench verified-400 (MULTI-ROUND codegen) +# +# Usage: +# bash scripts/eval_verified400_multi.sh +# bash scripts/eval_verified400_multi.sh --workers 64 --max_turns 5 +# ────────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")" + +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400" +JSONL_PATH="${DATA_ROOT}/dataset.json" +SKILL_PATH="${PROJECT_ROOT}/skillopt/envs/spreadsheetbench/skills/initial.md" + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_ROOT="${PROJECT_ROOT}/outputs/eval_multi_verified400_${TIMESTAMP}" + +echo "============================================================" +echo " Eval skill0 — MULTI-ROUND codegen — verified-400" +echo "============================================================" +echo " data_root: ${DATA_ROOT}" +echo " skill: ${SKILL_PATH}" +echo " mode: multi" +echo " out_root: ${OUT_ROOT}" +echo "============================================================" + +cd "${PROJECT_ROOT}" + +python scripts/eval_only.py \ + --config configs/spreadsheetbench_default.yaml \ + --skill "${SKILL_PATH}" \ + --split all \ + --mode multi \ + --data_root "${DATA_ROOT}" \ + --jsonl_path "${JSONL_PATH}" \ + --out_root "${OUT_ROOT}" \ + "$@" diff --git a/scripts/eval_verified400_single.sh b/scripts/eval_verified400_single.sh new file mode 100755 index 0000000..5793412 --- /dev/null +++ b/scripts/eval_verified400_single.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash 
+# ────────────────────────────────────────────────────────────────────────────── +# Eval skill0 on full SpreadsheetBench verified-400 (SINGLE-ROUND codegen) +# +# Usage: +# bash scripts/eval_verified400_single.sh +# bash scripts/eval_verified400_single.sh --workers 64 +# ────────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")" + +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" + +DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400" +JSONL_PATH="${DATA_ROOT}/dataset.json" +SKILL_PATH="${PROJECT_ROOT}/skillopt/envs/spreadsheetbench/skills/initial.md" + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_ROOT="${PROJECT_ROOT}/outputs/eval_single_verified400_${TIMESTAMP}" + +echo "============================================================" +echo " Eval skill0 — SINGLE-ROUND codegen — verified-400" +echo "============================================================" +echo " data_root: ${DATA_ROOT}" +echo " skill: ${SKILL_PATH}" +echo " mode: single" +echo " out_root: ${OUT_ROOT}" +echo "============================================================" + +cd "${PROJECT_ROOT}" + +python scripts/eval_only.py \ + --config configs/spreadsheetbench_default.yaml \ + --skill "${SKILL_PATH}" \ + --split all \ + --mode single \ + --data_root "${DATA_ROOT}" \ + --jsonl_path "${JSONL_PATH}" \ + --out_root "${OUT_ROOT}" \ + "$@" diff --git a/scripts/launch_harness_bestsetting_from_scratch.sh b/scripts/launch_harness_bestsetting_from_scratch.sh new file mode 100755 index 0000000..e5ad460 --- /dev/null +++ b/scripts/launch_harness_bestsetting_from_scratch.sh @@ -0,0 +1,120 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
&& pwd)"
PY="${PY:-python}"
RUN_ROOT="${RUN_ROOT:-$ROOT/outputs/harness_bestsetting_fromscratch_$(date -u +%Y%m%d_%H%M%S)_run}"
MAX_PARALLEL="${MAX_PARALLEL:-2}"

# Reject a non-numeric or zero value up front instead of failing inside the
# arithmetic test in launch().
if ! [[ "$MAX_PARALLEL" =~ ^[1-9][0-9]*$ ]]; then
  echo "MAX_PARALLEL must be a positive integer, got: $MAX_PARALLEL" >&2
  exit 2
fi

mkdir -p "$RUN_ROOT/logs"
cd "$ROOT"
export PYTHONPATH="$ROOT:${PYTHONPATH:-}"

# Config overrides shared by every run. Per-run options are passed after
# these on the command line.
COMMON=(
  model.teacher_backend=openai_chat
  model.teacher=gpt-5.5
  model.teacher_azure_openai_endpoint=https://t2vgoaigpt4o3.openai.azure.com/
  model.teacher_azure_openai_api_version=2024-12-01-preview
  model.teacher_azure_openai_auth_mode=azure_cli
  model.reasoning_effort=medium
  train.num_epochs=4
  train.train_size=0
  train.accumulation=1
  train.seed=42
  gradient.minibatch_size=8
  gradient.merge_batch_size=8
  gradient.analyst_workers=16
  gradient.use_deep_reflect=false
  optimizer.min_learning_rate=2
  optimizer.lr_control_mode=fixed
  optimizer.skill_update_mode=patch
  optimizer.use_slow_update=true
  optimizer.slow_update_samples=20
  optimizer.use_meta_skill=true
  optimizer.use_meta_reflect=false
  evaluation.use_gate=true
  evaluation.eval_test=true
  env.split_mode=split_dir
)

# Student-side overrides for the Codex exec backend.
CODEX=(
  model.student_backend=codex_exec
  model.student=gpt-5.5
  model.codex_exec_use_sdk=auto
  model.codex_exec_sandbox=workspace-write
  model.codex_exec_approval_policy=never
  model.codex_trace_to_teacher=true
)

# Student-side overrides for the Claude Code exec backend.
CLAUDE=(
  model.student_backend=claude_code_exec
  model.student=claude-sonnet-4-6
  model.claude_code_exec_use_sdk=auto
  model.codex_trace_to_teacher=false
)

# Number of background runs currently counted as in flight.
active=0

# launch RUN_ID CONFIG [CFG_OPTION...]
#
# Starts scripts/train.py for CONFIG in its own session (setsid) in the
# background, logging to $RUN_ROOT/logs/RUN_ID.log and writing outputs to
# $RUN_ROOT/RUN_ID. Once MAX_PARALLEL runs are in flight, blocks until one of
# them exits before returning.
launch() {
  local run_id="$1"; shift
  local config="$1"; shift
  local out="$RUN_ROOT/$run_id"
  local log="$RUN_ROOT/logs/$run_id.log"
  echo "START $run_id"
  setsid "$PY" -u scripts/train.py \
    --config "$config" \
    --cfg-options "${COMMON[@]}" "$@" "env.out_root=$out" \
    > "$log" 2>&1 < /dev/null &
  active=$((active + 1))
  if (( active >= MAX_PARALLEL )); then
    # `wait -n` returns the exit status of the job that finished. Under
    # `set -e` a single failed run would abort this launcher and the
    # remaining runs would never start, so report the failure and carry on.
    wait -n || echo "WARN: a run exited non-zero; see $RUN_ROOT/logs" >&2
    active=$((active - 1))
  fi
}

# SearchQA best openai-chat setting: optimizer.lr_scheduler=constant.
+launch HARNESS-BESTSETTING-searchqa-codex configs/searchqa/default.yaml \ + "${CODEX[@]}" \ + train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \ + env.split_dir=data/searchqa/splits + +launch HARNESS-BESTSETTING-searchqa-claude configs/searchqa/default.yaml \ + "${CLAUDE[@]}" \ + train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \ + env.split_dir=data/searchqa/splits + +# SpreadsheetBench best openai-chat setting: optimizer.lr_scheduler=constant. +# Must stay env.mode=multi; exec-backend multi support is fixed on this branch. +launch HARNESS-BESTSETTING-spreadsheetbench-codex configs/spreadsheetbench/default.yaml \ + "${CODEX[@]}" \ + train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi env.workers=4 + +launch HARNESS-BESTSETTING-spreadsheetbench-claude configs/spreadsheetbench/default.yaml \ + "${CLAUDE[@]}" \ + train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi env.workers=4 + +# LiveMathBench best openai-chat setting: optimizer.learning_rate=8. +launch HARNESS-BESTSETTING-livemathematicianbench-codex configs/livemathematicianbench/default.yaml \ + "${CODEX[@]}" \ + train.batch_size=40 optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant \ + env.split_dir=data/livemathbench/splits + +launch HARNESS-BESTSETTING-livemathematicianbench-claude configs/livemathematicianbench/default.yaml \ + "${CLAUDE[@]}" \ + train.batch_size=40 optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant \ + env.split_dir=data/livemathbench/splits + +# DocVQA best openai-chat setting was full batch. 
On 10% harness split, train=107. +launch HARNESS-BESTSETTING-docvqa10pct-codex configs/docvqa/default.yaml \ + "${CODEX[@]}" \ + train.batch_size=107 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct + +launch HARNESS-BESTSETTING-docvqa10pct-claude configs/docvqa/default.yaml \ + "${CLAUDE[@]}" \ + train.batch_size=107 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct + +wait +echo "All launched runs finished or exited. RUN_ROOT=$RUN_ROOT" diff --git a/scripts/launch_harness_canonical_claude18.sh b/scripts/launch_harness_canonical_claude18.sh new file mode 100755 index 0000000..0bd6310 --- /dev/null +++ b/scripts/launch_harness_canonical_claude18.sh @@ -0,0 +1,178 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" + +cd "$REPO" + +# Claude Code on this machine must use the local copilot-api proxy. 
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"

# Teacher-model connection settings (TEACHER_AZURE_OPENAI_*) come from a local
# secrets file that is not part of the repository.
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
  # shellcheck disable=SC1091
  source ".secrets/teacher_oaidr9.env"
else
  echo "missing .secrets/teacher_oaidr9.env" >&2
  exit 1
fi

stamp="$(date -u +%Y%m%d_%H%M%S)"
RUN_ROOT="${1:-outputs/harness_canonical_claude18_workers2_timeout1020_${stamp}_run}"
SESSION="${2:-harness_canon_claude18_${stamp}}"

# Fail before creating anything if the session name is taken; otherwise
# `tmux new-session` would abort the script after the first command file was
# already written.
if tmux has-session -t "$SESSION" 2>/dev/null; then
  echo "tmux session already exists: $SESSION" >&2
  exit 1
fi

mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"

# Config overrides shared by every run; per-run options are appended after
# these by launch_run.
COMMON_CFG=(
  model.teacher_backend=openai_chat
  model.teacher=gpt-5.5
  model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
  model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
  model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
  model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
  model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
  model.reasoning_effort=medium
  train.num_epochs=4
  train.train_size=0
  train.batch_size=40
  train.accumulation=1
  train.seed=42
  gradient.minibatch_size=8
  gradient.merge_batch_size=8
  gradient.analyst_workers=16
  gradient.use_deep_reflect=false
  optimizer.lr_control_mode=fixed
  optimizer.skill_update_mode=patch
  optimizer.use_slow_update=true
  optimizer.slow_update_samples=20
  optimizer.use_meta_skill=true
  optimizer.use_meta_reflect=false
  evaluation.use_gate=true
  evaluation.eval_test=true
  env.split_mode=split_dir
  env.workers=2
  env.exec_timeout=1020
  model.student_backend=claude_code_exec
  model.student=claude-sonnet-4-6
  model.claude_code_exec_use_sdk=sdk
  model.claude_code_exec_effort=medium
  model.claude_code_exec_max_thinking_tokens=16384
  model.codex_trace_to_teacher=false
)

# 0 until the tmux session has been created by the first launch_run call.
tmux_started=0

# launch_run RUN_ID CONFIG SKILL [CFG_OPTION...]
#
# Writes a self-contained command script to $RUN_ROOT/commands/RUN_ID.sh and
# runs it in a tmux window named RUN_ID: the first call creates the session,
# later calls add windows. The window stays open for an hour after the run
# exits so its EXIT:<code> line can be inspected. Prints RUN_ID.
launch_run() {
  local run_id="$1"
  local config="$2"
  local skill="$3"
  shift 3

  local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
  local log_file="$RUN_ROOT/logs/${run_id}.log"
  local out_root="$RUN_ROOT/$run_id"

  local -a cmd=(
    "$PYTHON" -u scripts/train.py
    --config "$config"
    --cfg-options
    "${COMMON_CFG[@]}"
    env.skill_init="$skill"
    env.out_root="$out_root"
    "$@"
  )

  {
    echo "#!/usr/bin/env bash"
    echo "set -euo pipefail"
    # %q quotes the path for re-reading by bash, matching the lines below.
    printf 'cd %q\n' "$REPO"
    printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
    printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
    printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
    printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
    printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
    printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
    printf '%q ' "${cmd[@]}"
    printf ' >%q 2>&1 < /dev/null\n' "$log_file"
  } > "$cmd_file"
  # The file embeds ANTHROPIC_AUTH_TOKEN, so keep it owner-only.
  chmod 700 "$cmd_file"

  # The `$?` and `$code` in the format string are expanded by the shell tmux
  # starts, not by this script.
  local tmux_cmd
  printf -v tmux_cmd 'bash %q; code=$?; echo EXIT:$code; sleep 3600' "$cmd_file"

  if [[ "$tmux_started" -eq 0 ]]; then
    tmux new-session -d -s "$SESSION" -n "$run_id" "$tmux_cmd"
    tmux_started=1
  else
    tmux new-window -t "$SESSION" -n "$run_id" "$tmux_cmd"
  fi
  echo "$run_id"
}

SEARCHQA_SKILL="docs/harness_source_skills/searchqa_best_skill.md"
LIVEMATH_SKILL="docs/harness_source_skills/livemathematicianbench_best_skill.md"
DOCVQA_SKILL="docs/harness_source_skills/docvqa_best_skill.md"
SPREADSHEET_SKILL="docs/harness_source_skills/spreadsheetbench_best_skill.md"

launch_run "HARNESS-Claude-SearchQA-sched-constant" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \
env.split_dir=data/searchqa/splits \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-SearchQA-sched-linear" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \ + env.split_dir=data/searchqa/splits \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=linear +launch_run "HARNESS-Claude-SearchQA-batch-full" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \ + env.split_dir=data/searchqa/splits \ + train.batch_size=400 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine +launch_run "HARNESS-Claude-SearchQA-lr8" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \ + env.split_dir=data/searchqa/splits \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +launch_run "HARNESS-Claude-LiveMath-lr8" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \ + env.split_dir=data/livemathbench/splits \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-LiveMath-lr16" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \ + env.split_dir=data/livemathbench/splits \ + optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-LiveMath-slow10" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \ + env.split_dir=data/livemathbench/splits \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine optimizer.slow_update_samples=10 +launch_run "HARNESS-Claude-LiveMath-sched-linear" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \ + env.split_dir=data/livemathbench/splits \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=linear +launch_run "HARNESS-Claude-LiveMath-minibatch4" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \ + env.split_dir=data/livemathbench/splits 
\ + gradient.minibatch_size=4 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine + +launch_run "HARNESS-Claude-DocVQA10-batch-full" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + train.batch_size=107 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine +launch_run "HARNESS-Claude-DocVQA10-lr16" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-DocVQA10-lr8" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-DocVQA10-minibatch32" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + gradient.minibatch_size=32 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine +launch_run "HARNESS-Claude-DocVQA10-batch24" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + train.batch_size=24 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine + +launch_run "HARNESS-Claude-Spreadsheet-sched-constant-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-Spreadsheet-lr4-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \ + optimizer.learning_rate=4 optimizer.min_learning_rate=1 
optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-Spreadsheet-lr16-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \ + optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant +launch_run "HARNESS-Claude-Spreadsheet-minibatch16-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \ + gradient.minibatch_size=16 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" diff --git a/scripts/launch_harness_canonical_claude4_smoke.sh b/scripts/launch_harness_canonical_claude4_smoke.sh new file mode 100755 index 0000000..fae6b7a --- /dev/null +++ b/scripts/launch_harness_canonical_claude4_smoke.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" + +cd "$REPO" + +export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}" +export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}" +export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}" +export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}" +export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}" +export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}" + +if [[ -f ".secrets/teacher_oaidr9.env" ]]; then + # shellcheck disable=SC1091 + source ".secrets/teacher_oaidr9.env" +else + echo "missing .secrets/teacher_oaidr9.env" >&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/harness_canonical_claude4_smoke_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-harness_canon_claude4_${stamp}}" + +mkdir 
-p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +COMMON_CFG=( + model.teacher_backend=openai_chat + model.teacher=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.lr_control_mode=fixed + optimizer.skill_update_mode=patch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 + model.student_backend=claude_code_exec + model.student=claude-sonnet-4-6 + model.claude_code_exec_use_sdk=sdk + model.claude_code_exec_effort=medium + model.claude_code_exec_max_thinking_tokens=16384 + model.codex_trace_to_teacher=false +) + +tmux_started=0 + +launch_run() { + local run_id="$1" + local config="$2" + local skill="$3" + shift 3 + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + local out_root="$RUN_ROOT/$run_id" + + local -a cmd=( + "$PYTHON" -u scripts/train.py + --config "$config" + --cfg-options + "${COMMON_CFG[@]}" + env.skill_init="$skill" + env.out_root="$out_root" + "$@" + ) + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL" + printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN" + printf 'export 
ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL" + printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL" + printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS" + printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +launch_run "HARNESS-Claude-SearchQA-sched-constant" "configs/searchqa/default.yaml" "docs/harness_source_skills/searchqa_best_skill.md" \ + env.split_dir=data/searchqa/splits \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant + +launch_run "HARNESS-Claude-LiveMath-lr8" "configs/livemathematicianbench/default.yaml" "docs/harness_source_skills/livemathematicianbench_best_skill.md" \ + env.split_dir=data/livemathbench/splits \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +launch_run "HARNESS-Claude-DocVQA10-lr8" "configs/docvqa/default.yaml" "docs/harness_source_skills/docvqa_best_skill.md" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +launch_run "HARNESS-Claude-Spreadsheet-lr4-multi" "configs/spreadsheetbench/default.yaml" "docs/harness_source_skills/spreadsheetbench_best_skill.md" \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \ + optimizer.learning_rate=4 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" diff --git 
a/scripts/launch_harness_canonical_wave1.sh b/scripts/launch_harness_canonical_wave1.sh new file mode 100755 index 0000000..89f2556 --- /dev/null +++ b/scripts/launch_harness_canonical_wave1.sh @@ -0,0 +1,168 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" +CODEX_WRAPPER="$REPO/scripts/codex_azure_mi.sh" + +cd "$REPO" + +# Claude Code is routed through the local copilot-api proxy on this machine. +# Do not rely on interactive Claude login state inside tmux/train workers. +export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}" +export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}" +export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}" +export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}" +export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}" +export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}" + +if [[ -f ".secrets/teacher_oaidr9.env" ]]; then + # shellcheck disable=SC1091 + source ".secrets/teacher_oaidr9.env" +else + echo "missing .secrets/teacher_oaidr9.env" >&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/harness_canonical_step12_wave1_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-harness_canon_wave1_${stamp}}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +COMMON_CFG=( + model.teacher_backend=openai_chat + model.teacher=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium 
+ train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.lr_control_mode=fixed + optimizer.skill_update_mode=patch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 +) + +tmux_started=0 + +launch_run() { + local run_id="$1" + local backend="$2" + local config="$3" + local skill="$4" + local split_dir="$5" + shift 5 + + local -a backend_cfg=() + if [[ "$backend" == "codex" ]]; then + backend_cfg=( + model.student_backend=codex_exec + model.student=gpt-5.5 + model.codex_exec_path="$CODEX_WRAPPER" + model.codex_exec_use_sdk=auto + model.codex_exec_sandbox=workspace-write + model.codex_exec_approval_policy=never + model.codex_exec_reasoning_effort=medium + model.codex_trace_to_teacher=true + ) + elif [[ "$backend" == "claude" ]]; then + backend_cfg=( + model.student_backend=claude_code_exec + model.student=claude-sonnet-4-6 + model.claude_code_exec_use_sdk=sdk + model.claude_code_exec_effort=medium + model.claude_code_exec_max_thinking_tokens=16384 + model.codex_trace_to_teacher=false + ) + else + echo "unknown backend: $backend" >&2 + exit 1 + fi + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + local out_root="$RUN_ROOT/$run_id" + + local -a cmd=( + "$PYTHON" -u scripts/train.py + --config "$config" + --cfg-options + "${COMMON_CFG[@]}" + "${backend_cfg[@]}" + env.split_dir="$split_dir" + env.skill_init="$skill" + env.out_root="$out_root" + "$@" + ) + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL" + printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' 
"$ANTHROPIC_AUTH_TOKEN" + printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL" + printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL" + printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS" + printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +SEARCHQA_SKILL="docs/harness_source_skills/searchqa_best_skill.md" +LIVEMATH_SKILL="docs/harness_source_skills/livemathematicianbench_best_skill.md" + +SEARCHQA_CFG="configs/searchqa/default.yaml" +LIVEMATH_CFG="configs/livemathematicianbench/default.yaml" + +SEARCHQA_SPLIT="data/searchqa/splits" +LIVEMATH_SPLIT="data/livemathbench/splits" + +for backend in codex claude; do + prefix="HARNESS-Codex" + [[ "$backend" == "claude" ]] && prefix="HARNESS-Claude" + + launch_run "${prefix}-SearchQA-sched-constant" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant + launch_run "${prefix}-SearchQA-sched-linear" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=linear + launch_run "${prefix}-SearchQA-batch-full" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \ + train.batch_size=400 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine + launch_run "${prefix}-SearchQA-lr8" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \ + optimizer.learning_rate=8 
optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + launch_run "${prefix}-LiveMath-lr8" "$backend" "$LIVEMATH_CFG" "$LIVEMATH_SKILL" "$LIVEMATH_SPLIT" \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + launch_run "${prefix}-LiveMath-lr16" "$backend" "$LIVEMATH_CFG" "$LIVEMATH_SKILL" "$LIVEMATH_SPLIT" \ + optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant +done + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" diff --git a/scripts/launch_harness_initial_claude4.sh b/scripts/launch_harness_initial_claude4.sh new file mode 100755 index 0000000..367865d --- /dev/null +++ b/scripts/launch_harness_initial_claude4.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" + +cd "$REPO" + +export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}" +export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}" +export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}" +export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}" +export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}" +export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}" + +if [[ -f ".secrets/teacher_oaidr9.env" ]]; then + # shellcheck disable=SC1091 + source ".secrets/teacher_oaidr9.env" +else + echo "missing .secrets/teacher_oaidr9.env" >&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/harness_initial_claude4_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-harness_initial_claude4_${stamp}}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +COMMON_CFG=( + model.teacher_backend=openai_chat + model.teacher=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + 
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.lr_control_mode=fixed + optimizer.skill_update_mode=patch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 + model.student_backend=claude_code_exec + model.student=claude-sonnet-4-6 + model.claude_code_exec_use_sdk=sdk + model.claude_code_exec_effort=medium + model.claude_code_exec_max_thinking_tokens=16384 + model.codex_trace_to_teacher=false +) + +tmux_started=0 + +launch_run() { + local run_id="$1" + local config="$2" + shift 2 + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + local out_root="$RUN_ROOT/$run_id" + + local -a cmd=( + "$PYTHON" -u scripts/train.py + --config "$config" + --cfg-options + "${COMMON_CFG[@]}" + env.out_root="$out_root" + "$@" + ) + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL" + printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN" + printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL" + printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL" + printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS" + printf 'export 
CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +launch_run "HARNESS-ClaudeInit-SearchQA-sched-constant" "configs/searchqa/default.yaml" \ + env.split_dir=data/searchqa/splits \ + optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant + +launch_run "HARNESS-ClaudeInit-LiveMath-lr8" "configs/livemathematicianbench/default.yaml" \ + env.split_dir=data/livemathbench/splits \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +launch_run "HARNESS-ClaudeInit-DocVQA10-lr8" "configs/docvqa/default.yaml" \ + env.split_dir=data/harness_splits/docvqa_zisu_first10pct \ + optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +launch_run "HARNESS-ClaudeInit-Spreadsheet-lr4-multi" "configs/spreadsheetbench/default.yaml" \ + env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \ + optimizer.learning_rate=4 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" diff --git a/scripts/launch_harness_initial_spreadsheet_clean.sh b/scripts/launch_harness_initial_spreadsheet_clean.sh new file mode 100755 index 0000000..eb094e6 --- /dev/null +++ b/scripts/launch_harness_initial_spreadsheet_clean.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" + +cd "$REPO" + +export 
ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}" +export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}" +export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}" +export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}" +export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}" +export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}" + +if [[ -f ".secrets/teacher_oaidr9.env" ]]; then + # shellcheck disable=SC1091 + source ".secrets/teacher_oaidr9.env" +else + echo "missing .secrets/teacher_oaidr9.env" >&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/harness_initial_spreadsheet_clean_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-harness_initial_spreadsheet_clean_${stamp}}" +RUN_ID="HARNESS-ClaudeInit-Spreadsheet-lr4-multi-clean" +SPLIT_DIR="${SPREADSHEET_SPLIT_DIR:-data/harness_splits/spreadsheetbench_full}" +DATA_ROOT="${SPREADSHEET_DATA_ROOT:-data/spreadsheetbench/files}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +cmd_file="$RUN_ROOT/commands/${RUN_ID}.sh" +log_file="$RUN_ROOT/logs/${RUN_ID}.log" +out_root="$RUN_ROOT/$RUN_ID" + +cmd=( + "$PYTHON" -u scripts/train.py + --config configs/spreadsheetbench/default.yaml + --cfg-options + model.teacher_backend=openai_chat + model.teacher=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + 
gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.lr_control_mode=fixed + optimizer.skill_update_mode=patch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 + model.student_backend=claude_code_exec + model.student=claude-sonnet-4-6 + model.claude_code_exec_use_sdk=sdk + model.claude_code_exec_effort=medium + model.claude_code_exec_max_thinking_tokens=16384 + model.codex_trace_to_teacher=false + env.out_root="$out_root" + env.split_dir="$SPLIT_DIR" + env.data_root="$DATA_ROOT" + env.mode=multi + optimizer.learning_rate=4 + optimizer.min_learning_rate=1 + optimizer.lr_scheduler=constant +) + +{ + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL" + printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN" + printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL" + printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL" + printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS" + printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" +} > "$cmd_file" +chmod +x "$cmd_file" + +tmux new-session -d -s "$SESSION" -n "$RUN_ID" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" +echo "RUN_ID=$RUN_ID" diff --git a/scripts/launch_lrctrl_fullrewrite_neutral3.sh b/scripts/launch_lrctrl_fullrewrite_neutral3.sh new file mode 100755 index 0000000..af5e455 --- /dev/null +++ b/scripts/launch_lrctrl_fullrewrite_neutral3.sh @@ -0,0 +1,116 @@ +#!/usr/bin/env bash +set -euo pipefail + 
+REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" + +cd "$REPO" + +if [[ -f ".secrets/teacher_oaidr9.env" ]]; then + # shellcheck disable=SC1091 + source ".secrets/teacher_oaidr9.env" +else + echo "missing .secrets/teacher_oaidr9.env" >&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/lrctrl_fullrewrite_neutral3_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-lrctrl_fullrewrite_neutral3_${stamp}}" +SEED="${3:-42}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +COMMON_CFG=( + model.teacher_backend=openai_chat + model.student_backend=openai_chat + model.teacher=gpt-5.5 + model.student=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}" + model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}" + model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed="${SEED}" + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + 
optimizer.learning_rate=4 + optimizer.min_learning_rate=2 + optimizer.lr_scheduler=cosine + optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 +) + +tmux_started=0 + +launch_run() { + local run_id="$1" + local config="$2" + shift 2 + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + local out_root="$RUN_ROOT/$run_id" + + local -a cmd=( + "$PYTHON" -u scripts/train.py + --config "$config" + --cfg-options + "${COMMON_CFG[@]}" + env.out_root="$out_root" + "$@" + ) + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +launch_run "LRCTRL-searchqa-full-rewrite-neutral3-seed${SEED}" "configs/searchqa/default.yaml" \ + env.split_dir=data/ablation_splits/searchqa/2-1-7_seed42 + +launch_run "LRCTRL-spreadsheetbench-full-rewrite-neutral3-seed${SEED}" "configs/spreadsheetbench/default.yaml" \ + env.split_dir=data/ablation_splits/spreadsheetbench/2-1-7_seed42 \ + env.data_root=data/spreadsheetbench_verified_400 \ + env.mode=multi + +launch_run "LRCTRL-livemathematicianbench-full-rewrite-neutral3-seed${SEED}" "configs/livemathematicianbench/default.yaml" \ + env.split_dir=data/ablation_splits/livemathematicianbench/2-1-7_seed42 + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" +echo "SEED=$SEED" diff 
--git a/scripts/launch_lrctrl_fullrewrite_neutral3_spreadsheet_repeats.sh b/scripts/launch_lrctrl_fullrewrite_neutral3_spreadsheet_repeats.sh new file mode 100755 index 0000000..bfa5ed4 --- /dev/null +++ b/scripts/launch_lrctrl_fullrewrite_neutral3_spreadsheet_repeats.sh @@ -0,0 +1,111 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO="/home/azureuser/workspace-gzy/SkillReflection" +PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python" + +cd "$REPO" + +if [[ -f ".secrets/teacher_oaidr9.env" ]]; then + # shellcheck disable=SC1091 + source ".secrets/teacher_oaidr9.env" +else + echo "missing .secrets/teacher_oaidr9.env" >&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/lrctrl_fullrewrite_neutral3_spreadsheet_promptweak_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-lrctrl_fr_spreadsheet_promptweak_${stamp}}" +START_INDEX="${3:-4}" +N_REPEATS="${4:-3}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +COMMON_CFG=( + model.teacher_backend=openai_chat + model.student_backend=openai_chat + model.teacher=gpt-5.5 + model.student=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}" + model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}" + model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + 
model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.learning_rate=4 + optimizer.min_learning_rate=2 + optimizer.lr_scheduler=cosine + optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 + env.split_dir=data/ablation_splits/spreadsheetbench/2-1-7_seed42 + env.data_root=data/spreadsheetbench_verified_400 + env.mode=multi +) + +tmux_started=0 + +launch_run() { + local run_id="$1" + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + local out_root="$RUN_ROOT/$run_id" + + local -a cmd=( + "$PYTHON" -u scripts/train.py + --config configs/spreadsheetbench/default.yaml + --cfg-options + "${COMMON_CFG[@]}" + env.out_root="$out_root" + ) + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +for ((i=START_INDEX; i&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/lrctrl_fullrewrite_neutral3_sq_lm_extra_workers2_timeout1020_${stamp}_run}" 
+SESSION="${2:-lrctrl_fr_sq_lm_extra_${stamp}}" +START_INDEX="${3:-4}" +N_REPEATS="${4:-3}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +COMMON_CFG=( + model.teacher_backend=openai_chat + model.student_backend=openai_chat + model.teacher=gpt-5.5 + model.student=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}" + model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}" + model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.learning_rate=4 + optimizer.min_learning_rate=2 + optimizer.lr_scheduler=cosine + optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_skill=true + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 +) + +tmux_started=0 + 
+launch_run() { + local run_id="$1" + local config="$2" + shift 2 + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + local out_root="$RUN_ROOT/$run_id" + + local -a cmd=( + "$PYTHON" -u scripts/train.py + --config "$config" + --cfg-options + "${COMMON_CFG[@]}" + env.out_root="$out_root" + "$@" + ) + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf '%q ' "${cmd[@]}" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +for ((i=START_INDEX; i&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d_%H%M%S)" +RUN_ROOT="${1:-outputs/spreadsheet_full_replacements_workers2_timeout1020_${stamp}_run}" +SESSION="${2:-spreadsheet_full_replacements_${stamp}}" + +mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands" + +tmux_started=0 + +launch_run() { + local run_id="$1" + shift + + local cmd_file="$RUN_ROOT/commands/${run_id}.sh" + local log_file="$RUN_ROOT/logs/${run_id}.log" + + { + echo "#!/usr/bin/env bash" + echo "set -euo pipefail" + echo "cd '$REPO'" + printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL" + printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN" + printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL" + printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL" + printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS" + printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC" + printf '%q ' "$@" + printf ' >%q 2>&1 < /dev/null\n' "$log_file" + } > "$cmd_file" + chmod +x "$cmd_file" + + if [[ "$tmux_started" -eq 0 ]]; then + tmux 
new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + tmux_started=1 + else + tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600" + fi + echo "$run_id" +} + +OPENAI_COMMON=( + "$PYTHON" -u scripts/train.py + --config configs/spreadsheetbench/default.yaml + --cfg-options + model.teacher_backend=openai_chat + model.student_backend=openai_chat + model.teacher=gpt-5.5 + model.student=gpt-5.5 + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}" + model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}" + model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}" + model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" + model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}" + model.reasoning_effort=medium + train.num_epochs=4 + train.train_size=0 + train.batch_size=40 + train.accumulation=1 + train.seed=42 + gradient.minibatch_size=8 + gradient.merge_batch_size=8 + gradient.analyst_workers=16 + gradient.use_deep_reflect=false + optimizer.learning_rate=4 + optimizer.min_learning_rate=2 + optimizer.lr_scheduler=cosine + optimizer.use_slow_update=true + optimizer.slow_update_samples=20 + optimizer.use_meta_reflect=false + evaluation.use_gate=true + evaluation.eval_test=true + 
env.split_mode=split_dir + env.workers=2 + env.exec_timeout=1020 + env.split_dir=data/ablation_splits/spreadsheetbench/2-1-7_seed42 + env.data_root=data/spreadsheetbench_verified_400 + env.mode=multi +) + +HARNESS_RUN="HARNESS-ClaudeInit-Spreadsheet-lr4-multi-full" +launch_run "$HARNESS_RUN" \ + "$PYTHON" -u scripts/train.py \ + --config configs/spreadsheetbench/default.yaml \ + --cfg-options \ + model.teacher_backend=openai_chat \ + model.teacher=gpt-5.5 \ + model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" \ + model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" \ + model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" \ + model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" \ + model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" \ + model.reasoning_effort=medium \ + train.num_epochs=4 \ + train.train_size=0 \ + train.batch_size=40 \ + train.accumulation=1 \ + train.seed=42 \ + gradient.minibatch_size=8 \ + gradient.merge_batch_size=8 \ + gradient.analyst_workers=16 \ + gradient.use_deep_reflect=false \ + optimizer.lr_control_mode=fixed \ + optimizer.skill_update_mode=patch \ + optimizer.use_slow_update=true \ + optimizer.slow_update_samples=20 \ + optimizer.use_meta_skill=true \ + optimizer.use_meta_reflect=false \ + evaluation.use_gate=true \ + evaluation.eval_test=true \ + env.split_mode=split_dir \ + env.workers=2 \ + env.exec_timeout=1020 \ + model.student_backend=claude_code_exec \ + model.student=claude-sonnet-4-6 \ + model.claude_code_exec_use_sdk=sdk \ + model.claude_code_exec_effort=medium \ + model.claude_code_exec_max_thinking_tokens=16384 \ + model.codex_trace_to_teacher=false \ + env.out_root="$RUN_ROOT/$HARNESS_RUN" \ + env.split_dir=data/spreadsheetbench/splits \ + env.data_root=data/spreadsheetbench/files \ + env.mode=multi \ + optimizer.learning_rate=4 \ + optimizer.min_learning_rate=1 \ + 
optimizer.lr_scheduler=constant + +for repeat in r1 r2 r3; do + run_id="LRCTRL-spreadsheetbench-full-rewrite-neutral3-full-${repeat}" + launch_run "$run_id" \ + "${OPENAI_COMMON[@]}" \ + optimizer.lr_control_mode=none \ + optimizer.skill_update_mode=full_rewrite_minibatch \ + optimizer.use_meta_skill=true \ + env.out_root="$RUN_ROOT/$run_id" +done + +for repeat in r1 r2 r3; do + run_id="SLOWMETA-spreadsheetbench-true-false-full-${repeat}" + launch_run "$run_id" \ + "${OPENAI_COMMON[@]}" \ + optimizer.lr_control_mode=fixed \ + optimizer.skill_update_mode=patch \ + optimizer.use_meta_skill=false \ + env.out_root="$RUN_ROOT/$run_id" +done + +echo "RUN_ROOT=$RUN_ROOT" +echo "SESSION=$SESSION" diff --git a/scripts/monitor_harness_claude18.sh b/scripts/monitor_harness_claude18.sh new file mode 100755 index 0000000..56fad35 --- /dev/null +++ b/scripts/monitor_harness_claude18.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="${1:?usage: monitor_harness_claude18.sh RUN_ROOT}" + +while true; do + ts="$(date -u +%Y-%m-%dT%H:%M:%SZ)" + echo "===== $ts CLAUDE18 =====" + uptime | sed 's/^/uptime=/' + active="$(pgrep -af "scripts/train.py.*${ROOT}" | grep -v pgrep | wc -l || true)" + claude_child="$(pgrep -af 'claude.*--output-format stream-json' | grep -v pgrep | wc -l || true)" + echo "active_train_total=$active" + echo "active_codex_train=0" + echo "active_claude_train=$active" + echo "claude_child=$claude_child" + + for d in "$ROOT"/HARNESS-Claude-*; do + [[ -d "$d" ]] || continue + rid="$(basename "$d")" + read -r base best < <(python3 - "$d" <<'PY' +import json, sys +from pathlib import Path +d = Path(sys.argv[1]) +s = d / "summary.json" +if not s.exists(): + print("pending pending") + raise SystemExit +try: + obj = json.loads(s.read_text()) +except Exception: + print("pending pending") + raise SystemExit +base = obj.get("baseline_test_hard", obj.get("base_test", "pending")) +best = obj.get("test_hard", obj.get("best_test", "pending")) +def fmt(x): + if 
isinstance(x, (int, float)): + return f"{x:.4f}" + return str(x) +print(fmt(base), fmt(best)) +PY +) + scan_files="$(mktemp)" + find "$d" \ + \( -path '*/codex_exec' -o -path '*/codex_multi' \) -prune -o \ + -maxdepth 6 -type f \ + \( -name 'claude_trace_summary.txt' -o -name 'codex_trace_summary.txt' -o -name '*.log' -o -name 'summary.json' \) \ + -print > "$scan_files" 2>/dev/null || true + auth="$({ xargs -r rg -l 'Not logged in|authentication_failed' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')" + e429="$({ xargs -r rg -l 'Too Many Requests|RateLimitError|Error code: 429|api_error_status.: 429|rate_limit|too_many_requests' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')" + e401="$({ xargs -r rg -l '401 Unauthorized|Error code: 401|HTTP 401|AuthenticationTypeDisabled|PermissionDeniedError' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')" + timeout="$({ xargs -r rg -l 'TimeoutError|Task timed out|timed out after|subprocess.TimeoutExpired|timeout_exceeded' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')" + teacher="$({ xargs -r rg -l 'APITimeoutError|APIConnectionError|AuthenticationError|Azure OpenAI Responses API is enabled only|teacher.*error' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')" + results="$(find "$d" -maxdepth 5 -path '*/results.jsonl' -type f -print0 2>/dev/null | xargs -0 -r wc -l | awk 'END{print $1+0}')" + empty="$({ xargs -r rg -l 'final response chars: 0|\"final_response\"\\s*:\\s*\"\"|\"result\"\\s*:\\s*\"\"' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')" + rm -f "$scan_files" + errors=$((auth + e429 + e401 + timeout + teacher)) + echo "$rid Base=$base Best=$best Errors=$errors auth=$auth 429=$e429 401=$e401 timeout=$timeout teacher=$teacher Results=$results Empty=$empty" + done | sort + echo + sleep 60 +done diff --git a/scripts/prepare_ablation_splits.py b/scripts/prepare_ablation_splits.py new file mode 100644 index 0000000..6aae772 --- /dev/null +++ 
b/scripts/prepare_ablation_splits.py
@@ -0,0 +1,150 @@
+#!/usr/bin/env python3
+"""Prepare fixed data splits for ablation experiments."""
+from __future__ import annotations
+
+import argparse
+import json
+import random
+from pathlib import Path
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+
+# Per dataset: raw JSON source, output root, and file name used for each partition.
+DATASETS = {
+    "searchqa": {
+        "raw": PROJECT_ROOT / "data/searchqa_train_2000.json",
+        "out": PROJECT_ROOT / "data/ablation_splits/searchqa",
+        "filenames": {"train": "train.json", "val": "selection.json", "test": "test.json"},
+    },
+    "spreadsheetbench": {
+        "raw": PROJECT_ROOT / "data/spreadsheetbench_verified_400/dataset.json",
+        "out": PROJECT_ROOT / "data/ablation_splits/spreadsheetbench",
+        "filenames": {"train": "train.json", "val": "sel.json", "test": "test.json"},
+    },
+}
+
+# "1shot" means 1 train / 1 val / rest test; the others are train:val:test ratios.
+SPLITS = ("1shot", "1:1:8", "2:1:7", "4:1:5")
+
+
+def load_items(path: Path) -> list[dict]:
+    """Load *path* as JSON and return it; raise TypeError unless it is a JSON array."""
+    with path.open(encoding="utf-8") as f:
+        data = json.load(f)
+    if not isinstance(data, list):
+        raise TypeError(f"Expected JSON array in {path}, got {type(data).__name__}")
+    return data
+
+
+def split_counts(total: int, split: str) -> tuple[int, int, int]:
+    """Return (train, val, test) sizes for *total* items under *split*.
+
+    "1shot" gives 1 / 1 / rest. An "a:b:c" ratio is floored per part and the
+    leftover items go to the parts with the largest fractional remainder
+    (ties to the larger weight), so the sizes always sum to *total*.
+
+    Raises ValueError for a malformed ratio, or for "1shot" with fewer than
+    3 items.
+    """
+    if split == "1shot":
+        if total < 3:
+            raise ValueError(f"Need at least 3 items for 1shot split, got {total}")
+        return 1, 1, total - 2
+
+    ratio = split
+    weights = [int(part) for part in ratio.split(":")]
+    if len(weights) != 3 or min(weights) <= 0:
+        raise ValueError(f"Invalid ratio: {ratio}")
+    denom = sum(weights)
+    raw = [total * weight / denom for weight in weights]
+    counts = [int(value) for value in raw]
+    remaining = total - sum(counts)
+    order = sorted(
+        range(3),
+        key=lambda idx: (raw[idx] - counts[idx], weights[idx]),
+        reverse=True,
+    )
+    for idx in order[:remaining]:
+        counts[idx] += 1
+    return counts[0], counts[1], counts[2]
+
+
+def split_tag(split: str) -> str:
+    """Return the directory-name form of *split* (":" replaced by "-")."""
+    return "1shot" if split == "1shot" else split.replace(":", "-")
+
+
+def write_json(path: Path, items: list[dict] | dict) -> None:
+    """Write *items* to *path* as indented UTF-8 JSON, creating parent directories."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as f:
+        json.dump(items, f, ensure_ascii=False, indent=2)
+
+
+def prepare_dataset(name: str, *, seed: int, force: bool) -> None:
+    """Write every split in SPLITS for dataset *name* under its output root.
+
+    A split whose manifest already exists is skipped unless *force* is set.
+    """
+    spec = DATASETS[name]
+    raw_path = spec["raw"]
+    out_root = spec["out"]
+    filenames = spec["filenames"]
+
+    items = load_items(raw_path)
+    # The permutation depends only on the seed, so build it once for all splits.
+    shuffled = list(items)
+    random.Random(seed).shuffle(shuffled)
+    for split in SPLITS:
+        ratio_tag = split_tag(split)
+        split_dir = out_root / f"{ratio_tag}_seed{seed}"
+        manifest_path = split_dir / "split_manifest.json"
+        if manifest_path.exists() and not force:
+            print(f"skip {name} {split}: {split_dir} exists")
+            continue
+
+        train_n, val_n, test_n = split_counts(len(shuffled), split)
+        train_items = shuffled[:train_n]
+        val_items = shuffled[train_n: train_n + val_n]
+        test_items = shuffled[train_n + val_n: train_n + val_n + test_n]
+
+        write_json(split_dir / "train" / filenames["train"], train_items)
+        write_json(split_dir / "val" / filenames["val"], val_items)
+        write_json(split_dir / "test" / filenames["test"], test_items)
+        write_json(
+            manifest_path,
+            {
+                "dataset": name,
+                "source": str(raw_path),
+                "split_mode": "precomputed_ratio",
+                "split_name": split,
+                "split_ratio": split if split != "1shot" else "1 train / 1 val / rest test",
+                "split_seed": seed,
+                "counts": {
+                    "train": len(train_items),
+                    "val": len(val_items),
+                    "test": len(test_items),
+                },
+            },
+        )
+        print(
+            f"wrote {name} {split} -> {split_dir} "
+            f"(train={len(train_items)}, val={len(val_items)}, test={len(test_items)})"
+        )
+
+
+def main() -> None:
+    """Parse command-line options and prepare the selected datasets (all by default)."""
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--force", action="store_true")
+    parser.add_argument("--dataset", choices=sorted(DATASETS), action="append")
+    args = parser.parse_args()
+
+    for name in args.dataset or sorted(DATASETS):
+        prepare_dataset(name, seed=args.seed, force=args.force)
+
+
+if __name__ == "__main__":
+    main()
diff
--git a/scripts/run_ablation_matrix.py b/scripts/run_ablation_matrix.py
new file mode 100755
index 0000000..a82ed52
--- /dev/null
+++ b/scripts/run_ablation_matrix.py
@@ -0,0 +1,680 @@
+#!/usr/bin/env python3
+"""Launch the SearchQA / SpreadsheetBench ablation matrix.
+
+By default this script prints commands only. Pass --execute to actually start
+runs. Every run writes to a unique out_root under the run root and logs stdout
+/ stderr to logs/<run_id>.log.
+"""
+from __future__ import annotations
+
+import argparse
+import os
+import re
+import subprocess
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+
+PROJECT_ROOT = Path(__file__).resolve().parents[1]
+PYTHON_BIN = Path(os.environ.get("PYTHON_BIN", "/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"))
+
+T2_ENDPOINT = "https://t2vgoaigpt4o3.openai.azure.com/"
+SEARCHAGENT5_ENDPOINT = "https://searchagent5.cognitiveservices.azure.com/"
+
+
+@dataclass(frozen=True)
+class Experiment:
+    run_id: str
+    benchmark: str
+    config: str
+    overrides: tuple[str, ...]
+
+
+BENCH_CONFIG = {
+    "searchqa": "configs/searchqa/default.yaml",
+    "spreadsheetbench": "configs/spreadsheetbench/default.yaml",
+    "livemathematicianbench": "configs/livemathematicianbench/default.yaml",
+    "alfworld": "configs/alfworld/default.yaml",
+    "docvqa": "configs/docvqa/default.yaml",
+}
+
+DEFAULT_SPLIT = {
+    "searchqa": "data/ablation_splits/searchqa/2-1-7_seed42",
+    "spreadsheetbench": "data/ablation_splits/spreadsheetbench/2-1-7_seed42",
+    "livemathematicianbench": "data/ablation_splits/livemathematicianbench/2-1-7_seed42",
+    "alfworld": "data/ablation_splits/alfworld/2-1-7_seed42",
+    "docvqa": "/home/azureuser/zisu/SkillReflection/data/docvqa/splits",
+}
+
+DEFAULT_TRAIN_SIZE = {
+    "searchqa": 400,
+    "spreadsheetbench": 80,
+    "livemathematicianbench": 35,
+    "alfworld": 39,
+    "docvqa": 1070,
+}
+
+BATCH_SIZE_VALUES: tuple[int | str, ...]
= (8, 24, 40, 56, "full") + +SPLITS = { + "searchqa": { + "1shot": ("data/ablation_splits/searchqa/1shot_seed42", ("optimizer.slow_update_samples=1",)), + "1-1-8": ("data/ablation_splits/searchqa/1-1-8_seed42", ()), + "2-1-7": ("data/ablation_splits/searchqa/2-1-7_seed42", ()), + "4-1-5": ("data/ablation_splits/searchqa/4-1-5_seed42", ()), + }, + "spreadsheetbench": { + "1shot": ("data/ablation_splits/spreadsheetbench/1shot_seed42", ("optimizer.slow_update_samples=1",)), + "1-1-8": ("data/ablation_splits/spreadsheetbench/1-1-8_seed42", ()), + "2-1-7": ("data/ablation_splits/spreadsheetbench/2-1-7_seed42", ()), + "4-1-5": ("data/ablation_splits/spreadsheetbench/4-1-5_seed42", ()), + }, + "livemathematicianbench": { + "1shot": ("data/ablation_splits/livemathematicianbench/1shot_seed42", ("optimizer.slow_update_samples=1",)), + "1-1-8": ("data/ablation_splits/livemathematicianbench/1-1-8_seed42", ()), + "2-1-7": ("data/ablation_splits/livemathematicianbench/2-1-7_seed42", ()), + "4-1-5": ("data/ablation_splits/livemathematicianbench/4-1-5_seed42", ()), + }, + "alfworld": { + "1shot": ("data/ablation_splits/alfworld/1shot_seed42", ("optimizer.slow_update_samples=1",)), + "1-1-8": ("data/ablation_splits/alfworld/1-1-8_seed42", ()), + "2-1-7": ("data/ablation_splits/alfworld/2-1-7_seed42", ()), + "4-1-5": ("data/ablation_splits/alfworld/4-1-5_seed42", ()), + }, + "docvqa": { + "1shot": ("data/ablation_splits/docvqa/1shot_seed42", ("optimizer.slow_update_samples=1",)), + "1-1-8": ("data/ablation_splits/docvqa/1-1-8_seed42", ()), + "2-1-7": ("/home/azureuser/zisu/SkillReflection/data/docvqa/splits", ()), + "4-1-5": ("data/ablation_splits/docvqa/4-1-5_seed42", ()), + }, +} + + +def common_overrides(benchmark: str, out_root: Path) -> list[str]: + return [ + "model.teacher_backend=openai_chat", + "model.student_backend=openai_chat", + "model.teacher=gpt-5.5", + "model.student=gpt-5.5", + f"model.teacher_azure_openai_endpoint={T2_ENDPOINT}", + 
"model.teacher_azure_openai_api_version=2024-12-01-preview", + "model.teacher_azure_openai_auth_mode=azure_cli", + f"model.student_azure_openai_endpoint={T2_ENDPOINT}", + "model.student_azure_openai_api_version=2024-12-01-preview", + "model.student_azure_openai_auth_mode=azure_cli", + "model.reasoning_effort=medium", + "train.num_epochs=4", + "train.train_size=0", + "train.batch_size=40", + "train.accumulation=1", + "train.seed=42", + "gradient.minibatch_size=8", + "gradient.merge_batch_size=8", + "gradient.analyst_workers=16", + "gradient.use_deep_reflect=false", + "optimizer.learning_rate=4", + "optimizer.min_learning_rate=2", + "optimizer.lr_scheduler=cosine", + "optimizer.skill_update_mode=patch", + "optimizer.use_slow_update=true", + "optimizer.slow_update_samples=20", + "optimizer.use_meta_skill=true", + "optimizer.use_meta_reflect=false", + "evaluation.use_gate=true", + "evaluation.eval_test=true", + "env.split_mode=split_dir", + f"env.split_dir={DEFAULT_SPLIT[benchmark]}", + f"env.out_root={out_root}", + ] + + +def make_experiment( + group: str, + benchmark: str, + suffix: str, + run_root: Path, + overrides: list[str], +) -> Experiment: + run_id = f"{group}-{benchmark}-{suffix}" + out_root = run_root / run_id + all_overrides = common_overrides(benchmark, out_root) + all_overrides.extend(overrides) + return Experiment( + run_id=run_id, + benchmark=benchmark, + config=BENCH_CONFIG[benchmark], + overrides=tuple(all_overrides), + ) + + +def build_matrix( + groups: set[str], + benchmarks: list[str], + run_root: Path, + *, + include_duplicate_defaults: bool = False, +) -> list[Experiment]: + exps: list[Experiment] = [] + group_order = [ + "default", + "split", + "batch", + "mbs", + "lr", + "sched", + "slown", + "mod", + "smodel", + "longpair", + "lrctrl", + ] + + for group in group_order: + if group not in groups: + continue + for benchmark in benchmarks: + if group == "default": + exps.append(make_experiment("DEFAULT", benchmark, "5.5", run_root, [])) + continue 
+ + if group == "split": + for tag, (split_dir, extra) in SPLITS[benchmark].items(): + if not include_duplicate_defaults and tag == "2-1-7": + continue + exps.append(make_experiment( + "SPLIT", + benchmark, + tag, + run_root, + [f"env.split_dir={split_dir}", *extra], + )) + continue + + if group == "mbs": + for value in (1, 2, 4, 8, 16, 32): + if not include_duplicate_defaults and value == 8: + continue + exps.append(make_experiment( + "MBS", + benchmark, + str(value), + run_root, + [f"gradient.minibatch_size={value}"], + )) + continue + + if group == "batch": + for value in BATCH_SIZE_VALUES: + if not include_duplicate_defaults and value == 40: + continue + batch_size = DEFAULT_TRAIN_SIZE[benchmark] if value == "full" else int(value) + exps.append(make_experiment( + "BATCH", + benchmark, + str(value), + run_root, + [ + f"train.batch_size={batch_size}", + "gradient.minibatch_size=8", + ], + )) + continue + + if group == "lr": + for value in (1, 2, 4, 8, 16): + exps.append(make_experiment( + "LR", + benchmark, + str(value), + run_root, + [ + "optimizer.lr_scheduler=constant", + "optimizer.min_learning_rate=1", + f"optimizer.learning_rate={value}", + ], + )) + continue + + if group == "sched": + for value in ("constant", "cosine", "linear"): + if not include_duplicate_defaults and value == "cosine": + continue + exps.append(make_experiment( + "SCHED", + benchmark, + value, + run_root, + [f"optimizer.lr_scheduler={value}"], + )) + continue + + if group == "slown": + for value in (5, 10, 20, 40): + if not include_duplicate_defaults and value == 20: + continue + exps.append(make_experiment( + "SLOWN", + benchmark, + str(value), + run_root, + [f"optimizer.slow_update_samples={value}"], + )) + continue + + if group == "mod": + settings = { + "slow-meta": ("true", "true"), + "slow-only": ("true", "false"), + "meta-only": ("false", "true"), + "none": ("false", "false"), + } + for tag, (slow, meta) in settings.items(): + if not include_duplicate_defaults and tag == 
"slow-meta": + continue + exps.append(make_experiment( + "MOD", + benchmark, + tag, + run_root, + [ + f"optimizer.use_slow_update={slow}", + f"optimizer.use_meta_skill={meta}", + ], + )) + continue + + if group == "smodel": + student_settings = { + "5.4": [ + "model.student=gpt-5.4-pro", + f"model.student_azure_openai_endpoint={T2_ENDPOINT}", + "model.student_azure_openai_api_version=2025-03-01-preview", + "model.student_azure_openai_auth_mode=azure_cli", + ], + "5.4-mini": [ + "model.student=gpt-5.4-mini", + f"model.student_azure_openai_endpoint={SEARCHAGENT5_ENDPOINT}", + "model.student_azure_openai_api_version=2024-12-01-preview", + "model.student_azure_openai_auth_mode=azure_cli", + ], + "5.5": [], + } + for tag, overrides in student_settings.items(): + if not include_duplicate_defaults and tag == "5.5": + continue + exps.append(make_experiment("SMODEL", benchmark, tag, run_root, overrides)) + continue + + if group == "longpair": + for value in ("changed", "unchanged"): + exps.append(make_experiment( + "LONGPAIR", + benchmark, + value, + run_root, + [f"optimizer.longitudinal_pair_policy={value}"], + )) + continue + + if group == "lrctrl": + settings = { + "autonomous": ["optimizer.lr_control_mode=autonomous"], + "full-rewrite": [ + "optimizer.lr_control_mode=none", + "optimizer.skill_update_mode=full_rewrite_minibatch", + ], + } + for tag, overrides in settings.items(): + exps.append(make_experiment("LRCTRL", benchmark, tag, run_root, overrides)) + continue + + return exps + + +def _build_matrix_legacy( + groups: set[str], + benchmarks: list[str], + run_root: Path, + *, + include_duplicate_defaults: bool = False, +) -> list[Experiment]: + exps: list[Experiment] = [] + for benchmark in benchmarks: + if "default" in groups: + exps.append(make_experiment("DEFAULT", benchmark, "5.5", run_root, [])) + + if "split" in groups: + for tag, (split_dir, extra) in SPLITS[benchmark].items(): + if not include_duplicate_defaults and tag == "2-1-7": + continue + 
exps.append(make_experiment( + "SPLIT", + benchmark, + tag, + run_root, + [f"env.split_dir={split_dir}", *extra], + )) + + if "mbs" in groups: + for value in (1, 2, 4, 8, 16, 32): + if not include_duplicate_defaults and value == 8: + continue + exps.append(make_experiment( + "MBS", + benchmark, + str(value), + run_root, + [f"gradient.minibatch_size={value}"], + )) + + if "batch" in groups: + for value in BATCH_SIZE_VALUES: + if not include_duplicate_defaults and value == 40: + continue + batch_size = DEFAULT_TRAIN_SIZE[benchmark] if value == "full" else int(value) + exps.append(make_experiment( + "BATCH", + benchmark, + str(value), + run_root, + [ + f"train.batch_size={batch_size}", + "gradient.minibatch_size=8", + ], + )) + + if "lr" in groups: + for value in (1, 2, 4, 8, 16): + exps.append(make_experiment( + "LR", + benchmark, + str(value), + run_root, + [ + "optimizer.lr_scheduler=constant", + "optimizer.min_learning_rate=1", + f"optimizer.learning_rate={value}", + ], + )) + + if "sched" in groups: + for value in ("constant", "cosine", "linear"): + if not include_duplicate_defaults and value == "cosine": + continue + exps.append(make_experiment( + "SCHED", + benchmark, + value, + run_root, + [f"optimizer.lr_scheduler={value}"], + )) + + if "slown" in groups: + for value in (5, 10, 20, 40): + if not include_duplicate_defaults and value == 20: + continue + exps.append(make_experiment( + "SLOWN", + benchmark, + str(value), + run_root, + [f"optimizer.slow_update_samples={value}"], + )) + + if "mod" in groups: + settings = { + "slow-meta": ("true", "true"), + "slow-only": ("true", "false"), + "meta-only": ("false", "true"), + "none": ("false", "false"), + } + for tag, (slow, meta) in settings.items(): + if not include_duplicate_defaults and tag == "slow-meta": + continue + exps.append(make_experiment( + "MOD", + benchmark, + tag, + run_root, + [ + f"optimizer.use_slow_update={slow}", + f"optimizer.use_meta_skill={meta}", + ], + )) + + if "smodel" in groups: + 
student_settings = { + "5.4": [ + "model.student=gpt-5.4-pro", + f"model.student_azure_openai_endpoint={T2_ENDPOINT}", + "model.student_azure_openai_api_version=2025-03-01-preview", + "model.student_azure_openai_auth_mode=azure_cli", + ], + "5.4-mini": [ + "model.student=gpt-5.4-mini", + f"model.student_azure_openai_endpoint={SEARCHAGENT5_ENDPOINT}", + "model.student_azure_openai_api_version=2024-12-01-preview", + "model.student_azure_openai_auth_mode=azure_cli", + ], + "5.5": [], + } + for tag, overrides in student_settings.items(): + if not include_duplicate_defaults and tag == "5.5": + continue + exps.append(make_experiment("SMODEL", benchmark, tag, run_root, overrides)) + + if "longpair" in groups: + for value in ("changed", "unchanged"): + exps.append(make_experiment( + "LONGPAIR", + benchmark, + value, + run_root, + [f"optimizer.longitudinal_pair_policy={value}"], + )) + + if "lrctrl" in groups: + settings = { + "autonomous": ["optimizer.lr_control_mode=autonomous"], + "full-rewrite": [ + "optimizer.lr_control_mode=none", + "optimizer.skill_update_mode=full_rewrite_minibatch", + ], + } + for tag, overrides in settings.items(): + exps.append(make_experiment("LRCTRL", benchmark, tag, run_root, overrides)) + + return exps + + +def command_for(exp: Experiment) -> list[str]: + return [ + str(PYTHON_BIN), + "scripts/train.py", + "--config", + exp.config, + "--cfg-options", + *exp.overrides, + ] + + +def active_run_ids(run_root: Path, valid_run_ids: set[str] | None = None) -> set[str]: + try: + raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True) + except subprocess.CalledProcessError: + return set() + pattern = re.compile(re.escape(str(run_root)) + r"/([^\s]+)") + active: set[str] = set() + for line in raw.splitlines(): + for match in pattern.finditer(line): + run_id = match.group(1).strip("'\"") + if run_id.endswith(".log") or "/" in run_id: + continue + if valid_run_ids is not None and run_id not in valid_run_ids: + continue + 
active.add(run_id)
+    return active
+
+
+def completed_run_ids(run_root: Path) -> set[str]:
+    return {
+        path.parent.name
+        for path in run_root.glob("*/summary.json")
+        if path.is_file()
+    }
+
+
+def print_commands(exps: list[Experiment]) -> None:
+    for exp in exps:
+        cmd = command_for(exp)
+        print(f"\n# {exp.run_id}")
+        print(" ".join(subprocess.list2cmdline([part]) for part in cmd))
+
+
+def run_commands(
+    exps: list[Experiment],
+    run_root: Path,
+    max_parallel: int,
+    run_retries: int,
+) -> int:
+    logs_dir = run_root / "logs"
+    logs_dir.mkdir(parents=True, exist_ok=True)
+    # Each entry is (experiment, process, open log file, attempt index).
+    active: list[tuple[Experiment, subprocess.Popen, object, int]] = []
+    valid_run_ids = {exp.run_id for exp in exps}
+    skipped_completed = completed_run_ids(run_root)
+    skipped_active = active_run_ids(run_root, valid_run_ids)
+    pending: list[tuple[Experiment, int]] = [
+        (exp, 0)
+        for exp in exps
+        if exp.run_id not in skipped_completed and exp.run_id not in skipped_active
+    ]
+    for run_id in sorted(skipped_completed):
+        print(f"[SKIP_COMPLETED] {run_id}", flush=True)
+    for run_id in sorted(skipped_active):
+        print(f"[SKIP_ACTIVE] {run_id}", flush=True)
+    failures = 0
+
+    while pending or active:
+        external_active = active_run_ids(run_root, valid_run_ids) - {entry[0].run_id for entry in active}
+        while pending and len(active) + len(external_active) < max_parallel:
+            exp, attempt = pending.pop(0)
+            log_path = logs_dir / f"{exp.run_id}.log"
+            if attempt:
+                log_path = logs_dir / f"{exp.run_id}.retry{attempt}.log"
+            log_f = open(log_path, "w", encoding="utf-8")
+            print(f"[START] {exp.run_id} attempt={attempt + 1} log={log_path}", flush=True)
+            proc = subprocess.Popen(
+                command_for(exp),
+                cwd=PROJECT_ROOT,
+                stdout=log_f,
+                stderr=subprocess.STDOUT,
+                text=True,
+            )
+            active.append((exp, proc, log_f, attempt))
+
+        time.sleep(5)
+        still_active: list[tuple[Experiment, subprocess.Popen, object, int]] = []
+        for exp, proc, log_f, attempt in active:
+            rc = proc.poll()
+            if rc is None:
+                still_active.append((exp, proc, log_f, attempt))
+                continue
+            log_f.close()
+            if rc == 0:
+                print(f"[DONE] {exp.run_id}", flush=True)
+            else:
+                if attempt < run_retries:
+                    next_attempt = attempt + 1
+                    pending.append((exp, next_attempt))
+                    print(f"[RETRY] {exp.run_id} rc={rc} next_attempt={next_attempt + 1}", flush=True)
+                else:
+                    failures += 1
+                    print(f"[FAIL] {exp.run_id} rc={rc}", flush=True)
+        active = still_active
+
+    return failures
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--groups",
+        nargs="+",
+        default=["default"],
+        choices=[
+            "default",
+            "split",
+            "batch",
+            "mbs",
+            "lr",
+            "sched",
+            "slown",
+            "mod",
+            "smodel",
+            "longpair",
+            "lrctrl",
+            "all",
+        ],
+        help="Experiment groups to include. Default: default.",
+    )
+    parser.add_argument(
+        "--bench",
+        nargs="+",
+        default=["searchqa", "spreadsheetbench"],
+        choices=["searchqa", "spreadsheetbench", "livemathematicianbench", "alfworld", "docvqa"],
+    )
+    parser.add_argument("--run-root", default="", help="Output root. Default: outputs/ablation_<timestamp>.")
+    parser.add_argument("--max-parallel", type=int, default=1)
+    parser.add_argument("--run-retries", type=int, default=1, help="Retry failed runs this many times. Default: 1.")
+    parser.add_argument(
+        "--include-duplicate-defaults",
+        action="store_true",
+        help="Also run ablation points that are exactly the default setting.",
+    )
+    parser.add_argument("--execute", action="store_true", help="Actually start runs.
Without this, print commands only.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + groups = set(args.groups) + if "all" in groups: + groups = {"default", "split", "batch", "mbs", "lr", "sched", "slown", "mod", "smodel"} + + ts = time.strftime("%Y%m%d_%H%M%S", time.gmtime()) + run_root = Path(args.run_root) if args.run_root else PROJECT_ROOT / "outputs" / f"ablation_{ts}" + if not run_root.is_absolute(): + run_root = PROJECT_ROOT / run_root + run_root.mkdir(parents=True, exist_ok=True) + + exps = build_matrix( + groups, + args.bench, + run_root, + include_duplicate_defaults=args.include_duplicate_defaults, + ) + print(f"run_root={run_root}") + print(f"num_experiments={len(exps)}") + print(f"groups={','.join(sorted(groups))}") + print(f"bench={','.join(args.bench)}") + + if not args.execute: + print_commands(exps) + return + + max_parallel = max(1, int(args.max_parallel)) + failures = run_commands( + exps, + run_root, + max_parallel=max_parallel, + run_retries=max(0, int(args.run_retries)), + ) + if failures: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/run_alfworld.sh b/scripts/run_alfworld.sh new file mode 100755 index 0000000..5b0f474 --- /dev/null +++ b/scripts/run_alfworld.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# ────────────────────────────────────────────────────────────────────────────── +# ReflACT — ALFWorld training launch script +# +# Usage: +# bash scripts/run_alfworld.sh +# bash scripts/run_alfworld.sh --num_epochs 2 --edit_budget 6 +# ────────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +# ── Paths ──────────────────────────────────────────────────────────────────── +WORKSPACE="/home/azureuser/workspace-gzy" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")" + +# Activate conda environment +export PATH="${WORKSPACE}/miniconda3/envs/skillopt/bin:${WORKSPACE}/miniconda3/bin:${PATH}" + +# ALFWorld data — uses 
~/.cache/alfworld by default (standard alfworld location)
+export ALFWORLD_DATA="${ALFWORLD_DATA:-${HOME}/.cache/alfworld}"
+
+# Ensure ReflACT is importable
+export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
+
+# ── Verify ALFWorld data exists ──────────────────────────────────────────────
+if [ ! -d "${ALFWORLD_DATA}/json_2.1.1" ]; then
+    echo "ERROR: ALFWorld data not found at ${ALFWORLD_DATA}/json_2.1.1"
+    echo ""
+    echo "To download ALFWorld data, run:"
+    echo " pip install alfworld[full]"
+    echo " alfworld-download"
+    echo ""
+    echo "Or set ALFWORLD_DATA to the directory containing json_2.1.1/"
+    exit 1
+fi
+
+# ── Azure OpenAI credentials ────────────────────────────────────────────────
+export AZURE_OPENAI_ENDPOINT="${AZURE_OPENAI_ENDPOINT:-https://agl-dev.cognitiveservices.azure.com/}"
+export AZURE_OPENAI_API_KEY="${AZURE_OPENAI_API_KEY:-}"
+export AZURE_OPENAI_API_VERSION="${AZURE_OPENAI_API_VERSION:-2025-04-01-preview}"
+
+# ── Model configuration ─────────────────────────────────────────────────────
+export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
+export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
+
+# ── Output directory ─────────────────────────────────────────────────────────
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/skillopt_alfworld_${STUDENT_DEPLOYMENT}_${TIMESTAMP}"
+
+# ── Run ──────────────────────────────────────────────────────────────────────
+echo "============================================================"
+echo " ReflACT — Reflective Agent Tuning (ALFWorld)"
+echo "============================================================"
+echo " Teacher: ${TEACHER_DEPLOYMENT}"
+echo " Student: ${STUDENT_DEPLOYMENT}"
+echo " ALFWORLD_DATA: ${ALFWORLD_DATA}"
+echo " Output: ${DEFAULT_OUT_ROOT}"
+echo "============================================================"
+
+cd "${PROJECT_ROOT}"
+
+python scripts/train.py \
+    --config configs/alfworld/default.yaml \
+    --out_root "${DEFAULT_OUT_ROOT}" \
#!/usr/bin/env bash
# Meta-skill ablation launcher.
#
# For every benchmark (searchqa, spreadsheetbench) and reasoning setting
# (medium, none) this trains twice, sequentially:
#   slow       slow update on, meta skill off
#   slow_meta  slow update on, meta skill on
# Meta-reflect and deep-reflect are disabled in both conditions.
# Each run writes to ${RUN_ROOT}/<benchmark>_<reasoning>_<condition>.
#
# PROJECT_ROOT and PYTHON_BIN default to the original workstation paths but
# can be overridden from the environment, e.g.:
#   PROJECT_ROOT=/path/to/repo PYTHON_BIN="$(command -v python)" \
#     bash scripts/run_meta_skill_ablation.sh
set -euo pipefail

PROJECT_ROOT="${PROJECT_ROOT:-/home/azureuser/workspace-gzy/SkillReflection_dev}"
PYTHON_BIN="${PYTHON_BIN:-/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python}"
TS="$(date -u +%Y%m%d_%H%M%S)"
RUN_ROOT="${PROJECT_ROOT}/outputs/meta_skill_ablation_${TS}"

mkdir -p "${RUN_ROOT}"

# run_train <benchmark> <reasoning> <condition>
#   benchmark: searchqa | spreadsheetbench   (selects the YAML config)
#   reasoning: medium | none                 (none passes an empty reasoning_effort)
#   condition: slow | slow_meta              (toggles optimizer.use_meta_skill)
# Any other value prints an error to stderr and exits the script with status 1.
run_train() {
  local benchmark="$1"
  local reasoning="$2"
  local condition="$3"
  local config_path=""
  local reasoning_override=""
  local meta_skill_flag=""

  case "${benchmark}" in
    searchqa)
      config_path="configs/searchqa/default.yaml"
      ;;
    spreadsheetbench)
      config_path="configs/spreadsheetbench/default.yaml"
      ;;
    *)
      echo "Unknown benchmark: ${benchmark}" >&2
      exit 1
      ;;
  esac

  case "${reasoning}" in
    medium)
      reasoning_override="model.reasoning_effort=medium"
      ;;
    none)
      reasoning_override="model.reasoning_effort="
      ;;
    *)
      echo "Unknown reasoning setting: ${reasoning}" >&2
      exit 1
      ;;
  esac

  case "${condition}" in
    slow)
      meta_skill_flag="optimizer.use_meta_skill=false"
      ;;
    slow_meta)
      meta_skill_flag="optimizer.use_meta_skill=true"
      ;;
    *)
      echo "Unknown condition: ${condition}" >&2
      exit 1
      ;;
  esac

  local out_root="${RUN_ROOT}/${benchmark}_${reasoning}_${condition}"

  echo
  echo "============================================================"
  echo "START ${benchmark} ${reasoning} ${condition}"
  echo "out_root=${out_root}"
  echo "============================================================"

  # Subshell keeps the cd local to this run.
  (
    cd "${PROJECT_ROOT}"
    "${PYTHON_BIN}" scripts/train.py \
      --config "${config_path}" \
      --cfg-options \
      "${reasoning_override}" \
      "optimizer.use_slow_update=true" \
      "${meta_skill_flag}" \
      "optimizer.use_meta_reflect=false" \
      "gradient.use_deep_reflect=false" \
      "env.out_root=${out_root}"
  )

  echo
  echo "============================================================"
  echo "DONE ${benchmark} ${reasoning} ${condition}"
  echo "============================================================"
}

for benchmark in searchqa spreadsheetbench; do
  for reasoning in medium none; do
    run_train "${benchmark}" "${reasoning}" "slow"
    run_train "${benchmark}" "${reasoning}" "slow_meta"
  done
done

echo
echo "All runs completed."
echo "Run root: ${RUN_ROOT}"
#!/usr/bin/env bash
# ──────────────────────────────────────────────────────────────────────────────
# ReflACT — SearchQA training launch script
#
# Usage:
#   bash scripts/run_searchqa.sh
#   bash scripts/run_searchqa.sh --data_path data/searchqa_train_2000.json
#   bash scripts/run_searchqa.sh --num_epochs 2 --edit_budget 6
#
# Extra arguments are forwarded verbatim to scripts/train.py.
# ──────────────────────────────────────────────────────────────────────────────
set -euo pipefail

# ── Paths ────────────────────────────────────────────────────────────────────
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"

# Ensure ReflACT is importable
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"

# ── Model configuration ─────────────────────────────────────────────────────
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"

# ── Output directory ─────────────────────────────────────────────────────────
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/skillopt_searchqa_${STUDENT_DEPLOYMENT}_${TIMESTAMP}"

# ── Run ──────────────────────────────────────────────────────────────────────
echo "============================================================"
echo "  ReflACT — Reflective Agent Tuning (SearchQA)"
echo "============================================================"
echo "  Teacher:   ${TEACHER_DEPLOYMENT}"
echo "  Student:   ${STUDENT_DEPLOYMENT}"
echo "============================================================"

cd "${PROJECT_ROOT}"

# The repository ships per-environment configs as configs/<env>/default.yaml;
# there is no configs/searchqa_default.yaml.
python scripts/train.py \
    --config configs/searchqa/default.yaml \
    --out_root "${DEFAULT_OUT_ROOT}" \
    "$@"

echo ""
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
+echo " Student: ${STUDENT_DEPLOYMENT}" +echo "============================================================" + +cd "${PROJECT_ROOT}" + +python scripts/train.py \ + --config configs/spreadsheetbench_default.yaml \ + --out_root "${DEFAULT_OUT_ROOT}" \ + "$@" + +echo "" +echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}" diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..27060ce --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,483 @@ +#!/usr/bin/env python3 +"""ReflACT unified training entry point. + +Usage +----- + python scripts/train.py --config configs/alfworld/default.yaml + +Any YAML key can be overridden from the command line:: + + python scripts/train.py --config configs/alfworld/default.yaml \\ + --batch_size 40 --num_epochs 2 --seed 123 + +Run ``python scripts/train.py --help`` for a full list of options. +""" +from __future__ import annotations + +import argparse +import datetime +import os +import sys + +# Ensure the project root is on sys.path so ``import skillopt`` works +# regardless of where the script is invoked from. 
# ── Environment registry ────────────────────────────────────────────────────

# Environment name -> adapter class. Filled lazily by _register_builtins().
_ENV_REGISTRY: dict[str, type] = {}

# Environment name -> reason its adapter could not be imported. Used only to
# make the "Unknown environment" error actionable.
_ENV_IMPORT_ERRORS: dict[str, str] = {}

# (environment name, adapter module, adapter class name)
_BUILTIN_ADAPTERS: tuple[tuple[str, str, str], ...] = (
    ("alfworld", "skillopt.envs.alfworld.adapter", "ALFWorldAdapter"),
    ("searchqa", "skillopt.envs.searchqa.adapter", "SearchQAAdapter"),
    (
        "livemathematicianbench",
        "skillopt.envs.livemathematicianbench.adapter",
        "LiveMathematicianBenchAdapter",
    ),
    ("babyvision", "skillopt.envs.babyvision.adapter", "BabyVisionAdapter"),
    (
        "spreadsheetbench",
        "skillopt.envs.spreadsheetbench.adapter",
        "SpreadsheetBenchAdapter",
    ),
    ("mmrb", "skillopt.envs.mmrb.adapter", "MMRBAdapter"),
    ("docvqa", "skillopt.envs.docvqa.adapter", "DocVQAAdapter"),
    ("mathverse", "skillopt.envs.mathverse.adapter", "MathVerseAdapter"),
    ("officeqa", "skillopt.envs.officeqa.adapter", "OfficeQAAdapter"),
    ("sealqa", "skillopt.envs.sealqa.adapter", "SealQAAdapter"),
    ("swebench", "skillopt.envs.swebench.adapter", "SWEBenchAdapter"),
)


def _register_builtins() -> None:
    """Lazy-import built-in adapters so we don't pull heavy deps at CLI parse time.

    An adapter whose module (or one of its dependencies) cannot be imported is
    skipped, exactly as before, but the reason is now kept in
    ``_ENV_IMPORT_ERRORS`` instead of being discarded.
    """
    import importlib

    for env_name, module_name, class_name in _BUILTIN_ADAPTERS:
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:  # optional deps not installed — skip
            _ENV_IMPORT_ERRORS[env_name] = f"{type(exc).__name__}: {exc}"
            continue
        adapter_cls = getattr(module, class_name, None)
        if adapter_cls is None:
            # ``from module import Name`` would have raised ImportError here.
            _ENV_IMPORT_ERRORS[env_name] = (
                f"ImportError: cannot import name '{class_name}' from '{module_name}'"
            )
            continue
        _ENV_REGISTRY[env_name] = adapter_cls
        _ENV_IMPORT_ERRORS.pop(env_name, None)


def get_adapter(cfg: dict):
    """Instantiate the environment adapter specified in ``cfg["env"]``.

    Only config keys that match a parameter name of the adapter's ``__init__``
    are forwarded; everything else in ``cfg`` is ignored.

    Args:
        cfg: Flat config dict. ``cfg["env"]`` selects the adapter and defaults
            to ``"alfworld"`` when absent.

    Returns:
        A new adapter instance.

    Raises:
        ValueError: If the environment is not registered. When the adapter is
            a known built-in whose import failed, the import error is included.
    """
    import inspect

    _register_builtins()
    env_name = cfg.get("env", "alfworld")
    if env_name not in _ENV_REGISTRY:
        message = (
            f"Unknown environment '{env_name}'. "
            f"Available: {list(_ENV_REGISTRY.keys())}"
        )
        reason = _ENV_IMPORT_ERRORS.get(env_name)
        if reason:
            message += f" (adapter import failed: {reason})"
        raise ValueError(message)
    adapter_cls = _ENV_REGISTRY[env_name]

    # Inspect adapter __init__ signature and only pass accepted kwargs
    accepted = set(inspect.signature(adapter_cls.__init__).parameters) - {"self"}
    adapter_kwargs = {key: cfg[key] for key in accepted if key in cfg}

    return adapter_cls(**adapter_kwargs)


# ── CLI ──────────────────────────────────────────────────────────────────────

# argparse ``type=`` converter: "true"/"1"/"yes" (any case) -> True, anything else -> False.
_BOOL = lambda x: x.lower() in ("true", "1", "yes")  # noqa: E731


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the training CLI.

    Every legacy flat option defaults to ``None`` so that the config loader can
    tell "not given" apart from an explicit value.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        The parsed namespace; ``config`` is required, ``cfg_options`` is a
        (possibly empty) list of ``section.key=value`` strings.
    """
    p = argparse.ArgumentParser(
        description="ReflACT: Reflective Agent Tuning",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    p.add_argument("--config", type=str, required=True,
                   help="Path to YAML config file")
    p.add_argument("--cfg-options", nargs="+", default=[],
                   help="Override config: section.key=value (e.g. train.batch_size=40)")

    # Legacy flat CLI overrides (still work, prefer --cfg-options for new usage)
    p.add_argument("--env", type=str)
    p.add_argument("--backend", type=str,
                   choices=["azure_openai", "codex", "codex_exec", "claude", "claude_chat", "claude_code_exec"])
    p.add_argument("--teacher_model", type=str)
    p.add_argument("--student_model", type=str)
    p.add_argument("--teacher_backend", type=str)
    p.add_argument("--student_backend", type=str)
    p.add_argument("--reasoning_effort", type=str,
                   choices=["", "low", "medium", "high", "xhigh", "max"])
    p.add_argument("--rewrite_reasoning_effort", type=str)
    p.add_argument("--rewrite_max_completion_tokens", type=int)
    p.add_argument("--azure_endpoint", type=str)
    p.add_argument("--azure_api_version", type=str)
    p.add_argument("--azure_api_key", type=str)
    p.add_argument("--azure_openai_endpoint", type=str)
    p.add_argument("--azure_openai_api_version", type=str)
    p.add_argument("--azure_openai_api_key", type=str)
    p.add_argument("--azure_openai_auth_mode", type=str)
    p.add_argument("--azure_openai_ad_scope", type=str)
    p.add_argument("--azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--teacher_azure_openai_endpoint", type=str)
    p.add_argument("--teacher_azure_openai_api_version", type=str)
    p.add_argument("--teacher_azure_openai_api_key", type=str)
    p.add_argument("--teacher_azure_openai_auth_mode", type=str)
    p.add_argument("--teacher_azure_openai_ad_scope", type=str)
    p.add_argument("--teacher_azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--student_azure_openai_endpoint", type=str)
    p.add_argument("--student_azure_openai_api_version", type=str)
    p.add_argument("--student_azure_openai_api_key", type=str)
    p.add_argument("--student_azure_openai_auth_mode", type=str)
    p.add_argument("--student_azure_openai_ad_scope", type=str)
    p.add_argument("--student_azure_openai_managed_identity_client_id", type=str)
    p.add_argument("--codex_exec_path", type=str)
    p.add_argument("--codex_exec_sandbox", type=str)
    p.add_argument("--codex_exec_profile", type=str)
    p.add_argument("--codex_exec_full_auto", type=_BOOL)
    p.add_argument("--codex_exec_reasoning_effort", type=str)
    p.add_argument("--codex_exec_use_sdk", type=str)
    p.add_argument("--codex_exec_network_access", type=_BOOL)
    p.add_argument("--codex_exec_web_search", type=_BOOL)
    p.add_argument("--codex_exec_approval_policy", type=str)
    p.add_argument("--claude_code_exec_path", type=str)
    p.add_argument("--claude_code_exec_profile", type=str)
    p.add_argument("--claude_code_exec_use_sdk", type=str)
    p.add_argument("--claude_code_exec_effort", type=str)
    p.add_argument("--claude_code_exec_max_thinking_tokens", type=int)
    p.add_argument("--codex_trace_to_teacher", type=_BOOL)
    p.add_argument("--skill_init", type=str)
    p.add_argument("--num_epochs", type=int)
    p.add_argument("--train_size", type=int)
    p.add_argument("--steps_per_epoch", type=int)
    p.add_argument("--batch_size", type=int)
    p.add_argument("--accumulation", type=int)
    p.add_argument("--seed", type=int)
    p.add_argument("--edit_budget", type=int)
    p.add_argument("--min_edit_budget", type=int)
    p.add_argument("--lr_scheduler", type=str,
                   choices=["constant", "linear", "cosine", "autonomous"])
    p.add_argument("--lr_control_mode", type=str,
                   choices=["fixed", "autonomous", "none"])
    p.add_argument("--merge_batch_size", type=int)
    p.add_argument("--max_analyst_rounds", type=int)
    p.add_argument("--sel_env_num", type=int)
    p.add_argument("--test_env_num", type=int)
    p.add_argument("--eval_test", type=_BOOL)
    p.add_argument("--use_gate", type=_BOOL)
    p.add_argument("--max_steps", type=int)
    p.add_argument("--max_api_workers", type=int)
    p.add_argument("--analyst_workers", type=int)
    p.add_argument("--failure_only", type=_BOOL)
    p.add_argument("--minibatch_size", type=int)
    p.add_argument("--use_meta_reflect", type=_BOOL)
    p.add_argument("--meta_edit_budget", type=int)
    p.add_argument("--skill_update_mode", type=str,
                   choices=[
                       "patch",
                       "rewrite_from_suggestions",
                       "rewrite",
                       "suggestions",
                       "full_rewrite",
                       "full_rewrite_minibatch",
                       "minibatch_full_rewrite",
                   ])
    p.add_argument("--use_deep_reflect", type=_BOOL)
    p.add_argument("--deep_reflect_failures", type=int)
    p.add_argument("--deep_reflect_successes", type=int)
    p.add_argument("--use_slow_update", type=_BOOL)
    p.add_argument("--slow_update_samples", type=int)
    p.add_argument("--longitudinal_pair_policy", type=str,
                   choices=["mixed", "changed", "unchanged"])
    p.add_argument("--use_meta_skill", type=_BOOL)
    p.add_argument("--data_path", type=str)
    p.add_argument("--split_mode", type=str,
                   choices=["ratio", "split_dir"])
    p.add_argument("--split_ratio", type=str)
    p.add_argument("--split_seed", type=int)
    p.add_argument("--split_dir", type=str)
    p.add_argument("--split_output_dir", type=str)
    p.add_argument("--data_root", type=str)
    # Passed by the SpreadsheetBench launch scripts; without it argparse
    # aborts with "unrecognized arguments".
    p.add_argument("--jsonl_path", type=str)
    p.add_argument("--max_turns", type=int)
    p.add_argument("--workers", type=int)
    p.add_argument("--limit", type=int)
    p.add_argument("--shuffle_choices", type=_BOOL)
    p.add_argument("--use_theorem", type=_BOOL)
    p.add_argument("--use_sketch", type=_BOOL)
    p.add_argument("--image_detail", type=str)
    p.add_argument("--judge_model", type=str)
    p.add_argument("--judge_max_completion_tokens", type=int)
    p.add_argument("--judge_retries", type=int)
    p.add_argument("--out_root", type=str)
    p.add_argument("--mode", type=str)

    return p.parse_args(argv)
"model.rewrite_reasoning_effort", + "rewrite_max_completion_tokens": "model.rewrite_max_completion_tokens", + "azure_endpoint": "model.azure_endpoint", + "azure_api_version": "model.azure_api_version", + "azure_api_key": "model.azure_api_key", + "azure_openai_endpoint": "model.azure_openai_endpoint", + "azure_openai_api_version": "model.azure_openai_api_version", + "azure_openai_api_key": "model.azure_openai_api_key", + "azure_openai_auth_mode": "model.azure_openai_auth_mode", + "azure_openai_ad_scope": "model.azure_openai_ad_scope", + "azure_openai_managed_identity_client_id": "model.azure_openai_managed_identity_client_id", + "teacher_azure_openai_endpoint": "model.teacher_azure_openai_endpoint", + "teacher_azure_openai_api_version": "model.teacher_azure_openai_api_version", + "teacher_azure_openai_api_key": "model.teacher_azure_openai_api_key", + "teacher_azure_openai_auth_mode": "model.teacher_azure_openai_auth_mode", + "teacher_azure_openai_ad_scope": "model.teacher_azure_openai_ad_scope", + "teacher_azure_openai_managed_identity_client_id": "model.teacher_azure_openai_managed_identity_client_id", + "student_azure_openai_endpoint": "model.student_azure_openai_endpoint", + "student_azure_openai_api_version": "model.student_azure_openai_api_version", + "student_azure_openai_api_key": "model.student_azure_openai_api_key", + "student_azure_openai_auth_mode": "model.student_azure_openai_auth_mode", + "student_azure_openai_ad_scope": "model.student_azure_openai_ad_scope", + "student_azure_openai_managed_identity_client_id": "model.student_azure_openai_managed_identity_client_id", + "codex_exec_path": "model.codex_exec_path", + "codex_exec_sandbox": "model.codex_exec_sandbox", + "codex_exec_profile": "model.codex_exec_profile", + "codex_exec_full_auto": "model.codex_exec_full_auto", + "codex_exec_reasoning_effort": "model.codex_exec_reasoning_effort", + "codex_exec_use_sdk": "model.codex_exec_use_sdk", + "codex_exec_network_access": "model.codex_exec_network_access", 
def load_config(args: argparse.Namespace) -> dict:
    """Load config with _base_ inheritance, then apply CLI overrides.

    Steps, in order:
      1. Load the YAML file with ``--cfg-options`` applied.
      2. Apply legacy flat ``--key value`` options (structured configs get
         them translated through ``_LEGACY_TO_STRUCTURED``; unmapped keys go
         under ``env.``).
      3. Flatten structured configs into the flat dict the trainer/adapter use.
      4. Back-fill ``azure_openai_*`` keys from their older ``azure_*`` names.
      5. Derive teacher/student backends from an explicitly requested backend
         and swap OpenAI default model names for the Claude default when a
         Claude backend is selected and the model was not overridden.
      6. Fill in and absolutise ``out_root``.

    Args:
        args: Namespace produced by the CLI parser (needs ``config``,
            ``cfg_options`` and the legacy option attributes).

    Returns:
        The flat config dict, always containing an absolute ``out_root``.

    Raises:
        ValueError: If a ``model.backend`` entry in ``--cfg-options`` has no
            ``=value`` part.
    """
    from skillopt.config import load_config as _load, flatten_config, is_structured

    cfg = _load(args.config, overrides=args.cfg_options)
    structured = is_structured(cfg)

    # Apply legacy --key value overrides
    cli = {k: v for k, v in vars(args).items()
           if v is not None and k not in ("config", "cfg_options")}
    if cli:
        if structured:
            from skillopt.config import apply_overrides
            mapped = []
            for k, v in cli.items():
                dotted = _LEGACY_TO_STRUCTURED.get(k)
                if dotted:
                    mapped.append(f"{dotted}={v}")
                else:
                    # Environment-specific options have no fixed section.
                    mapped.append(f"env.{k}={v}")
            apply_overrides(cfg, mapped)
        else:
            cfg.update(cli)

    # Flatten structured config → flat dict for trainer/adapter
    flat = flatten_config(cfg) if structured else cfg

    # Older configs use the short azure_* names; only fill the new key when it
    # is unset so an explicit azure_openai_* value always wins.
    for new_key, old_key in (
        ("azure_openai_endpoint", "azure_endpoint"),
        ("azure_openai_api_version", "azure_api_version"),
        ("azure_openai_api_key", "azure_api_key"),
    ):
        if flat.get(new_key) in (None, "") and flat.get(old_key) not in (None, ""):
            flat[new_key] = flat[old_key]

    # "Explicit" means requested on this command line (flag or cfg-option),
    # as opposed to a backend that merely appears in the YAML.
    explicit_backend = getattr(args, "backend", None)
    if explicit_backend is None:
        for option in args.cfg_options or []:
            key, sep, value = str(option).partition("=")
            if key.strip() == "model.backend":
                if not sep:
                    raise ValueError(
                        f"Malformed --cfg-options entry '{option}': expected model.backend=<name>"
                    )
                explicit_backend = value.strip()
                break

    backend = normalize_backend_name(flat.get("model_backend") or flat.get("student_backend") or "azure_openai")

    def _has_model_override(dotted_key: str, legacy_key: str) -> bool:
        """Return True if the model was set via legacy flag or cfg-option."""
        if getattr(args, legacy_key, None) is not None:
            return True
        for option in args.cfg_options or []:
            key = str(option).split("=", 1)[0].strip()
            if key == dotted_key:
                return True
        return False

    if explicit_backend is not None:
        backend = normalize_backend_name(explicit_backend)
        flat["model_backend"] = backend
        if backend in {"claude", "claude_chat"}:
            flat.setdefault("teacher_backend", "claude_chat")
            flat.setdefault("student_backend", "claude_chat")
        elif backend in {"codex", "codex_exec"}:
            flat.setdefault("teacher_backend", "openai_chat")
            flat.setdefault("student_backend", "codex_exec")
        elif backend == "claude_code_exec":
            flat.setdefault("teacher_backend", "openai_chat")
            flat.setdefault("student_backend", "claude_code_exec")
        else:
            flat.setdefault("teacher_backend", "openai_chat")
            flat.setdefault("student_backend", "openai_chat")
    else:
        flat.setdefault("teacher_backend", "openai_chat")
        flat.setdefault("student_backend", "openai_chat")

    # An OpenAI default model name left in place on a Claude backend is a
    # config default, not a user choice — replace it unless overridden.
    if flat.get("teacher_backend") == "claude_chat":
        if (
            str(flat.get("teacher_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
            and not _has_model_override("model.teacher", "teacher_model")
        ):
            flat["teacher_model"] = default_model_for_backend("claude_chat")
    if flat.get("student_backend") == "claude_chat":
        if (
            str(flat.get("student_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
            and not _has_model_override("model.student", "student_model")
        ):
            flat["student_model"] = default_model_for_backend("claude_chat")
    if flat.get("student_backend") == "claude_code_exec":
        if (
            str(flat.get("student_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
            and not _has_model_override("model.student", "student_model")
        ):
            flat["student_model"] = default_model_for_backend("claude_chat")

    # Auto-generate output root
    if not flat.get("out_root"):
        env = flat.get("env", "unknown")
        # ``or`` also covers a key that is present but null in the YAML, which
        # would otherwise fail on ``None.replace``.
        model = str(flat.get("teacher_model") or "unknown").replace("/", "-")
        ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        flat["out_root"] = os.path.join("outputs", f"skillopt_{env}_{model}_{ts}")

    flat["out_root"] = os.path.abspath(flat["out_root"])
    return flat
#!/usr/bin/env bash
# ──────────────────────────────────────────────────────────────────────────────
# ReflACT3 — SearchQA Training
#
# Usage:
#   bash scripts/train_searchqa.sh
#   bash scripts/train_searchqa.sh --limit 50
#   bash scripts/train_searchqa.sh --num_epochs 2 --workers 32
#
# Extra arguments are forwarded verbatim to scripts/train.py.
# DATA_PATH may be overridden from the environment.
# ──────────────────────────────────────────────────────────────────────────────
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"

# ── Models ───────────────────────────────────────────────────────────────────
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"

# ── Data ─────────────────────────────────────────────────────────────────────
DATA_PATH="${DATA_PATH:-/home/azureuser/workspace-yqh/refleAct/search-qa/data/searchqa_train_2000.json}"

# ── Output ───────────────────────────────────────────────────────────────────
# No timestamp in the path: repeated runs for the same student share a directory.
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/searchqa-metaskill/searchqa_${STUDENT_DEPLOYMENT}"

echo "============================================================"
echo "  ReflACT3 — SearchQA Training"
echo "  Teacher:  ${TEACHER_DEPLOYMENT}"
echo "  Student:  ${STUDENT_DEPLOYMENT}"
echo "  Data:     ${DATA_PATH}"
echo "============================================================"

cd "${PROJECT_ROOT}"

# The repository ships per-environment configs as configs/<env>/default.yaml;
# there is no configs/searchqa_default.yaml.
python scripts/train.py \
    --config configs/searchqa/default.yaml \
    --data_path "${DATA_PATH}" \
    --out_root "${DEFAULT_OUT_ROOT}" \
    "$@"

echo ""
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
"Done! Results saved to: ${OUT_ROOT}" diff --git a/scripts/train_spreadsheet_single.sh b/scripts/train_spreadsheet_single.sh new file mode 100755 index 0000000..0d67b5e --- /dev/null +++ b/scripts/train_spreadsheet_single.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +# ────────────────────────────────────────────────────────────────────────────── +# ReflACT — SpreadsheetBench training (SINGLE-ROUND codegen, no tool-call) +# +# Usage: +# bash scripts/train_spreadsheet_single.sh +# bash scripts/train_spreadsheet_single.sh --num_epochs 2 --edit_budget 6 +# ────────────────────────────────────────────────────────────────────────────── +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")" + +export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}" +export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}" +export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}" + +DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400" +JSONL_PATH="${DATA_ROOT}/dataset.json" +SPLIT_DIR="/home/azureuser/workspace-yqh/refleACT3/data/spreadsheetbench_split_2_1_7" + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_ROOT="${PROJECT_ROOT}/outputs/spreadsheet-metaskill-new/train_single_${STUDENT_DEPLOYMENT}" + +echo "============================================================" +echo " ReflACT — SpreadsheetBench Training (SINGLE-ROUND)" +echo "============================================================" +echo " Teacher: ${TEACHER_DEPLOYMENT}" +echo " Student: ${STUDENT_DEPLOYMENT}" +echo " Mode: single" +echo " Data: ${DATA_ROOT}" +echo " Split: ${SPLIT_DIR}" +echo " Output: ${OUT_ROOT}" +echo "============================================================" + +cd "${PROJECT_ROOT}" + +python scripts/train.py \ + --config configs/spreadsheetbench_default.yaml \ + --mode single \ + --data_root "${DATA_ROOT}" \ + --jsonl_path "${JSONL_PATH}" \ + --split_dir "${SPLIT_DIR}" \ + --out_root "${OUT_ROOT}" \ + "$@" + +echo "" 
# Launcher log marker, e.g. "[FAIL] some_run_id": group 1 is the event, group 2 the run id.
RUN_RE = re.compile(r"\[(START|DONE|FAIL|RETRY)\]\s+([^\s]+)")
# Substrings that mark a run log as containing an error (case-insensitive).
ERROR_RE = re.compile(
    r"Traceback|RuntimeError|AuthenticationError|PermissionDenied|"
    r"DeploymentNotFound|LLM call failed|LLM message call failed|"
    r"BadRequestError|RateLimitError",
    re.IGNORECASE,
)


def read_text(path: Path) -> str:
    """Return the file's text (UTF-8, undecodable bytes replaced), or "" if it does not exist."""
    try:
        return path.read_text(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        return ""


def parse_launcher(path: Path) -> dict[str, list[str]]:
    """Collect run ids per launcher event from the launcher log.

    Returns:
        A dict with the keys START, DONE, FAIL and RETRY, each mapping to the
        run ids seen for that event in file order (duplicates kept). At most
        one marker is taken per line. A missing log yields four empty lists.
    """
    events = {"START": [], "DONE": [], "FAIL": [], "RETRY": []}
    for line in read_text(path).splitlines():
        match = RUN_RE.search(line)
        if match:
            events[match.group(1)].append(match.group(2))
    return events


def active_run_ids(run_root: Path) -> list[str]:
    """List run ids that appear in the command line of a live ``scripts/train.py`` process.

    A run id is whatever non-whitespace text follows ``<run_root>/`` in a
    matching process command line. Returns a sorted, de-duplicated list; empty
    when ``pgrep`` exits non-zero (it does so when nothing matches).
    """
    try:
        raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
    except subprocess.CalledProcessError:
        return []
    active: list[str] = []
    pattern = re.compile(re.escape(str(run_root)) + r"/([^\s]+)")
    for line in raw.splitlines():
        for match in pattern.finditer(line):
            active.append(match.group(1))
    return sorted(set(active))


def scan_errors(logs_dir: Path) -> dict[str, str]:
    """Map run id -> first error pattern found in that run's log file(s).

    Log names are reduced to a run id by cutting at ".watchrerun" and dropping
    a trailing ".log", so a rerun log is attributed to the same run id as the
    original; when both match, the later file in sorted order wins.
    """
    errors: dict[str, str] = {}
    for log_path in sorted(logs_dir.glob("*.log")):
        text = read_text(log_path)
        match = ERROR_RE.search(text)
        if match:
            run_id = log_path.name.split(".watchrerun", 1)[0].removesuffix(".log")
            errors[run_id] = match.group(0)
    return errors


def load_state(path: Path) -> dict:
    """Load the watcher state, falling back to a fresh state when unusable.

    The state file is best-effort bookkeeping, so a missing, unreadable or
    corrupt file must not stop the watcher. Valid JSON that is not an object
    (e.g. a list) is treated the same way, because callers use the result as
    a dict.

    Returns:
        The decoded JSON object, or ``{"reruns": {}}``.
    """
    try:
        state = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        return {"reruns": {}}
    if not isinstance(state, dict):
        return {"reruns": {}}
    return state


def save_state(path: Path, state: dict) -> None:
    """Write ``state`` as JSON via a sibling ``.tmp`` file that then replaces ``path``."""
    tmp = path.with_suffix(".tmp")
    tmp.write_text(json.dumps(state, indent=2, ensure_ascii=False), encoding="utf-8")
    # Readers see either the old or the new file, never a partial write.
    tmp.replace(path)


def write_status(
    run_root: Path,
    total: int,
    events: dict[str, list[str]],
    active: list[str],
    completed: list[str],
    pending: list[str],
    errors: dict[str, str],
    reruns: dict[str, int],
) -> None:
    """Render the current progress to ``<run_root>/STATUS.md``.

    Args:
        run_root: Run root directory; also where ``launcher.log`` is read from.
        total: Number of planned runs.
        events: Launcher events as returned by ``parse_launcher`` (the FAIL
            and RETRY lists are de-duplicated and sorted for display).
        active: Run ids with a live training process.
        completed: Run ids counted as completed.
        pending: Run ids not yet completed.
        errors: Run id -> matched error pattern.
        reruns: Run id -> number of watcher-started reruns (missing means 0).
    """
    now = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
    failed = sorted(set(events["FAIL"]))
    retrying = sorted(set(events["RETRY"]))
    lines = [
        "# Ablation Status",
        "",
        f"Updated: {now}",
        f"Run root: `{run_root}`",
        "",
        "| Metric | Count |",
        "| --- | ---: |",
        f"| Total planned | {total} |",
        f"| Completed summaries | {len(completed)} |",
        f"| Active train processes | {len(active)} |",
        f"| Pending/not summarized | {len(pending)} |",
        f"| Launcher final fails | {len(failed)} |",
        f"| Launcher retries | {len(retrying)} |",
        f"| Logs with error patterns | {len(errors)} |",
        "",
        "## Active",
        "",
        *(f"- `{run_id}`" for run_id in active),
        "",
        "## Final Fails",
        "",
        *(f"- `{run_id}` watcher_reruns={reruns.get(run_id, 0)}" for run_id in failed),
        "",
        "## Error Patterns",
        "",
        *(f"- `{run_id}`: `{err}`" for run_id, err in sorted(errors.items())),
        "",
        "## Recent Launcher",
        "",
        "```text",
        # Tail of the launcher log: last 30 lines.
        "\n".join(read_text(run_root / "launcher.log").splitlines()[-30:]),
        "```",
    ]
    (run_root / "STATUS.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
parser.add_argument("--interval", type=int, default=60) + parser.add_argument("--watcher-retries", type=int, default=1) + parser.add_argument("--groups", nargs="+", default=["all"]) + parser.add_argument("--bench", nargs="+", default=["searchqa", "spreadsheetbench"]) + args = parser.parse_args() + + run_root = Path(args.run_root).resolve() + logs_dir = run_root / "logs" + logs_dir.mkdir(parents=True, exist_ok=True) + state_path = run_root / "watcher_state.json" + + groups = set(args.groups) + if "all" in groups: + groups = {"default", "split", "mbs", "lr", "sched", "slown", "mod", "smodel"} + experiments = { + exp.run_id: exp + for exp in build_matrix(groups, args.bench, run_root, include_duplicate_defaults=False) + } + + active_reruns: dict[str, subprocess.Popen] = {} + while True: + state = load_state(state_path) + reruns = state.setdefault("reruns", {}) + events = parse_launcher(run_root / "launcher.log") + active = active_run_ids(run_root) + completed = sorted( + run_id for run_id in experiments + if (run_root / run_id / "summary.json").exists() + ) + pending = sorted(set(experiments) - set(completed)) + errors = scan_errors(logs_dir) + + # Reap watcher-started reruns. 
+ for run_id, proc in list(active_reruns.items()): + rc = proc.poll() + if rc is None: + continue + active_reruns.pop(run_id, None) + with open(logs_dir / f"{run_id}.watcher.log", "a", encoding="utf-8") as f: + f.write(f"\n[WATCHER_DONE] rc={rc} time={time.time()}\n") + + for run_id in sorted(set(events["FAIL"])): + if run_id not in experiments: + continue + if (run_root / run_id / "summary.json").exists(): + continue + if run_id in active or run_id in active_reruns: + continue + count = int(reruns.get(run_id, 0)) + if count >= args.watcher_retries: + continue + reruns[run_id] = count + 1 + save_state(state_path, state) + log_path = logs_dir / f"{run_id}.watchrerun{count + 1}.log" + with open(log_path, "w", encoding="utf-8") as log_f: + log_f.write(f"[WATCHER_START] run_id={run_id} attempt={count + 1}\n") + log_f.flush() + proc = subprocess.Popen( + command_for(experiments[run_id]), + cwd=PROJECT_ROOT, + stdout=log_f, + stderr=subprocess.STDOUT, + text=True, + close_fds=True, + ) + active_reruns[run_id] = proc + + save_state(state_path, state) + write_status( + run_root=run_root, + total=len(experiments), + events=events, + active=active, + completed=completed, + pending=pending, + errors=errors, + reruns={k: int(v) for k, v in reruns.items()}, + ) + time.sleep(max(5, int(args.interval))) + + +if __name__ == "__main__": + main() diff --git a/skillopt/__init__.py b/skillopt/__init__.py new file mode 100644 index 0000000..1e957a9 --- /dev/null +++ b/skillopt/__init__.py @@ -0,0 +1,29 @@ +"""ReflACT: Reflective Agent Tuning. + +A general-purpose framework for iteratively optimizing LLM agent skills +through structured reflection and self-improvement. + +Pipeline stages: + 1. Rollout — execute episodes with current skill + 2. Reflect — analyze trajectories, generate patches + 3. Aggregate — hierarchical merge of patches + 4. Select — rank and select top edits + 5. Update — apply edits to skill document + 6. 
Evaluate — validate candidate skill, accept/reject +""" + +__version__ = "0.1.0" + +from skillopt.types import ( # noqa: F401 + BatchSpec, + Edit, + EditOp, + FailureSummaryEntry, + GateAction, + GateResult, + MetaReflectResult, + Patch, + RawPatch, + RolloutResult, + SlowUpdateResult, +) diff --git a/skillopt/config.py b/skillopt/config.py new file mode 100644 index 0000000..faf376b --- /dev/null +++ b/skillopt/config.py @@ -0,0 +1,263 @@ +"""ReflACT config loading engine — structured YAML with inheritance. + +Supports two config formats: + 1. **Structured** (new): sections like ``model``, ``train``, ``gradient``, + ``optimizer``, ``evaluation``, ``env`` — with ``_base_`` inheritance. + 2. **Flat** (legacy): all keys at top level — fully backward compatible. + +Usage:: + + from skillopt.config import load_config, flatten_config + + cfg = load_config("configs/searchqa_default.yaml") + flat = flatten_config(cfg) # always returns flat dict for trainer +""" +from __future__ import annotations + +import copy +import os +from typing import Any + +import yaml + +# ── Section names that indicate a structured config ────────────────────── + +_STRUCTURED_SECTIONS = frozenset({ + "model", "train", "gradient", "optimizer", "evaluation", "env", +}) + +# ── Structured → flat key mapping ──────────────────────────────────────── + +_FLATTEN_MAP: dict[str, str] = { + "model.backend": "model_backend", + "model.teacher": "teacher_model", + "model.student": "student_model", + "model.teacher_backend": "teacher_backend", + "model.student_backend": "student_backend", + "model.reasoning_effort": "reasoning_effort", + "model.rewrite_reasoning_effort": "rewrite_reasoning_effort", + "model.rewrite_max_completion_tokens": "rewrite_max_completion_tokens", + "model.codex_exec_path": "codex_exec_path", + "model.codex_exec_sandbox": "codex_exec_sandbox", + "model.codex_exec_profile": "codex_exec_profile", + "model.codex_exec_full_auto": "codex_exec_full_auto", + 
"model.codex_exec_reasoning_effort": "codex_exec_reasoning_effort", + "model.codex_exec_use_sdk": "codex_exec_use_sdk", + "model.codex_exec_network_access": "codex_exec_network_access", + "model.codex_exec_web_search": "codex_exec_web_search", + "model.codex_exec_approval_policy": "codex_exec_approval_policy", + "model.claude_code_exec_path": "claude_code_exec_path", + "model.claude_code_exec_profile": "claude_code_exec_profile", + "model.claude_code_exec_use_sdk": "claude_code_exec_use_sdk", + "model.claude_code_exec_effort": "claude_code_exec_effort", + "model.claude_code_exec_max_thinking_tokens": "claude_code_exec_max_thinking_tokens", + "model.codex_trace_to_teacher": "codex_trace_to_teacher", + "model.azure_endpoint": "azure_endpoint", + "model.azure_api_version": "azure_api_version", + "model.azure_api_key": "azure_api_key", + "model.azure_openai_endpoint": "azure_openai_endpoint", + "model.azure_openai_api_version": "azure_openai_api_version", + "model.azure_openai_api_key": "azure_openai_api_key", + "model.azure_openai_auth_mode": "azure_openai_auth_mode", + "model.azure_openai_ad_scope": "azure_openai_ad_scope", + "model.azure_openai_managed_identity_client_id": "azure_openai_managed_identity_client_id", + "model.teacher_azure_openai_endpoint": "teacher_azure_openai_endpoint", + "model.teacher_azure_openai_api_version": "teacher_azure_openai_api_version", + "model.teacher_azure_openai_api_key": "teacher_azure_openai_api_key", + "model.teacher_azure_openai_auth_mode": "teacher_azure_openai_auth_mode", + "model.teacher_azure_openai_ad_scope": "teacher_azure_openai_ad_scope", + "model.teacher_azure_openai_managed_identity_client_id": "teacher_azure_openai_managed_identity_client_id", + "model.student_azure_openai_endpoint": "student_azure_openai_endpoint", + "model.student_azure_openai_api_version": "student_azure_openai_api_version", + "model.student_azure_openai_api_key": "student_azure_openai_api_key", + "model.student_azure_openai_auth_mode": 
"student_azure_openai_auth_mode", + "model.student_azure_openai_ad_scope": "student_azure_openai_ad_scope", + "model.student_azure_openai_managed_identity_client_id": "student_azure_openai_managed_identity_client_id", + "train.num_epochs": "num_epochs", + "train.train_size": "train_size", + "train.steps_per_epoch": "steps_per_epoch", + "train.batch_size": "batch_size", + "train.accumulation": "accumulation", + "train.seed": "seed", + "gradient.minibatch_size": "minibatch_size", + "gradient.merge_batch_size": "merge_batch_size", + "gradient.analyst_workers": "analyst_workers", + "gradient.failure_only": "failure_only", + "gradient.use_deep_reflect": "use_deep_reflect", + "gradient.deep_reflect_failures": "deep_reflect_failures", + "gradient.deep_reflect_successes": "deep_reflect_successes", + "gradient.max_analyst_rounds": "max_analyst_rounds", + "optimizer.learning_rate": "edit_budget", + "optimizer.min_learning_rate": "min_edit_budget", + "optimizer.lr_scheduler": "lr_scheduler", + "optimizer.lr_control_mode": "lr_control_mode", + "optimizer.skill_update_mode": "skill_update_mode", + "optimizer.use_meta_reflect": "use_meta_reflect", + "optimizer.meta_learning_rate": "meta_edit_budget", + "optimizer.use_slow_update": "use_slow_update", + "optimizer.slow_update_samples": "slow_update_samples", + "optimizer.longitudinal_pair_policy": "longitudinal_pair_policy", + "optimizer.use_meta_skill": "use_meta_skill", + "evaluation.use_gate": "use_gate", + "evaluation.sel_env_num": "sel_env_num", + "evaluation.test_env_num": "test_env_num", + "evaluation.eval_test": "eval_test", + "env.name": "env", + "env.skill_init": "skill_init", + "env.out_root": "out_root", +} + + +# ── Deep merge ─────────────────────────────────────────────────────────── + +def _deep_merge(base: dict, override: dict) -> dict: + """Recursively merge *override* into *base* (returns new dict).""" + result = copy.deepcopy(base) + for key, val in override.items(): + if key in result and 
isinstance(result[key], dict) and isinstance(val, dict): + result[key] = _deep_merge(result[key], val) + else: + result[key] = copy.deepcopy(val) + return result + + +# ── YAML loading with _base_ inheritance ───────────────────────────────── + +def _load_yaml(path: str, _visited: set[str] | None = None) -> dict: + """Load a YAML file, resolving ``_base_`` inheritance recursively.""" + abs_path = os.path.abspath(path) + if _visited is None: + _visited = set() + if abs_path in _visited: + raise ValueError(f"Circular _base_ inheritance: {abs_path}") + _visited.add(abs_path) + + with open(abs_path) as f: + cfg = yaml.safe_load(f) or {} + + base_ref = cfg.pop("_base_", None) + if base_ref: + base_path = os.path.join(os.path.dirname(abs_path), base_ref) + base_cfg = _load_yaml(base_path, _visited) + cfg = _deep_merge(base_cfg, cfg) + + return cfg + + +# ── Format detection ───────────────────────────────────────────────────── + +def is_structured(cfg: dict) -> bool: + """Return True if *cfg* uses the new structured section format.""" + return any( + key in _STRUCTURED_SECTIONS and isinstance(cfg.get(key), dict) + for key in cfg + ) + + +# ── Flatten ────────────────────────────────────────────────────────────── + +def flatten_config(cfg: dict) -> dict: + """Convert a structured config to the flat dict expected by the trainer. + + If *cfg* is already flat, returns a shallow copy unchanged. + """ + if not is_structured(cfg): + return dict(cfg) + + flat: dict[str, Any] = {} + + evaluation_section = cfg.get("evaluation", {}) + if isinstance(evaluation_section, dict) and evaluation_section.get("use_gate") is False: + raise ValueError( + "Gate validation is mandatory in this branch. Remove " + "`evaluation.use_gate: false` from the config." 
+ ) + + # Apply the explicit mapping + for dotted, flat_key in _FLATTEN_MAP.items(): + section, key = dotted.split(".", 1) + section_dict = cfg.get(section, {}) + if isinstance(section_dict, dict) and key in section_dict: + flat[flat_key] = section_dict[key] + + # Pass through env-specific keys not in the explicit mapping + env_section = cfg.get("env", {}) + if isinstance(env_section, dict): + mapped_env_keys = { + k.split(".", 1)[1] + for k in _FLATTEN_MAP + if k.startswith("env.") + } + for key, val in env_section.items(): + if key not in mapped_env_keys: + flat[key] = val + + return flat + + +# ── Override application ───────────────────────────────────────────────── + +def _cast_value(val_str: str) -> Any: + """Auto-cast a CLI string value to int / float / bool / str.""" + if val_str.lower() in ("true", "yes"): + return True + if val_str.lower() in ("false", "no"): + return False + try: + return int(val_str) + except ValueError: + pass + try: + return float(val_str) + except ValueError: + pass + return val_str + + +def apply_overrides(cfg: dict, overrides: list[str]) -> None: + """Apply ``key=value`` overrides to a structured config (in place). + + Supports both ``section.key=value`` (for structured configs) and + ``key=value`` (for flat configs or flat keys in env section). + """ + for item in overrides: + if "=" not in item: + raise ValueError(f"Invalid override (expected key=value): {item!r}") + key, val_str = item.split("=", 1) + val = _cast_value(val_str) + + if "." 
in key: + section, subkey = key.split(".", 1) + if section in cfg and isinstance(cfg[section], dict): + cfg[section][subkey] = val + else: + cfg.setdefault(section, {})[subkey] = val + else: + # Flat key — apply to top level (for legacy compat) + cfg[key] = val + + +# ── Public API ─────────────────────────────────────────────────────────── + +def load_config( + path: str, + overrides: list[str] | None = None, +) -> dict: + """Load a config file with ``_base_`` inheritance and optional overrides. + + Parameters + ---------- + path : str + Path to the YAML config file. + overrides : list[str] | None + ``key=value`` strings from ``--cfg-options``. + + Returns + ------- + dict + The merged config (structured or flat depending on the YAML). + """ + cfg = _load_yaml(path) + if overrides: + apply_overrides(cfg, overrides) + return cfg diff --git a/skillopt/datasets/__init__.py b/skillopt/datasets/__init__.py new file mode 100644 index 0000000..3aa2eb8 --- /dev/null +++ b/skillopt/datasets/__init__.py @@ -0,0 +1,7 @@ +"""ReflACT Datasets -- task batch planning and data loading. + +Analogous to the datasets and dataloaders in neural network training: +provides batch sampling, epoch planning, and data management for the +ReflACT training pipeline. +""" +from skillopt.datasets.base import BaseDataLoader, BatchSpec, SplitDataLoader # noqa: F401 diff --git a/skillopt/datasets/base.py b/skillopt/datasets/base.py new file mode 100644 index 0000000..668f201 --- /dev/null +++ b/skillopt/datasets/base.py @@ -0,0 +1,512 @@ +"""Generic task dataloader abstractions for ReflACT. + +ReflACT does not train model parameters directly. Instead, it iterates over +task batches, rolls out the current skill, reflects on failures/successes, +and updates the skill document. Because of that, the "dataloader" abstraction +here is closer to a batch sampler / episode planner than a tensor loader. + +Class hierarchy:: + + BaseDataLoader # abstract — simulator-backed envs (e.g. 
ALFWorld) + └── SplitDataLoader # abstract — dataset-backed envs with split_dir + +SplitDataLoader supports two dataset entry modes: + +1. ``split_mode="split_dir"``: consume an existing split directory. +2. ``split_mode="ratio"``: build a deterministic split directory from a raw + dataset path using an explicit train:val:test ratio. + +In either case, the standardised split layout is: + + split_dir/ + ├── train/ # training items + ├── val/ # validation / selection items (gate) + └── test/ # held-out test items + +Each subdirectory's contents are benchmark-specific. Subclasses only need +to implement ``load_split_items(split_path)`` to teach the loader how to +read items from one of those directories. +""" +from __future__ import annotations + +import glob +import json +import os +import random +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(slots=True) +class BatchSpec: + """A concrete batch request consumed by the training loop. + + Parameters + ---------- + phase : str + ``"train"`` or ``"eval"``. + split : str + Dataset split name, typically ``"train"`` or an eval split. + seed : int + Random seed used to construct the batch deterministically. + batch_size : int + Requested number of items / episodes in this batch. + payload : object | None + Environment-specific batch payload. For dataset-backed environments + this is often a list of sampled items; for simulator-backed + environments this may be ``None`` and the seed alone can define the + batch. + metadata : dict[str, Any] + Optional structured metadata for logging, resume, or curriculum logic. + """ + + phase: str + split: str + seed: int + batch_size: int + payload: object | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +class BaseDataLoader(ABC): + """Abstract base class for task batch planning in ReflACT. + + Subclasses are responsible for defining how a train or eval batch is + sampled. 
The default implementation here provides deterministic epoch seed + planning so all loaders share the same reproducibility behavior. + """ + + def setup(self, cfg: dict) -> None: + """Optional one-time initialization with the full trainer config.""" + + def set_out_root(self, out_root: str) -> None: + """Optional hook for loaders that persist split files or state.""" + + def state_dict(self) -> dict[str, Any]: + """Return serializable loader state for resume support.""" + return {} + + def load_state_dict(self, state: dict[str, Any]) -> None: + """Restore loader state from :meth:`state_dict` output.""" + + def get_train_size(self) -> int | None: + """Return the size of the training pool when known.""" + return None + + @staticmethod + def make_base_seeds(steps_per_epoch: int, accumulation: int, seed: int) -> list[int]: + """Return the deterministic seed pool used to define train batches.""" + batches_per_epoch = steps_per_epoch * accumulation + return [seed + i + 1 for i in range(batches_per_epoch)] + + @staticmethod + def shuffle_epoch_seeds(base_seeds: list[int], epoch: int, seed: int) -> list[int]: + """Return the per-epoch deterministic shuffle of *base_seeds*.""" + epoch_rng = random.Random(seed + epoch * 1000) + shuffled = list(base_seeds) + epoch_rng.shuffle(shuffled) + return shuffled + + def plan_train_epoch( + self, + *, + epoch: int, + steps_per_epoch: int, + accumulation: int, + batch_size: int, + seed: int, + **kwargs, + ) -> list[BatchSpec]: + """Build the full list of training batches for one epoch.""" + base_seeds = self.make_base_seeds( + steps_per_epoch=steps_per_epoch, + accumulation=accumulation, + seed=seed, + ) + shuffled_seeds = self.shuffle_epoch_seeds(base_seeds, epoch=epoch, seed=seed) + return [ + self.build_train_batch(batch_size=batch_size, seed=batch_seed, **kwargs) + for batch_seed in shuffled_seeds + ] + + @abstractmethod + def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec: + """Construct one training 
batch specification.""" + + @abstractmethod + def build_eval_batch( + self, + env_num: int, + split: str, + seed: int, + **kwargs, + ) -> BatchSpec: + """Construct one evaluation batch specification.""" + + +# ── Split-based dataloader for dataset-backed environments ────────────── + +# Canonical split names expected under split_dir/ +SPLIT_NAMES = ("train", "val", "test") + +# Maps legacy / trainer split names → canonical directory names +_SPLIT_ALIAS: dict[str, str] = { + "train": "train", + "valid_seen": "val", + "selection": "val", + "val": "val", + "valid_unseen": "test", + "test": "test", +} + + +def _load_json_or_jsonl(path: str) -> list[dict]: + """Load a list of items from a JSON or JSONL file.""" + with open(path, encoding="utf-8") as f: + content = f.read().strip() + if not content: + return [] + + try: + data = json.loads(content) + except json.JSONDecodeError: + data = None + + if isinstance(data, list): + return data + if isinstance(data, dict): + nested = data.get("data") + if isinstance(nested, list): + return nested + return list(data.values()) + + items: list[dict] = [] + for line in content.splitlines(): + line = line.strip() + if line: + items.append(json.loads(line)) + return items + + +def _parse_split_ratio(text: str) -> tuple[int, int, int]: + parts = [part.strip() for part in str(text or "").split(":") if part.strip()] + if len(parts) != 3: + raise ValueError( + f"split_ratio must be in train:val:test form, got {text!r}" + ) + try: + train, val, test = (int(part) for part in parts) + except ValueError as exc: + raise ValueError( + f"split_ratio must contain integers, got {text!r}" + ) from exc + if min(train, val, test) <= 0: + raise ValueError(f"split_ratio parts must be positive, got {text!r}") + return train, val, test + + +def _compute_split_counts(total: int, ratio: tuple[int, int, int]) -> tuple[int, int, int]: + weights = list(ratio) + denom = sum(weights) + raw = [total * weight / denom for weight in weights] + counts = [int(value) 
for value in raw] + remaining = total - sum(counts) + order = sorted( + range(len(raw)), + key=lambda idx: (raw[idx] - counts[idx], weights[idx]), + reverse=True, + ) + for idx in order[:remaining]: + counts[idx] += 1 + return counts[0], counts[1], counts[2] + + +class SplitDataLoader(BaseDataLoader): + """Base class for dataset-backed environments. + + Supported modes: + + - ``split_mode="split_dir"``: load an existing ``train/``, ``val/``, + ``test/`` directory tree. + - ``split_mode="ratio"``: load raw items from ``data_path`` and materialize + a deterministic split directory with the requested ratio. + """ + + def __init__( + self, + split_dir: str = "", + data_path: str = "", + split_mode: str = "ratio", + split_ratio: str = "2:1:7", + split_seed: int = 42, + split_output_dir: str = "", + seed: int = 42, + limit: int = 0, + **kwargs, + ) -> None: + self.split_dir = split_dir + self.data_path = data_path + self.split_mode = split_mode + self.split_ratio = split_ratio + self.split_seed = int(split_seed) + self.split_output_dir = split_output_dir + self.seed = seed + self.limit = limit + self._splits: dict[str, list[dict]] = {} + + # ── Setup ──────────────────────────────────────────────────────────── + + def setup(self, cfg: dict) -> None: + if not self.split_mode: + self.split_mode = str(cfg.get("split_mode", "ratio") or "ratio") + if not self.split_dir: + self.split_dir = cfg.get("split_dir", "") + if not self.data_path: + self.data_path = cfg.get("data_path", "") + if not self.split_output_dir: + self.split_output_dir = cfg.get("split_output_dir", "") + if "split_seed" in cfg and not self.split_seed: + self.split_seed = int(cfg.get("split_seed", 0) or 0) + if not self.split_seed: + self.split_seed = self.seed + if not self.split_ratio: + self.split_ratio = str(cfg.get("split_ratio", "2:1:7") or "2:1:7") + + mode = str(self.split_mode or "ratio").strip().lower() + if mode not in {"ratio", "split_dir"}: + raise ValueError( + f"{type(self).__name__} split_mode 
must be 'ratio' or 'split_dir', " + f"got {self.split_mode!r}" + ) + self.split_mode = mode + + if self.split_mode == "ratio": + self.split_dir = self._materialize_ratio_split(cfg) + if not self.split_dir: + raise ValueError( + f"{type(self).__name__} requires either " + "`split_mode=ratio` with `data_path`, or `split_mode=split_dir` " + f"with `split_dir` pointing to {'/'.join(SPLIT_NAMES)}/." + ) + self._load_all_splits() + + def _resolve_split_output_dir(self, cfg: dict) -> str: + if self.split_output_dir: + return os.path.abspath(self.split_output_dir) + out_root = os.path.abspath(str(cfg.get("out_root") or os.getcwd())) + env_name = str(cfg.get("env") or type(self).__name__.replace("DataLoader", "").lower()) + ratio_tag = str(self.split_ratio or "2:1:7").replace(":", "-") + return os.path.join(out_root, "_generated_splits", f"{env_name}_{ratio_tag}_seed{self.split_seed}") + + def load_raw_items(self, data_path: str) -> list[dict]: + """Load raw items from a dataset path before ratio splitting. + + Subclasses can override when the raw dataset is not a single JSON/JSONL + file or when directory layouts require custom normalization. + """ + if os.path.isdir(data_path): + if any(os.path.isdir(os.path.join(data_path, name)) for name in SPLIT_NAMES): + raise ValueError( + f"{type(self).__name__} got a split directory as data_path. " + "Use split_mode=split_dir and pass it as split_dir instead." 
+ ) + candidates = sorted(glob.glob(os.path.join(data_path, "*.json"))) + candidates += sorted(glob.glob(os.path.join(data_path, "*.jsonl"))) + if len(candidates) != 1: + raise ValueError( + f"{type(self).__name__} expected data_path to be one JSON/JSONL file " + f"or a directory containing exactly one such file, got: {data_path}" + ) + return _load_json_or_jsonl(candidates[0]) + return _load_json_or_jsonl(data_path) + + def write_split_items(self, split_path: str, items: list[dict]) -> None: + os.makedirs(split_path, exist_ok=True) + out_path = os.path.join(split_path, "items.json") + with open(out_path, "w", encoding="utf-8") as f: + json.dump(items, f, ensure_ascii=False, indent=2) + + def _materialize_ratio_split(self, cfg: dict) -> str: + data_path = os.path.abspath(str(self.data_path or "").strip()) + if not data_path: + raise ValueError( + f"{type(self).__name__} requires data_path when split_mode=ratio." + ) + + ratio = _parse_split_ratio(self.split_ratio) + items = self.load_raw_items(data_path) + if not isinstance(items, list) or not items: + raise ValueError(f"No raw items available for ratio split from {data_path}") + + shuffled = list(items) + rng = random.Random(self.split_seed) + rng.shuffle(shuffled) + + train_n, val_n, test_n = _compute_split_counts(len(shuffled), ratio) + train_items = shuffled[:train_n] + val_items = shuffled[train_n: train_n + val_n] + test_items = shuffled[train_n + val_n: train_n + val_n + test_n] + + split_dir = self._resolve_split_output_dir(cfg) + manifest = { + "source_data_path": data_path, + "split_mode": "ratio", + "split_ratio": self.split_ratio, + "split_seed": self.split_seed, + "counts": { + "train": len(train_items), + "val": len(val_items), + "test": len(test_items), + }, + } + os.makedirs(split_dir, exist_ok=True) + self.write_split_items(os.path.join(split_dir, "train"), train_items) + self.write_split_items(os.path.join(split_dir, "val"), val_items) + self.write_split_items(os.path.join(split_dir, "test"), 
test_items) + with open(os.path.join(split_dir, "split_manifest.json"), "w", encoding="utf-8") as f: + json.dump(manifest, f, ensure_ascii=False, indent=2) + print( + f" [{type(self).__name__}] generated ratio split {self.split_ratio} " + f"at {split_dir} from {data_path}" + ) + return split_dir + + def _load_all_splits(self) -> None: + for name in SPLIT_NAMES: + split_path = os.path.join(self.split_dir, name) + if not os.path.isdir(split_path): + raise ValueError( + f"Missing '{name}/' subdirectory in split_dir: {self.split_dir}" + ) + items = self.load_split_items(split_path) + if self.limit: + items = items[: self.limit] + self._splits[name] = items + + counts = " ".join(f"{k}={len(v)}" for k, v in self._splits.items()) + print(f" [{type(self).__name__}] {counts} (from {self.split_dir})") + + def load_split_items(self, split_path: str) -> list[dict]: + """Load items from one split directory (e.g. ``split_dir/train/``). + + Default: finds the first ``.json`` file in the directory and loads it + as a JSON array. Subclasses can override for custom formats. 
+ """ + json_files = sorted(glob.glob(os.path.join(split_path, "*.json"))) + if not json_files: + raise FileNotFoundError( + f"No .json file found in {split_path}" + ) + with open(json_files[0], encoding="utf-8") as f: + items = json.load(f) + if not isinstance(items, list): + raise ValueError( + f"Expected JSON array in {json_files[0]}, got {type(items).__name__}" + ) + return items + + # ── Accessors ──────────────────────────────────────────────────────── + + @property + def train_items(self) -> list[dict]: + return self._splits.get("train", []) + + @property + def val_items(self) -> list[dict]: + return self._splits.get("val", []) + + @property + def test_items(self) -> list[dict]: + return self._splits.get("test", []) + + def get_split_items(self, split: str) -> list[dict]: + """Resolve a split name (including legacy aliases) to its item list.""" + canonical = _SPLIT_ALIAS.get(split, split) + return list(self._splits.get(canonical, self.val_items)) + + def get_train_size(self) -> int: + return len(self.train_items) + + def plan_train_epoch( + self, + *, + epoch: int, + steps_per_epoch: int, + accumulation: int, + batch_size: int, + seed: int, + **kwargs, + ) -> list[BatchSpec]: + """Build one full epoch that covers the train split in shuffled order. + + For split-backed datasets, an epoch should correspond to one pass over + the available training items rather than repeated independent sampling. + """ + epoch_rng = random.Random(seed + epoch * 1000) + items = list(self.train_items) + epoch_rng.shuffle(items) + + total_batches = steps_per_epoch * accumulation + if total_batches <= 0: + return [] + + batches: list[BatchSpec] = [] + cursor = 0 + for batch_idx in range(total_batches): + batch_items = items[cursor: cursor + batch_size] + cursor += len(batch_items) + + # Extremely small datasets can leave trailing empty microbatches + # when accumulation > 1. Reuse the shuffled prefix in that case so + # the trainer still receives the expected batch count. 
+ if not batch_items and items: + refill_rng = random.Random(seed + epoch * 1000 + batch_idx + 1) + batch_items = list(items) + refill_rng.shuffle(batch_items) + batch_items = batch_items[:batch_size] + + batches.append( + BatchSpec( + phase="train", + split="train", + seed=seed + epoch * 1000 + batch_idx + 1, + batch_size=len(batch_items), + payload=batch_items, + ) + ) + + return batches + + # ── Batch construction ─────────────────────────────────────────────── + + def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec: + rng = random.Random(seed) + items = list(self.train_items) + rng.shuffle(items) + items = items[:batch_size] + return BatchSpec( + phase="train", + split="train", + seed=seed, + batch_size=len(items), + payload=items, + ) + + def build_eval_batch( + self, + env_num: int, + split: str, + seed: int, + **kwargs, + ) -> BatchSpec: + items = self.get_split_items(split) + if env_num and env_num < len(items): + items = items[:env_num] + return BatchSpec( + phase="eval", + split=split, + seed=seed, + batch_size=len(items), + payload=items, + ) diff --git a/skillopt/engine/__init__.py b/skillopt/engine/__init__.py new file mode 100644 index 0000000..b876e70 --- /dev/null +++ b/skillopt/engine/__init__.py @@ -0,0 +1,9 @@ +"""ReflACT Engine -- the training runner. + +Analogous to the Runner in mmengine: orchestrates the full training pipeline +including rollout, gradient computation, aggregation, optimization, and +evaluation. +""" +from skillopt.engine.trainer import ReflACTTrainer # noqa: F401 + +__all__ = ["ReflACTTrainer"] diff --git a/skillopt/engine/trainer.py b/skillopt/engine/trainer.py new file mode 100644 index 0000000..43dd39e --- /dev/null +++ b/skillopt/engine/trainer.py @@ -0,0 +1,2195 @@ +"""ReflACT Trainer — the main training loop. + +Orchestrates the 6-stage ReflACT pipeline: + 1. Rollout — execute episodes with current skill + 2. Reflect — analyze trajectories, generate patches + 3. 
Aggregate — hierarchical merge of patches + 4. Select — rank and select top edits + 5. Update — apply edits to skill document + 6. Evaluate — validate candidate skill, accept/reject + +The trainer is environment-agnostic; all environment-specific logic is +delegated to an :class:`~skillopt.envs.base.EnvAdapter` instance. +""" +from __future__ import annotations + +import glob +import json +import math +import os +import random +import re +import time +from collections import defaultdict + +from skillopt.datasets.base import BatchSpec +from skillopt.envs.base import EnvAdapter +from skillopt.evaluation.gate import evaluate_gate +from skillopt.gradient.aggregate import merge_patches +from skillopt.optimizer.meta_reflect import build_epoch_history, run_meta_reflect +from skillopt.optimizer.meta_skill import run_meta_skill +from skillopt.optimizer.clip import rank_and_select +from skillopt.optimizer.lr_autonomous import decide_autonomous_learning_rate +from skillopt.optimizer.rewrite import rewrite_skill_from_suggestions +from skillopt.optimizer.scheduler import build_scheduler +from skillopt.optimizer.skill import apply_patch_with_report +from skillopt.optimizer.slow_update import ( + build_comparison_pairs, + extract_slow_update_field, + inject_empty_slow_update_field, + replace_slow_update_field, + run_slow_update, + save_comparison_pairs, +) +from skillopt.optimizer.update_modes import ( + get_payload_items, + is_full_rewrite_minibatch_mode, + normalize_update_mode, + payload_label, + short_item_summary, +) +from skillopt.model import ( + configure_azure_openai, + configure_claude_code_exec, + configure_codex_exec, + get_token_summary, + reset_token_tracker, + set_reasoning_effort, + set_student_backend, + set_student_deployment, + set_teacher_backend, + set_teacher_deployment, +) +from skillopt.utils import compute_score, skill_hash + + +# ── Patch normalization ─────────────────────────────────────────────────────── + +def _normalise_patches( + raw_patches: 
# ── Patch normalization ───────────────────────────────────────────────────────

def _normalise_patches(
    raw_patches: list[dict | None],
    update_mode: str = "patch",
) -> tuple[list[dict], list[dict]]:
    """Unwrap raw analyst patches and sort them into failure/success lists.

    For every dict in *raw_patches* the inner ``"patch"`` dict is used (or
    the dict itself when that key is missing). Entries whose payload is
    empty for the given update mode are dropped. Each payload item that is a
    dict is tagged in place with ``source_type`` and ``support_count`` unless
    it already carries those keys.

    Returns ``(failure_patches, success_patches)``.
    """
    mode = normalize_update_mode(update_mode)
    failures: list[dict] = []
    successes: list[dict] = []

    for raw in raw_patches:
        if not isinstance(raw, dict):
            continue
        payload = raw.get("patch", raw)
        if not isinstance(payload, dict):
            continue
        entries = get_payload_items(payload, mode)
        if not entries:
            continue

        origin = raw.get("source_type", "failure")
        # A missing or zero batch size still counts as one supporting sample.
        weight = max(int(raw.get("batch_size", 0) or 0), 1)
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            entry.setdefault("source_type", origin)
            entry.setdefault("support_count", weight)

        target = successes if origin == "success" else failures
        target.append(payload)

    return failures, successes


def _normalise_longitudinal_pair_policy(policy: str | None) -> str:
    """Map a user-supplied pair policy (or alias) to its canonical name.

    ``None`` and the empty string mean ``"mixed"``. Raises ``ValueError`` for
    an unrecognised value.
    """
    key = str(policy or "mixed").strip().lower()
    synonyms = {
        "mixed": ("mixed", "default", "random", "all"),
        "changed": ("changed", "change", "delta", "10_01", "01_10"),
        "unchanged": ("unchanged", "stable", "same", "00_11"),
    }
    for canonical, names in synonyms.items():
        if key in names:
            return canonical
    raise ValueError(
        "optimizer.longitudinal_pair_policy must be one of "
        "mixed, changed, unchanged"
    )


def _normalise_lr_control_mode(mode: str | None) -> str:
    """Map a user-supplied LR control mode (or alias) to its canonical name.

    ``None`` and the empty string mean ``"fixed"``. Raises ``ValueError`` for
    an unrecognised value.
    """
    key = str(mode or "fixed").strip().lower()
    synonyms = {
        "fixed": ("fixed", "manual", "scheduler", "scheduled"),
        "autonomous": ("autonomous", "auto", "teacher"),
        "none": ("none", "off", "no_lr"),
    }
    for canonical, names in synonyms.items():
        if key in names:
            return canonical
    raise ValueError("optimizer.lr_control_mode must be one of fixed, autonomous, none")
def _filter_longitudinal_pairs(pairs: list[dict], policy: str) -> list[dict]:
    """Keep only the comparison pairs whose category matches *policy*.

    ``"mixed"`` returns *pairs* itself (same list object, no copy);
    ``"changed"`` keeps improved/regressed pairs; ``"unchanged"`` keeps
    persistent-fail/stable-success pairs. Any other policy raises
    ``ValueError``.
    """
    if policy == "mixed":
        return pairs
    if policy == "changed":
        keep = {"improved", "regressed"}
    elif policy == "unchanged":
        keep = {"persistent_fail", "stable_success"}
    else:
        raise ValueError(f"Unknown longitudinal pair policy: {policy}")
    return [p for p in pairs if p.get("category") in keep]


def _pair_category_counts(pairs: list[dict]) -> dict[str, int]:
    """Count pairs per category.

    The four known categories are always present in the result; a category
    outside that set (including the empty string for a missing key) gets its
    own entry.
    """
    counts = {
        "improved": 0,
        "regressed": 0,
        "persistent_fail": 0,
        "stable_success": 0,
    }
    for pair in pairs:
        cat = str(pair.get("category", ""))
        counts[cat] = counts.get(cat, 0) + 1
    return counts


def _safe_pair_id(value: str) -> str:
    """Turn *value* into a short path-safe token (at most 80 characters).

    Runs of characters outside ``[A-Za-z0-9_.-]`` collapse to one underscore;
    an empty result becomes ``"item"``.
    """
    safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", str(value)).strip("_")
    return safe[:80] or "item"


def _build_longitudinal_pairs(
    *,
    adapter: EnvAdapter,
    dataloader,
    prev_skill: str,
    curr_skill: str,
    initial_items: list[dict],
    initial_prev_results: list[dict],
    initial_curr_results: list[dict],
    prev_rollout_dir: str,
    curr_rollout_dir: str,
    policy: str,
    target_n: int,
    seed: int,
    out_root: str,
) -> tuple[list[dict], list[dict]]:
    """Build longitudinal pairs, optionally filtering by change category.

    ``mixed`` preserves the legacy behavior exactly. ``changed`` keeps only
    10/01 pairs and attempts to top up to ``target_n`` by scanning the train
    split once. ``unchanged`` keeps only 00/11 pairs and does not top up.

    Returns ``(selected_pairs, all_pairs)``. ``selected_pairs`` is truncated
    to ``target_n`` only when the top-up scan actually ran.
    """
    all_pairs = build_comparison_pairs(
        initial_prev_results,
        initial_curr_results,
        initial_items,
        prev_rollout_dir=prev_rollout_dir,
        curr_rollout_dir=curr_rollout_dir,
    )
    selected_pairs = _filter_longitudinal_pairs(all_pairs, policy)
    if policy != "changed" or len(selected_pairs) >= target_n or dataloader is None:
        return selected_pairs, all_pairs

    train_items = list(getattr(dataloader, "train_items", []) or [])
    if not train_items:
        return selected_pairs, all_pairs

    # Only items that were not part of the initial comparison are candidates.
    seen_ids = {str(p.get("id", "")) for p in all_pairs}
    rng = random.Random(seed)
    candidates = list(train_items)
    rng.shuffle(candidates)
    candidates = [item for item in candidates if str(item.get("id", "")) not in seen_ids]

    for idx, item in enumerate(candidates):
        if len(selected_pairs) >= target_n:
            break
        item_id = _safe_pair_id(str(item.get("id", f"item_{idx}")))
        batch = BatchSpec(
            phase="train",
            split="train",
            seed=seed + idx + 1,
            batch_size=1,
            payload=[item],
        )
        env = adapter.build_env_from_batch(batch, out_root=out_root)
        prev_dir = os.path.join(prev_rollout_dir, "topup", item_id)
        curr_dir = os.path.join(curr_rollout_dir, "topup", item_id)
        # Roll the same item out under both skills so the pair is comparable.
        prev_results = adapter.rollout(env, prev_skill, prev_dir)
        curr_results = adapter.rollout(env, curr_skill, curr_dir)
        pair = build_comparison_pairs(
            prev_results,
            curr_results,
            [item],
            prev_rollout_dir=prev_dir,
            curr_rollout_dir=curr_dir,
        )
        all_pairs.extend(pair)
        selected_pairs.extend(_filter_longitudinal_pairs(pair, policy))

    return selected_pairs[:target_n], all_pairs


# ── History / persistence helpers ─────────────────────────────────────────────

_SECRET_KEYS = {
    "azure_api_key",
    "api_key",
    "openai_api_key",
}

# Config keys ending in one of these suffixes are treated as credentials, so
# prefixed variants such as ``azure_openai_api_key`` or
# ``teacher_azure_openai_api_key`` are redacted as well.
_SECRET_KEY_SUFFIXES = ("api_key",)


def _is_secret_key(key: object) -> bool:
    """Return True when the config key *key* names a credential."""
    name = str(key).lower()
    return name in _SECRET_KEYS or name.endswith(_SECRET_KEY_SUFFIXES)


def _redact_value(val: str) -> str:
    """Mask a secret: all stars up to 8 characters, else ``abcd...wxyz``."""
    if len(val) <= 8:
        return "*" * len(val)
    return f"{val[:4]}...{val[-4:]}"


def _redact_cfg(cfg: dict) -> dict:
    """Return a shallow copy of *cfg* with credential values masked.

    A key is masked when its lower-cased name is listed in ``_SECRET_KEYS``
    or ends with ``api_key``. Empty or falsy values are left as they are, and
    *cfg* itself is never modified.
    """
    redacted = dict(cfg)
    for key, value in cfg.items():
        if value and _is_secret_key(key):
            redacted[key] = _redact_value(str(value))
    return redacted


def _load_history(out_root: str) -> list[dict]:
    """Load ``history.json`` from *out_root*; an absent file yields ``[]``."""
    path = os.path.join(out_root, "history.json")
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    return []


def _save_history(out_root: str, history: list[dict]) -> None:
    """Write *history* to ``history.json`` in *out_root* as UTF-8 JSON."""
    path = os.path.join(out_root, "history.json")
    # ensure_ascii=False emits raw non-ASCII text, so the encoding is pinned
    # instead of depending on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)


def _save_skill(out_root: str, step: int, content: str) -> None:
    """Write the skill text for *step* to ``skills/skill_vNNNN.md`` (UTF-8)."""
    skills_dir = os.path.join(out_root, "skills")
    os.makedirs(skills_dir, exist_ok=True)
    with open(
        os.path.join(skills_dir, f"skill_v{step:04d}.md"), "w", encoding="utf-8",
    ) as f:
        f.write(content)


def _load_skill(out_root: str, step: int) -> str:
    """Read the skill text saved for *step*; raises if the file is missing."""
    path = os.path.join(out_root, "skills", f"skill_v{step:04d}.md")
    with open(path, encoding="utf-8") as f:
        return f.read()


def _load_meta_skill_content(out_root: str, epoch: int) -> str:
    """Return the stripped meta-skill text saved for *epoch*.

    Best effort: a non-positive epoch, a missing file, or an unreadable or
    malformed result file all yield the empty string.
    """
    if epoch <= 0:
        return ""
    path = os.path.join(
        out_root, "meta_skill", f"epoch_{epoch:02d}", "meta_skill_result.json",
    )
    if not os.path.exists(path):
        return ""
    try:
        with open(path, encoding="utf-8") as f:
            result = json.load(f)
        return str(result.get("meta_skill_content", "")).strip()
    except (OSError, ValueError, AttributeError):
        # I/O failure, invalid JSON/encoding, or a non-object JSON document.
        return ""


def _load_runtime_state(out_root: str) -> dict | None:
    """Load ``runtime_state.json``; ``None`` if absent, unreadable or not a dict."""
    path = os.path.join(out_root, "runtime_state.json")
    if not os.path.exists(path):
        return None
    try:
        with open(path, encoding="utf-8") as f:
            state = json.load(f)
    except (OSError, ValueError):
        # Best effort: an unreadable or corrupt state file counts as no state.
        return None
    return state if isinstance(state, dict) else None


def _save_runtime_state(out_root: str, state: dict) -> None:
    """Write *state* to ``runtime_state.json`` in *out_root* as UTF-8 JSON."""
    path = os.path.join(out_root, "runtime_state.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f, ensure_ascii=False, indent=2)


def _resolve_train_size(cfg: dict, dataloader) -> int:
    """Determine the number of training items for one epoch.

    The size comes from ``cfg["train_size"]`` when positive, otherwise from
    the dataloader (``get_train_size()`` if callable, else
    ``len(train_items)``). Raises ``ValueError`` when both are available and
    disagree, or when neither gives a positive size.
    """
    configured = int(cfg.get("train_size", 0) or 0)
    inferred: int | None = None

    if dataloader is not None:
        getter = getattr(dataloader, "get_train_size", None)
        if callable(getter):
            try:
                value = getter()
            except Exception:
                # Dataloader implementations are pluggable; treat any failure
                # as "size unknown" and fall back to the configured value.
                value = None
            if value is not None:
                inferred = int(value)
        elif hasattr(dataloader, "train_items"):
            try:
                inferred = len(dataloader.train_items)
            except Exception:
                inferred = None

    if inferred is not None and inferred <= 0:
        inferred = None

    if configured > 0 and inferred is not None and configured != inferred:
        raise ValueError(
            f"Configured train_size={configured} does not match loaded train split "
            f"size={inferred}. Fix the config or the dataset split."
        )

    train_size = configured if configured > 0 else inferred
    if train_size is None or train_size <= 0:
        raise ValueError(
            "Unable to determine train_size automatically. "
            "Provide train.train_size in the config for this environment."
        )
    return int(train_size)


def _compute_task_type_buckets(results: list[dict], task_types: list[str]) -> dict[str, dict]:
    """Accumulate per-task-type totals of hard and soft scores.

    Every name in *task_types* plus ``"overall"`` gets a bucket even when no
    result falls into it. Results without a ``task_type`` are counted under
    ``"other"``. Missing or ``None`` scores count as zero.
    """
    buckets: dict[str, dict] = {}
    for task in task_types + ["overall"]:
        buckets[task] = {"total": 0, "hard": 0, "soft": 0.0}
    for r in results:
        tt = r.get("task_type", "other")
        for key in [tt, "overall"]:
            if key not in buckets:
                buckets[key] = {"total": 0, "hard": 0, "soft": 0.0}
            buckets[key]["total"] += 1
            buckets[key]["hard"] += int(r.get("hard", 0) or 0)
            buckets[key]["soft"] += float(r.get("soft", 0.0) or 0.0)
    return buckets


def _format_rejection_buffer(buffer: list[dict]) -> str:
    """**DEPRECATED** — kept for backward compat; use _format_step_buffer."""
    return _format_step_buffer(buffer)
def _extract_failure_patterns(
    rollout_results: list[dict],
    step_dir: str,
) -> list[dict]:
    """Extract compact failure patterns from rollout results.

    Failed results (falsy ``hard``) are grouped by the part of
    ``fail_reason`` before the first colon; a missing, ``None`` or empty
    reason is grouped under ``"unknown"``. Descriptions found in the
    ``failure_summary`` of minibatch patch files under *step_dir* replace the
    group labels when available.

    Returns one ``{"pattern", "count", "task_ids"}`` dict per group, or an
    empty list when nothing failed.
    """
    failures = [r for r in rollout_results if not r.get("hard")]
    if not failures:
        return []

    # Group by fail_reason prefix
    groups: dict[str, list[dict]] = defaultdict(list)
    for r in failures:
        # The key may be present with a None/non-string value; coerce first so
        # the substring test below cannot raise.
        reason = str(r.get("fail_reason") or "unknown")
        prefix = reason.split(":")[0].strip() if ":" in reason else reason
        groups[prefix].append(r)

    # Try richer descriptions from analyst patches
    analyst_descs: list[str] = []
    patch_globs = [
        os.path.join(step_dir, "patches", "minibatch_fail_*.json"),
        os.path.join(step_dir, "batch_*", "patches", "minibatch_fail_*.json"),
    ]
    seen_patch_files: set[str] = set()
    for pattern in patch_globs:
        for fname in sorted(glob.glob(pattern)):
            if fname in seen_patch_files:
                continue
            seen_patch_files.add(fname)
            try:
                with open(fname, encoding="utf-8") as f:
                    patch = json.load(f)
                for fs in patch.get("failure_summary", []):
                    ft = fs.get("failure_type", "")
                    sd = fs.get("description", "")
                    analyst_descs.append(f"{ft}: {sd}" if sd else ft)
            except (OSError, ValueError, AttributeError, TypeError):
                # Best effort: skip patch files that are unreadable, not valid
                # JSON, or not shaped like an analyst patch.
                continue

    # NOTE(review): descriptions are handed out to groups purely by iteration
    # order, not by matching failure type — confirm this pairing is intended.
    patterns = []
    desc_iter = iter(analyst_descs)
    for prefix, items in groups.items():
        desc = next(desc_iter, None) or prefix
        patterns.append({
            "pattern": desc,
            "count": len(items),
            "task_ids": [str(r.get("id", "?")) for r in items],
        })
    return patterns


def _format_step_buffer(buffer: list[dict]) -> str:
    """Format the unified step buffer into a single context block.

    Each entry captures what happened at a previous step: failure patterns
    observed during rollout, and — when the step was rejected — the specific
    edits that were tried and the resulting score drop. Every entry must
    provide ``"step"`` and ``"action"``; the other keys are optional.

    Returns empty string when *buffer* is empty.
    """
    if not buffer:
        return ""

    parts = [
        "Below is a summary of previous steps in this epoch. "
        "Use it to avoid repeating ineffective edits and to prioritise "
        "failure patterns that remain unsolved.\n"
    ]

    for entry in buffer:
        step = entry["step"]
        action = entry["action"]
        n_fail = entry.get("n_fail", 0)
        n_total = entry.get("n_total", "?")

        parts.append(f"### Step {step} — {action.upper()} ({n_fail}/{n_total} failed)")

        # Failure patterns (at most three task ids are listed per pattern)
        for p in entry.get("failure_patterns", []):
            ids = ", ".join(p["task_ids"][:3])
            parts.append(f' - "{p["pattern"]}" (×{p["count"]}, tasks: {ids})')

        # Rejected edits (only present on reject)
        rejected = entry.get("rejected_edits", [])
        if rejected:
            score_before = entry.get("score_before", "?")
            score_after = entry.get("score_after", "?")
            parts.append(
                f" Rejected edits (score {score_before} → {score_after}):"
            )
            for i, e in enumerate(rejected, 1):
                if e.get("op") is not None:
                    # Patch-style edit: operation, optional target, content.
                    op = e.get("op", "?")
                    content = e.get("content", "")
                    target = e.get("target", "")
                    if target:
                        parts.append(f' {i}. [{op}] target="{target[:80]}" → "{content}"')
                    else:
                        parts.append(f' {i}. [{op}] "{content}"')
                else:
                    # Suggestion-style entry: type, title, instruction.
                    kind = e.get("type", "?")
                    title = e.get("title", "")
                    instruction = e.get("instruction", "")
                    parts.append(f' {i}. [{kind}] "{title}" → "{instruction}"')

    return "\n".join(parts)
Returns summary dict.""" + cfg = self.cfg + adapter = self.adapter + out_root = cfg["out_root"] + os.makedirs(out_root, exist_ok=True) + + # ── Adapter setup (one-time init) ──────────────────────────── + adapter.setup(cfg) + dataloader = adapter.get_dataloader() + + def _build_train_env(batch: BatchSpec): + env_manager = adapter.build_env_from_batch(batch, out_root=out_root) + return env_manager, batch.batch_size, batch.seed + + def _build_eval_env(split: str, env_num: int, seed: int): + if dataloader is None: + env_manager = adapter.build_eval_env( + env_num=env_num, + split=split, + seed=seed, + out_root=out_root, + ) + actual_n = len(env_manager) if hasattr(env_manager, "__len__") else env_num + return env_manager, actual_n + + batch = dataloader.build_eval_batch( + env_num=env_num, + split=split, + seed=seed, + out_root=out_root, + ) + env_manager = adapter.build_env_from_batch(batch, out_root=out_root) + return env_manager, batch.batch_size + + # ── Configure models ───────────────────────────────────────────── + backend = cfg.get("model_backend", "azure_openai") + configure_azure_openai( + endpoint=( + cfg.get("azure_openai_endpoint") + or cfg.get("azure_endpoint") + or None + ), + api_version=( + cfg.get("azure_openai_api_version") + or cfg.get("azure_api_version") + or None + ), + api_key=( + cfg.get("azure_openai_api_key") + or cfg.get("azure_api_key") + or None + ), + auth_mode=cfg.get("azure_openai_auth_mode") or None, + ad_scope=cfg.get("azure_openai_ad_scope") or None, + managed_identity_client_id=cfg.get("azure_openai_managed_identity_client_id") or None, + teacher_endpoint=cfg.get("teacher_azure_openai_endpoint") or None, + teacher_api_version=cfg.get("teacher_azure_openai_api_version") or None, + teacher_api_key=cfg.get("teacher_azure_openai_api_key") or None, + teacher_auth_mode=cfg.get("teacher_azure_openai_auth_mode") or None, + teacher_ad_scope=cfg.get("teacher_azure_openai_ad_scope") or None, + teacher_managed_identity_client_id=( + 
cfg.get("teacher_azure_openai_managed_identity_client_id") or None + ), + student_endpoint=cfg.get("student_azure_openai_endpoint") or None, + student_api_version=cfg.get("student_azure_openai_api_version") or None, + student_api_key=cfg.get("student_azure_openai_api_key") or None, + student_auth_mode=cfg.get("student_azure_openai_auth_mode") or None, + student_ad_scope=cfg.get("student_azure_openai_ad_scope") or None, + student_managed_identity_client_id=( + cfg.get("student_azure_openai_managed_identity_client_id") or None + ), + ) + teacher_backend = cfg.get("teacher_backend") + student_backend = cfg.get("student_backend") + if not teacher_backend or not student_backend: + if backend in {"claude", "claude_chat"}: + teacher_backend = teacher_backend or "claude_chat" + student_backend = student_backend or "claude_chat" + elif backend in {"codex", "codex_exec"}: + teacher_backend = teacher_backend or "openai_chat" + student_backend = student_backend or "codex_exec" + elif backend == "claude_code_exec": + teacher_backend = teacher_backend or "openai_chat" + student_backend = student_backend or "claude_code_exec" + else: + teacher_backend = teacher_backend or "openai_chat" + student_backend = student_backend or "openai_chat" + cfg["teacher_backend"] = teacher_backend + cfg["student_backend"] = student_backend + set_teacher_backend(teacher_backend) + set_student_backend(student_backend) + set_teacher_deployment(cfg["teacher_model"]) + set_student_deployment(cfg["student_model"]) + configure_codex_exec( + path=cfg.get("codex_exec_path", "codex"), + sandbox=cfg.get("codex_exec_sandbox", "workspace-write"), + profile=cfg.get("codex_exec_profile", ""), + full_auto=cfg.get("codex_exec_full_auto", False), + reasoning_effort=cfg.get("codex_exec_reasoning_effort", "none"), + use_sdk=cfg.get("codex_exec_use_sdk", None), + network_access=cfg.get("codex_exec_network_access", False), + web_search=cfg.get("codex_exec_web_search", False), + 
approval_policy=cfg.get("codex_exec_approval_policy", "never"), + ) + configure_claude_code_exec( + path=cfg.get("claude_code_exec_path", "claude"), + profile=cfg.get("claude_code_exec_profile", ""), + use_sdk=cfg.get("claude_code_exec_use_sdk", None), + effort=cfg.get("claude_code_exec_effort", cfg.get("reasoning_effort", "medium")), + max_thinking_tokens=cfg.get("claude_code_exec_max_thinking_tokens", 16384), + ) + os.environ["REFLACT_CODEX_TRACE_TO_TEACHER"] = ( + "1" + if student_backend == "codex_exec" and cfg.get("codex_trace_to_teacher", False) + else "0" + ) + reasoning = cfg.get("reasoning_effort", "") or None + set_reasoning_effort(reasoning) + if student_backend == "claude_code_exec" and cfg.get("use_deep_reflect", False): + raise NotImplementedError("claude_code_exec does not support use_deep_reflect yet.") + print( + f" [model config] backend={backend} " + f"teacher={cfg['teacher_model']} ({teacher_backend}) " + f"student={cfg['student_model']} ({student_backend}) " + f"reasoning={reasoning or 'off'}" + ) + + # ── Initialize Ray ─────────────────────────────────────────────── + if adapter.requires_ray(): + try: + import ray + except ImportError as e: + raise ImportError( + "This environment requires ray, but ray is not installed." 
+ ) from e + + if not ray.is_initialized(): + ray.init(num_gpus=0) + + # ── Load initial skill ─────────────────────────────────────────── + skill_init_path = os.path.abspath(cfg["skill_init"]) + if os.path.exists(skill_init_path): + with open(skill_init_path) as f: + skill_init = f.read() + print(f" [initial skill] {skill_init_path} ({len(skill_init)} chars)") + else: + skill_init = "" + print(" [initial skill] no initial skill file — starting from blank") + + # ── Training parameters ────────────────────────────────────────── + batch_size = cfg["batch_size"] + num_epochs = cfg["num_epochs"] + accumulation = cfg["accumulation"] + seed = cfg["seed"] + merge_bs = cfg["merge_batch_size"] + max_analyst_rounds = int(cfg.get("max_analyst_rounds", 3) or 3) + update_mode = normalize_update_mode(cfg.get("skill_update_mode", "patch")) + lr_control_mode = _normalise_lr_control_mode(cfg.get("lr_control_mode", "fixed")) + if is_full_rewrite_minibatch_mode(update_mode): + lr_control_mode = "none" + longitudinal_pair_policy = _normalise_longitudinal_pair_policy( + cfg.get("longitudinal_pair_policy", "mixed") + ) + rewrite_reasoning_effort = cfg.get("rewrite_reasoning_effort", "high") + if rewrite_reasoning_effort == "": + rewrite_reasoning_effort = None + rewrite_max_completion_tokens = int(cfg.get("rewrite_max_completion_tokens", 64000)) + if batch_size <= 0: + raise ValueError(f"batch_size must be positive, got {batch_size}") + if accumulation <= 0: + raise ValueError(f"accumulation must be positive, got {accumulation}") + + train_size = _resolve_train_size(cfg, dataloader) + steps_per_epoch = math.ceil(train_size / (batch_size * accumulation)) + batches_per_epoch = steps_per_epoch * accumulation + total_steps = num_epochs * steps_per_epoch + + # Persist resolved derived fields so config.json / summary.json match + # the actual runtime recipe. 
+ cfg["train_size"] = train_size + cfg["steps_per_epoch"] = steps_per_epoch + cfg["batches_per_epoch"] = batches_per_epoch + cfg["samples_per_epoch"] = train_size + cfg["skill_update_mode"] = update_mode + cfg["lr_control_mode"] = lr_control_mode + + # Save config after deriving runtime values. + with open(os.path.join(out_root, "config.json"), "w") as f: + json.dump(_redact_cfg(cfg), f, indent=2, ensure_ascii=False) + + train_pool_size = train_size + + scheduler = build_scheduler( + mode=cfg.get("lr_scheduler", "constant"), + max_lr=cfg["edit_budget"], + min_lr=cfg.get("min_edit_budget", 2), + total_steps=total_steps, + ) + + # Fixed training pool: base seeds (each seed = one deterministic batch) + if dataloader is not None: + base_seeds = dataloader.make_base_seeds( + steps_per_epoch=steps_per_epoch, + accumulation=accumulation, + seed=seed, + ) + else: + base_seeds = [seed + i + 1 for i in range(batches_per_epoch)] + + print(f"\n [config] epochs={num_epochs} steps/epoch={steps_per_epoch} " + f"(auto) accum={accumulation} batch_size={batch_size}") + print(f" [config] train_size={train_size}") + print(f" [config] batches/epoch={batches_per_epoch} " + f"total_steps={total_steps} " + f"games/epoch={train_pool_size}") + print(f" [config] lr_scheduler={cfg.get('lr_scheduler', 'constant')} " + f"edit_budget={cfg['edit_budget']} " + f"min_edit_budget={cfg.get('min_edit_budget', 2)}") + print(f" [config] skill_update_mode={update_mode} " + f"lr_control_mode={lr_control_mode} " + f"rewrite_reasoning_effort={rewrite_reasoning_effort or 'off'} " + f"rewrite_max_completion_tokens={rewrite_max_completion_tokens} " + f"max_analyst_rounds={max_analyst_rounds}") + print(f" [config] longitudinal_pair_policy={longitudinal_pair_policy}") + print(f" [config] base_seeds={base_seeds}") + + # ── Resume check ───────────────────────────────────────────────── + history = _load_history(out_root) + runtime_state = _load_runtime_state(out_root) + if runtime_state: + last_step = 
int(runtime_state.get("last_completed_step", 0) or 0) + current_skill_path = runtime_state.get("current_skill_path") or os.path.join( + out_root, "skills", f"skill_v{last_step:04d}.md", + ) + with open(current_skill_path) as f: + current_skill = f.read() + best_skill_path = runtime_state.get("best_skill_path") or os.path.join( + out_root, "best_skill.md", + ) + if os.path.exists(best_skill_path): + with open(best_skill_path) as f: + best_skill = f.read() + else: + best_skill = current_skill + current_score = float(runtime_state.get("current_score", -1.0) or -1.0) + best_score = float(runtime_state.get("best_score", current_score) or current_score) + best_step = runtime_state.get("best_step", last_step) + current_origin = str( + runtime_state.get("current_origin") + or (f"step_{last_step:04d}" if last_step > 0 else "initial_skill") + ) + best_origin = str(runtime_state.get("best_origin") or current_origin) + resume_from = last_step + 1 + scheduler.load_state_dict({"current_step": last_step}) + print( + f" [resume] from step {resume_from} " + f"current={current_score:.4f} best={best_score:.4f} " + f"(origin={current_origin})" + ) + elif history: + last_step = history[-1]["step"] + current_skill = _load_skill(out_root, last_step) + best_rec = max(history, key=lambda h: h.get("best_score", 0.0)) + best_score = best_rec["best_score"] + best_step = best_rec["best_step"] + best_skill_path = os.path.join(out_root, "best_skill.md") + if os.path.exists(best_skill_path): + with open(best_skill_path) as f: + best_skill = f.read() + else: + best_skill = _load_skill(out_root, best_step) + current_score = history[-1].get("current_score", best_score) + current_origin = f"step_{last_step:04d}" + best_origin = f"step_{int(best_step):04d}" if isinstance(best_step, int) else str(best_step) + resume_from = last_step + 1 + scheduler.load_state_dict({"current_step": last_step}) + print( + f" [resume] from step {resume_from} " + f"current={current_score:.4f} best={best_score:.4f}" + ) + 
else: + current_skill = skill_init + best_skill = skill_init + best_score = -1.0 + current_score = -1.0 + best_step = 0 + current_origin = "initial_skill" + best_origin = "initial_skill" + resume_from = 1 + + _save_skill(out_root, 0, skill_init) + + def _persist_runtime_state(last_completed_step: int) -> None: + _save_runtime_state( + out_root, + { + "last_completed_step": last_completed_step, + "current_skill_path": os.path.join( + out_root, "skills", f"skill_v{last_completed_step:04d}.md", + ), + "current_score": current_score, + "current_origin": current_origin, + "best_skill_path": os.path.join(out_root, "best_skill.md"), + "best_score": best_score, + "best_step": best_step, + "best_origin": best_origin, + }, + ) + + # ── Selection cache ────────────────────────────────────────────── + sel_cache: dict[str, tuple[float, float]] = {} + for rec in history: + sh = rec.get("candidate_hash", "") + if sh and rec.get("selection_hard") is not None: + sel_cache[sh] = (rec["selection_hard"], rec["selection_soft"]) + + # ── Baseline evaluation on selection set ───────────────────────── + if cfg.get("use_gate") is False: + raise ValueError( + "Gate validation is mandatory in this branch. Remove " + "`evaluation.use_gate=false` from the config." 
+ ) + if current_score < 0: + print(f"\n{'='*60}") + print(" BASELINE — evaluate initial skill on Selection set (valid_seen)") + print(f"{'='*60}") + sel_env, sel_n = _build_eval_env( + split="valid_seen", + env_num=cfg["sel_env_num"], + seed=seed, + ) + print(f" Selection items: {sel_n}") + baseline_dir = os.path.join(out_root, "selection_eval_baseline") + baseline_results = adapter.rollout(sel_env, skill_init, baseline_dir) + current_score, baseline_soft = compute_score(baseline_results) + best_score = current_score + sh = skill_hash(skill_init) + sel_cache[sh] = (current_score, baseline_soft) + current_origin = "initial_skill" + best_origin = "initial_skill" + _persist_runtime_state(0) + print( + f" [baseline result] selection hard={current_score:.4f} " + f"soft={baseline_soft:.4f}" + ) + + # ── Training loop ──────────────────────────────────────────────── + t_loop_start = time.time() + + if resume_from > total_steps: + print(f"\n [skip] all {total_steps} steps complete — jumping to evaluation") + + global_step = 0 + for epoch in range(1, num_epochs + 1): + if dataloader is not None: + epoch_batches = dataloader.plan_train_epoch( + epoch=epoch, + steps_per_epoch=steps_per_epoch, + accumulation=accumulation, + batch_size=batch_size, + seed=seed, + out_root=out_root, + ) + shuffled_seeds = [batch.seed for batch in epoch_batches] + else: + epoch_batches = [] + epoch_rng = random.Random(seed + epoch * 1000) + shuffled_seeds = base_seeds.copy() + epoch_rng.shuffle(shuffled_seeds) + + # Step buffer: accumulates per-step context (failure patterns + + # rejected edits) within this epoch so teachers see full history. 
+ step_buffer: list[dict] = [] + active_meta_skill = ( + _load_meta_skill_content(out_root, epoch - 1) + if cfg.get("use_meta_skill", False) + else "" + ) + + print( + f"\n [EPOCH {epoch}/{num_epochs}] " + f"shuffled_seeds={shuffled_seeds}" + ) + if active_meta_skill: + print( + f" [meta skill] loaded from epoch {epoch - 1} " + f"({len(active_meta_skill)} chars)" + ) + + for step_in_epoch in range(steps_per_epoch): + global_step += 1 + if global_step < resume_from: + continue + + step_t0 = time.time() + step_dir = os.path.join(out_root, "steps", f"step_{global_step:04d}") + os.makedirs(step_dir, exist_ok=True) + + tokens_before = get_token_summary() + + print( + f"\n [STEP {global_step}/{total_steps}] " + f"epoch={epoch} step_in_epoch={step_in_epoch} " + f"{'='*30}" + ) + + step_rec: dict = { + "step": global_step, + "epoch": epoch, + "step_in_epoch": step_in_epoch, + "timing": {}, + "tokens": {}, + } + + # ── Accumulation: Rollout + Reflect ────────────────────── + all_failure_patches: list[dict] = [] + all_success_patches: list[dict] = [] + all_raw_patches: list[dict | None] = [] + all_rollout_results: list[dict] = [] + accum_rollout_stats: list[dict] = [] + total_rollout_time = 0.0 + total_reflect_time = 0.0 + total_deep_reflect_time = 0.0 + + for a in range(accumulation): + batch_idx = step_in_epoch * accumulation + a + if dataloader is not None: + batch_spec = epoch_batches[batch_idx] + train_env, train_n, batch_seed = _build_train_env(batch_spec) + else: + batch_seed = shuffled_seeds[batch_idx] + train_env = adapter.build_train_env( + batch_size=batch_size, + seed=batch_seed, + out_root=out_root, + ) + train_n = len(train_env) if hasattr(train_env, "__len__") else batch_size + + # Directory routing + if accumulation > 1: + batch_dir = os.path.join(step_dir, f"batch_{a}") + else: + batch_dir = step_dir + + rollout_dir = os.path.join(batch_dir, "rollout") + patches_dir = os.path.join(batch_dir, "patches") + + # ① ROLLOUT 
──────────────────────────────────────────── + t_phase = time.time() + print(f" [1/6 ROLLOUT] train items={train_n} (from pool, batch_seed={batch_seed})") + rollout_results = adapter.rollout( + train_env, current_skill, rollout_dir, + use_eval_feedback=True, + ) + r_hard, r_soft = compute_score(rollout_results) + total_rollout_time += time.time() - t_phase + all_rollout_results.extend(rollout_results) + print(f" [1/6 done] hard={r_hard:.4f} soft={r_soft:.4f}") + + # ② REFLECT ──────────────────────────────────────────── + t_phase = time.time() + pred_dir = os.path.join(rollout_dir, "predictions") + + # Build step context from buffer + step_buffer_context = _format_step_buffer(step_buffer) + + raw_patches = adapter.reflect( + rollout_results, current_skill, batch_dir, + prediction_dir=pred_dir, patches_dir=patches_dir, + random_seed=batch_seed, + step_buffer_context=step_buffer_context, + meta_skill_context=active_meta_skill, + ) + failure_patches, success_patches = _normalise_patches( + raw_patches, + update_mode=update_mode, + ) + all_failure_patches.extend(failure_patches) + all_success_patches.extend(success_patches) + all_raw_patches.extend(raw_patches) + total_reflect_time += time.time() - t_phase + + print( + f" [2/6 done] failure_patches={len(failure_patches)} " + f"success_patches={len(success_patches)}" + ) + + deep_failure_patches: list[dict] = [] + deep_success_patches: list[dict] = [] + if cfg.get("use_deep_reflect", False): + t_phase = time.time() + deep_raw_patches = adapter.deep_reflect( + rollout_results, + current_skill, + batch_dir, + env_manager=train_env, + prediction_dir=pred_dir, + random_seed=batch_seed, + step_buffer_context=step_buffer_context, + meta_skill_context=active_meta_skill, + ) + deep_failure_patches, deep_success_patches = _normalise_patches( + deep_raw_patches, + update_mode=update_mode, + ) + all_failure_patches.extend(deep_failure_patches) + all_success_patches.extend(deep_success_patches) + 
all_raw_patches.extend(deep_raw_patches) + total_deep_reflect_time += time.time() - t_phase + print( + f" [2b/6 DEEP REFLECT] failure_patches={len(deep_failure_patches)} " + f"success_patches={len(deep_success_patches)}" + ) + + # Track per-batch stats + accum_rollout_stats.append({ + "batch_idx": a, + "batch_seed": batch_seed, + "n_envs": len(rollout_results), + "hard": r_hard, + "soft": r_soft, + "n_failure_patches": len(failure_patches), + "n_success_patches": len(success_patches), + "n_deep_failure_patches": len(deep_failure_patches), + "n_deep_success_patches": len(deep_success_patches), + }) + + # ── End of accumulation loop ───────────────────────────── + + # Aggregate rollout stats across batches + total_n = sum(b["n_envs"] for b in accum_rollout_stats) + agg_hard = sum(b["hard"] * b["n_envs"] for b in accum_rollout_stats) / max(total_n, 1) + agg_soft = sum(b["soft"] * b["n_envs"] for b in accum_rollout_stats) / max(total_n, 1) + + step_rec["rollout_hard"] = round(agg_hard, 6) + step_rec["rollout_soft"] = round(agg_soft, 6) + step_rec["rollout_n"] = total_n + step_rec["accumulation_batches"] = accum_rollout_stats + step_rec["timing"]["rollout_s"] = round(total_rollout_time, 1) + step_rec["timing"]["reflect_s"] = round(total_reflect_time, 1) + if cfg.get("use_deep_reflect", False): + step_rec["timing"]["deep_reflect_s"] = round(total_deep_reflect_time, 1) + + n_total_patches = len(all_failure_patches) + len(all_success_patches) + step_rec["n_patches"] = n_total_patches + step_rec["n_failure_patches"] = len(all_failure_patches) + step_rec["n_success_patches"] = len(all_success_patches) + + if accumulation > 1: + print( + f" [accum done] total: failure={len(all_failure_patches)} " + f"success={len(all_success_patches)} " + f"from {accumulation} batches" + ) + + # ── No patches? 
Skip ───────────────────────────────────── + if not all_failure_patches and not all_success_patches: + step_rec["action"] = "skip_no_patches" + step_rec["current_score"] = current_score + step_rec["best_score"] = best_score + step_rec["best_step"] = best_step + step_rec["skill_len"] = len(current_skill) + step_rec["wall_time_s"] = round(time.time() - step_t0, 1) + history.append(step_rec) + _save_history(out_root, history) + _save_skill(out_root, global_step, current_skill) + _persist_runtime_state(global_step) + with open(os.path.join(step_dir, "step_record.json"), "w") as f: + json.dump(step_rec, f, indent=2, ensure_ascii=False) + print(" [skip] no usable patches — skill unchanged") + continue + + # ③ AGGREGATE ────────────────────────────────────────────── + t_phase = time.time() + merged_patch = merge_patches( + current_skill, all_failure_patches, all_success_patches, + batch_size=merge_bs, verbose=True, + workers=cfg["analyst_workers"], + update_mode=update_mode, + meta_skill_context=active_meta_skill, + ) + with open(os.path.join(step_dir, "merged_patch.json"), "w") as f: + json.dump(merged_patch, f, ensure_ascii=False, indent=2) + + merged_items = get_payload_items(merged_patch, update_mode) + n_edits_merged = len(merged_items) + step_rec["n_edits_merged"] = n_edits_merged + step_rec["timing"]["aggregate_s"] = round(time.time() - t_phase, 1) + print(f" [3/6 done] merged {n_edits_merged} {payload_label(update_mode)}") + + # ④ SELECT ───────────────────────────────────────────────── + t_phase = time.time() + lr_decision = None + if is_full_rewrite_minibatch_mode(update_mode): + edit_budget = None + ranked_patch = merged_patch + ranked_items = merged_items + n_edits_ranked = len(ranked_items) + step_rec["n_edits_ranked"] = n_edits_ranked + step_rec["edit_budget"] = None + step_rec["lr_control_mode"] = "none" + with open(os.path.join(step_dir, "ranked_edits.json"), "w") as f: + json.dump(ranked_patch, f, ensure_ascii=False, indent=2) + else: + if lr_control_mode 
== "autonomous": + lr_decision = decide_autonomous_learning_rate( + skill_content=current_skill, + merged_patch=merged_patch, + update_mode=update_mode, + rollout_hard=agg_hard, + rollout_soft=agg_soft, + rollout_n=total_n, + step_buffer_context=step_buffer_context, + meta_skill_context=active_meta_skill, + ) + edit_budget = int(lr_decision["learning_rate"]) + with open(os.path.join(step_dir, "lr_decision.json"), "w") as f: + json.dump(lr_decision, f, ensure_ascii=False, indent=2) + with open(os.path.join(out_root, "lr_history.jsonl"), "a") as f: + f.write(json.dumps({ + "step": global_step, + "epoch": epoch, + **lr_decision, + }, ensure_ascii=False) + "\n") + else: + edit_budget = scheduler.step() + ranked_patch = rank_and_select( + current_skill, merged_patch, + max_edits=edit_budget, + update_mode=update_mode, + meta_skill_context=active_meta_skill, + ) + with open(os.path.join(step_dir, "ranked_edits.json"), "w") as f: + json.dump(ranked_patch, f, ensure_ascii=False, indent=2) + + ranked_items = get_payload_items(ranked_patch, update_mode) + n_edits_ranked = len(ranked_items) + step_rec["n_edits_ranked"] = n_edits_ranked + step_rec["edit_budget"] = edit_budget + step_rec["lr_control_mode"] = lr_control_mode + if lr_decision is not None: + step_rec["lr_decision"] = lr_decision + step_rec["timing"]["select_s"] = round(time.time() - t_phase, 1) + + support_counts = [ + item.get("support_count", 0) for item in ranked_items if isinstance(item, dict) + ] + step_rec["support_counts"] = support_counts + if is_full_rewrite_minibatch_mode(update_mode): + print( + f" [4/6 SELECT] skipped LR/select; " + f"using {n_edits_ranked} merged {payload_label(update_mode)}" + ) + else: + print( + f" [4/6 SELECT] " + f"{n_edits_merged} -> {n_edits_ranked} {payload_label(update_mode)} " + f"(budget={edit_budget}, lr_control={lr_control_mode})" + ) + + # ⑤ UPDATE ───────────────────────────────────────────────── + t_phase = time.time() + rewrite_result = None + if update_mode == 
"rewrite_from_suggestions": + rewrite_result = rewrite_skill_from_suggestions( + current_skill, + ranked_patch, + step_buffer_context=step_buffer_context, + env=cfg.get("env"), + reasoning_effort=rewrite_reasoning_effort, + max_completion_tokens=rewrite_max_completion_tokens, + ) + if rewrite_result and rewrite_result.get("new_skill"): + candidate_skill = rewrite_result["new_skill"] + apply_report = [] + with open(os.path.join(step_dir, "rewrite_result.json"), "w") as f: + json.dump(rewrite_result, f, ensure_ascii=False, indent=2) + else: + candidate_skill = current_skill + apply_report = [] + elif is_full_rewrite_minibatch_mode(update_mode): + skill_candidates = get_payload_items(ranked_patch, update_mode) + selected_candidate = next( + ( + item for item in skill_candidates + if isinstance(item, dict) and str(item.get("new_skill", "")).strip() + ), + None, + ) + if selected_candidate: + candidate_skill = str(selected_candidate["new_skill"]).rstrip() + "\n" + apply_report = [] + rewrite_result = { + "reasoning": ranked_patch.get("reasoning", ""), + "change_summary": selected_candidate.get("change_summary", []), + "title": selected_candidate.get("title", ""), + "source_type": selected_candidate.get("source_type", ""), + } + with open(os.path.join(step_dir, "full_rewrite_result.json"), "w") as f: + json.dump( + { + "selected_candidate": selected_candidate, + "merged_patch": ranked_patch, + }, + f, + ensure_ascii=False, + indent=2, + ) + else: + candidate_skill = current_skill + apply_report = [] + else: + candidate_skill, apply_report = apply_patch_with_report(current_skill, ranked_patch) + with open(os.path.join(step_dir, "candidate_skill.md"), "w") as f: + f.write(candidate_skill) + if apply_report: + with open(os.path.join(step_dir, "edit_apply_report.json"), "w") as f: + json.dump(apply_report, f, indent=2, ensure_ascii=False) + + cand_hash = skill_hash(candidate_skill) + step_rec["candidate_hash"] = cand_hash + step_rec["candidate_skill_len"] = 
len(candidate_skill) + if rewrite_result: + step_rec["rewrite_change_summary"] = rewrite_result.get("change_summary", []) + if apply_report: + step_rec["edit_apply_summary"] = { + "total": len(apply_report), + "applied": sum( + 1 for row in apply_report if str(row.get("status", "")).startswith("applied") + ), + "skipped": sum( + 1 for row in apply_report if str(row.get("status", "")).startswith("skipped") + ), + "errors": sum( + 1 for row in apply_report if row.get("status") == "error" + ), + } + step_rec["timing"]["update_s"] = round(time.time() - t_phase, 1) + if ( + update_mode == "rewrite_from_suggestions" + and rewrite_result is None + ) or ( + is_full_rewrite_minibatch_mode(update_mode) + and rewrite_result is None + ): + step_rec["action"] = "skip_no_rewrite" + step_rec["current_score"] = current_score + step_rec["best_score"] = best_score + step_rec["best_step"] = best_step + step_rec["skill_len"] = len(current_skill) + step_rec["wall_time_s"] = round(time.time() - step_t0, 1) + history.append(step_rec) + _save_history(out_root, history) + _save_skill(out_root, global_step, current_skill) + _persist_runtime_state(global_step) + with open(os.path.join(step_dir, "step_record.json"), "w") as f: + json.dump(step_rec, f, indent=2, ensure_ascii=False) + print(" [skip] no usable rewrite generated — skill unchanged") + continue + print( + f" [5/6 UPDATE] " + f"skill_len {len(current_skill)} -> {len(candidate_skill)}" + ) + + # ⑥ EVALUATE ─────────────────────────────────────────────── + t_phase = time.time() + if cand_hash in sel_cache: + cand_hard, cand_soft = sel_cache[cand_hash] + print( + f" [6/6 EVALUATE] " + f"cache hit {cand_hash}: hard={cand_hard:.4f}" + ) + else: + sel_env, sel_n = _build_eval_env( + split="valid_seen", + env_num=cfg["sel_env_num"], + seed=seed, + ) + print(f" [6/6 EVALUATE] selection items={sel_n}") + sel_eval_dir = os.path.join(step_dir, "selection_eval") + sel_results = adapter.rollout(sel_env, candidate_skill, sel_eval_dir) + 
cand_hard, cand_soft = compute_score(sel_results) + sel_cache[cand_hash] = (cand_hard, cand_soft) + + step_rec["selection_hard"] = cand_hard + step_rec["selection_soft"] = cand_soft + + gate = evaluate_gate( + candidate_skill=candidate_skill, + cand_hard=cand_hard, + current_skill=current_skill, + current_score=current_score, + best_skill=best_skill, + best_score=best_score, + best_step=best_step, + global_step=global_step, + ) + step_rec["action"] = gate.action + prev_current = current_score + prev_best = best_score + current_skill = gate.current_skill + current_score = gate.current_score + best_skill = gate.best_skill + best_score = gate.best_score + best_step = gate.best_step + if gate.action in {"accept", "accept_new_best"}: + current_origin = f"step_{global_step:04d}" + if gate.action == "accept_new_best": + best_origin = current_origin + + if gate.action == "accept_new_best": + print( + f" [6/6 EVALUATE] ACCEPT (new best) " + f"hard={cand_hard:.4f} > prev best {prev_best:.4f}" + ) + elif gate.action == "accept": + print( + f" [6/6 EVALUATE] ACCEPT " + f"hard={cand_hard:.4f} > current={prev_current:.4f}" + ) + else: + print( + f" [6/6 EVALUATE] REJECT " + f"hard={cand_hard:.4f} <= current={current_score:.4f}" + ) + + step_rec["timing"]["evaluate_s"] = round(time.time() - t_phase, 1) + + # ── Step buffer: unified failure patterns + rejected edits ─ + action = step_rec.get("action", "unknown") + n_total = len(all_rollout_results) or 1 + n_fail = sum(1 for r in all_rollout_results if not r.get("hard")) + failure_patterns = _extract_failure_patterns( + all_rollout_results, step_dir, + ) + + buf_entry: dict = { + "step": global_step, + "action": action, + "n_total": n_total, + "n_fail": n_fail, + "failure_patterns": failure_patterns, + } + + # Attach rejected edits when the step was rejected + if "reject" in action and ranked_patch: + rejected_edits = [ + short_item_summary(item, update_mode) + for item in ranked_items + if isinstance(item, dict) + ] + 
buf_entry["score_before"] = current_score + buf_entry["score_after"] = cand_hard + buf_entry["rejected_edits"] = rejected_edits + + step_buffer.append(buf_entry) + + # Persist for meta-reflect + digest_path = os.path.join(step_dir, "trajectory_digest.json") + with open(digest_path, "w") as f: + json.dump(buf_entry, f, indent=2, ensure_ascii=False) + + # ── Token snapshot ─────────────────────────────────────── + tokens_after = get_token_summary() + step_tokens: dict = {} + for stage in tokens_after: + if stage == "_total": + continue + after = tokens_after[stage] + before = tokens_before.get(stage, {}) + step_tokens[stage] = { + "calls": after.get("calls", 0) - before.get("calls", 0), + "prompt_tokens": after.get("prompt_tokens", 0) + - before.get("prompt_tokens", 0), + "completion_tokens": after.get("completion_tokens", 0) + - before.get("completion_tokens", 0), + } + step_rec["tokens"] = step_tokens + + # ── Save state ─────────────────────────────────────────── + step_rec["current_score"] = current_score + step_rec["best_score"] = best_score + step_rec["best_step"] = best_step + step_rec["current_origin"] = current_origin + step_rec["best_origin"] = best_origin + step_rec["skill_len"] = len(current_skill) + step_rec["wall_time_s"] = round(time.time() - step_t0, 1) + + _save_skill(out_root, global_step, current_skill) + with open(os.path.join(out_root, "best_skill.md"), "w") as f: + f.write(best_skill) + history.append(step_rec) + _save_history(out_root, history) + _persist_runtime_state(global_step) + with open(os.path.join(step_dir, "step_record.json"), "w") as f: + json.dump(step_rec, f, indent=2, ensure_ascii=False) + + timing = step_rec["timing"] + print( + f"\n [STEP {global_step} done] " + f"epoch={epoch} action={step_rec['action']} " + f"current={current_score:.4f} best={best_score:.4f} " + f"dt={step_rec['wall_time_s']}s\n" + f" timing: rollout={timing.get('rollout_s',0)}s " + f"reflect={timing.get('reflect_s',0)}s " + 
f"deep_reflect={timing.get('deep_reflect_s',0)}s " + f"aggregate={timing.get('aggregate_s',0)}s " + f"select={timing.get('select_s',0)}s " + f"evaluate={timing.get('evaluate_s',0)}s" + ) + + epoch_last_step_skill = current_skill + epoch_comparison_pairs: list[dict] | None = None + + # ── SLOW UPDATE (end of epoch) ────────────────────────────── + use_slow = cfg.get("use_slow_update", False) + if use_slow: + slow_dir = os.path.join(out_root, "slow_update", f"epoch_{epoch:02d}") + slow_done_path = os.path.join(slow_dir, "slow_result.json") + + if os.path.exists(slow_done_path): + # Resume support + print( + f"\n [SLOW UPDATE epoch {epoch}] " + f"resumed — already done" + ) + with open(slow_done_path) as f: + slow_saved = json.load(f) + comparison_path = os.path.join(slow_dir, "comparison_pairs.json") + if os.path.exists(comparison_path): + try: + with open(comparison_path) as f: + epoch_comparison_pairs = json.load(f) + except Exception: + epoch_comparison_pairs = None + if ( + slow_saved.get("slow_update_content") + and slow_saved.get("action") in {"accept", "accept_new_best"} + and epoch >= 2 + ): + current_skill = replace_slow_update_field( + current_skill, slow_saved["slow_update_content"], + ) + elif epoch == 1: + # Epoch 1: inject empty placeholder + os.makedirs(slow_dir, exist_ok=True) + current_skill = inject_empty_slow_update_field(current_skill) + current_origin = f"slow_update_placeholder_epoch_{epoch:02d}" + _save_skill(out_root, global_step, current_skill) + with open(os.path.join(out_root, "best_skill.md"), "w") as f: + f.write(best_skill if best_score > current_score else current_skill) + with open(slow_done_path, "w") as f: + json.dump({"action": "inject_placeholder", "epoch": epoch}, f, indent=2) + _persist_runtime_state(global_step) + print( + f"\n [SLOW UPDATE epoch {epoch}] " + f"injected empty placeholder" + ) + else: + # Epoch 2+: longitudinal comparison + os.makedirs(slow_dir, exist_ok=True) + print( + f"\n {'='*60}\n" + f" SLOW UPDATE — Epoch 
{epoch} " + f"(comparing epoch {epoch-1} vs {epoch})\n" + f" {'='*60}" + ) + + # 1. Get skill from last step of previous epoch + prev_epoch_records = [ + h for h in history if h.get("epoch") == epoch - 1 + ] + prev_epoch_last_step = prev_epoch_records[-1]["step"] + prev_skill = _load_skill(out_root, prev_epoch_last_step) + + # 2. Sample items from train set + slow_n = cfg.get("slow_update_samples", 20) + slow_seed = seed + epoch * 2000 + if dataloader is not None: + slow_batch = dataloader.build_train_batch( + batch_size=slow_n, + seed=slow_seed, + out_root=out_root, + ) + slow_env = adapter.build_env_from_batch( + slow_batch, out_root=out_root, + ) + else: + slow_env = adapter.build_train_env( + batch_size=slow_n, + seed=slow_seed, + out_root=out_root, + ) + slow_items = list(slow_env) if hasattr(slow_env, "__iter__") else slow_env + print(f" [slow update] sampled {len(slow_items)} train items (seed={slow_seed})") + + # 3. Rollout with both skills + t_slow = time.time() + prev_rollout_dir = os.path.join(slow_dir, "rollout_prev") + curr_rollout_dir = os.path.join(slow_dir, "rollout_curr") + results_prev = adapter.rollout(slow_env, prev_skill, prev_rollout_dir) + results_curr = adapter.rollout(slow_env, current_skill, curr_rollout_dir) + + prev_hard, _ = compute_score(results_prev) + curr_hard, _ = compute_score(results_curr) + print( + f" [slow update] prev epoch hard={prev_hard:.4f} " + f"curr epoch hard={curr_hard:.4f}" + ) + + # 4. 
Build and save structured comparison pairs + comparison_pairs, all_comparison_pairs = _build_longitudinal_pairs( + adapter=adapter, + dataloader=dataloader, + prev_skill=prev_skill, + curr_skill=current_skill, + initial_items=slow_items, + initial_prev_results=results_prev, + initial_curr_results=results_curr, + prev_rollout_dir=prev_rollout_dir, + curr_rollout_dir=curr_rollout_dir, + policy=longitudinal_pair_policy, + target_n=slow_n, + seed=slow_seed, + out_root=out_root, + ) + epoch_comparison_pairs = comparison_pairs + if all_comparison_pairs is not comparison_pairs: + save_comparison_pairs( + all_comparison_pairs, + os.path.join(slow_dir, "comparison_pairs_all.json"), + ) + save_comparison_pairs( + comparison_pairs, + os.path.join(slow_dir, "comparison_pairs.json"), + ) + n_regressed = sum(1 for p in comparison_pairs if p["category"] == "regressed") + n_improved = sum(1 for p in comparison_pairs if p["category"] == "improved") + n_persist = sum(1 for p in comparison_pairs if p["category"] == "persistent_fail") + n_stable = sum(1 for p in comparison_pairs if p["category"] == "stable_success") + print( + f" [slow update] comparison: " + f"regressed={n_regressed} improved={n_improved} " + f"persistent_fail={n_persist} stable_success={n_stable} " + f"policy={longitudinal_pair_policy} " + f"kept={len(comparison_pairs)}/{len(all_comparison_pairs)}" + ) + + # 5. Extract previous slow update guidance for reflection + existing_guidance = extract_slow_update_field(current_skill) + + # 6. 
Teacher analysis (with reflection on previous guidance) + slow_result = run_slow_update( + current_skill, + results_prev, + results_curr, + slow_items, + prev_skill=prev_skill, + prev_slow_update_content=existing_guidance, + prev_rollout_dir=prev_rollout_dir, + curr_rollout_dir=curr_rollout_dir, + comparison_pairs=comparison_pairs, + ) + slow_time = round(time.time() - t_slow, 1) + + if slow_result and slow_result.get("slow_update_content"): + slow_candidate = replace_slow_update_field( + current_skill, slow_result["slow_update_content"], + ) + slow_candidate_hash = skill_hash(slow_candidate) + with open(os.path.join(slow_dir, "candidate_skill.md"), "w") as f: + f.write(slow_candidate) + slow_result["time_s"] = slow_time + slow_result["prev_hard"] = prev_hard + slow_result["curr_hard"] = curr_hard + slow_result["candidate_hash"] = slow_candidate_hash + slow_result["update_origin"] = "slow_update_momentum" + slow_result["update_target"] = ( + "Address longitudinal regressions and persistent failures " + "observed across adjacent epochs." 
+ ) + + if slow_candidate_hash in sel_cache: + slow_sel_hard, slow_sel_soft = sel_cache[slow_candidate_hash] + print( + f" [slow gate] cache hit: hard={slow_sel_hard:.4f}" + ) + else: + sel_env, sel_n = _build_eval_env( + split="valid_seen", + env_num=cfg["sel_env_num"], + seed=seed, + ) + print(f" [slow gate] selection items={sel_n}") + slow_eval_dir = os.path.join(slow_dir, "selection_eval") + slow_eval_results = adapter.rollout( + sel_env, slow_candidate, slow_eval_dir, + ) + slow_sel_hard, slow_sel_soft = compute_score(slow_eval_results) + sel_cache[slow_candidate_hash] = (slow_sel_hard, slow_sel_soft) + + slow_gate = evaluate_gate( + candidate_skill=slow_candidate, + cand_hard=slow_sel_hard, + current_skill=current_skill, + current_score=current_score, + best_skill=best_skill, + best_score=best_score, + best_step=best_step, + global_step=global_step, + ) + slow_result["selection_hard"] = slow_sel_hard + slow_result["selection_soft"] = slow_sel_soft + slow_result["action"] = slow_gate.action + prev_current = current_score + prev_best = best_score + current_skill = slow_gate.current_skill + current_score = slow_gate.current_score + best_skill = slow_gate.best_skill + best_score = slow_gate.best_score + best_step = slow_gate.best_step + if slow_gate.action in {"accept", "accept_new_best"}: + current_origin = f"slow_update_epoch_{epoch:02d}" + if slow_gate.action == "accept_new_best": + best_origin = current_origin + print( + f" [slow gate] ACCEPT (new best) " + f"hard={slow_sel_hard:.4f} > prev best {prev_best:.4f}" + ) + elif slow_gate.action == "accept": + print( + f" [slow gate] ACCEPT " + f"hard={slow_sel_hard:.4f} > current={prev_current:.4f}" + ) + else: + print( + f" [slow gate] REJECT " + f"hard={slow_sel_hard:.4f} <= current={current_score:.4f}" + ) + + print( + f" [slow update] guidance written " + f"({len(slow_result['slow_update_content'])} chars), " + f"{slow_time}s" + ) + else: + slow_result = slow_result or {} + slow_result["action"] = "no_content" 
+ slow_result["time_s"] = slow_time + print( + f" [slow update] no guidance produced, " + f"{slow_time}s" + ) + + # 5. Save + with open(slow_done_path, "w") as f: + json.dump(slow_result, f, indent=2, ensure_ascii=False) + _save_skill(out_root, global_step, current_skill) + with open(os.path.join(out_root, "best_skill.md"), "w") as f: + f.write(best_skill) + _persist_runtime_state(global_step) + + print( + f"\n [SLOW UPDATE epoch {epoch} done] " + f"current={current_score:.4f} best={best_score:.4f}" + ) + + # ── META SKILL (end of epoch, teacher-side memory) ───────── + use_meta_skill = cfg.get("use_meta_skill", False) + if use_meta_skill: + meta_skill_dir = os.path.join(out_root, "meta_skill", f"epoch_{epoch:02d}") + meta_skill_done_path = os.path.join(meta_skill_dir, "meta_skill_result.json") + os.makedirs(meta_skill_dir, exist_ok=True) + + if os.path.exists(meta_skill_done_path): + print(f"\n [META SKILL epoch {epoch}] resumed — already done") + elif epoch == 1: + with open(meta_skill_done_path, "w") as f: + json.dump( + {"action": "skip_first_epoch", "epoch": epoch}, + f, indent=2, ensure_ascii=False, + ) + print(f"\n [META SKILL epoch {epoch}] skipped — first epoch") + else: + print( + f"\n {'='*60}\n" + f" META SKILL — Epoch {epoch} " + f"(teacher memory from epoch {epoch-1} vs {epoch})\n" + f" {'='*60}" + ) + + prev_epoch_records = [h for h in history if h.get("epoch") == epoch - 1] + prev_epoch_last_step = prev_epoch_records[-1]["step"] + prev_skill = _load_skill(out_root, prev_epoch_last_step) + prev_meta_skill = _load_meta_skill_content(out_root, epoch - 1) + + if epoch_comparison_pairs is None: + meta_n = cfg.get("slow_update_samples", 20) + meta_seed = seed + epoch * 2000 + if dataloader is not None: + meta_batch = dataloader.build_train_batch( + batch_size=meta_n, + seed=meta_seed, + out_root=out_root, + ) + meta_env = adapter.build_env_from_batch( + meta_batch, out_root=out_root, + ) + else: + meta_env = adapter.build_train_env( + batch_size=meta_n, + 
seed=meta_seed, + out_root=out_root, + ) + meta_items = list(meta_env) if hasattr(meta_env, "__iter__") else meta_env + prev_rollout_dir = os.path.join(meta_skill_dir, "rollout_prev") + curr_rollout_dir = os.path.join(meta_skill_dir, "rollout_curr") + results_prev = adapter.rollout(meta_env, prev_skill, prev_rollout_dir) + results_curr = adapter.rollout(meta_env, epoch_last_step_skill, curr_rollout_dir) + epoch_comparison_pairs, all_meta_comparison_pairs = _build_longitudinal_pairs( + adapter=adapter, + dataloader=dataloader, + prev_skill=prev_skill, + curr_skill=epoch_last_step_skill, + initial_items=meta_items, + initial_prev_results=results_prev, + initial_curr_results=results_curr, + prev_rollout_dir=prev_rollout_dir, + curr_rollout_dir=curr_rollout_dir, + policy=longitudinal_pair_policy, + target_n=meta_n, + seed=meta_seed, + out_root=out_root, + ) + if all_meta_comparison_pairs is not epoch_comparison_pairs: + save_comparison_pairs( + all_meta_comparison_pairs, + os.path.join(meta_skill_dir, "comparison_pairs_all.json"), + ) + save_comparison_pairs( + epoch_comparison_pairs, + os.path.join(meta_skill_dir, "comparison_pairs.json"), + ) + meta_counts = _pair_category_counts(epoch_comparison_pairs) + print( + f" [meta skill] comparison: " + f"regressed={meta_counts.get('regressed', 0)} " + f"improved={meta_counts.get('improved', 0)} " + f"persistent_fail={meta_counts.get('persistent_fail', 0)} " + f"stable_success={meta_counts.get('stable_success', 0)} " + f"policy={longitudinal_pair_policy} " + f"kept={len(epoch_comparison_pairs)}/{len(all_meta_comparison_pairs)}" + ) + + t_meta_skill = time.time() + meta_skill_result = run_meta_skill( + prev_skill=prev_skill, + curr_skill=epoch_last_step_skill, + comparison_pairs=epoch_comparison_pairs or [], + prev_meta_skill_content=prev_meta_skill, + ) + meta_skill_time = round(time.time() - t_meta_skill, 1) + + if meta_skill_result and meta_skill_result.get("meta_skill_content"): + meta_skill_result["time_s"] = 
meta_skill_time + meta_skill_result["action"] = "write_meta_skill" + print( + f" [meta skill] memory written " + f"({len(meta_skill_result['meta_skill_content'])} chars), " + f"{meta_skill_time}s" + ) + else: + meta_skill_result = meta_skill_result or {} + meta_skill_result["time_s"] = meta_skill_time + meta_skill_result["action"] = "no_content" + print(f" [meta skill] no memory produced, {meta_skill_time}s") + + with open(meta_skill_done_path, "w") as f: + json.dump(meta_skill_result, f, indent=2, ensure_ascii=False) + + # ── META-REFLECT (end of epoch) ───────────────────────────── + use_meta = cfg.get("use_meta_reflect", False) + if use_meta: + # Collect this epoch's step records from history + epoch_records = [ + h for h in history if h.get("epoch") == epoch + ] + if epoch_records: + meta_step_tag = f"meta_epoch_{epoch}" + meta_dir = os.path.join(out_root, "meta_reflect", f"epoch_{epoch:02d}") + meta_done_path = os.path.join(meta_dir, "meta_result.json") + + # Resume support: skip if already done + if os.path.exists(meta_done_path): + with open(meta_done_path) as f: + meta_result = json.load(f) + meta_summary = meta_result.get("meta_summary", "") + meta_action = meta_result.get("action", "unknown") + print( + f"\n [META-REFLECT epoch {epoch}] " + f"resumed — {meta_action}" + ) + else: + os.makedirs(meta_dir, exist_ok=True) + print( + f"\n {'='*60}\n" + f" META-REFLECT — Epoch {epoch} " + f"({len(epoch_records)} steps)\n" + f" {'='*60}" + ) + + meta_edit_budget = cfg.get("meta_edit_budget", 4) + + # Build epoch history text + epoch_history_text = build_epoch_history( + epoch_records, out_root, + update_mode=update_mode, + ) + + # Load previous meta summary + prev_meta_path = os.path.join( + out_root, "meta_reflect", + f"epoch_{epoch - 1:02d}", "meta_result.json", + ) + prev_meta_summary = "" + if os.path.exists(prev_meta_path): + try: + with open(prev_meta_path) as f: + prev = json.load(f) + prev_meta_summary = prev.get("meta_summary", "") + except Exception: + 
pass + + # Get env-specific meta prompt if available + meta_system = adapter.get_meta_reflect_prompt() \ + if hasattr(adapter, "get_meta_reflect_prompt") else None + + # Run meta-reflect + t_meta = time.time() + meta_result = run_meta_reflect( + skill_content=current_skill, + epoch_history_text=epoch_history_text, + prev_meta_summary=prev_meta_summary, + meta_edit_budget=meta_edit_budget, + system_prompt=meta_system, + update_mode=update_mode, + ) + meta_time = round(time.time() - t_meta, 1) + + meta_items = get_payload_items(meta_result.get("patch", {}) if meta_result else {}, update_mode) + if meta_result and meta_items: + for item in meta_items: + item.setdefault("update_origin", "meta_reflect_momentum") + item.setdefault( + "update_target", + "Consolidate epoch-level accepted/rejected edit patterns.", + ) + meta_summary = meta_result.get("meta_summary", "") + print( + f" [meta-reflect] " + f"{len(meta_items)} {payload_label(update_mode)} proposed, " + f"{meta_time}s" + ) + + meta_rewrite_result = None + if update_mode == "rewrite_from_suggestions": + meta_rewrite_result = rewrite_skill_from_suggestions( + current_skill, + meta_result["patch"], + env=cfg.get("env"), + reasoning_effort=rewrite_reasoning_effort, + max_completion_tokens=rewrite_max_completion_tokens, + ) + if meta_rewrite_result and meta_rewrite_result.get("new_skill"): + meta_candidate = meta_rewrite_result["new_skill"] + meta_apply_report = [] + else: + meta_candidate = current_skill + meta_apply_report = [] + else: + meta_candidate, meta_apply_report = apply_patch_with_report( + current_skill, meta_result["patch"], + ) + meta_cand_hash = skill_hash(meta_candidate) + + # Save meta candidate + with open(os.path.join(meta_dir, "meta_candidate.md"), "w") as f: + f.write(meta_candidate) + with open(os.path.join(meta_dir, "meta_patch.json"), "w") as f: + json.dump(meta_result, f, indent=2, ensure_ascii=False) + if meta_apply_report: + with open(os.path.join(meta_dir, "meta_edit_apply_report.json"), 
"w") as f: + json.dump(meta_apply_report, f, indent=2, ensure_ascii=False) + if meta_rewrite_result: + with open(os.path.join(meta_dir, "meta_rewrite_result.json"), "w") as f: + json.dump(meta_rewrite_result, f, indent=2, ensure_ascii=False) + meta_result["rewrite_change_summary"] = meta_rewrite_result.get("change_summary", []) + + if update_mode == "rewrite_from_suggestions" and meta_rewrite_result is None: + meta_action = "skip_no_rewrite" + meta_result["action"] = meta_action + meta_result["meta_summary"] = meta_summary + meta_result["time_s"] = meta_time + print( + " [meta-reflect] no usable rewrite generated — " + f"skill unchanged, {meta_time}s" + ) + else: + # Gate: evaluate meta candidate + if meta_cand_hash in sel_cache: + meta_hard, meta_soft = sel_cache[meta_cand_hash] + print( + f" [meta-gate] " + f"cache hit: hard={meta_hard:.4f}" + ) + else: + sel_env, _ = _build_eval_env( + split="valid_seen", + env_num=cfg["sel_env_num"], + seed=seed, + ) + meta_eval_dir = os.path.join(meta_dir, "selection_eval") + meta_eval_results = adapter.rollout( + sel_env, meta_candidate, meta_eval_dir, + ) + meta_hard, meta_soft = compute_score(meta_eval_results) + sel_cache[meta_cand_hash] = (meta_hard, meta_soft) + + meta_gate = evaluate_gate( + candidate_skill=meta_candidate, + cand_hard=meta_hard, + current_skill=current_skill, + current_score=current_score, + best_skill=best_skill, + best_score=best_score, + best_step=best_step, + global_step=global_step, + ) + meta_action = meta_gate.action + prev_score = current_score + current_skill = meta_gate.current_skill + current_score = meta_gate.current_score + best_skill = meta_gate.best_skill + best_score = meta_gate.best_score + best_step = meta_gate.best_step + if meta_gate.action in {"accept", "accept_new_best"}: + current_origin = f"meta_reflect_epoch_{epoch:02d}" + if meta_gate.action == "accept_new_best": + best_origin = current_origin + if meta_gate.action == "accept_new_best": + print( + f" [meta-gate] ACCEPT (new 
best) " + f"hard={meta_hard:.4f} > " + f"prev best {prev_score:.4f}" + ) + elif meta_gate.action == "accept": + print( + f" [meta-gate] ACCEPT " + f"hard={meta_hard:.4f} > " + f"current={prev_score:.4f}" + ) + else: + print( + f" [meta-gate] REJECT " + f"hard={meta_hard:.4f} <= " + f"current={current_score:.4f}" + ) + + # Save meta result with gate outcome + meta_result["action"] = meta_action + meta_result["gate_score"] = meta_hard + meta_result["time_s"] = meta_time + meta_result["update_origin"] = "meta_reflect_momentum" + meta_result["update_target"] = ( + "Consolidate epoch-level editing directions that helped or hurt." + ) + else: + meta_summary = meta_result.get("meta_summary", "") if meta_result else "" + meta_action = f"skip_no_{payload_label(update_mode)}" + if meta_result is None: + meta_result = {} + meta_result["action"] = meta_action + meta_result["meta_summary"] = meta_summary + meta_result["time_s"] = meta_time + print( + f" [meta-reflect] no {payload_label(update_mode)} proposed — " + f"skill unchanged, {meta_time}s" + ) + + # Persist + with open(meta_done_path, "w") as f: + json.dump(meta_result, f, indent=2, ensure_ascii=False) + + # Save updated skill after meta-reflect + _save_skill(out_root, global_step, current_skill) + with open(os.path.join(out_root, "best_skill.md"), "w") as f: + f.write(best_skill) + _persist_runtime_state(global_step) + + print( + f"\n [META-REFLECT epoch {epoch} done] " + f"action={meta_action} " + f"current={current_score:.4f} " + f"best={best_score:.4f}" + ) + + # ── Save best skill ────────────────────────────────────────────── + with open(os.path.join(out_root, "best_skill.md"), "w") as f: + f.write(best_skill) + _persist_runtime_state(global_step) + print( + f"\n [done] best skill from step {best_step}, " + f"score={best_score:.4f}" + ) + + # ── Final test evaluation (valid_unseen) ───────────────────────── + baseline_test_hard = None + baseline_test_soft = None + test_hard = None + test_soft = None + + if 
cfg["eval_test"]: + task_types = adapter.get_task_types() + + # Baseline: S_0 on test set (valid_unseen) + print(f"\n{'='*60}") + print(" BASELINE TEST — evaluate initial skill on Test set (valid_unseen)") + print(f"{'='*60}") + test_env, test_n = _build_eval_env( + split="valid_unseen", + env_num=cfg["test_env_num"], + seed=seed, + ) + print(f" Test items: {test_n}") + baseline_test_dir = os.path.join(out_root, "test_eval_baseline") + baseline_test_results = adapter.rollout(test_env, skill_init, baseline_test_dir) + baseline_test_hard, baseline_test_soft = compute_score(baseline_test_results) + baseline_buckets = _compute_task_type_buckets(baseline_test_results, task_types) + print("\n === Baseline Test Results (S_0) ===") + for task_type in task_types + ["overall"]: + b = baseline_buckets.get(task_type, {"total": 0, "hard": 0}) + t = max(b["total"], 1) + print( + f" {task_type:<40s}: " + f"hard={b['hard']}/{b['total']}={b['hard']/t:.4f}" + ) + with open(os.path.join(baseline_test_dir, "summary.json"), "w") as f: + json.dump( + { + k: { + "total": b["total"], + "hard_acc": b["hard"] / max(b["total"], 1), + } + for k, b in baseline_buckets.items() + }, + f, indent=2, ensure_ascii=False, + ) + + # Best skill on test set + print(f"\n{'='*60}") + print(" BEST SKILL TEST — evaluate best skill on Test set (valid_unseen)") + print(f"{'='*60}") + test_env2, test_n2 = _build_eval_env( + split="valid_unseen", + env_num=cfg["test_env_num"], + seed=seed, + ) + print(f" Test items: {test_n2}") + test_dir = os.path.join(out_root, "test_eval") + test_results = adapter.rollout(test_env2, best_skill, test_dir) + test_hard, test_soft = compute_score(test_results) + best_buckets = _compute_task_type_buckets(test_results, task_types) + print("\n === Best Skill Test Results ===") + for task_type in task_types + ["overall"]: + b = best_buckets.get(task_type, {"total": 0, "hard": 0}) + t = max(b["total"], 1) + print( + f" {task_type:<40s}: " + 
f"hard={b['hard']}/{b['total']}={b['hard']/t:.4f}" + ) + with open(os.path.join(test_dir, "summary.json"), "w") as f: + json.dump( + { + k: { + "total": b["total"], + "hard_acc": b["hard"] / max(b["total"], 1), + } + for k, b in best_buckets.items() + }, + f, indent=2, ensure_ascii=False, + ) + + # Comparison + delta_hard = (test_hard or 0) - (baseline_test_hard or 0) + print(f"\n === Improvement (best vs baseline) ===") + print( + f" hard: {baseline_test_hard:.4f} -> {test_hard:.4f} " + f"(delta={delta_hard:+.4f})" + ) + + # ── Global summary ─────────────────────────────────────────────── + total_wall = time.time() - t_loop_start + n_accept = sum(1 for h in history if "accept" in h.get("action", "")) + n_reject = sum(1 for h in history if h.get("action") == "reject") + n_skip = sum(1 for h in history if h.get("action") == "skip_no_patches") + + token_summary = get_token_summary() + + # Epoch-level statistics + epoch_stats = [] + for e in range(1, num_epochs + 1): + epoch_records = [h for h in history if h.get("epoch") == e] + if epoch_records: + epoch_stats.append({ + "epoch": e, + "steps": [h["step"] for h in epoch_records], + "accepts": sum(1 for h in epoch_records if "accept" in h.get("action", "")), + "rejects": sum(1 for h in epoch_records if h.get("action") == "reject"), + "skips": sum(1 for h in epoch_records if h.get("action") == "skip_no_patches"), + "best_score_at_epoch_end": epoch_records[-1].get("best_score", 0.0), + "current_score_at_epoch_end": epoch_records[-1].get("current_score", 0.0), + }) + + summary = { + "version": "skillopt-0.1.0", + "config": _redact_cfg(cfg), + "baseline_selection_hard": sel_cache.get( + skill_hash(skill_init), (None, None), + )[0], + "best_selection_hard": best_score, + "best_step": best_step, + "current_origin": current_origin, + "best_origin": best_origin, + "total_steps": len(history), + "total_accepts": n_accept, + "total_rejects": n_reject, + "total_skips": n_skip, + "epoch_stats": epoch_stats, + 
"baseline_test_hard": baseline_test_hard, + "baseline_test_soft": baseline_test_soft, + "test_hard": test_hard, + "test_soft": test_soft, + "test_delta_hard": ( + (test_hard or 0) - (baseline_test_hard or 0) + if test_hard is not None + else None + ), + "total_wall_time_s": round(total_wall, 1), + "token_summary": token_summary, + } + with open(os.path.join(out_root, "summary.json"), "w") as f: + json.dump(summary, f, indent=2, ensure_ascii=False) + + print(f"\n{'='*60}") + print(" Final Summary") + print(f"{'='*60}") + print( + f" steps={len(history)} accept={n_accept} " + f"reject={n_reject} skip={n_skip}" + ) + print(f" best_score={best_score:.4f} (step {best_step}) wall={total_wall:.0f}s") + if epoch_stats: + for es in epoch_stats: + print( + f" epoch {es['epoch']}: accept={es['accepts']} reject={es['rejects']} " + f"best={es['best_score_at_epoch_end']:.4f}" + ) + if test_hard is not None: + print(f" test_hard={test_hard:.4f} test_soft={test_soft:.4f}") + if token_summary.get("_total"): + t = token_summary["_total"] + print( + f" total tokens: {t['total_tokens']:,} " + f"(prompt={t['prompt_tokens']:,} " + f"completion={t['completion_tokens']:,} " + f"calls={t['calls']})" + ) + + return summary diff --git a/skillopt/envs/__init__.py b/skillopt/envs/__init__.py new file mode 100644 index 0000000..ecd0aaa --- /dev/null +++ b/skillopt/envs/__init__.py @@ -0,0 +1 @@ +"""ReflACT environment adapters.""" diff --git a/skillopt/envs/alfworld/__init__.py b/skillopt/envs/alfworld/__init__.py new file mode 100644 index 0000000..e9a28ff --- /dev/null +++ b/skillopt/envs/alfworld/__init__.py @@ -0,0 +1,5 @@ +"""ALFWorld environment adapter for ReflACT.""" + +from skillopt.envs.alfworld.adapter import ALFWorldAdapter + +__all__ = ["ALFWorldAdapter"] diff --git a/skillopt/envs/alfworld/adapter.py b/skillopt/envs/alfworld/adapter.py new file mode 100644 index 0000000..9fd0909 --- /dev/null +++ b/skillopt/envs/alfworld/adapter.py @@ -0,0 +1,585 @@ +"""ALFWorld environment 
@dataclass(frozen=True)
class ALFWorldBatchRun:
    """Deferred description of one ALFWorld batch.

    Nothing is opened when this object is created; the adapter turns it into
    real simulator environments chunk by chunk during rollout, so a large
    evaluation set never has every ALFWorld simulator open at the same time.
    """

    # Number of episodes in the batch; also what ``len()`` reports.
    env_num: int
    # Dataset selector handed to the environment builder.
    eval_dataset: str
    # Base seed for the batch.
    seed: int
    # Whether the batch is built in training mode.
    is_train: bool
    # Requested number of environments per rollout chunk.
    workers: int
    # Optional explicit game files, one per episode.
    specific_gamefiles: list[str] | None = None
    # Optional result identifiers, one per episode.
    result_ids: list[str] | None = None
    # Optional item dicts; this is what iteration yields.
    items: list[dict] | None = None

    def __iter__(self):
        """Iterate over the attached items (nothing when none are attached)."""
        attached = self.items if self.items else []
        return iter(attached)

    def __len__(self) -> int:
        """Return the episode count, treating a falsy ``env_num`` as zero."""
        count = self.env_num
        return int(count) if count else 0
class ALFWorldAdapter(EnvAdapter):
    """ALFWorld environment adapter.

    Connects the ALFWorld data loader, chunked simulator rollouts, and the
    minibatch reflection stage to the :class:`EnvAdapter` interface.

    Parameters
    ----------
    split_dir, data_path, split_mode, split_ratio, split_seed, split_output_dir, seed, limit, train_size
        Forwarded unchanged to :class:`ALFWorldDataLoader`.
    max_steps : int
        Maximum steps per ALFWorld episode (default 50).
    workers : int
        Simulator environments opened per rollout chunk (default 8); coerced
        to at least 1.
    max_api_workers : int
        Maximum concurrent API calls during rollout (default 8).
    analyst_workers : int
        Parallel workers for analyst stage (default 16).
    failure_only : bool
        If True, only run error analyst (skip success analyst).
    minibatch_size : int
        Trajectories per analyst group, M (default 8).
    edit_budget : int
        Maximum edits per minibatch, L (default 4).
    use_deep_reflect : bool
        Enables :meth:`deep_reflect`; when False that method returns ``[]``.
    deep_reflect_failures, deep_reflect_successes : int
        How many failed / successful trajectories to request when selecting
        representative items for the deep-reflect probe.
    """

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "split_dir",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        seed: int = 42,
        limit: int = 0,
        train_size: int = 0,
        max_steps: int = 50,
        workers: int = 8,
        max_api_workers: int = 8,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        self.max_steps = max_steps
        # A falsy or non-positive worker count degrades to a single worker.
        self.workers = max(int(workers or 1), 1)
        self.max_api_workers = max_api_workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = ALFWorldDataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
            train_size=train_size,
        )
        # gamefile path -> parsed traj_data.json (None when unavailable).
        self._traj_cache: dict[str, dict | None] = {}

    def setup(self, cfg: dict) -> None:
        """Run the base-class setup, then let the data loader read ``cfg``."""
        super().setup(cfg)
        self.dataloader.setup(cfg)

    def _load_traj_data(self, item: dict) -> dict | None:
        """Load the ``traj_data.json`` that sits next to the item's gamefile.

        Returns ``None`` when the item has no gamefile, the file cannot be
        read or parsed, or the payload is not a JSON object. Outcomes
        (including failures) are cached per gamefile path.
        """
        gamefile = str(item.get("gamefile") or "").strip()
        if not gamefile:
            return None
        if gamefile in self._traj_cache:
            return self._traj_cache[gamefile]

        traj_path = os.path.join(os.path.dirname(gamefile), "traj_data.json")
        try:
            with open(traj_path, encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            # Reference material is optional: any read/parse problem simply
            # means "no reference available" for this item.
            data = None
        if not isinstance(data, dict):
            # Every consumer calls ``.get`` on the payload, so a list or
            # scalar document is treated the same as a missing file.
            data = None
        self._traj_cache[gamefile] = data
        return data

    @staticmethod
    def _unique_lines(values: list[str], *, limit: int = 0) -> list[str]:
        """Strip, drop empties, and de-duplicate ``values`` preserving order.

        Stops once ``limit`` lines were collected; ``limit <= 0`` means no cap.
        """
        lines: list[str] = []
        seen: set[str] = set()
        for raw in values:
            line = str(raw or "").strip()
            if not line or line in seen:
                continue
            seen.add(line)
            lines.append(line)
            if limit > 0 and len(lines) >= limit:
                break
        return lines

    @staticmethod
    def _format_high_pddl(high_pddl: list[dict]) -> list[str]:
        """Render planner steps as numbered ``"N. action(arg, ...)"`` lines.

        Each step prefers its ``discrete_action``; when that has no action
        name the ``planner_action`` name is used. Steps that yield no text are
        omitted, but numbering follows the position in ``high_pddl``.
        """
        steps: list[str] = []
        for idx, step in enumerate(high_pddl or [], start=1):
            discrete = step.get("discrete_action") or {}
            action = str(discrete.get("action") or "").strip()
            args = [
                str(arg).strip()
                for arg in (discrete.get("args") or [])
                if str(arg).strip()
            ]
            if action and args:
                text = f"{action}({', '.join(args)})"
            elif action:
                text = action
            else:
                planner_action = step.get("planner_action") or {}
                text = str(planner_action.get("action") or "").strip()
            if text:
                steps.append(f"{idx}. {text}")
        return steps

    def _build_reference_bundle(self, item: dict) -> dict:
        """Collect the reference fields of an item from its trajectory data.

        Returns ``{}`` when no trajectory data is available; otherwise a dict
        with keys ``task_type``, ``task_descs``, ``high_descs``,
        ``pddl_params``, ``high_pddl`` and ``scene_summary``.
        """
        data = self._load_traj_data(item)
        if not data:
            return {}

        anns = ((data.get("turk_annotations") or {}).get("anns") or [])
        task_descs = self._unique_lines(
            [ann.get("task_desc", "") for ann in anns],
            limit=3,
        )
        high_descs = self._unique_lines(
            [step for ann in anns for step in (ann.get("high_descs") or [])],
            limit=12,
        )
        # Empty placeholders are dropped; False/0 values are kept on purpose.
        pddl_params = {
            key: value
            for key, value in (data.get("pddl_params") or {}).items()
            if value not in ("", None, [], {})
        }
        scene = data.get("scene") or {}
        scene_summary = {
            key: scene.get(key)
            for key in ("floor_plan", "scene_num", "dirty_and_empty")
            if scene.get(key) not in ("", None, [], {})
        }
        high_pddl = self._format_high_pddl((data.get("plan") or {}).get("high_pddl") or [])
        task_type = str(data.get("task_type") or item.get("task_type") or "").strip()
        return {
            "task_type": task_type,
            "task_descs": task_descs,
            "high_descs": high_descs,
            "pddl_params": pddl_params,
            "high_pddl": high_pddl,
            "scene_summary": scene_summary,
        }

    def build_reference_text(self, item: dict) -> str:
        """Return the item's reference bundle as Markdown sections.

        Sections with no content are omitted; the result is ``""`` when the
        item has no reference data.
        """
        bundle = self._build_reference_bundle(item)
        if not bundle:
            return ""

        parts: list[str] = []
        if bundle["task_type"]:
            parts.append(f"## Reference Task Type\n{bundle['task_type']}")
        if bundle["task_descs"]:
            parts.append(
                "## Reference Human Task Descriptions\n"
                + "\n".join(f"- {line}" for line in bundle["task_descs"])
            )
        if bundle["high_descs"]:
            parts.append(
                "## Reference Human High-Level Steps\n"
                + "\n".join(f"{idx}. {line}" for idx, line in enumerate(bundle["high_descs"], start=1))
            )
        if bundle["pddl_params"]:
            parts.append(
                "## Reference PDDL Params\n"
                + "\n".join(f"- {key}: {value}" for key, value in bundle["pddl_params"].items())
            )
        if bundle["high_pddl"]:
            parts.append(
                "## Reference Planner High-Level Plan\n" + "\n".join(bundle["high_pddl"])
            )
        if bundle["scene_summary"]:
            parts.append(
                "## Reference Scene Summary\n"
                + "\n".join(f"- {key}: {value}" for key, value in bundle["scene_summary"].items())
            )
        return "\n\n".join(parts)

    def get_reference_metadata(self, item: dict) -> dict:
        """Summarize which reference fields exist for ``item``.

        Returns ``{"fields": [...], "preview": "..."}`` where ``preview`` is a
        short excerpt of each available field, truncated to 600 characters.
        """
        bundle = self._build_reference_bundle(item)
        if not bundle:
            return {"fields": [], "preview": ""}

        fields: list[str] = []
        previews: list[str] = []
        if bundle["task_type"]:
            fields.append("task_type")
            previews.append(f"[task_type] {bundle['task_type']}")
        if bundle["task_descs"]:
            fields.append("task_desc")
            previews.append("[task_desc]\n" + "\n".join(bundle["task_descs"][:2]))
        if bundle["high_descs"]:
            fields.append("high_descs")
            previews.append("[high_descs]\n" + "\n".join(bundle["high_descs"][:3]))
        if bundle["pddl_params"]:
            fields.append("pddl_params")
            previews.append(
                "[pddl_params]\n"
                + "\n".join(
                    f"{key}: {value}" for key, value in list(bundle["pddl_params"].items())[:4]
                )
            )
        if bundle["high_pddl"]:
            fields.append("plan.high_pddl")
            previews.append("[plan.high_pddl]\n" + "\n".join(bundle["high_pddl"][:3]))
        if bundle["scene_summary"]:
            fields.append("scene")
            previews.append(
                "[scene]\n"
                + "\n".join(
                    f"{key}: {value}" for key, value in bundle["scene_summary"].items()
                )
            )
        return {
            "fields": fields,
            "preview": "\n\n".join(previews)[:600],
        }

    @staticmethod
    def _infer_dataset_from_gamefile(gamefile: str) -> tuple[str, bool]:
        """Map a gamefile path to ``(eval_dataset, is_train)``.

        Paths containing ``/valid_seen/`` or ``/valid_unseen/`` select the
        matching evaluation dataset; everything else counts as training data.
        """
        path = str(gamefile or "")
        if "/valid_seen/" in path:
            return "eval_in_distribution", False
        if "/valid_unseen/" in path:
            return "eval_out_of_distribution", False
        return "train", True

    def get_dataloader(self):
        """Return the :class:`ALFWorldDataLoader` owned by this adapter."""
        return self.dataloader

    def _comparison_items(self, items: list[dict]) -> list[dict]:
        """Return copies of ``items`` with a ``task_description`` when known.

        The first human task description is preferred, falling back to the
        task type; items without reference data are copied unchanged.
        """
        enriched: list[dict] = []
        for item in items:
            row = dict(item)
            bundle = self._build_reference_bundle(row)
            if bundle.get("task_descs"):
                row["task_description"] = bundle["task_descs"][0]
            elif bundle.get("task_type"):
                row["task_description"] = bundle["task_type"]
            enriched.append(row)
        return enriched

    def requires_ray(self) -> bool:
        """This adapter reports that it does not need Ray."""
        return False

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        """Turn a planned :class:`BatchSpec` into a lazy :class:`ALFWorldBatchRun`.

        No simulator is started here; that happens chunk by chunk in
        :meth:`_run_batch`. Extra keyword arguments are accepted and ignored.
        """
        gamefiles = list(batch.metadata.get("gamefiles") or [])
        result_ids = list(batch.metadata.get("result_ids") or [])
        items = self._comparison_items(list(batch.payload or []))
        return ALFWorldBatchRun(
            env_num=batch.batch_size,
            eval_dataset=batch.metadata.get("eval_dataset", batch.split),
            seed=batch.seed,
            is_train=batch.metadata.get("is_train", batch.phase == "train"),
            specific_gamefiles=gamefiles or None,
            result_ids=result_ids or None,
            items=items,
            workers=self.workers,
        )

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        """Plan a training batch with the data loader and wrap it lazily."""
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        """Plan an evaluation batch for ``split`` and wrap it lazily."""
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def rollout(
        self,
        env_manager,
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict]:
        """Run (or resume) a rollout and persist it as ``results.jsonl``.

        If ``out_dir/results.jsonl`` already holds at least one parseable
        record, those records are returned and nothing is executed. Otherwise
        the batch is run -- chunked for :class:`ALFWorldBatchRun`, directly for
        an already-built environment manager -- and the results are written
        one JSON object per line.
        """
        results_path = os.path.join(out_dir, "results.jsonl")
        os.makedirs(out_dir, exist_ok=True)

        # Resume support
        if os.path.exists(results_path):
            existing: list[dict] = []
            # Records are written with ensure_ascii=False, so the file must be
            # read back as UTF-8 rather than with the locale default.
            with open(results_path, encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        existing.append(json.loads(line))
                    except ValueError:
                        # Tolerate a truncated/corrupt record (e.g. from an
                        # interrupted write) instead of discarding the run.
                        continue
            if existing:
                return existing

        if isinstance(env_manager, ALFWorldBatchRun):
            results = self._run_batch(
                env_manager,
                skill_content=skill_content,
                out_dir=out_dir,
            )
        else:
            results = run_alfworld_batch(
                env_manager=env_manager,
                skill_content=skill_content,
                max_steps=self.max_steps,
                out_root=out_dir,
                max_api_workers=self.max_api_workers,
                result_ids=getattr(env_manager, "_skillopt_result_ids", None),
            )

        with open(results_path, "w", encoding="utf-8") as f:
            for r in results:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

        return results

    @staticmethod
    def _close_env(env_manager) -> None:
        """Call ``env_manager.close()`` when such a callable exists."""
        close = getattr(env_manager, "close", None)
        if callable(close):
            close()

    def _run_batch(
        self,
        batch: ALFWorldBatchRun,
        skill_content: str,
        out_dir: str,
        *,
        diagnostic_mode: bool = False,
        diagnostic_instruction: str = "",
    ) -> list[dict]:
        """Materialize ``batch`` in chunks of at most ``workers`` environments.

        Each chunk builds its own environment (seeded with ``batch.seed`` plus
        the chunk offset), runs it, and closes it again even when the run
        raises. Results are concatenated in chunk order.
        """
        total = int(batch.env_num or 0)
        if total <= 0:
            return []

        workers = max(1, min(int(batch.workers or self.workers), total))
        if total > workers:
            print(
                f" [alfworld rollout] episodes={total} "
                f"env_workers={workers} chunks={(total + workers - 1) // workers}"
            )

        all_results: list[dict] = []
        for start in range(0, total, workers):
            chunk_size = min(workers, total - start)
            chunk_gamefiles = (
                batch.specific_gamefiles[start:start + chunk_size]
                if batch.specific_gamefiles
                else None
            )
            # Without explicit ids, synthesize ids from the global episode index.
            chunk_ids = (
                batch.result_ids[start:start + chunk_size]
                if batch.result_ids
                else [f"env_{idx:03d}" for idx in range(start, start + chunk_size)]
            )
            chunk_env = build_alfworld_env(
                env_num=chunk_size,
                eval_dataset=batch.eval_dataset,
                seed=batch.seed + start,
                is_train=batch.is_train,
                specific_gamefiles=chunk_gamefiles,
            )
            try:
                chunk_results = run_alfworld_batch(
                    env_manager=chunk_env,
                    skill_content=skill_content,
                    max_steps=self.max_steps,
                    out_root=out_dir,
                    max_api_workers=min(self.max_api_workers, chunk_size),
                    diagnostic_mode=diagnostic_mode,
                    diagnostic_instruction=diagnostic_instruction,
                    result_ids=chunk_ids,
                )
            finally:
                self._close_env(chunk_env)
            all_results.extend(chunk_results)
        return all_results

    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Run minibatch reflection over ``results`` and return its output.

        Recognized keyword arguments: ``prediction_dir`` and ``patches_dir``
        (default to sub-directories of ``out_dir``), ``random_seed``,
        ``step_buffer_context`` and ``meta_skill_context``.
        """
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")

        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
        )

    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Probe a few representative episodes and reflect on the re-runs.

        Steps: select representative items, generate one diagnostic probe
        instruction (saved to ``deep_reflect/probe.json``), re-run the selected
        gamefiles in diagnostic mode, then run minibatch reflection on those
        re-runs. Returns ``[]`` when deep reflect is disabled, nothing is
        selected, no probe is produced, or a selected item has no gamefile.
        """
        if not self.use_deep_reflect:
            return []

        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        selected_items = self.select_representative_items(
            results,
            results,
            n_failures=self.deep_reflect_failures,
            n_successes=self.deep_reflect_successes,
            seed=random_seed,
        )
        if not selected_items:
            return []

        selected_ids = {str(item["id"]) for item in selected_items}
        selected_results = [row for row in results if str(row.get("id")) in selected_ids]
        selected_examples = self.attach_reference_context(selected_results, selected_items)

        # Count how many selected items expose each reference field.
        field_counts: dict[str, int] = {}
        selected_metadata: list[dict] = []
        for item in selected_items:
            meta = self.get_reference_metadata(item)
            for field in meta["fields"]:
                field_counts[field] = field_counts.get(field, 0) + 1
            selected_metadata.append({
                "id": str(item["id"]),
                "task_type": str(item.get("task_type") or "alfworld"),
                "gamefile": str(item.get("gamefile") or ""),
                "reference_fields": meta["fields"],
                "reference_preview": meta["preview"],
            })

        deep_dir = os.path.join(out_dir, "deep_reflect")
        rollout_dir = os.path.join(deep_dir, "rollout")
        patches_dir = os.path.join(deep_dir, "patches")
        os.makedirs(deep_dir, exist_ok=True)
        field_summary = ", ".join(
            f"{field}({count}/{len(selected_items)})"
            for field, count in sorted(field_counts.items())
        ) or "none"
        print(
            f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
            f"reference_fields={field_summary}"
        )
        probe = generate_deep_probe_instruction(
            skill_content=skill_content,
            items=selected_examples,
            prediction_dir=prediction_dir,
            system_prompt=self.get_deep_probe_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            output_requirements=[
                "- Some trajectories may include a hidden Reference block. Use it to target the student's latent subgoal, missing precondition, or next-step intent, but do not reveal or paraphrase that reference to the student.",
                # NOTE(review): the tag names below restore markup that was
                # missing from these two sentences; they are assumed to match
                # the tags used by the rollout prompts -- confirm.
                "- The instruction must request a brief diagnostic readout inside the existing <think>...</think> block.",
                "- The student must still output exactly one admissible action inside <action>...</action>.",
                "- Do not ask for exhaustive inventories, full plans, or long chain-of-thought.",
                "- The instruction text should be ready to append directly to the student's prompt.",
            ],
        )
        if not probe:
            return []

        with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
            json.dump(
                {
                    **probe,
                    "reference_summary": {
                        "selected_count": len(selected_items),
                        "field_counts": field_counts,
                    },
                    "selected_examples": selected_metadata,
                },
                f,
                ensure_ascii=False,
                indent=2,
            )

        gamefiles = [str(item.get("gamefile") or "") for item in selected_items]
        if any(not gamefile for gamefile in gamefiles):
            return []
        # The dataset is inferred from the first gamefile only.
        eval_dataset, is_train = self._infer_dataset_from_gamefile(gamefiles[0])
        deep_env = ALFWorldBatchRun(
            env_num=len(selected_items),
            eval_dataset=eval_dataset,
            # Only an absent seed falls back to 42; an explicit 0 is honored.
            seed=42 if random_seed is None else random_seed,
            is_train=is_train,
            specific_gamefiles=gamefiles,
            workers=min(self.workers, max(len(selected_items), 1)),
            result_ids=[str(item["id"]) for item in selected_items],
        )
        deep_results = self._run_batch(
            deep_env,
            skill_content=skill_content,
            out_dir=rollout_dir,
            diagnostic_mode=True,
            diagnostic_instruction=probe["probe_instruction"],
        )
        deep_results = self.attach_reference_context(deep_results, selected_items)
        return run_minibatch_reflect(
            results=deep_results,
            skill_content=skill_content,
            prediction_dir=os.path.join(rollout_dir, "predictions"),
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
        )

    def get_task_types(self) -> list[str]:
        """Return a fresh list of the ALFWorld task types (``TASKS``)."""
        return list(TASKS)
class ALFWorldDataLoader(SplitDataLoader):
    """ALFWorld batch planner.

    In split_dir mode, batches are fixed gamefile items so ablations differ
    only in how the same training set is batched.

    Every batch produced here carries ALFWorld-specific metadata
    (``eval_dataset``, ``is_train``, ``gamefiles``, ``result_ids``) derived
    from its items.
    """

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "split_dir",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        seed: int = 42,
        limit: int = 0,
        train_size: int = 0,
        **kwargs,
    ) -> None:
        """Configure the split loader.

        ``train_size`` (when positive) overrides the size reported by
        :meth:`get_train_size`. Extra keyword arguments are accepted and
        ignored.
        """
        super().__init__(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )
        self.train_size_override = int(train_size or 0)

    @staticmethod
    def _metadata_for_items(items: list[dict], split: str, phase: str) -> dict:
        """Derive ALFWorld batch metadata from ``items``.

        The dataset is ``"train"`` unless the first gamefile path contains
        ``/valid_seen/`` or ``/valid_unseen/``. ``split`` is accepted for call
        symmetry but does not influence the result.

        Raises
        ------
        ValueError
            If any item lacks a non-empty ``gamefile``.
        """
        gamefiles = [str(item.get("gamefile") or "") for item in items]
        if not all(gamefiles):
            raise ValueError("ALFWorld split items must contain non-empty gamefile paths.")

        eval_dataset = "train"
        is_train = phase == "train"
        first = gamefiles[0] if gamefiles else ""
        if "/valid_seen/" in first:
            eval_dataset = "eval_in_distribution"
            is_train = False
        elif "/valid_unseen/" in first:
            eval_dataset = "eval_out_of_distribution"
            is_train = False

        result_ids: list[str] = []
        for idx, item in enumerate(items):
            item_id = item.get("id")
            # Only a missing/empty id falls back to the batch position; a
            # literal 0 is a real id and must survive so that result ids keep
            # matching ``str(item["id"])`` lookups elsewhere.
            result_ids.append(str(idx if item_id in (None, "") else item_id))

        return {
            "eval_dataset": eval_dataset,
            "is_train": is_train,
            "gamefiles": gamefiles,
            "result_ids": result_ids,
        }

    def get_train_size(self) -> int:
        """Return the training-set size, honoring a positive override."""
        if self.train_size_override > 0:
            return self.train_size_override
        return super().get_train_size()

    def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
        """Build one training batch and attach ALFWorld metadata.

        The returned spec's ``batch_size`` is the number of items actually in
        the payload.
        """
        batch = super().build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        items = list(batch.payload or [])
        batch.metadata.update(self._metadata_for_items(items, "train", "train"))
        return BatchSpec(
            phase="train",
            split="train",
            seed=seed,
            batch_size=len(items),
            payload=items,
            metadata=batch.metadata,
        )

    def plan_train_epoch(
        self,
        *,
        epoch: int,
        steps_per_epoch: int,
        accumulation: int,
        batch_size: int,
        seed: int,
        **kwargs,
    ) -> list[BatchSpec]:
        """Plan an epoch with the base loader, annotating each batch in place."""
        batches = super().plan_train_epoch(
            epoch=epoch,
            steps_per_epoch=steps_per_epoch,
            accumulation=accumulation,
            batch_size=batch_size,
            seed=seed,
            **kwargs,
        )
        for batch in batches:
            items = list(batch.payload or [])
            batch.metadata.update(self._metadata_for_items(items, "train", "train"))
        return batches

    def build_eval_batch(
        self,
        env_num: int,
        split: str,
        seed: int,
        **kwargs,
    ) -> BatchSpec:
        """Build one evaluation batch for ``split`` and attach ALFWorld metadata.

        The returned spec's ``batch_size`` is the number of items actually in
        the payload.
        """
        batch = super().build_eval_batch(
            env_num=env_num,
            split=split,
            seed=seed,
            **kwargs,
        )
        items = list(batch.payload or [])
        batch.metadata.update(self._metadata_for_items(items, split, "eval"))
        return BatchSpec(
            phase="eval",
            split=split,
            seed=seed,
            batch_size=len(items),
            payload=items,
            metadata=batch.metadata,
        )
+ +## ALFWorld Task Types +- pick_and_place: Put object in/on a receptacle +- pick_two_obj_and_place: Put two instances of an object in/on a receptacle +- look_at_obj_in_light: Examine an object under a desklamp +- pick_heat_then_place_in_recep: Heat an object and put it in/on a receptacle +- pick_cool_then_place_in_recep: Cool an object and put it in/on a receptacle +- pick_clean_then_place_in_recep: Clean an object and put it in/on a receptacle + +## Failure Type Categories +- **navigation_loop**: the agent revisits the same locations repeatedly without progress +- **missed_object**: the agent fails to pick up a visible/reachable goal object +- **wrong_sequence**: the agent performs actions in the wrong order (e.g., placing before transforming) +- **premature_stop**: the agent stops or gets stuck before completing all goal conditions +- **action_loop**: the agent repeats the same action without advancing +- **appliance_error**: the agent misuses or skips an appliance (microwave, fridge, sink) +- **rule_missing**: the skill lacks a relevant rule for this situation +- **rule_wrong**: an existing skill rule is misleading or incorrect +- **rule_ignored**: the skill has the right rule but the agent did not follow it +- **other**: none of the above + +## Analysis Process +1. Read ALL trajectories in the minibatch. +2. Identify the most prevalent, systematic failure patterns across them. +3. For each pattern, classify its failure type. +4. Propose skill edits that address the COMMON patterns — not individual edge cases. +5. Edits must be generalizable; do not hardcode task-specific values. +6. Only patch gaps in the skill — do not duplicate existing content. + +You will be told the maximum number of edits (the budget L). Produce AT MOST L edits, +focusing on the highest-impact patterns. You may produce fewer if warranted. 
+ +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": <number of trajectories analyzed>, + "failure_summary": [ + {"failure_type": "<one of the categories above>", "count": <int>, "description": "<short description of the pattern>"} + ], + "patch": { + "reasoning": "<why these edits address the common failure patterns>", + "edits": [ + {"op": "append", "content": "<markdown to append to the skill>"}, + {"op": "insert_after", "target": "<existing skill text to insert after>", "content": "<markdown to insert>"}, + {"op": "replace", "target": "<existing skill text to replace>", "content": "<replacement markdown>"}, + {"op": "delete", "target": "<existing skill text to delete>"} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. diff --git a/skillopt/envs/alfworld/prompts/analyst_success.md b/skillopt/envs/alfworld/prompts/analyst_success.md new file mode 100644 index 0000000..957d3a8 --- /dev/null +++ b/skillopt/envs/alfworld/prompts/analyst_success.md @@ -0,0 +1,33 @@ +You are an expert success-pattern analyst for AI agents operating in ALFWorld, +a text-based embodied household environment. + +You will be given MULTIPLE successful agent trajectories from a single minibatch +and the current skill document. Your job is to identify generalizable behavior +patterns that are COMMON across the batch and worth encoding in the skill. + +## Rules +- Only propose patches for patterns NOT already covered in the skill. +- Focus on patterns that appear across MULTIPLE trajectories in the batch. +- Be concise. Patterns must generalize beyond specific tasks. +- Prefer reinforcing existing sections over adding new top-level sections. +- If the agents' success involved efficient exploration or smart appliance usage, + consider reinforcing that in the patch. + +You will be told the maximum number of edits (the budget L). Produce AT MOST L edits, +focusing on the most broadly applicable patterns. You may produce fewer if warranted.
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": <number of trajectories analyzed>, + "success_patterns": ["<pattern 1>", "<pattern 2>"], + "patch": { + "reasoning": "<why these edits capture the common success patterns>", + "edits": [ + {"op": "append", "content": "<markdown to append to the skill>"}, + {"op": "insert_after", "target": "<existing skill text to insert after>", "content": "<markdown to insert>"}, + {"op": "replace", "target": "<existing skill text to replace>", "content": "<replacement markdown>"}, + {"op": "delete", "target": "<existing skill text to delete>"} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. diff --git a/skillopt/envs/alfworld/prompts/deep_probe.md b/skillopt/envs/alfworld/prompts/deep_probe.md new file mode 100644 index 0000000..c38e94c --- /dev/null +++ b/skillopt/envs/alfworld/prompts/deep_probe.md @@ -0,0 +1,35 @@ +You are an expert diagnostic-probe designer for ALFWorld embodied tasks. + +You will design one short diagnostic instruction to append to the student's prompt +for a handful of representative ALFWorld trajectories. + +The goal is to expose whether the student has the right intermediate subgoal, +object/receptacle state, and next-step intention without substantially changing +the current scaffold. + +## Hard Constraints +1. Do NOT substantially change the student's existing action-selection scaffold. +2. Do NOT prescribe a brand-new planner or long multi-step policy. +3. Do NOT ask for exhaustive search over all objects or all admissible actions. +4. Keep the diagnostic readout brief and place it inside the existing <think>...</think> block. +5. The student must still output exactly one admissible action inside <action>...</action>. +6. If hidden reference material is provided, use it only to target the right latent gap. +7. Never copy hidden reference content into the student-facing probe.
+ +## Good Probe Targets +- current subgoal +- target object / target receptacle / target state +- decisive missing precondition +- why one candidate action is better than a tempting alternative +- whether the current step should explore, transform an object, or place it + +## Bad Probe Targets +- a full optimal plan from start to finish +- exhaustive object inventories +- a new theorem-like or planner-like protocol + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/envs/alfworld/prompts/rollout_no_history.md b/skillopt/envs/alfworld/prompts/rollout_no_history.md new file mode 100644 index 0000000..d1d605b --- /dev/null +++ b/skillopt/envs/alfworld/prompts/rollout_no_history.md @@ -0,0 +1,8 @@ + +You are an expert agent operating in the ALFRED Embodied Environment. +Your current observation is: {current_observation} +Your admissible actions of the current situation are: [{admissible_actions}]. + +Now it's your turn to take an action. +You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within tags. +Once you've finished your reasoning, you should choose an admissible action for current step and present it within tags. diff --git a/skillopt/envs/alfworld/prompts/rollout_with_history.md b/skillopt/envs/alfworld/prompts/rollout_with_history.md new file mode 100644 index 0000000..f0a635d --- /dev/null +++ b/skillopt/envs/alfworld/prompts/rollout_with_history.md @@ -0,0 +1,9 @@ + +You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description} +Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history} +You are now at step {current_step} and your current observation is: {current_observation} +Your admissible actions of the current situation are: [{admissible_actions}]. 
+ +Now it's your turn to take an action. +You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within tags. +Once you've finished your reasoning, you should choose an admissible action for current step and present it within tags. diff --git a/skillopt/envs/alfworld/prompts/rollout_with_memory.md b/skillopt/envs/alfworld/prompts/rollout_with_memory.md new file mode 100644 index 0000000..c90dc7f --- /dev/null +++ b/skillopt/envs/alfworld/prompts/rollout_with_memory.md @@ -0,0 +1,16 @@ + +You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description} + +## Retrieved Relevant Experience + +{retrieved_memories} + +## Current Progress + +Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history} +You are now at step {current_step} and your current observation is: {current_observation} +Your admissible actions of the current situation are: [{admissible_actions}]. + +Now it's your turn to take an action. +You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within tags. +Once you've finished your reasoning, you should choose an admissible action for current step and present it within tags. diff --git a/skillopt/envs/alfworld/reflect.py b/skillopt/envs/alfworld/reflect.py new file mode 100644 index 0000000..a32d989 --- /dev/null +++ b/skillopt/envs/alfworld/reflect.py @@ -0,0 +1,4 @@ +"""ALFWorld Reflect stage. + +Prompts are now loaded from .md files by the base adapter. +""" diff --git a/skillopt/envs/alfworld/rollout.py b/skillopt/envs/alfworld/rollout.py new file mode 100644 index 0000000..50e8fa3 --- /dev/null +++ b/skillopt/envs/alfworld/rollout.py @@ -0,0 +1,359 @@ +"""ALFWorld rollout module for ReflACT. 
+ +Provides: + - build_alfworld_env(): build ALFWorld environment (wraps vendored SkillRL env) + - run_alfworld_batch(): run a batch of ALFWorld episodes in parallel + - TASKS: list of ALFWorld task types +""" +from __future__ import annotations + +import json +import os +import re +import sys +import time +import concurrent.futures +import numpy as np + +from skillopt.model import chat_student + +# ── Constants ───────────────────────────────────────────────────────────────── + +TASKS = [ + "pick_and_place", + "pick_two_obj_and_place", + "look_at_obj_in_light", + "pick_heat_then_place_in_recep", + "pick_cool_then_place_in_recep", + "pick_clean_then_place_in_recep", +] + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +def _get_task_type(gamefile: str) -> str: + for task in TASKS: + if task in gamefile: + return task + return "other" + + +def _extract_action(model_response: str) -> str | None: + match = re.search(r"(.*?)", model_response, re.DOTALL) + return match.group(1).strip() if match else None + + +def _extract_think(model_response: str) -> str | None: + match = re.search(r"(.*?)", model_response, re.DOTALL) + return match.group(1).strip() if match else None + + +def _build_skill_prompt(skill_content: str) -> str: + """Build the skill section to inject into the agent's system prompt.""" + if not skill_content or not skill_content.strip(): + return "" + return ( + "\n\n## Skill Knowledge\n" + "Below is a skill document with learned strategies. 
" + "Use these guidelines to inform your decisions:\n\n" + f"{skill_content}\n" + ) + + +def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) -> str: + if not diagnostic_instruction or not diagnostic_instruction.strip(): + return prompt + return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n" + + +# ── Environment builder ────────────────────────────────────────────────────── + + +def build_alfworld_env( + env_num: int, + eval_dataset: str = "eval_out_of_distribution", + seed: int = 42, + is_train: bool = False, + specific_gamefiles: list[str] | None = None, +): + """Build ALFWorld environment manager. + + Args: + env_num: number of parallel environments + eval_dataset: 'eval_in_distribution' or 'eval_out_of_distribution' or train + seed: random seed + is_train: whether to use training set + + Returns: + env_manager: AlfWorldEnvironmentManager instance + """ + from omegaconf import OmegaConf + from functools import partial + + from skillopt.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs + from skillopt.envs.alfworld.vendor.alfworld_projection import alfworld_projection + from skillopt.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager + + HERE = os.path.dirname(os.path.abspath(__file__)) + + alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml") + env_kwargs = {"eval_dataset": eval_dataset} + + envs = build_alfworld_envs( + alf_config_path, + seed=seed, + env_num=env_num, + group_n=1, + is_train=is_train, + env_kwargs=env_kwargs, + resources_per_worker=None, + gamefiles=specific_gamefiles, + ) + + config = OmegaConf.create( + { + "env": { + "history_length": 2, + "env_name": "alfworld/AlfredTWEnv", + } + } + ) + + projection_f = partial(alfworld_projection) + env_manager = AlfWorldEnvironmentManager(envs, projection_f, config) + return env_manager + + +# ── Batch rollout ───────────────────────────────────────────────────────────── + + +def run_alfworld_batch( + env_manager, + 
skill_content: str, + max_steps: int = 50, + out_root: str = "", + max_api_workers: int = 8, + temperature: float = 0.4, + max_completion_tokens: int = 2048, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + result_ids: list[str] | None = None, +) -> list[dict]: + """Run a batch of ALFWorld episodes. + + Returns a list of result dicts compatible with SkillReflection v2 pipeline: + [ + { + "id": "_", + "hard": 0 or 1, + "soft": 0.0 or 1.0, + "n_turns": , + "fail_reason": "", + "agent_ok": True, + "task_type": "", + "gamefile": "", + "task_description": "", + }, + ... + ] + + Also saves conversation.json per environment in out_root/predictions// + """ + skill_prompt = _build_skill_prompt(skill_content) + + obs, infos = env_manager.reset({}) + env_num = len(obs["text"]) + env_dones = [False] * env_num + overall_success = [False] * env_num + + # Build per-env metadata + env_meta: list[dict] = [] + for i in range(env_num): + gamefile = infos[i].get("extra.gamefile", "") if isinstance(infos[i], dict) else "" + task_type = _get_task_type(gamefile) + # Extract task description from initial observation + task_desc = "" + anchor_text = obs["anchor"][i] if "anchor" in obs else "" + task_start = anchor_text.find("Your task is to: ") + if task_start != -1: + task_desc = anchor_text[task_start + len("Your task is to: "):].strip() + + env_meta.append({ + "gamefile": gamefile, + "task_type": task_type, + "task_description": task_desc, + }) + + # Per-env conversation records + conversations: list[list[dict]] = [[] for _ in range(env_num)] + + for step_idx in range(max_steps): + if all(env_dones): + break + + active_indices = [i for i in range(env_num) if not env_dones[i]] + + # Build prompts with skill injection + prompts: dict[int, str] = {} + for i in active_indices: + prompt = obs["text"][i] + if skill_prompt: + # Inject skill before the action instruction + prompt = skill_prompt + "\n" + prompt + if diagnostic_mode and diagnostic_instruction.strip(): + 
prompt = _append_diagnostic_instruction(prompt, diagnostic_instruction) + prompts[i] = prompt + + # Call API in parallel + actions = ["None"] * env_num + action_timeout = 180 + + def call_api(idx): + try: + response, _ = chat_student( + system="You are an expert agent operating in the ALFRED Embodied Environment.", + user=prompts[idx], + max_completion_tokens=max_completion_tokens, + retries=5, + stage="rollout", + timeout=120, + ) + response = (response or "").strip() + if not response: + return idx, "empty model responselook" + if _extract_action(response) is None: + return idx, "missing action taglook" + return idx, response + except Exception as e: + return idx, "errorlook" + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers) + try: + futures = {executor.submit(call_api, i): i for i in active_indices} + started_at = {future: time.time() for future in futures} + pending_futs = set(futures) + while pending_futs: + done, _ = concurrent.futures.wait( + pending_futs, + timeout=5, + return_when=concurrent.futures.FIRST_COMPLETED, + ) + now = time.time() + timed_out = [ + future for future in pending_futs - done + if now - started_at[future] >= action_timeout + ] + for future in done: + pending_futs.remove(future) + try: + idx, response = future.result() + except Exception: # noqa: BLE001 + idx = futures[future] + response = "errorlook" + actions[idx] = response + for future in timed_out: + pending_futs.remove(future) + idx = futures[future] + actions[idx] = "api timeoutlook" + finally: + executor.shutdown(wait=False, cancel_futures=True) + + # Save model responses before stepping + model_responses = {i: actions[i] for i in active_indices} + + # Step environment + obs, rewards, dones, infos = env_manager.step(actions) + + # Record trajectory + for i in active_indices: + step_record = { + "step": step_idx, + "action": _extract_action(model_responses[i]), + "reasoning": _extract_think(model_responses[i]), + "model_response": 
model_responses[i], + "env_feedback": obs["anchor"][i] if "anchor" in obs else "", + "reward": float(rewards[i]), + "done": bool(dones[i]), + } + conversations[i].append(step_record) + + # Update done status + for i in range(env_num): + if env_dones[i]: + continue + if dones[i]: + env_dones[i] = True + won = bool(infos[i].get("won", False)) + overall_success[i] = won + + # Build results and save conversations + results: list[dict] = [] + pred_dir = os.path.join(out_root, "predictions") if out_root else "" + + for i in range(env_num): + gamefile = env_meta[i]["gamefile"] + task_type = env_meta[i]["task_type"] + task_desc = env_meta[i]["task_description"] + n_turns = len(conversations[i]) + won = overall_success[i] + + # Generate stable task ID from env index and gamefile + task_id = str(result_ids[i]) if result_ids and i < len(result_ids) else f"env_{i:03d}" + + fail_reason = "" + if not won: + if not env_dones[i]: + fail_reason = f"Timeout after {max_steps} steps" + else: + fail_reason = "Episode ended without completing the task" + + result = { + "id": task_id, + "hard": 1 if won else 0, + "soft": 1.0 if won else 0.0, + "n_turns": n_turns, + "fail_reason": fail_reason, + "agent_ok": True, # ALFWorld agent always runs OK (no crash) + "task_type": task_type, + "gamefile": gamefile, + "task_description": task_desc, + "instruction_type": task_type, # for compatibility with v2 pipeline + } + results.append(result) + + # Save conversation + if pred_dir: + conv_dir = os.path.join(pred_dir, task_id) + os.makedirs(conv_dir, exist_ok=True) + with open(os.path.join(conv_dir, "conversation.json"), "w") as f: + json.dump(conversations[i], f, ensure_ascii=False, indent=2) + + return results + + +# ── Item loading (for compatibility with split_three_way) ──────────────────── + + +def load_alfworld_items( + eval_dataset: str, + env_num: int, + seed: int = 42, + is_train: bool = False, +) -> list[dict]: + """Create pseudo-item dicts for ALFWorld environments. 
+ + Since ALFWorld doesn't have a static JSON dataset like SpreadsheetBench, + we create lightweight item dicts that carry enough metadata for the pipeline. + The actual environment is built dynamically. + + Returns: + List of dicts with "id" keys, one per environment slot. + """ + items = [] + for i in range(env_num): + items.append({ + "id": f"env_{i:03d}", + "eval_dataset": eval_dataset, + "env_index": i, + }) + return items diff --git a/skillopt/envs/alfworld/skills/initial.md b/skillopt/envs/alfworld/skills/initial.md new file mode 100644 index 0000000..d19ad02 --- /dev/null +++ b/skillopt/envs/alfworld/skills/initial.md @@ -0,0 +1,45 @@ +# ALFWorld Embodied Agent Skill + +## Overview +This skill guides agents operating in the ALFWorld text-based embodied environment. +The agent must complete household tasks by navigating rooms, interacting with objects, +and using appliances. Actions must be chosen from the admissible action list provided +at each step. + +**Output format**: Always output `...` for reasoning, then `...` for the chosen action. + +--- + +## Task Types + +| Type | Goal | Key Steps | +|------|------|-----------| +| Pick & Place | Put object X in/on receptacle Y | Find X -> take X -> go to Y -> put X in/on Y | +| Pick Two & Place | Put two instances of X in/on Y | Find X1 -> take -> place -> find X2 -> take -> place | +| Examine in Light | Examine object X under desklamp | Find X -> take X -> find desklamp -> use desklamp | +| Clean & Place | Clean object X and put in/on Y | Find X -> take X -> go to sink -> clean X -> go to Y -> put X | +| Heat & Place | Heat object X and put in/on Y | Find X -> take X -> go to microwave -> heat X -> go to Y -> put X | +| Cool & Place | Cool object X and put in/on Y | Find X -> take X -> go to fridge -> cool X -> go to Y -> put X | + +--- + +## General Principles + +1. **Decompose the task**: Parse the goal into ordered sub-goals (locate, acquire, transform, deliver). Complete each before moving to the next. +2. 
**Systematic exploration**: Search each surface and container exactly once before revisiting. Open closed containers (drawers, cabinets, fridge) before judging them empty. +3. **Grab immediately**: When a required object is visible and reachable, take it right away before moving elsewhere. +4. **Transform before placing**: If the task requires cleaning, heating, or cooling, perform the state change at the appropriate appliance before heading to the final destination. +5. **Direct delivery**: Once holding the transformed (or untransformed) goal object, navigate straight to the target receptacle and place it. +6. **Track progress**: Maintain an internal count of how many objects still need to be found and placed. Only stop searching when the count reaches zero. +7. **Avoid loops**: Never repeat the same action more than twice in a row. If stuck, move to a different unexplored location. +8. **Only choose admissible actions**: Always pick an action from the admissible action list. Do not invent actions. + +--- + +## Common Mistakes to Avoid + +- **Revisiting searched locations**: Keep track of which surfaces/containers have been checked; do not re-examine them. +- **Ignoring visible objects**: If the target object appears in the observation, pick it up immediately. +- **Skipping state changes**: Do not place an object at the destination without first cleaning/heating/cooling it when required. +- **Premature termination**: Do not stop the episode until all goal conditions are verified as met. +- **Action loops**: Repeatedly toggling or examining the same object wastes steps. Move on to new locations instead. diff --git a/skillopt/envs/alfworld/vendor/__init__.py b/skillopt/envs/alfworld/vendor/__init__.py new file mode 100644 index 0000000..93dd8cb --- /dev/null +++ b/skillopt/envs/alfworld/vendor/__init__.py @@ -0,0 +1,9 @@ +"""Vendored ALFWorld environment runtime. + +Minimal subset of SkillRL's agent_system package needed to run +ALFWorld environments with ReflACT. 
Original source: +https://github.com/NTU-LANTERN/SkillRL (Apache-2.0 License) +""" +from .alfworld_envs import AlfworldEnvs, build_alfworld_envs +from .alfworld_projection import alfworld_projection +from .env_manager import AlfWorldEnvironmentManager diff --git a/skillopt/envs/alfworld/vendor/alfworld_envs.py b/skillopt/envs/alfworld/vendor/alfworld_envs.py new file mode 100644 index 0000000..06b9716 --- /dev/null +++ b/skillopt/envs/alfworld/vendor/alfworld_envs.py @@ -0,0 +1,221 @@ +# Vendored from SkillRL (Apache-2.0 License) +# Original: agent_system/environments/env_package/alfworld/envs.py +# Modified: imports use pip-installed alfworld package instead of vendored copy. + +import os +import multiprocessing as mp +import traceback +import yaml +import gymnasium as gym +import numpy as np + +from alfworld.agents.environment import get_environment + + +def load_config_file(path): + assert os.path.exists(path), f"Invalid config file: {path}" + with open(path) as reader: + config = yaml.safe_load(reader) + return config + + +def compute_reward(info, multi_modal=False): + if multi_modal: + reward = 10.0 * float(info['won']) + float(info['goal_condition_success_rate']) + else: + reward = 10.0 * float(info['won']) + return reward + + +class AlfworldWorker: + """Stateful worker that holds one ALFWorld sub-environment.""" + + def __init__(self, config, seed, base_env, gamefile=None): + if gamefile: + base_env.game_files = [gamefile] + if hasattr(base_env, "num_games"): + base_env.num_games = 1 + self.env = base_env.init_env(batch_size=1) + self.env.seed(seed) + + def step(self, action): + actions = [action] + obs, scores, dones, infos = self.env.step(actions) + infos['observation_text'] = obs + return obs, scores, dones, infos + + def reset(self): + obs, infos = self.env.reset() + infos['observation_text'] = obs + return obs, infos + + +def _worker_loop(cmd_q, result_q, config, seed, is_train, eval_dataset, gamefile): + """Run one ALFWorld environment in a child 
process.""" + try: + env_type = config['env']['type'] + base_env = get_environment(env_type)( + config, + train_eval='train' if is_train else eval_dataset, + ) + worker = AlfworldWorker(config, seed, base_env, gamefile) + result_q.put((True, "ready")) + except BaseException: + result_q.put((False, traceback.format_exc())) + return + + while True: + cmd, payload = cmd_q.get() + if cmd == "close": + result_q.put((True, None)) + return + try: + if cmd == "reset": + result = worker.reset() + elif cmd == "step": + result = worker.step(payload) + else: + raise ValueError(f"Unknown ALFWorld worker command: {cmd}") + result_q.put((True, result)) + except BaseException: + result_q.put((False, traceback.format_exc())) + + +class _ProcessWorker: + """Small stdlib actor wrapper for one environment process.""" + + def __init__(self, ctx, config, seed, is_train, eval_dataset, gamefile=None): + self.cmd_q = ctx.Queue(maxsize=1) + self.result_q = ctx.Queue(maxsize=1) + self.process = ctx.Process( + target=_worker_loop, + args=(self.cmd_q, self.result_q, config, seed, is_train, eval_dataset, gamefile), + ) + self.process.start() + ok, payload = self.result_q.get() + if not ok: + self.close(kill=True) + raise RuntimeError(f"Failed to start ALFWorld worker:\n{payload}") + + def send(self, cmd, payload=None): + self.cmd_q.put((cmd, payload)) + + def recv(self): + ok, payload = self.result_q.get() + if not ok: + raise RuntimeError(f"ALFWorld worker failed:\n{payload}") + return payload + + def close(self, kill=False): + if self.process.is_alive() and not kill: + try: + self.send("close") + self.recv() + except Exception: + kill = True + if kill and self.process.is_alive(): + self.process.terminate() + self.process.join(timeout=5) + if self.process.is_alive(): + self.process.kill() + self.process.join(timeout=1) + self.cmd_q.close() + self.result_q.close() + + +class AlfworldEnvs(gym.Env): + """Vectorized ALFWorld environment using local process workers.""" + + def __init__(self, 
alf_config_path, seed, env_num, group_n, + resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None): + super().__init__() + if env_kwargs is None: + env_kwargs = {} + + eval_dataset = env_kwargs.get('eval_dataset', 'eval_in_distribution') + config = load_config_file(alf_config_path) + env_type = config['env']['type'] + self.multi_modal = (env_type == 'AlfredThorEnv') + self.num_processes = env_num * group_n + self.group_n = group_n + self.gamefiles = list(gamefiles or []) + if self.gamefiles and len(self.gamefiles) != self.num_processes: + raise ValueError( + f"Expected {self.num_processes} gamefiles, got {len(self.gamefiles)}" + ) + + start_method = os.environ.get("ALFWORLD_WORKER_START_METHOD") or None + ctx = mp.get_context(start_method) if start_method else mp.get_context() + self.workers = [] + for i in range(self.num_processes): + worker_gamefile = self.gamefiles[i] if self.gamefiles else None + worker = _ProcessWorker( + ctx, + config, + seed + (i // self.group_n), + is_train, + eval_dataset, + worker_gamefile, + ) + self.workers.append(worker) + + self.prev_admissible_commands = [None for _ in range(self.num_processes)] + + def step(self, actions): + assert len(actions) == self.num_processes + + for i, worker in enumerate(self.workers): + worker.send("step", actions[i]) + results = [worker.recv() for worker in self.workers] + + text_obs_list = [] + rewards_list = [] + dones_list = [] + info_list = [] + + for i, (obs, scores, dones, info) in enumerate(results): + for k in info.keys(): + info[k] = info[k][0] + text_obs_list.append(obs[0]) + dones_list.append(dones[0]) + info_list.append(info) + self.prev_admissible_commands[i] = info['admissible_commands'] + rewards_list.append(compute_reward(info, self.multi_modal)) + + image_obs_list = None + return text_obs_list, image_obs_list, rewards_list, dones_list, info_list + + def reset(self): + for worker in self.workers: + worker.send("reset") + results = [worker.recv() for worker in self.workers] + 
+ text_obs_list = [] + info_list = [] + + for i, (obs, info) in enumerate(results): + for k in info.keys(): + info[k] = info[k][0] + text_obs_list.append(obs[0]) + self.prev_admissible_commands[i] = info['admissible_commands'] + info_list.append(info) + + image_obs_list = None + return text_obs_list, image_obs_list, info_list + + @property + def get_admissible_commands(self): + return self.prev_admissible_commands + + def close(self): + for worker in self.workers: + worker.close() + + +def build_alfworld_envs(alf_config_path, seed, env_num, group_n, + resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None): + """Build vectorized ALFWorld environments.""" + return AlfworldEnvs( + alf_config_path, seed, env_num, group_n, + resources_per_worker, is_train, env_kwargs, gamefiles, + ) diff --git a/skillopt/envs/alfworld/vendor/alfworld_projection.py b/skillopt/envs/alfworld/vendor/alfworld_projection.py new file mode 100644 index 0000000..8c499ff --- /dev/null +++ b/skillopt/envs/alfworld/vendor/alfworld_projection.py @@ -0,0 +1,60 @@ +# Vendored from SkillRL (Apache-2.0 License) +# Original: agent_system/environments/env_package/alfworld/projection.py + +from typing import List +import re + + +def alfworld_projection(actions: List[str], action_pools: List[List[str]]): + """Process raw model outputs into valid ALFWorld actions. + + Extracts text from ``...`` tags and validates that + the response also contains ``...`` tags. + + Parameters + ---------- + actions : list[str] + Raw model outputs, one per environment. + action_pools : list[list[str]] + Admissible action lists per environment (unused but kept for API compat). + + Returns + ------- + actions : list[str] + Cleaned action strings. + valids : list[int] + 1 if the action was successfully parsed, 0 otherwise. 
+ """ + valids = [0] * len(actions) + + for i in range(len(actions)): + original_str = actions[i] + actions[i] = actions[i].lower() + + start_tag = "" + end_tag = "" + start_idx = actions[i].find(start_tag) + end_idx = actions[i].find(end_tag) + try: + if start_idx == -1 or end_idx == -1: + actions[i] = actions[i][-30:] + continue + + extracted_action = actions[i][start_idx + len(start_tag):end_idx].strip().lower() + actions[i] = extracted_action + valids[i] = 1 + + except Exception: + actions[i] = actions[i][-30:] + + # Require ... + think_start_idx = original_str.find("") + think_end_idx = original_str.find("") + if think_start_idx == -1 or think_end_idx == -1: + valids[i] = 0 + + # Reject responses containing Chinese characters + if re.search(r'[\u4e00-\u9fff]', original_str): + valids[i] = 0 + + return actions, valids diff --git a/skillopt/envs/alfworld/vendor/alfworld_prompts.py b/skillopt/envs/alfworld/vendor/alfworld_prompts.py new file mode 100644 index 0000000..bb7ec49 --- /dev/null +++ b/skillopt/envs/alfworld/vendor/alfworld_prompts.py @@ -0,0 +1,8 @@ +# Vendored from SkillRL (Apache-2.0 License) +# Original: agent_system/environments/prompts/alfworld.py + +from skillopt.prompts import load_prompt + +ALFWORLD_TEMPLATE_NO_HIS = load_prompt("rollout_no_history", env="alfworld") +ALFWORLD_TEMPLATE = load_prompt("rollout_with_history", env="alfworld") +ALFWORLD_TEMPLATE_WITH_MEMORY = load_prompt("rollout_with_memory", env="alfworld") diff --git a/skillopt/envs/alfworld/vendor/config_tw.yaml b/skillopt/envs/alfworld/vendor/config_tw.yaml new file mode 100644 index 0000000..e9bf169 --- /dev/null +++ b/skillopt/envs/alfworld/vendor/config_tw.yaml @@ -0,0 +1,145 @@ +dataset: + data_path: '$ALFWORLD_DATA/json_2.1.1/train' + eval_id_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_seen' # null/None to disable + eval_ood_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_unseen' # null/None to disable + num_train_games: -1 # max training games (<=0 indicates full dataset) + 
num_eval_games: -1 # max evaluation games (<=0 indicates full dataset) + +logic: + domain: '$ALFWORLD_DATA/logic/alfred.pddl' # PDDL domain file that defines the world dynamics + grammar: '$ALFWORLD_DATA/logic/alfred.twl2' # Grammar file that defines the text feedbacks + +env: + type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid' + # regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file + domain_randomization: False # shuffle Textworld print order and object id nums + task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place + expert_timeout_steps: 150 # max steps before timeout for expert to solve the task + expert_type: "handcoded" # 'handcoded' or 'planner'. Note: the planner is very slow for real-time use + goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED) + + hybrid: + start_eps: 100000 # starting episode of hybrid training, tw-only training upto this point + thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training + eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training + + thor: + screen_width: 300 # width of THOR window + screen_height: 300 # height of THOR window + smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow) + save_frames_to_disk: False # save frame PNGs to disk (useful for making videos) + save_frames_path: './videos/' # path to save frame PNGs + +controller: + type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER) + debug: False + load_receps: True # load receptacle locations from precomputed dict (if available) + +mask_rcnn: + pretrained_model_path: '$ALFWORLD_DATA/detectors/mrcnn.pth' + +general: + random_seed: 42 + use_cuda: True # disable this when running on machine without 
cuda + visdom: False # plot training/eval curves, run with visdom server + task: 'alfred' + training_method: 'dagger' # 'dqn' or 'dagger' + save_path: './training/' # path to save pytorch models + observation_pool_capacity: 3 # k-size queue, 0 indicates no observation + hide_init_receptacles: False # remove initial observation containing navigable receptacles + + training: + batch_size: 10 + max_episode: 50000 + smoothing_eps: 0.1 + optimizer: + learning_rate: 0.001 + clip_grad_norm: 5 + + evaluate: + run_eval: True + batch_size: 10 + env: + type: "AlfredTWEnv" + + checkpoint: + report_frequency: 1000 # report every N episode + experiment_tag: 'test' # name of experiment + load_pretrained: False # during test, enable this so that the agent load your pretrained model + load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path + + model: + encoder_layers: 1 + decoder_layers: 1 + encoder_conv_num: 5 + block_hidden_dim: 64 + n_heads: 1 + dropout: 0.1 + block_dropout: 0.1 + recurrent: True + +rl: + action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working) + max_target_length: 20 # max token length for seq2seq generation + beam_width: 10 # 1 means greedy + generate_top_k: 3 + + training: + max_nb_steps_per_episode: 50 # terminate after this many steps + learn_start_from_this_episode: 0 # delay updates until this epsiode + target_net_update_frequency: 500 # sync target net with online net per this many epochs + + replay: + accumulate_reward_from_final: True + count_reward_lambda: 0.0 # 0 to disable + novel_object_reward_lambda: 0.0 # 0 to disable + discount_gamma_game_reward: 0.9 + discount_gamma_count_reward: 0.5 + discount_gamma_novel_object_reward: 0.5 + replay_memory_capacity: 500000 # adjust this depending on your RAM size + replay_memory_priority_fraction: 0.5 + update_per_k_game_steps: 5 + replay_batch_size: 64 + multi_step: 3 
+ replay_sample_history_length: 4 + replay_sample_update_from: 2 + + epsilon_greedy: + noisy_net: False # if this is true, then epsilon greedy is disabled + epsilon_anneal_episodes: 1000 # -1 if not annealing + epsilon_anneal_from: 0.3 + epsilon_anneal_to: 0.1 + +dagger: + action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working) + max_target_length: 20 # max token length for seq2seq generation + beam_width: 10 # 1 means greedy + generate_top_k: 5 + unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation + + training: + max_nb_steps_per_episode: 50 # terminate after this many steps + + fraction_assist: + fraction_assist_anneal_episodes: 50000 + fraction_assist_anneal_from: 1.0 + fraction_assist_anneal_to: 0.01 + + fraction_random: + fraction_random_anneal_episodes: 0 + fraction_random_anneal_from: 0.0 + fraction_random_anneal_to: 0.0 + + replay: + replay_memory_capacity: 500000 + update_per_k_game_steps: 5 + replay_batch_size: 64 + replay_sample_history_length: 4 + replay_sample_update_from: 2 + +vision_dagger: + model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input) + resnet_fc_dim: 64 + maskrcnn_top_k_boxes: 10 # top k box features + use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!) + sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn' diff --git a/skillopt/envs/alfworld/vendor/env_base.py b/skillopt/envs/alfworld/vendor/env_base.py new file mode 100644 index 0000000..00affa7 --- /dev/null +++ b/skillopt/envs/alfworld/vendor/env_base.py @@ -0,0 +1,84 @@ +# Vendored from SkillRL (Apache-2.0 License) +# Original: agent_system/environments/base.py +# Trimmed to only include what ALFWorld needs. 
def to_numpy(data):
    """Convert *data* to a NumPy array.

    Supported inputs:
      * ``torch.Tensor`` -- detached, moved to CPU and converted;
      * ``np.ndarray``   -- returned unchanged (same object);
      * ``int`` / ``float`` / ``bool`` / ``tuple`` / ``list`` -- wrapped with
        ``np.array``.

    Raises:
        ValueError: If *data* is of any other type.
    """
    # Import torch lazily so this module has no hard dependency on it.
    try:
        import torch
    except ImportError:
        tensor_type = None
    else:
        tensor_type = torch.Tensor

    if tensor_type is not None and isinstance(data, tensor_type):
        return data.detach().cpu().numpy()
    if isinstance(data, np.ndarray):
        return data
    # Builtin tuple/list are what the typing aliases resolved to at runtime.
    if isinstance(data, (int, float, bool, tuple, list)):
        return np.array(data)
    raise ValueError(f"Unsupported type: {type(data)}")
class AlfWorldEnvironmentManager(EnvironmentManagerBase):
    """Manages parallel ALFWorld environments with observation templating.

    On top of the base manager this class keeps a per-environment
    ``SimpleMemory`` of (observation, action) pairs, remembers the task
    description and game file of each environment, and renders the text
    observation through the ALFWorld prompt templates.
    """

    def __init__(self, envs, projection_f, config):
        """Create the history buffer, then defer to the base constructor."""
        self.memory = SimpleMemory()
        # Never assigned anything else within this class.
        self.retrieval_memory = None
        super().__init__(envs, projection_f, config)

    def reset(self, kwargs):
        """Reset all environments and return ``(observations, infos)``.

        ``kwargs`` is accepted for interface compatibility but is not used
        in this method. The returned observation dict has keys ``'text'``
        (templated prompt), ``'image'`` and ``'anchor'`` (raw text obs).
        """
        text_obs, image_obs, infos = self.envs.reset()
        # Remember the game files so step() can restore them into infos.
        self.gamefile = parse_gamefile(infos)
        self.memory.reset(batch_size=len(text_obs))
        self.tasks = []
        # Kept so the next step() can store it alongside the chosen action.
        self.pre_text_obs = text_obs
        self.extract_task(text_obs)

        # NOTE(review): get_admissible_commands is passed without being
        # called and later indexed per environment -- presumably a property
        # on the vectorized env; confirm.
        full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands, init=True)
        return {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}, infos

    def step(self, text_actions: List[str]):
        """Project text actions, step the environments and template the result.

        Returns ``(next_observations, rewards, dones, infos)``; rewards and
        dones are converted with ``to_numpy`` and every info dict gains an
        ``'is_action_valid'`` entry.
        """
        actions, valids = self.projection_f(text_actions, self.envs.get_admissible_commands)
        text_obs, image_obs, rewards, dones, infos = self.envs.step(actions)
        # Store the observation the action was taken FROM, then advance it.
        self.memory.store({'text_obs': self.pre_text_obs, 'action': actions})
        self.pre_text_obs = text_obs

        full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands)
        # Only the first info is inspected to decide whether the game files
        # recorded at reset() need to be written back.
        if infos[0].get("extra.gamefile") is None:
            infos = set_gamefile(infos, self.gamefile)

        for i, info in enumerate(infos):
            info['is_action_valid'] = to_numpy(valids[i])

        next_observations = {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}
        rewards = to_numpy(rewards)
        dones = to_numpy(dones)
        return next_observations, rewards, dones, infos

    def extract_task(self, text_obs: List[str]):
        """Append the text after ``'Your task is to: '`` of each obs to ``self.tasks``.

        Raises:
            ValueError: If an observation does not contain the marker.
        """
        for obs in text_obs:
            task_start = obs.find('Your task is to: ')
            if task_start != -1:
                self.tasks.append(obs[task_start + len('Your task is to: '):].strip())
            else:
                raise ValueError("Task description not found in text observation.")

    def build_text_obs(self, text_obs: List[str], admissible_actions: List[List[str]], init: bool = False) -> List[str]:
        """Render one prompt string per environment.

        The no-history template is used on the initial observation
        (``init=True``) or when ``config.env.history_length <= 0``;
        otherwise the history template is filled from ``self.memory``.
        """
        postprocess_text_obs = []
        # memory_contexts / valid_lens are only bound (and only read below)
        # when history is enabled and this is not the initial observation.
        if not init and self.config.env.history_length > 0:
            memory_contexts, valid_lens = self.memory.fetch(
                self.config.env.history_length,
                obs_key="text_obs",
                action_key="action",
            )

        for i in range(len(text_obs)):
            # 'help' is filtered out of the admissible actions shown.
            reformatted_admissible_actions = "\n ".join(
                f"'{s}'" for s in admissible_actions[i] if s != 'help'
            )

            if init or self.config.env.history_length <= 0:
                obs = ALFWORLD_TEMPLATE_NO_HIS.format(
                    current_observation=text_obs[i],
                    admissible_actions=reformatted_admissible_actions,
                )
            else:
                obs = ALFWORLD_TEMPLATE.format(
                    task_description=self.tasks[i],
                    step_count=len(self.memory[i]),
                    history_length=valid_lens[i],
                    action_history=memory_contexts[i],
                    current_step=len(self.memory[i]) + 1,
                    current_observation=text_obs[i],
                    admissible_actions=reformatted_admissible_actions,
                )
            postprocess_text_obs.append(obs)
        return postprocess_text_obs

    def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
        """Record success for one trajectory from its last active step.

        Scans the steps in reverse, and at the first item whose
        ``'active_masks'`` is truthy appends ``float(info['won'])`` to
        ``success['success_rate']`` (plus a per-task entry when a game file
        is known), then stops.
        """
        for i in reversed(range(len(total_batch_list[batch_idx]))):
            batch_item = total_batch_list[batch_idx][i]
            if batch_item['active_masks']:
                info = total_infos[batch_idx][i]
                won_value = float(info['won'])
                success['success_rate'].append(won_value)

                gamefile = info.get("extra.gamefile")
                if gamefile:
                    self._process_gamefile(gamefile, won_value, success)
                return

    def _process_gamefile(self, gamefile, won_value, success):
        """Append ``won_value`` under ``"<task>_success_rate"`` for the first task name found in ``gamefile``."""
        tasks = [
            "pick_and_place",
            "pick_two_obj_and_place",
            "look_at_obj_in_light",
            "pick_heat_then_place_in_recep",
            "pick_cool_then_place_in_recep",
            "pick_clean_then_place_in_recep",
        ]
        for task in tasks:
            if task in gamefile:
                success[f"{task}_success_rate"].append(won_value)
                break
class BaseMemory(ABC):
    """Abstract interface for per-environment memory buffers."""

    @abstractmethod
    def __len__(self):
        pass

    @abstractmethod
    def __getitem__(self, idx: int):
        pass

    @abstractmethod
    def reset(self, batch_size: int):
        pass

    @abstractmethod
    def store(self, record: Dict[str, List[Any]]):
        pass

    @abstractmethod
    def fetch(self, step: int):
        pass


class SimpleMemory(BaseMemory):
    """Per-environment history buffer for storing observations and actions.

    ``_data[i]`` is the list of step records (dicts) of environment ``i``.
    """

    def __init__(self):
        # Start with an empty buffer (not None) so len() and indexing are
        # well-defined before the first reset().
        self._data = []
        self.keys = None
        self.batch_size = 0

    def __len__(self):
        """Return the number of environments tracked (0 before reset)."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the list of step records of environment ``idx``."""
        return self._data[idx]

    def reset(self, batch_size: int):
        """Drop all history and allocate ``batch_size`` empty histories."""
        # Clear in place first so any outstanding reference to the old
        # buffer does not keep stale records alive.
        self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        self.keys = None

    def store(self, record: Dict[str, List[Any]]):
        """Append one step for every environment.

        ``record`` maps each field name to a per-environment sequence. The
        field names are fixed by the first call after ``reset`` and asserted
        to be identical (including order) on later calls.
        """
        if self.keys is None:
            self.keys = list(record.keys())
        assert self.keys == list(record.keys())

        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str = "text_obs",
        action_key: str = "action",
    ) -> Tuple[List[str], List[int]]:
        """Render the most recent steps of every environment as text.

        Args:
            history_length: Maximum number of most recent steps to include.
                Values ``<= 0`` yield no history.
            obs_key: Record field holding the observation.
            action_key: Record field holding the action.

        Returns:
            ``(memory_contexts, valid_lengths)``: one newline-joined string
            and the number of steps actually included, per environment.
            Step numbers in the text are 1-based positions in the full
            history.
        """
        memory_contexts, valid_lengths = [], []

        for env_idx in range(self.batch_size):
            history = self._data[env_idx]
            # ``history[-0:]`` would be the WHOLE list and a negative length
            # would slice from the front, so guard non-positive lengths.
            recent = history[-history_length:] if history_length > 0 else []
            first_step = len(history) - len(recent) + 1

            lines = [
                f"[Observation {step_num}: '{rec[obs_key]}', Action {step_num}: '{rec[action_key]}']"
                for step_num, rec in enumerate(recent, start=first_step)
            ]

            memory_contexts.append("\n".join(lines))
            valid_lengths.append(len(recent))

        return memory_contexts, valid_lengths
class BabyVisionAdapter(EnvAdapter):
    """Environment adapter wiring BabyVision data, rollout and reflection together.

    Holds the rollout/judge/reflection settings given at construction and a
    ``BabyVisionDataLoader``; the "environment" handed to ``rollout`` is
    simply the list of items of a batch.
    """

    def build_reference_text(self, item: dict) -> str:
        """Return the item's ``cot`` as a ``## Reference CoT`` section, or ``""`` if it has none."""
        cot = str(item.get("cot") or "").strip()
        if not cot:
            return ""
        return f"## Reference CoT\n{cot}"

    def get_reference_metadata(self, item: dict) -> dict:
        """Describe which reference fields the item has, with a preview of at most 400 characters."""
        cot = str(item.get("cot") or "").strip()
        if not cot:
            return {"fields": [], "preview": ""}
        return {
            "fields": ["cot"],
            "preview": cot[:400],
        }

    # NOTE(review): this constructor does not call super().__init__();
    # confirm EnvAdapter needs no initialisation of its own.
    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        max_turns: int = 1,
        workers: int = 32,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        seed: int = 42,
        limit: int = 0,
        image_detail: str = "auto",
        judge_model: str = "gpt-5.4",
        judge_max_completion_tokens: int = 256,
        judge_retries: int = 5,
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        """Store the settings and build the dataloader.

        The split/data arguments (``split_*``, ``data_path``, ``seed``,
        ``limit``) are forwarded to ``BabyVisionDataLoader``; the remaining
        ones are kept on the instance for ``rollout`` / ``reflect`` /
        ``deep_reflect``.
        """
        self.max_turns = max_turns
        self.workers = workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.image_detail = image_detail
        self.judge_model = judge_model
        self.judge_max_completion_tokens = judge_max_completion_tokens
        self.judge_retries = judge_retries
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = BabyVisionDataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )

    def setup(self, cfg: dict) -> None:
        """Run the base setup, then set up the dataloader with the same config."""
        super().setup(cfg)
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        """Return the ``BabyVisionDataLoader`` owned by this adapter."""
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        """Return the batch payload as a new list (empty when there is no payload)."""
        return list(batch.payload or [])

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        """Build a training batch via the dataloader and return its items."""
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        """Build an evaluation batch for ``split`` via the dataloader and return its items."""
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def rollout(
        self,
        env_manager,
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict]:
        """Run ``run_batch`` over the items in ``env_manager``.

        Optional ``diagnostic_mode``, ``diagnostic_instruction`` and
        ``diagnostic_trace_context_by_id`` are read from ``kwargs`` and
        forwarded.
        """
        # For this adapter the "environment" is the list of items itself.
        items: list[dict] = env_manager
        return run_batch(
            items=items,
            out_root=out_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            workers=self.workers,
            image_detail=self.image_detail,
            judge_model=self.judge_model,
            judge_max_completion_tokens=self.judge_max_completion_tokens,
            judge_retries=self.judge_retries,
            diagnostic_mode=kwargs.get("diagnostic_mode", False),
            diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
            diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
        )

    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Run minibatch reflection over rollout ``results``.

        ``prediction_dir`` and ``patches_dir`` default to sub-directories of
        ``out_dir``; ``random_seed``, ``step_buffer_context`` and
        ``meta_skill_context`` are optional ``kwargs``.
        """
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")

        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            # Falls back to "patch" when no config dict has been attached.
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Re-run selected items with a diagnostic probe and reflect on them.

        Steps: select representative items, generate a probe instruction,
        write ``probe.json`` under ``<out_dir>/deep_reflect``, re-run the
        selected items in diagnostic mode, then run minibatch reflection on
        those results.

        Returns ``[]`` when deep reflection is disabled, when no items are
        selected, or when no probe is produced.
        """
        if not self.use_deep_reflect:
            return []

        env_manager = kwargs.get("env_manager")
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        codex_backend = get_student_backend() == "codex_exec"
        selected_items = self.select_representative_items(
            results,
            # Only a list-valued env_manager is passed on as the item pool.
            env_manager if isinstance(env_manager, list) else None,
            n_failures=self.deep_reflect_failures,
            n_successes=self.deep_reflect_successes,
            seed=random_seed,
        )
        if not selected_items:
            return []
        selected_ids = {str(item["id"]) for item in selected_items}
        selected_results = [row for row in results if str(row.get("id")) in selected_ids]
        selected_examples = self.attach_reference_context(selected_results, selected_items)
        if codex_backend:
            selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
        selected_metadata = []
        cot_count = 0
        for item in selected_items:
            meta = self.get_reference_metadata(item)
            if meta["fields"]:
                cot_count += 1
            selected_metadata.append({
                "id": str(item["id"]),
                "task_type": str(item.get("subtype") or item.get("task_type") or "babyvision"),
                "reference_fields": meta["fields"],
                "reference_preview": meta["preview"],
            })

        deep_dir = os.path.join(out_dir, "deep_reflect")
        rollout_dir = os.path.join(deep_dir, "rollout")
        patches_dir = os.path.join(deep_dir, "patches")
        os.makedirs(deep_dir, exist_ok=True)
        print(
            f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
            f"reference_fields=cot({cot_count}/{len(selected_items)})"
        )
        probe = generate_deep_probe_instruction(
            skill_content=skill_content,
            items=selected_examples,
            prediction_dir=prediction_dir,
            system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
        )
        if not probe:
            return []
        diagnostic_trace_context_by_id = None
        if codex_backend:
            # NOTE(review): selected_items is rebound here, but cot_count and
            # selected_metadata above were computed from the earlier
            # selection; if the helper narrows the list, probe.json mixes the
            # new selected_count with the old counts -- confirm intended.
            selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
                selected_items=selected_items,
                selected_examples=selected_examples,
                prediction_dir=prediction_dir,
                probe=probe,
            )
        probe_record = {
            **probe,
            "reference_summary": {
                "selected_count": len(selected_items),
                "field_counts": {
                    "cot": cot_count,
                },
            },
            "selected_examples": selected_metadata,
        }
        with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
            json.dump(probe_record, f, ensure_ascii=False, indent=2)
        deep_results = run_batch(
            items=selected_items,
            out_root=rollout_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            # No more workers than items, but always at least one.
            workers=min(self.workers, max(len(selected_items), 1)),
            image_detail=self.image_detail,
            judge_model=self.judge_model,
            judge_max_completion_tokens=self.judge_max_completion_tokens,
            judge_retries=self.judge_retries,
            diagnostic_mode=True,
            diagnostic_instruction=probe["probe_instruction"],
            diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
        )
        deep_results = self.attach_reference_context(deep_results, selected_items)
        return run_minibatch_reflect(
            results=deep_results,
            skill_content=skill_content,
            prediction_dir=os.path.join(rollout_dir, "predictions"),
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def get_task_types(self) -> list[str]:
        """Return the task types reported by the dataloader."""
        return self.dataloader.get_task_types()
list[dict], choice_answer: Any) -> str: + text = str(raw or "").strip().lower() + if text in {"choice", "multiple_choice", "mcq", "option"}: + return "choice" + if text in {"blank", "open", "open_ended", "fill_blank", "short_answer"}: + return "blank" + if options or choice_answer not in (None, "", []): + return "choice" + return "blank" + + +def _coerce_options(raw: Any) -> list[dict]: + options: list[dict] = [] + if isinstance(raw, list): + for idx, item in enumerate(raw): + if isinstance(item, dict): + text = str(item.get("text") or item.get("content") or item.get("option") or "").strip() + label = str(item.get("label") or _CHOICE_LABELS[idx]).strip() + else: + text = str(item).strip() + label = _CHOICE_LABELS[idx] + if text: + options.append({"label": label, "text": text}) + elif isinstance(raw, dict): + for idx, (key, value) in enumerate(raw.items()): + text = str(value).strip() + if text: + options.append({"label": str(key).strip() or _CHOICE_LABELS[idx], "text": text}) + return options + + +def _normalize_choice_answer(choice_answer: Any, options: list[dict]) -> dict[str, str]: + if not options: + return {"label": "", "text": ""} + + if isinstance(choice_answer, dict): + label = str(choice_answer.get("label") or "").strip().upper() + text = str(choice_answer.get("text") or "").strip() + for option in options: + if label and option["label"].strip().upper() == label: + return {"label": option["label"], "text": option["text"]} + if text and option["text"] == text: + return {"label": option["label"], "text": option["text"]} + + if isinstance(choice_answer, int): + idx = choice_answer + if 0 <= idx < len(options): + return dict(options[idx]) + if 1 <= idx <= len(options): + return dict(options[idx - 1]) + + text = str(choice_answer or "").strip() + label = text.upper().rstrip(".):") + for option in options: + if option["label"].strip().upper() == label: + return dict(option) + if option["text"] == text: + return dict(option) + + return {"label": "", "text": ""} + 
+ +def _coerce_blank_answers(raw: Any) -> list[str]: + if isinstance(raw, list): + return [str(item).strip() for item in raw if str(item).strip()] + if raw is None: + return [] + text = str(raw).strip() + return [text] if text else [] + + +def load_items(data_path: str) -> list[dict]: + """Load and normalise BabyVision items from a directory or JSONL file.""" + if not data_path: + raise ValueError("BabyVision requires data_path pointing to a local dataset directory or meta_data.jsonl.") + + if os.path.isdir(data_path): + meta_path = os.path.join(data_path, "meta_data.jsonl") + image_root = os.path.join(data_path, "images") + else: + meta_path = data_path + image_root = os.path.join(os.path.dirname(data_path), "images") + + if not os.path.exists(meta_path): + raise ValueError( + "BabyVision expected a meta_data.jsonl file. " + f"Could not find: {meta_path}" + ) + + raw_items = _iter_jsonl(meta_path) + items: list[dict] = [] + for idx, raw in enumerate(raw_items): + options = _coerce_options(raw.get("options") or raw.get("choices") or raw.get("choiceOptions")) + ans_type = _normalize_ans_type(raw.get("ansType"), options, raw.get("choiceAns")) + correct_choice = _normalize_choice_answer(raw.get("choiceAns"), options) + blank_answers = _coerce_blank_answers(raw.get("blankAns")) + + image_name = str( + raw.get("image") + or raw.get("image_path") + or raw.get("image_file") + or raw.get("img") + or "" + ).strip() + if not image_name: + continue + image_path = image_name if os.path.isabs(image_name) else os.path.join(image_root, image_name) + if not os.path.exists(image_path): + alt = os.path.join(os.path.dirname(meta_path), image_name) + if os.path.exists(alt): + image_path = alt + else: + continue + + task_id = str(raw.get("taskId") or raw.get("id") or idx + 1) + task_type = str(raw.get("type") or raw.get("taskType") or "unknown").strip() or "unknown" + subtype = str(raw.get("subtype") or raw.get("subType") or task_type).strip() or task_type + question = 
class BabyVisionDataLoader(SplitDataLoader):
    """Split-aware BabyVision loader that also records the task types seen."""

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        seed: int = 42,
        limit: int = 0,
        **kwargs,
    ) -> None:
        """Forward the split settings to the base loader; extra ``kwargs`` are ignored."""
        base_options = {
            "split_dir": split_dir,
            "data_path": data_path,
            "split_mode": split_mode,
            "split_ratio": split_ratio,
            "split_seed": split_seed,
            "split_output_dir": split_output_dir,
            "seed": seed,
            "limit": limit,
        }
        super().__init__(**base_options)
        self._task_types: list[str] = []

    def load_raw_items(self, data_path: str) -> list[dict]:
        """Load and normalise the raw items found at ``data_path``."""
        return load_items(data_path)

    def setup(self, cfg: dict) -> None:
        """Run the base setup, then collect the sorted task types of all splits."""
        super().setup(cfg)
        seen = set()
        for item in self.train_items + self.val_items + self.test_items:
            # Prefer the subtype, then the task type, else "unknown".
            seen.add(item.get("subtype") or item.get("task_type") or "unknown")
        self._task_types = sorted(seen)

    def get_task_types(self) -> list[str]:
        """Return a copy of the task types collected by ``setup``."""
        return self._task_types[:]
def normalize_text(text: str) -> str:
    """Lower-case *text*, remove ASCII punctuation and collapse whitespace runs."""
    lowered = str(text).strip().lower()
    # One pass that deletes every character listed in string.punctuation.
    stripped = lowered.translate(str.maketrans("", "", string.punctuation))
    return " ".join(stripped.split())
def _judge_answer(
    *,
    item: dict,
    prediction_text: str,
    extracted_answer: str,
    judge_model: str,
    max_completion_tokens: int,
    retries: int,
) -> dict:
    """Ask the judge model whether ``extracted_answer`` matches the gold answer.

    The ground truth is the correct choice label for choice items; for other
    items it is the single blank answer, or all blank answers joined with
    ``" | "``. For choice items with options, the options are appended to
    the question. ``prediction_text`` is accepted but not used here.

    Returns a dict with the raw judge output (``raw``), the verdict
    (``correct``), the lower-cased judge text (``reason``) and
    ``matched_gold`` (the ground truth when correct, else ``""``).
    """
    is_choice = item["ans_type"] == "choice"

    if is_choice:
        ground_truth = str(item["correct_choice"]["label"])
    elif len(item["blank_answers"]) == 1:
        ground_truth = item["blank_answers"][0]
    else:
        ground_truth = " | ".join(item["blank_answers"])

    question = str(item["question"])
    if is_choice and item.get("choices"):
        question = f"{question}\nChoices:\n{_format_choices(item['choices'])}"

    user_prompt = load_prompt("judge", env="babyvision").format(
        question=question,
        groundtruth=ground_truth,
        modeloutput=extracted_answer,
    )
    raw, _ = chat_with_deployment(
        deployment=judge_model,
        system="You are a careful and strict evaluator.",
        user=user_prompt,
        max_completion_tokens=max_completion_tokens,
        retries=retries,
        stage="babyvision_judge",
    )

    judge_response_clean = str(raw).strip().lower()
    # Correct only if the reply contains "true"; a "false" reply and an
    # unrecognised reply both count as incorrect.
    correct = "true" in judge_response_clean
    return {
        "raw": raw,
        "correct": correct,
        "reason": judge_response_clean,
        "matched_gold": ground_truth if correct else "",
    }
+ if item["ans_type"] == "choice": + result["predicted_label"] = str(answer or "").strip().upper().rstrip(".):") + result["predicted_text"] = "" + result["correct_label"] = str(item["correct_choice"].get("label") or "") + result["correct_text"] = str(item["correct_choice"].get("text") or "") + else: + result["gold_answers"] = list(item["blank_answers"]) + best_f1 = 0.0 + for gold in item["blank_answers"]: + best_f1 = max(best_f1, _token_f1(str(answer or ""), gold)) + result["string_f1"] = best_f1 + + return result + + +def evaluation_mode() -> str: + return _EVAL_MODE diff --git a/skillopt/envs/babyvision/prompts/analyst_error.md b/skillopt/envs/babyvision/prompts/analyst_error.md new file mode 100644 index 0000000..79c0c0d --- /dev/null +++ b/skillopt/envs/babyvision/prompts/analyst_error.md @@ -0,0 +1,36 @@ +You are an expert failure-analysis agent for child-level visual reasoning tasks. + +You will be given MULTIPLE failed BabyVision trajectories from a minibatch and the current skill document. +Each trajectory includes the text prompt, the model answer, and the evaluation result. +You do not have direct access to raw pixel content during reflection, so focus on general reasoning, +option-selection, and visual-question-answering behaviors that can be improved through prompting. + +## Failure Type Categories +- **visual_detail_miss**: the agent likely overlooked a salient visual attribute, relation, count, or object state +- **option_mismatch**: the agent selected the wrong option despite relevant evidence likely being present +- **instruction_slip**: the agent ignored output format or answered too vaguely +- **answer_granularity**: the agent gave an answer that was too broad, too narrow, or mismatched the expected specificity +- **other**: none of the above + +## Rules +1. Focus on patterns recurring across the minibatch. +2. Prefer reusable behaviors for inspecting images and grounding answers in visible evidence. +3. Do not memorize dataset-specific answers. 
+4. Only patch gaps not already covered by the current skill. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} diff --git a/skillopt/envs/babyvision/prompts/analyst_success.md b/skillopt/envs/babyvision/prompts/analyst_success.md new file mode 100644 index 0000000..212a345 --- /dev/null +++ b/skillopt/envs/babyvision/prompts/analyst_success.md @@ -0,0 +1,25 @@ +You are an expert success-pattern analyst for child-level visual reasoning tasks. + +You will be given MULTIPLE successful BabyVision trajectories from a minibatch and the current skill document. +Identify generalizable behavior patterns that help the agent inspect the image carefully and answer at the right level of specificity. + +## Rules +- Focus on broadly useful visual QA behaviors. +- Prefer patterns about systematic image inspection, comparing options, and concise grounded answers. +- Do not add dataset-specific facts. +- "edits" may be empty if the skill already captures the useful patterns. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} diff --git a/skillopt/envs/babyvision/prompts/deep_probe.md b/skillopt/envs/babyvision/prompts/deep_probe.md new file mode 100644 index 0000000..ff53c53 --- /dev/null +++ b/skillopt/envs/babyvision/prompts/deep_probe.md @@ -0,0 +1,25 @@ +You are an expert diagnostic-probe designer for BabyVision-style visual reasoning tasks. 
+ +You will be shown representative trajectories, the current student skill, and the student's original prompt context. +Design one SMALL diagnostic instruction that exposes the student's intermediate visual judgment without materially changing the original scaffold. + +## Hard Constraints +1. Do NOT substantially change the original scaffold. +2. Do NOT prescribe a new step-by-step solving method. +3. You MAY ask for a short structured list of a few intermediate conclusions, candidate cues, or counted units, as long as it stays close to the original scaffold. +4. Do NOT ask for exhaustive listing of all cells, all objects, or a full chain-of-thought. +5. Ask only for a short readout that reveals the student's current latent state. +6. Keep it brief and structured, and require the final answer to remain in \boxed{...}. + +## Good Probe Targets +- top answer and runner-up +- decisive visual cue +- suspicious region or compared objects +- counting unit or formatting interpretation +- 2-4 short intermediate conclusions that directly support the final answer + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/envs/babyvision/prompts/judge.md b/skillopt/envs/babyvision/prompts/judge.md new file mode 100644 index 0000000..7f0872e --- /dev/null +++ b/skillopt/envs/babyvision/prompts/judge.md @@ -0,0 +1,35 @@ +You are a careful and strict evaluator. You will be given: + +1. **Question** +2. **Ground Truth Answer** (correct answer) +3. **Model Output** (answer from another model) + +**Your goal:** Determine if the Model Output **accurately matches** the Ground Truth Answer in meaning. + +* Matching means: the facts, entities, and key details are equivalent, even if phrasing differs. +* Not matching means: the Model Output is wrong, incomplete, contains extra incorrect facts, or changes the meaning. + +**Process (internal reasoning):** + +1. Read and understand the Question, Ground Truth Answer, and Model Output. +2.
Ignore small wording differences, formatting, or synonyms. +3. If all factual content matches, conclude `True`. Otherwise, conclude `False`. + +**Important:** + +* Think through your decision step-by-step **internally** before responding. +* In your final output, return **only** True or False, with no extra text or explanation. + +**Output format:** + +True + +or + +False + +**Input:** + +Question: {question}, +Ground Truth Answer: {groundtruth}, +Model Output: {modeloutput} diff --git a/skillopt/envs/babyvision/prompts/rollout_system.md b/skillopt/envs/babyvision/prompts/rollout_system.md new file mode 100644 index 0000000..42921a8 --- /dev/null +++ b/skillopt/envs/babyvision/prompts/rollout_system.md @@ -0,0 +1,13 @@ +You are an expert visual reasoning agent solving child-level image understanding tasks. + +{skill_section}## Task Format +You will receive one image and one question about it. +Inspect the image carefully before answering. Ground the answer in visible evidence. + +## Answer Format +Think step by step, then provide your final answer in \boxed{{Answer}} format. +- For multiple-choice questions, output only the single choice label, such as \boxed{{A}}. +- For open questions, output only a short final answer inside \boxed{{...}}. + +Example: +\boxed{{B}} diff --git a/skillopt/envs/babyvision/reflect.py b/skillopt/envs/babyvision/reflect.py new file mode 100644 index 0000000..bef0999 --- /dev/null +++ b/skillopt/envs/babyvision/reflect.py @@ -0,0 +1,4 @@ +"""BabyVision Reflect stage. + +Prompts are now loaded from .md files by the base adapter.
+""" diff --git a/skillopt/envs/babyvision/rollout.py b/skillopt/envs/babyvision/rollout.py new file mode 100644 index 0000000..ff1e64e --- /dev/null +++ b/skillopt/envs/babyvision/rollout.py @@ -0,0 +1,467 @@ +"""BabyVision rollout — multimodal visual QA with image input.""" +from __future__ import annotations + +import base64 +import json +import mimetypes +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +from skillopt.envs.babyvision.evaluator import evaluate_item, evaluation_mode, extract_boxed_answer +from skillopt.model import chat_student_messages, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="babyvision").format(skill_section=skill_section) + + +def _format_choices(choices: list[dict]) -> str: + return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices) + + +def _build_user_text( + item: dict, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> str: + parts = [] + if diagnostic_trace_context.strip(): + parts.append( + "## Previous Codex Trace Snapshot\n" + "This is a partial transcript from an earlier attempt. 
Use it as your current reasoning context.\n\n" + f"{diagnostic_trace_context.strip()}" + ) + parts.append(f"## Question\n{item['question']}") + if item["ans_type"] == "choice": + parts.append(f"## Choices\n{_format_choices(item['choices'])}") + parts.append("Answer using the single correct option label in \\boxed{...}.") + else: + parts.append("Answer with a short phrase in \\boxed{...}.") + if diagnostic_mode and diagnostic_instruction.strip(): + parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}") + return "\n\n".join(parts) + + +def _image_to_data_uri(path: str) -> str: + mime = mimetypes.guess_type(path)[0] or "image/png" + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{encoded}" + + +def _build_messages( + item: dict, + skill_content: str, + image_detail: str, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> tuple[list[dict], str, str]: + system = _build_system(skill_content) + user_text = _build_user_text( + item, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + image_url = { + "url": _image_to_data_uri(item["image_path"]), + } + if image_detail and image_detail != "auto": + image_url["detail"] = image_detail + messages = [ + {"role": "system", "content": system}, + { + "role": "user", + "content": [ + {"type": "text", "text": user_text}, + {"type": "image_url", "image_url": image_url}, + ], + }, + ] + return messages, system, user_text + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current BabyVision visual reasoning question.", + preamble=( + "Use this skill when answering the current visual reasoning question.\n" + "Inspect the attached image carefully and return the final answer in \\boxed{...}." 
+ ), + ) + + +def _run_codex_once( + *, + pred_dir: str, + item: dict, + skill_content: str, + model: str, + timeout: int, + image_detail: str, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", + previous_response: str = "", +) -> tuple[str, str, str, str]: + user_text = _build_user_text( + item, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + task_parts = [user_text] + if previous_response: + task_parts.append( + "## Previous Attempt\n" + f"{previous_response}\n\n" + "Review the same image and question carefully. If needed, correct the answer." + ) + task_text = "\n\n".join(task_parts) + skill_md = _build_codex_skill(skill_content) + work_dir = os.path.join(pred_dir, "codex_exec") + prepare_workspace( + work_dir=work_dir, + skill_md=skill_md, + task_text=task_text, + images=[item["image_path"]], + ) + prompt = ( + "Use the `skillopt-student` skill available in this workspace.\n" + "Read `task.md`, inspect the attached image, and answer the question.\n" + "Return the final answer in \\boxed{...}." 
+ ) + final_message, raw = run_student_exec( + work_dir=work_dir, + prompt=prompt, + model=model, + timeout=timeout, + images=[item["image_path"]], + ) + return final_message or raw, raw, skill_md, task_text + + +def process_one( + item: dict, + out_root: str, + skill_content: str, + *, + max_turns: int = 1, + image_detail: str = "auto", + judge_model: str = "gpt-5.4", + judge_max_completion_tokens: int = 256, + judge_retries: int = 5, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> dict: + item_id = str(item["id"]) + result = { + "id": item_id, + "question": item["question"], + "task_type": item.get("subtype") or item.get("task_type") or "babyvision", + "task_description": item["question"], + "hard": 0, + "soft": 0.0, + "predicted_answer": "", + "predicted_label": "", + "predicted_text": "", + "response": "", + "fail_reason": "", + "agent_ok": False, + "n_turns": 0, + "image_path": item["image_path"], + "ans_type": item["ans_type"], + "evaluation_mode": evaluation_mode(), + "judge_model": judge_model, + } + if item["ans_type"] == "choice": + result["correct_label"] = item["correct_choice"]["label"] + result["correct_text"] = item["correct_choice"]["text"] + else: + result["gold_answers"] = item["blank_answers"] + + try: + pred_dir = os.path.join(out_root, "predictions", item_id) + os.makedirs(pred_dir, exist_ok=True) + + if is_student_exec_backend(): + from skillopt.model import azure_openai as _llm + + response = "" + conversation: list[dict] = [ + {"role": "user", "content": f"{item['question']}\n\n[image] {os.path.basename(item['image_path'])}"} + ] + system_prompt = "" + user_text = "" + for turn in range(max_turns): + response, raw, system_prompt, user_text = _run_codex_once( + pred_dir=pred_dir, + item=item, + skill_content=skill_content, + model=_llm.STUDENT_DEPLOYMENT, + timeout=120, + image_detail=image_detail, + diagnostic_mode=diagnostic_mode if turn == 0 else False, + 
diagnostic_instruction=diagnostic_instruction if turn == 0 else "", + diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "", + previous_response=response if turn > 0 else "", + ) + conversation.append({"type": "message", "turn": turn + 1, "content": response}) + if extract_boxed_answer(response) is not None: + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) - 1 + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system_prompt) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user_text) + + eval_result = evaluate_item( + item=item, + prediction_text=response, + judge_model=judge_model, + max_completion_tokens=judge_max_completion_tokens, + retries=judge_retries, + ) + result["evaluation_mode"] = eval_result["evaluation_mode"] + result["judge_raw"] = eval_result["judge_raw"] + result["judge_reason"] = eval_result["judge_reason"] + result["matched_gold"] = eval_result["matched_gold"] + if item["ans_type"] == "choice": + result["predicted_label"] = eval_result["predicted_label"] + result["predicted_text"] = eval_result["predicted_text"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' " + f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})" + ) + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {item['question']}\n" + f"Predicted label: {eval_result['predicted_label']!r}\n" + f"Predicted text: {eval_result['predicted_text']!r}\n" + f"Correct label: {eval_result['correct_label']!r}\n" + f"Correct text: {eval_result['correct_text']!r}\n" + f"Judge correct: {eval_result['em']}\n" + f"Judge reason: 
{eval_result['judge_reason']}" + ) + else: + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"judge=0: predicted '{eval_result['predicted_answer']}' " + f"but expected {item['blank_answers']} ({eval_result['judge_reason']})" + ) + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {item['question']}\n" + f"Predicted answer: {eval_result['predicted_answer']!r}\n" + f"Gold answers: {item['blank_answers']!r}\n" + f"Judge correct: {eval_result['em']}\n" + f"Judge reason: {eval_result['judge_reason']}\n" + f"String F1: {eval_result.get('string_f1', 0.0):.4f}" + ) + conversation.append({"role": "system", "content": eval_detail}) + with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + return result + + messages, system_prompt, user_text = _build_messages( + item, + skill_content, + image_detail, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + response = "" + conversation: list[dict] = [ + {"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"} + ] + + for turn in range(max_turns): + if turn == 0: + resp_text, _ = chat_student_messages( + messages=messages, + max_completion_tokens=768, + retries=5, + stage="rollout", + ) + else: + refinement_text = ( + f"Your previous answer was:\n{response}\n\n" + "Review the same image and question carefully. " + "If needed, correct your answer. Output the final answer in \\boxed{...}." 
+ ) + refinement_messages = [ + messages[0], + messages[1], + {"role": "assistant", "content": response}, + {"role": "user", "content": refinement_text}, + ] + resp_text, _ = chat_student_messages( + messages=refinement_messages, + max_completion_tokens=512, + retries=5, + stage="rollout", + ) + response = resp_text + conversation.append({"type": "message", "turn": turn + 1, "content": resp_text}) + if extract_boxed_answer(resp_text) is not None: + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) - 1 + + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system_prompt) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user_text) + + eval_result = evaluate_item( + item=item, + prediction_text=response, + judge_model=judge_model, + max_completion_tokens=judge_max_completion_tokens, + retries=judge_retries, + ) + result["evaluation_mode"] = eval_result["evaluation_mode"] + result["judge_raw"] = eval_result["judge_raw"] + result["judge_reason"] = eval_result["judge_reason"] + result["matched_gold"] = eval_result["matched_gold"] + + if item["ans_type"] == "choice": + result["predicted_label"] = eval_result["predicted_label"] + result["predicted_text"] = eval_result["predicted_text"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' " + f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})" + ) + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {item['question']}\n" + f"Predicted label: {eval_result['predicted_label']!r}\n" + f"Predicted text: {eval_result['predicted_text']!r}\n" + f"Correct label: {eval_result['correct_label']!r}\n" + f"Correct 
text: {eval_result['correct_text']!r}\n" + f"Judge correct: {eval_result['em']}\n" + f"Judge reason: {eval_result['judge_reason']}" + ) + else: + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"judge=0: predicted '{eval_result['predicted_answer']}' " + f"but expected {item['blank_answers']} ({eval_result['judge_reason']})" + ) + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {item['question']}\n" + f"Predicted answer: {eval_result['predicted_answer']!r}\n" + f"Gold answers: {item['blank_answers']!r}\n" + f"Judge correct: {eval_result['em']}\n" + f"Judge reason: {eval_result['judge_reason']}\n" + f"String F1: {eval_result.get('string_f1', 0.0):.4f}" + ) + + conversation.append({"role": "system", "content": eval_detail}) + with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + except Exception as e: # noqa: BLE001 + result["fail_reason"] = f"error: {e}" + return result + + +def run_batch( + items: list[dict], + out_root: str, + skill_content: str, + *, + max_turns: int = 1, + workers: int = 32, + image_detail: str = "auto", + judge_model: str = "gpt-5.4", + judge_max_completion_tokens: int = 256, + judge_retries: int = 5, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context_by_id: dict[str, str] | None = None, +) -> list[dict]: + results_path = os.path.join(out_root, "results.jsonl") + os.makedirs(out_root, exist_ok=True) + + expected_eval_mode = evaluation_mode() + done_ids: set[str] = set() + existing: list[dict] = [] + rewrite_results = False + if os.path.exists(results_path): + with open(results_path, encoding="utf-8") as f: + for line in f: + try: + row = json.loads(line) + if row.get("evaluation_mode") != expected_eval_mode: + rewrite_results = True + continue + 
done_ids.add(str(row["id"])) + existing.append(row) + except Exception: + rewrite_results = True + + pending = [item for item in items if str(item["id"]) not in done_ids] + if not pending and not rewrite_results: + return existing + + results = list(existing) + file_mode = "w" if rewrite_results else "a" + with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex: + if rewrite_results: + for row in existing: + outf.write(json.dumps(row, ensure_ascii=False) + "\n") + futs = { + ex.submit( + process_one, + item, + out_root, + skill_content, + max_turns=max_turns, + image_detail=image_detail, + judge_model=judge_model, + judge_max_completion_tokens=judge_max_completion_tokens, + judge_retries=judge_retries, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""), + ): item + for item in pending + } + for fut in as_completed(futs): + row = fut.result() + results.append(row) + outf.write(json.dumps(row, ensure_ascii=False) + "\n") + outf.flush() + return results diff --git a/skillopt/envs/babyvision/skills/initial.md b/skillopt/envs/babyvision/skills/initial.md new file mode 100644 index 0000000..56564f8 --- /dev/null +++ b/skillopt/envs/babyvision/skills/initial.md @@ -0,0 +1,18 @@ +# BabyVision Visual QA Heuristics + +## Image Inspection +- First identify the main objects, their attributes, and their spatial relations before answering. +- If the question involves counting, compare all relevant instances carefully instead of stopping after the first match. +- If the question asks about color, size, position, or action, verify the specific visible evidence for that attribute. + +## Multiple Choice +- Compare every option against the visible image evidence before deciding. +- Prefer the option that matches the image exactly; reject options that are only partially true or too vague. 
+- When two options are close, check the smallest discriminating visual detail. + +## Open Answers +- Answer with the shortest phrase that is fully supported by the image. +- Match the expected level of specificity: not broader than the image evidence, not narrower than the question asks. + +## Final Answer +- Output only the final answer inside \boxed{...}. diff --git a/skillopt/envs/base.py b/skillopt/envs/base.py new file mode 100644 index 0000000..4267944 --- /dev/null +++ b/skillopt/envs/base.py @@ -0,0 +1,396 @@ +"""ReflACT environment adapter — abstract interface. + +To connect ReflACT to a new environment (benchmark, simulator, etc.), +implement a subclass of :class:`EnvAdapter` with environment-specific +rollout and reflection logic. + +Example:: + + class MyBenchAdapter(EnvAdapter): + def build_train_env(self, batch_size, seed, **kw): + return MyEnvManager(split="train", n=batch_size, seed=seed) + + def build_eval_env(self, env_num, split, seed, **kw): + return MyEnvManager(split=split, n=env_num, seed=seed) + + def rollout(self, env_manager, skill_content, out_dir, **kw): + # Run episodes, return [{"id": ..., "hard": 0/1, "soft": 0.0-1.0, ...}] + ... + + def reflect(self, results, skill_content, out_dir, **kw): + # Analyze trajectories, return list of patch dicts + ... + + def get_task_types(self): + return ["task_a", "task_b"] +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +import os +import random + +from skillopt.datasets.base import BaseDataLoader, BatchSpec +from skillopt.model.codex_harness import extract_codex_trace_prefix, format_codex_trace_steps, parse_codex_raw +from skillopt.prompts import load_prompt + + +class EnvAdapter(ABC): + """Abstract adapter for connecting ReflACT to any environment. + + Subclasses must implement all abstract methods. The ReflACT trainer + calls these methods at the appropriate pipeline stages.
+ """ + + # ── Lifecycle hooks ──────────────────────────────────────────────────── + + def setup(self, cfg: dict) -> None: + """Called once by the trainer before the training loop begins. + + Override to perform one-time initialization that requires the full + config (e.g., data loading, split creation). Default is a no-op. + """ + self._cfg = dict(cfg) + + def get_dataloader(self) -> BaseDataLoader | None: + """Return the task dataloader used by this adapter, if any.""" + return None + + def requires_ray(self) -> bool: + """Return whether this adapter requires Ray runtime initialization.""" + return False + + def deep_reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + """Optional deeper diagnostic reflection pass. + + Default behavior is a no-op. Dataset-backed adapters may override this + to re-query the student on a small representative subset of the current + batch using minimally-perturbed diagnostic prompts that expose + intermediate reasoning state. 
+ """ + return [] + + def build_reference_text(self, item: dict) -> str: + """Return hidden reference material for deep reflection, if any.""" + return str(item.get("reference_text") or "").strip() + + def get_reference_metadata(self, item: dict) -> dict: + """Return structured metadata about hidden reference material.""" + reference_text = self.build_reference_text(item) + if not reference_text: + return {"fields": [], "preview": ""} + return { + "fields": ["reference_text"], + "preview": reference_text[:400], + } + + def get_codex_deep_probe_prompt(self) -> str | None: + env_name = getattr(self, "_cfg", {}).get("env_name") + return load_prompt("deep_probe_codex", env=env_name) + + def attach_codex_probe_context( + self, + results: list[dict], + prediction_dir: str, + ) -> list[dict]: + """Attach compact Codex step metadata for codex-aware deep reflection.""" + enriched: list[dict] = [] + for row in results: + merged = dict(row) + tid = str(row.get("id")) + raw_path = os.path.join(prediction_dir, tid, "codex_raw.txt") + if os.path.exists(raw_path): + with open(raw_path, encoding="utf-8") as f: + raw = f.read() + parsed = parse_codex_raw(raw) + merged["codex_probe_trace_steps"] = format_codex_trace_steps(raw) + merged["codex_probe_step_count"] = len(parsed["steps"]) + enriched.append(merged) + return enriched + + def resolve_codex_probe_target( + self, + *, + selected_items: list[dict], + selected_examples: list[dict], + prediction_dir: str, + probe: dict, + ) -> tuple[list[dict], dict[str, str] | None, dict]: + """Resolve the teacher-selected codex probe target and raw trace prefix.""" + target_id = str(probe.get("probe_target_id", "")).strip() + selected_id_set = {str(item["id"]) for item in selected_items} + if target_id not in selected_id_set: + target_id = str(selected_items[0]["id"]) + target_item = next(item for item in selected_items if str(item["id"]) == target_id) + target_result = next( + (row for row in selected_examples if str(row.get("id")) == 
target_id), + None, + ) + max_probe_step = int((target_result or {}).get("codex_probe_step_count", 0)) + default_probe_step = max_probe_step - 1 if max_probe_step > 1 else max_probe_step + probe_after_step = int(probe.get("probe_after_step", default_probe_step)) + if max_probe_step > 0: + probe_after_step = max(0, min(probe_after_step, max_probe_step)) + else: + probe_after_step = 0 + raw_path = os.path.join(prediction_dir, target_id, "codex_raw.txt") + trace_prefix = "" + if os.path.exists(raw_path): + with open(raw_path, encoding="utf-8") as f: + trace_prefix = extract_codex_trace_prefix(f.read(), after_step=probe_after_step) + updated_probe = dict(probe) + updated_probe["probe_target_id"] = target_id + updated_probe["probe_after_step"] = probe_after_step + return [target_item], {target_id: trace_prefix}, updated_probe + + def attach_reference_context( + self, + results: list[dict], + items: list[dict] | None, + ) -> list[dict]: + """Attach environment-specific hidden reference text to result dicts.""" + if not results or not items: + return list(results) + + item_by_id = { + str(item.get("id")): item + for item in items + if isinstance(item, dict) and item.get("id") is not None + } + enriched: list[dict] = [] + for row in results: + merged = dict(row) + item = item_by_id.get(str(row.get("id"))) + if item: + reference_text = self.build_reference_text(item) + if reference_text: + merged["reference_text"] = reference_text + enriched.append(merged) + return enriched + + def select_representative_items( + self, + results: list[dict], + items: list[dict] | None, + *, + n_failures: int, + n_successes: int, + seed: int | None = None, + ) -> list[dict]: + """Select a small diverse subset of current-batch items by outcome.""" + if not items: + return [] + + item_by_id = { + str(item.get("id")): item + for item in items + if isinstance(item, dict) and item.get("id") is not None + } + failures = [ + (result, item_by_id[str(result.get("id"))]) + for result in results + if 
not result.get("hard") and str(result.get("id")) in item_by_id + ] + successes = [ + (result, item_by_id[str(result.get("id"))]) + for result in results + if result.get("hard") and str(result.get("id")) in item_by_id + ] + + rng = random.Random(seed) + + def _pick(pool: list[tuple[dict, dict]], quota: int) -> list[dict]: + if quota <= 0 or not pool: + return [] + shuffled = list(pool) + rng.shuffle(shuffled) + + picked_ids: set[str] = set() + picked: list[dict] = [] + seen_types: set[str] = set() + + for result, item in shuffled: + task_type = str(result.get("task_type") or item.get("task_type") or item.get("subtype") or "unknown") + item_id = str(item["id"]) + if task_type in seen_types or item_id in picked_ids: + continue + picked.append(item) + picked_ids.add(item_id) + seen_types.add(task_type) + if len(picked) >= quota: + return picked + + for _, item in shuffled: + item_id = str(item["id"]) + if item_id in picked_ids: + continue + picked.append(item) + picked_ids.add(item_id) + if len(picked) >= quota: + break + return picked + + selected = _pick(failures, n_failures) + selected_ids = {str(item["id"]) for item in selected} + selected.extend( + item for item in _pick(successes, n_successes) + if str(item["id"]) not in selected_ids + ) + return selected + + def build_env_from_batch(self, batch: BatchSpec, **kwargs): + """Build an environment manager or item list from a :class:`BatchSpec`. + + Default behavior preserves the legacy adapter API by routing training + batches through :meth:`build_train_env` and evaluation batches through + :meth:`build_eval_env`. + """ + if batch.phase == "train": + return self.build_train_env(batch_size=batch.batch_size, seed=batch.seed, **kwargs) + return self.build_eval_env( + env_num=batch.batch_size, + split=batch.split, + seed=batch.seed, + **kwargs, + ) + + @abstractmethod + def build_train_env(self, batch_size: int, seed: int, **kwargs): + """Build a training environment manager. 
+ + Returns + ------- + object + An environment manager that can be passed to :meth:`rollout`. + """ + + @abstractmethod + def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs): + """Build an evaluation environment manager. + + Parameters + ---------- + env_num : int + Number of evaluation environments. + split : str + Dataset split (e.g. ``"valid_seen"``, ``"valid_unseen"``). + seed : int + Random seed for reproducibility. + + Returns + ------- + object + An environment manager that can be passed to :meth:`rollout`. + """ + + @abstractmethod + def rollout( + self, + env_manager, + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict]: + """Run a batch of episodes using the current skill. + + Returns + ------- + list[dict] + Each dict conforms to :class:`~skillopt.types.RolloutResult`: + must have ``"id"`` (str), ``"hard"`` (0/1), ``"soft"`` + (float 0-1). May include env-specific fields. + """ + + @abstractmethod + def reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + """Analyze rollout results and produce patches. + + Each returned dict conforms to :class:`~skillopt.types.RawPatch`: + ``"patch"`` (with ``"edits"`` list) + ``"source_type"`` + (``"failure"`` or ``"success"``). + + Returns + ------- + list[dict | None] + Raw analyst outputs; ``None`` entries are filtered out. + """ + + @abstractmethod + def get_task_types(self) -> list[str]: + """Return the list of task type names for this environment.""" + + # ── Prompt configuration (two-level priority) ──────────────────────── + # + # Priority: env-specific prompt file > generic default prompt file. + # + # Prompts are loaded from ``.md`` files via ``load_prompt(name, env)``: + # 1. ``skillopt/envs//prompts/.md`` (env-specific) + # 2. ``skillopt/prompts/.md`` (generic fallback) + # + # Subclasses can still override ``get_*_prompt()`` for full control. 
from __future__ import annotations

# ---------------------------------------------------------------------------
# File: skillopt/envs/base.py -- prompt-selection members of the adapter base
# class.  The class header precedes this chunk, so the members are written at
# the column of the collapsed source line.
# ---------------------------------------------------------------------------


@property
def _env_name(self) -> str:
    """Derive the env directory name from this adapter's module path.

    ``"skillopt.envs.searchqa.adapter"`` yields ``"searchqa"``.  A module path
    that does not have ``envs`` as its third-from-last component yields ``""``.
    """
    parts = type(self).__module__.split(".")
    if len(parts) >= 3 and parts[-3] == "envs":
        return parts[-2]
    return ""


def _load_env_prompt(self, name: str) -> str | None:
    """Load prompt *name* for this env; return ``None`` if no file is found."""
    try:
        return load_prompt(name, env=self._env_name)
    except FileNotFoundError:
        return None


def _skill_update_mode(self) -> str:
    """Return the normalized ``skill_update_mode`` setting (default ``"patch"``).

    Tolerates a missing or ``None`` ``_cfg`` attribute.
    """
    cfg = getattr(self, "_cfg", None) or {}
    return str(cfg.get("skill_update_mode", "patch")).strip().lower()


def _load_mode_prompt(self, base_name: str, *, allow_full_rewrite: bool = True) -> str | None:
    """Load *base_name*, preferring the variant for the active update mode.

    ``<base_name>_full_rewrite`` is tried for the full-rewrite aliases (only
    when *allow_full_rewrite* is true) and ``<base_name>_rewrite`` for the
    rewrite-from-suggestions aliases.  A missing variant file falls back to
    the plain ``<base_name>`` prompt.
    """
    mode = self._skill_update_mode()
    variant = ""
    if allow_full_rewrite and mode in {
        "full_rewrite",
        "full_rewrite_minibatch",
        "minibatch_full_rewrite",
        "skill_rewrite_minibatch",
    }:
        variant = "_full_rewrite"
    elif mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
        variant = "_rewrite"
    if variant:
        prompt = self._load_env_prompt(base_name + variant)
        if prompt is not None:
            return prompt
    return self._load_env_prompt(base_name)


def get_error_minibatch_prompt(self) -> str | None:
    """Return the failure-analyst system prompt for the active update mode."""
    return self._load_mode_prompt("analyst_error")


def get_success_minibatch_prompt(self) -> str | None:
    """Return the success-analyst system prompt for the active update mode."""
    return self._load_mode_prompt("analyst_success")


def get_deep_probe_prompt(self) -> str | None:
    """Return the ``deep_probe`` prompt, or ``None`` if no file is found."""
    return self._load_env_prompt("deep_probe")


def get_meta_reflect_prompt(self) -> str | None:
    """Return the meta-reflect prompt, preferring ``meta_reflect_rewrite``.

    The rewrite variant is now selected for every rewrite-from-suggestions
    alias accepted by the analyst getters (previously only the literal
    ``"rewrite_from_suggestions"``).  No full-rewrite variant is looked up.
    """
    return self._load_mode_prompt("meta_reflect", allow_full_rewrite=False)


# ---------------------------------------------------------------------------
# File: skillopt/envs/deep_reflect.py
# ---------------------------------------------------------------------------
import json
import os
from typing import Any, Callable

from skillopt.gradient.deep_probe import generate_deep_probe_instruction
from skillopt.gradient.reflect import run_minibatch_reflect


def _default_probe_metadata(item: dict) -> dict:
    """Build the default per-item summary stored in ``probe.json``."""
    return {
        "id": str(item.get("id")),
        "task_type": str(item.get("task_type") or item.get("topic") or "unknown"),
        "question_preview": str(item.get("question") or "")[:200],
    }


def run_no_reference_deep_reflect(
    adapter: Any,
    results: list[dict],
    skill_content: str,
    out_dir: str,
    *,
    env_manager: Any = None,
    prediction_dir: str | None = None,
    random_seed: int | None = None,
    step_buffer_context: str = "",
    output_requirements: list[str] | None = None,
    metadata_builder: Callable[[dict], dict] | None = None,
) -> list[dict | None]:
    """Run teacher-designed diagnostic probing without hidden references.

    Steps: pick representative items via
    ``adapter.select_representative_items``, ask
    ``generate_deep_probe_instruction`` for a probe, write ``probe.json``
    under ``<out_dir>/deep_reflect``, re-run the selected items through
    ``adapter.rollout`` in diagnostic mode, and hand those results to
    ``run_minibatch_reflect``.

    Parameters
    ----------
    adapter : Any
        Environment adapter; optional tuning attributes are read with
        ``getattr`` defaults.
    results : list[dict]
        Rollout rows; only rows whose ``"id"`` was selected are probed.
    skill_content : str
        Current skill text.
    out_dir : str
        Step output directory.
    env_manager : Any
        Must be a ``list`` of items, otherwise nothing is done.
    prediction_dir : str | None
        Defaults to ``<out_dir>/predictions``.
    random_seed : int | None
        Forwarded to item selection and to the reflect call.
    step_buffer_context, output_requirements
        Forwarded unchanged to the probe/reflect helpers.
    metadata_builder : Callable[[dict], dict] | None
        Builds the per-item summary stored in ``probe.json``.

    Returns
    -------
    list[dict | None]
        Output of ``run_minibatch_reflect``; ``[]`` when deep reflect is
        disabled, ``env_manager`` is not a list, nothing was selected, or no
        probe was generated.
    """
    if not getattr(adapter, "use_deep_reflect", False):
        return []
    if not isinstance(env_manager, list):
        return []

    prediction_dir = prediction_dir or os.path.join(out_dir, "predictions")
    selected_items = adapter.select_representative_items(
        results,
        env_manager,
        n_failures=getattr(adapter, "deep_reflect_failures", 4),
        n_successes=getattr(adapter, "deep_reflect_successes", 2),
        seed=random_seed,
    )
    if not selected_items:
        return []

    selected_ids = {str(item["id"]) for item in selected_items}
    selected_results = [row for row in results if str(row.get("id")) in selected_ids]
    build_metadata = metadata_builder if metadata_builder is not None else _default_probe_metadata
    selected_metadata = [build_metadata(item) for item in selected_items]

    deep_dir = os.path.join(out_dir, "deep_reflect")
    rollout_dir = os.path.join(deep_dir, "rollout")
    patches_dir = os.path.join(deep_dir, "patches")
    os.makedirs(deep_dir, exist_ok=True)
    print(
        f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
        "mode=no_reference_probe"
    )

    probe = generate_deep_probe_instruction(
        skill_content=skill_content,
        items=selected_results,
        prediction_dir=prediction_dir,
        system_prompt=adapter.get_deep_probe_prompt(),
        step_buffer_context=step_buffer_context,
        output_requirements=output_requirements,
    )
    if not probe:
        return []

    probe_record = {
        **probe,
        "reference_summary": {
            "mode": "no_reference_probe",
            "selected_count": len(selected_items),
        },
        "selected_examples": selected_metadata,
    }
    with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
        json.dump(probe_record, f, ensure_ascii=False, indent=2)

    deep_results = adapter.rollout(
        selected_items,
        skill_content,
        rollout_dir,
        diagnostic_mode=True,
        diagnostic_instruction=probe["probe_instruction"],
    )

    # Any ``_cfg`` without a ``get`` method (including None) means "patch".
    cfg = getattr(adapter, "_cfg", None)
    update_mode = cfg.get("skill_update_mode", "patch") if hasattr(cfg, "get") else "patch"

    return run_minibatch_reflect(
        results=deep_results,
        skill_content=skill_content,
        prediction_dir=os.path.join(rollout_dir, "predictions"),
        patches_dir=patches_dir,
        workers=getattr(adapter, "analyst_workers", 8),
        failure_only=getattr(adapter, "failure_only", False),
        minibatch_size=getattr(adapter, "minibatch_size", 8),
        edit_budget=getattr(adapter, "edit_budget", 4),
        random_seed=random_seed,
        error_system=adapter.get_error_minibatch_prompt(),
        success_system=adapter.get_success_minibatch_prompt(),
        step_buffer_context=step_buffer_context,
        update_mode=update_mode,
    )


# Next file in the patch: skillopt/envs/docvqa/__init__.py
# File: skillopt/envs/docvqa/__init__.py
#     """DocVQA environment package for ReflACT."""
#
# File: skillopt/envs/docvqa/adapter.py
from __future__ import annotations

import os

from skillopt.datasets.base import BatchSpec
from skillopt.envs.base import EnvAdapter
from skillopt.envs.deep_reflect import run_no_reference_deep_reflect
from skillopt.envs.docvqa.dataloader import DocVQADataLoader
from skillopt.envs.docvqa.rollout import run_batch
from skillopt.gradient.reflect import run_minibatch_reflect


class DocVQAAdapter(EnvAdapter):
    """Environment adapter that wires DocVQA data, rollout and reflection together."""

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "split_dir",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        max_turns: int = 1,
        exec_timeout: int = 120,
        workers: int = 16,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        seed: int = 42,
        limit: int = 0,
        task_timeout: int = 600,
        image_detail: str = "auto",
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        """Store rollout/reflect settings and build the data loader.

        NOTE(review): the source declared ``exec_timeout`` twice (defaults 120
        and 600), which is a SyntaxError.  The second occurrence is presumably
        the per-task wall-clock budget that ``run_batch`` takes as
        ``task_timeout`` (default 600 there too); it keeps its positional slot
        under that name -- confirm against the configs.
        """
        self.max_turns = max_turns
        self.exec_timeout = exec_timeout
        self.task_timeout = task_timeout
        self.workers = workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.image_detail = image_detail
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = DocVQADataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )

    def setup(self, cfg: dict) -> None:
        """Run the base-class setup, then let the data loader read *cfg*."""
        super().setup(cfg)
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        """Return the ``DocVQADataLoader`` owned by this adapter."""
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        """Return the batch payload as a plain list (empty when there is none)."""
        return list(batch.payload or [])

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        """Build the item list for one training batch."""
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        """Build the item list for an evaluation batch on *split*."""
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]:
        """Run ``run_batch`` over the items in *env_manager*.

        ``diagnostic_mode`` and ``diagnostic_instruction`` are read from
        *kwargs*; every other setting comes from the constructor.
        """
        items: list[dict] = env_manager
        return run_batch(
            items=items,
            out_root=out_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            exec_timeout=self.exec_timeout,
            workers=self.workers,
            image_detail=self.image_detail,
            diagnostic_mode=kwargs.get("diagnostic_mode", False),
            diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
            task_timeout=self.task_timeout,
        )

    def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]:
        """Run minibatch reflection over *results*.

        ``prediction_dir`` and ``patches_dir`` default to sub-directories of
        *out_dir*; ``random_seed`` and ``step_buffer_context`` are optional
        keyword arguments.
        """
        cfg = getattr(self, "_cfg", None) or {}
        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")),
            patches_dir=kwargs.get("patches_dir", os.path.join(out_dir, "patches")),
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=kwargs.get("random_seed"),
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=kwargs.get("step_buffer_context", ""),
            update_mode=cfg.get("skill_update_mode", "patch"),
        )

    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Delegate to ``run_no_reference_deep_reflect`` with DocVQA-specific requirements."""
        return run_no_reference_deep_reflect(
            self,
            results,
            skill_content,
            out_dir,
            env_manager=kwargs.get("env_manager"),
            prediction_dir=kwargs.get("prediction_dir"),
            random_seed=kwargs.get("random_seed"),
            step_buffer_context=kwargs.get("step_buffer_context", ""),
            # NOTE(review): the "..." spans in the second requirement look like
            # stripped tag markup; strings are left byte-identical -- confirm
            # the intended tag names.
            output_requirements=[
                "- There is no hidden reference block. Use only the document image prompt, student output, and evaluation result to infer what intermediate state is worth probing.",
                "- The instruction must explicitly request a short ... block before the final ....",
                "- The readout should focus on visual region, field/table/figure label, OCR text read, candidate answer, and answer-format normalization.",
                "- Do not ask for exhaustive transcription or a full chain-of-thought.",
                "- The instruction text should be ready to append directly to the student's prompt.",
            ],
            metadata_builder=lambda item: {
                "id": str(item.get("id")),
                "task_type": str(item.get("task_type") or "docvqa"),
                "question_preview": str(item.get("question") or "")[:200],
                "image_path": item.get("image_path", ""),
                "docId": item.get("docId", ""),
                "page": item.get("ucsf_document_page_no", ""),
            },
        )

    def get_task_types(self) -> list[str]:
        """Return distinct task types across all splits, in first-seen order."""
        items = self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items
        # dict.fromkeys keeps insertion order and de-duplicates in one pass.
        seen = dict.fromkeys(str(item.get("task_type") or "docvqa") for item in items)
        return list(seen) or ["docvqa"]


# ---------------------------------------------------------------------------
# File: skillopt/envs/docvqa/dataloader.py -- its import header shares this
# collapsed source line and is kept here for reference:
#     from __future__ import annotations
#     import ast
#     import csv
#     from pathlib import Path
# ---------------------------------------------------------------------------
import ast
import csv
from pathlib import Path

from skillopt.datasets.base import SplitDataLoader


def _parse_answers(raw: str) -> list[str]:
    """Parse a CSV answer cell into a list of non-empty answer strings.

    A Python-literal list (``"['a', 'b']"``) yields its stripped, non-empty
    items, and a quoted string literal yields its stripped content.  Anything
    else is returned as the original text, so a scalar answer such as
    ``"1,000"`` or ``"1e3"`` is never rewritten into ``"(1, 0)"`` or
    ``"1000.0"`` by literal evaluation.  An empty cell yields ``[]``.
    """
    text = str(raw or "").strip()
    if not text:
        return []
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal: treat the cell as one plain answer.
        return [text]
    if isinstance(parsed, list):
        stripped = (str(item).strip() for item in parsed)
        return [answer for answer in stripped if answer]
    if isinstance(parsed, str):
        return [parsed.strip()]
    # Numbers, tuples, dicts, ...: keep the text exactly as written.
    return [text]


def _extract_document_path(question: str) -> tuple[str, str]:
    """Split ``"<question> document_path: <path>"`` into ``(question, path)``.

    The path is ``""`` when the ``document_path:`` marker is absent; only the
    first marker is used as the split point.
    """
    marker = "document_path:"
    main, found, tail = question.partition(marker)
    if not found:
        return question.strip(), ""
    return main.strip(), tail.strip()


def _normalize_row(row: dict[str, str]) -> dict:
    """Convert one CSV row into the item dict returned by the loader."""

    def field(name: str) -> str:
        return str(row.get(name) or "").strip()

    question_text, document_path = _extract_document_path(str(row.get("question") or ""))
    answers = _parse_answers(row.get("answer") or row.get("ground_truth") or "")
    image_path = str(row.get("image_path") or document_path or "").strip()
    task_type = str(row.get("topic") or row.get("category") or "docvqa").strip() or "docvqa"
    return {
        "id": str(row.get("questionId") or row.get("id") or "").strip(),
        "question": question_text,
        "answer": answers[0] if answers else "",
        "answers": answers,
        "task_type": task_type,
        "subtask": task_type,
        "image_paths": [image_path] if image_path else [],
        "image_path": image_path,
        "questionId": field("questionId"),
        "docId": field("docId"),
        "ucsf_document_id": field("ucsf_document_id"),
        "ucsf_document_page_no": field("ucsf_document_page_no"),
        "source_split": field("source_split"),
    }


class DocVQADataLoader(SplitDataLoader):
    """Split data loader that reads DocVQA items from CSV files."""

    def load_split_items(self, split_path: str) -> list[dict]:
        """Load and normalize the rows of the first ``*.csv`` in *split_path*.

        Files are ordered by name and only the first one is read.

        Raises
        ------
        FileNotFoundError
            If the directory contains no ``.csv`` file.
        """
        csv_files = sorted(Path(split_path).glob("*.csv"))
        if not csv_files:
            raise FileNotFoundError(f"No .csv file found in {split_path}")
        # utf-8-sig drops a leading BOM, which would otherwise end up inside
        # the first header name; plain UTF-8 files decode identically.
        with csv_files[0].open(encoding="utf-8-sig", newline="") as f:
            return [_normalize_row(row) for row in csv.DictReader(f)]
diff --git a/skillopt/envs/docvqa/evaluator.py b/skillopt/envs/docvqa/evaluator.py new file mode 100644 index 0000000..85c09ef --- /dev/null +++ b/skillopt/envs/docvqa/evaluator.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import ast +import json +from collections.abc import Iterable +from typing import Any + +DEFAULT_ANLS_THRESHOLD = 0.5 + + +def _normalize_text(value: Any) -> str: + if value is None: + return "" + text = str(value).strip().lower() + return " ".join(text.split()) + + +def _levenshtein_distance(a: str, b: str) -> int: + if a == b: + return 0 + if not a: + return len(b) + if not b: + return len(a) + if len(a) > len(b): + a, b = b, a + previous = list(range(len(b) + 1)) + for i, char_a in enumerate(a, start=1): + current = [i] + for j, char_b in enumerate(b, start=1): + insert_cost = current[j - 1] + 1 + delete_cost = previous[j] + 1 + replace_cost = previous[j - 1] + (char_a != char_b) + current.append(min(insert_cost, delete_cost, replace_cost)) + previous = current + return previous[-1] + + +def _score_single_answer(predicted: Any, target: Any, threshold: float) -> float: + predicted_norm = _normalize_text(predicted) + target_norm = _normalize_text(target) + if not predicted_norm and not target_norm: + return 1.0 + if not predicted_norm or not target_norm: + return 0.0 + distance = _levenshtein_distance(predicted_norm, target_norm) + normalized_distance = distance / max(len(predicted_norm), len(target_norm)) + if normalized_distance >= threshold: + return 0.0 + return 1.0 - normalized_distance + + +def _extract_answer_strings(raw: Any) -> list[str]: + if raw is None: + return [""] + if isinstance(raw, str): + text = raw.strip() + if not text: + return [""] + parsed = None + if text[0] in "[{": + try: + parsed = json.loads(text) + except json.JSONDecodeError: + try: + parsed = ast.literal_eval(text) + except (ValueError, SyntaxError): + parsed = None + if parsed is None: + return [text] + return _extract_answer_strings(parsed) + if 
isinstance(raw, dict): + for key in ("answers", "ground_truth", "answer"): + if key in raw: + return _extract_answer_strings(raw[key]) + return [str(raw)] + if isinstance(raw, Iterable) and not isinstance(raw, (bytes, bytearray)): + answers: list[str] = [] + for item in raw: + if isinstance(item, dict): + for key in ("text", "answer", "value"): + if key in item: + answers.extend(_extract_answer_strings(item[key])) + break + else: + answers.append(str(item)) + continue + answers.append(str(item)) + return answers or [""] + return [str(raw)] + + +def extract_answer(text: str) -> str: + lower = text.lower() + start = lower.rfind("") + end = lower.rfind("") + if start != -1 and end != -1 and end > start: + return text[start + len(""):end].strip() + lines = [line.strip() for line in text.splitlines() if line.strip()] + return lines[-1] if lines else text.strip() + + +def evaluate(prediction_text: str, gold_answers: Any) -> dict: + answer = extract_answer(prediction_text) + answers = _extract_answer_strings(gold_answers) + score = 0.0 + for target in answers: + score = max(score, _score_single_answer(answer, target, DEFAULT_ANLS_THRESHOLD)) + return { + "anls": score, + "predicted_answer": answer, + "gold_answers": answers, + } diff --git a/skillopt/envs/docvqa/prompts/analyst_error.md b/skillopt/envs/docvqa/prompts/analyst_error.md new file mode 100644 index 0000000..9f6c367 --- /dev/null +++ b/skillopt/envs/docvqa/prompts/analyst_error.md @@ -0,0 +1,35 @@ +You are an expert failure-analysis agent for visual document question answering tasks. + +You will be given MULTIPLE failed DocVQA trajectories from a single minibatch and the current skill document. Each trajectory includes the model response and an evaluation result scored with ANLS against one or more acceptable answers. + +Your job is to identify the most important COMMON failure patterns across the batch and propose concise skill edits. 
+ +## Failure Type Categories +- evidence_miss: the model overlooked the relevant visible region or line +- near_match_confusion: the model selected a nearby but incorrect text span +- normalization_error: the answer differed mainly in formatting, spacing, punctuation, or minor text normalization +- reading_error: the model misread the document content +- other: none of the above + +## Rules +- Focus on common, reusable reading and extraction behaviors. +- Do not hardcode image-specific answers. +- Prefer concise edits that improve evidence selection and exact span extraction. + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": <int>, + "failure_summary": [ + {"failure_type": "<category>", "count": <int>, "description": "<short description>"} + ], + "patch": { + "reasoning": "<why these edits>", + "edits": [ + {"op": "append", "content": "<text to add>"}, + {"op": "insert_after", "target": "<existing text>", "content": "<text to add>"}, + {"op": "replace", "target": "<existing text>", "content": "<replacement text>"}, + {"op": "delete", "target": "<existing text>"} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. diff --git a/skillopt/envs/docvqa/prompts/analyst_success.md b/skillopt/envs/docvqa/prompts/analyst_success.md new file mode 100644 index 0000000..2ce71d8 --- /dev/null +++ b/skillopt/envs/docvqa/prompts/analyst_success.md @@ -0,0 +1,24 @@ +You are an expert success-pattern analyst for visual document question answering tasks. + +You will be given MULTIPLE successful DocVQA trajectories from a single minibatch and the current skill document. Your job is to identify common visual reading and exact-answer extraction behaviors worth encoding in the skill. + +## Rules +- Focus on patterns shared across multiple successful trajectories. +- Reinforce reusable behaviors like locating the right region, copying exact spans, and preferring the shortest exact answer over paraphrase. +- Only propose patches for patterns not already captured by the current skill.
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": <int>, + "success_patterns": ["<pattern 1>", "<pattern 2>"], + "patch": { + "reasoning": "<why these edits>", + "edits": [ + {"op": "append", "content": "<text to add>"}, + {"op": "insert_after", "target": "<existing text>", "content": "<text to add>"}, + {"op": "replace", "target": "<existing text>", "content": "<replacement text>"}, + {"op": "delete", "target": "<existing text>"} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. diff --git a/skillopt/envs/docvqa/prompts/rollout_system.md b/skillopt/envs/docvqa/prompts/rollout_system.md new file mode 100644 index 0000000..e859c02 --- /dev/null +++ b/skillopt/envs/docvqa/prompts/rollout_system.md @@ -0,0 +1,12 @@ +You are an expert visual document question answering agent. + +{skill_section}You will receive a document image and a question about the document. +Read the visual evidence carefully and answer concisely. + +Rules: +- Ground the answer in the visible document content. +- Prefer exact spans, numbers, dates, and names from the document. +- Do not invent content that is not visible. +- If multiple near-matches exist, choose the one best supported by the document. + +Return the final answer inside <answer>...</answer>.
# File: skillopt/envs/docvqa/rollout.py
from __future__ import annotations

import json
import os
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

from skillopt.envs.docvqa.evaluator import evaluate
from skillopt.model import chat_student_messages, get_student_backend, is_student_exec_backend
from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
from skillopt.prompts import load_prompt

# NOTE(review): the collapsed source read "inside ...." and tested
# `"" in response.lower()` (always true), i.e. the tag markup appears to have
# been stripped.  "<answer>" is reconstructed from that wording -- confirm.
_ANSWER_OPEN_TAG = "<answer>"
_ANSWER_SPAN = "<answer>...</answer>"


def _build_system(skill_content: str) -> str:
    """Render the ``rollout_system`` prompt with an optional ``## Skill`` section."""
    if skill_content.strip():
        skill_section = f"## Skill\n{skill_content.strip()}\n\n"
    else:
        skill_section = ""
    return load_prompt("rollout_system", env="docvqa").format(skill_section=skill_section)


def _image_to_data_uri(path: str) -> str:
    """Read *path* and return it as a base64 ``data:`` URI (PNG if the type is unknown)."""
    import base64
    import mimetypes

    mime = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{encoded}"


def _build_messages(
    item: dict,
    skill_content: str,
    image_detail: str,
    *,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
) -> tuple[list[dict], str, str]:
    """Build the chat messages for one item.

    Returns ``(messages, system_prompt, user_text)``.  The diagnostic
    instruction is appended under ``## Training Readout`` only when
    *diagnostic_mode* is set and the instruction is non-blank.  ``detail`` is
    added to the image part only for a value other than ``"auto"``.
    """
    system = _build_system(skill_content)
    user_text = item["question"] + f"\n\nReturn the final answer inside {_ANSWER_SPAN}."
    if diagnostic_mode and diagnostic_instruction.strip():
        user_text += f"\n\n## Training Readout\n{diagnostic_instruction.strip()}"
    image_url = {"url": _image_to_data_uri(item["image_path"])}
    if image_detail and image_detail != "auto":
        image_url["detail"] = image_detail
    messages = [
        {"role": "system", "content": system},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_text},
                {"type": "image_url", "image_url": image_url},
            ],
        },
    ]
    return messages, system, user_text


def _build_codex_skill(skill_content: str) -> str:
    """Render the skill document handed to the exec-style student backend."""
    return render_skill_md(
        skill_content,
        description="Dynamic ReflACT skill for solving the current DocVQA document-image question.",
        preamble=(
            "Use this skill when answering the current DocVQA question.\n"
            f"Inspect the attached document image carefully and return the final answer inside {_ANSWER_SPAN}."
        ),
    )


def _run_codex_once(
    *,
    pred_dir: str,
    item: dict,
    skill_content: str,
    model: str,
    timeout: int,
    image_detail: str,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    previous_response: str = "",
) -> tuple[str, str, str, str]:
    """Run one exec-backend attempt in ``<pred_dir>/codex_exec``.

    Returns ``(response, raw_output, skill_md, task_text)``, where
    ``response`` is the final message, or the raw output when the final
    message is empty.
    """
    _messages, _system, user_text = _build_messages(
        item,
        skill_content,
        image_detail,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
    )
    image_abs = os.path.abspath(item["image_path"])
    task_parts = [
        user_text,
        (
            "## Document Image\n"
            "The document image is available in this workspace via `ATTACHMENTS.md`.\n"
            f"Original image path: `{image_abs}`\n"
            "Open or inspect that image before answering; do not answer from memory."
        ),
    ]
    if previous_response:
        task_parts.append(
            "## Previous Attempt\n"
            f"{previous_response}\n\n"
            "Review the same document image carefully and correct the answer if needed."
        )
    task_text = "\n\n".join(task_parts)
    skill_md = _build_codex_skill(skill_content)
    work_dir = os.path.join(pred_dir, "codex_exec")
    prepare_workspace(
        work_dir=work_dir,
        skill_md=skill_md,
        task_text=task_text,
        images=[item["image_path"]],
    )
    prompt = (
        "Use the `skillopt-student` skill available in this workspace.\n"
        "Read `task.md`, inspect the attached document image, and answer the DocVQA question.\n"
        f"Return the final answer inside {_ANSWER_SPAN}."
    )
    final_message, raw = run_student_exec(
        work_dir=work_dir,
        prompt=prompt,
        model=model,
        timeout=timeout,
        images=[item["image_path"]],
    )
    return final_message or raw, raw, skill_md, task_text


def process_one(
    item: dict,
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    exec_timeout: int = 120,
    image_detail: str = "auto",
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
) -> dict:
    """Answer one DocVQA item, score it, and write its artifacts.

    Up to *max_turns* attempts are made; the loop stops at the first response
    containing the answer tag.  Prompts and ``conversation.json`` are written
    to ``<out_root>/predictions/<id>``.  Any exception raised after the result
    row is initialised is recorded in ``fail_reason`` instead of propagating.

    Returns the result row with ``id``, ``hard`` (1 when ANLS >= 0.999),
    ``soft`` (ANLS) and supporting fields.
    """
    item_id = str(item["id"])
    result = {
        "id": item_id,
        "question": item["question"],
        "task_type": item.get("subtask") or item.get("task_type") or "docvqa",
        "task_description": item["question"],
        "hard": 0,
        "soft": 0.0,
        "predicted_answer": "",
        "response": "",
        "fail_reason": "",
        "agent_ok": False,
        "n_turns": 0,
        "image_paths": item.get("image_paths", []),
        "gold_answer": item.get("answers", []),
    }
    pred_dir = os.path.join(out_root, "predictions", item_id)
    try:
        response = ""
        system_prompt = ""
        user_text = ""
        conversation: list[dict] = []
        if is_student_exec_backend():
            from skillopt.model import azure_openai as _llm

            conversation = [
                {
                    "role": "user",
                    "content": item["question"] + "\n\n" + f"[image] {os.path.basename(item['image_path'])}",
                }
            ]
            for turn in range(max_turns):
                # The diagnostic instruction is only sent on the first turn;
                # later turns carry the previous response instead.
                response, _raw, system_prompt, user_text = _run_codex_once(
                    pred_dir=pred_dir,
                    item=item,
                    skill_content=skill_content,
                    model=_llm.STUDENT_DEPLOYMENT,
                    timeout=exec_timeout,
                    image_detail=image_detail,
                    diagnostic_mode=diagnostic_mode if turn == 0 else False,
                    diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
                    previous_response=response if turn > 0 else "",
                )
                conversation.append({"type": "message", "turn": turn + 1, "content": response})
                if _ANSWER_OPEN_TAG in response.lower():
                    break
        else:
            messages, system_prompt, user_text = _build_messages(
                item,
                skill_content,
                image_detail,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
            )
            conversation = [
                {
                    "role": "user",
                    "content": user_text + "\n\n" + f"[image] {os.path.basename(item['image_path'])}",
                }
            ]
            for turn in range(max_turns):
                if turn == 0:
                    turn_messages = messages
                    token_budget = 768
                else:
                    turn_messages = [
                        messages[0],
                        messages[1],
                        {"role": "assistant", "content": response},
                        {
                            "role": "user",
                            "content": (
                                "Review the same image carefully and answer again. "
                                f"Keep the final answer inside {_ANSWER_SPAN}."
                            ),
                        },
                    ]
                    token_budget = 512
                resp_text, _ = chat_student_messages(
                    messages=turn_messages,
                    max_completion_tokens=token_budget,
                    retries=5,
                    stage="rollout",
                    timeout=exec_timeout,
                )
                response = resp_text
                conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
                if _ANSWER_OPEN_TAG in resp_text.lower():
                    break

        result["response"] = response
        result["agent_ok"] = True
        # The first conversation entry is the user prompt, not a turn.
        result["n_turns"] = len(conversation) - 1

        os.makedirs(pred_dir, exist_ok=True)
        with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
            f.write(system_prompt)
        with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
            f.write(user_text)

        eval_result = evaluate(response, item.get("answers", []))
        result["predicted_answer"] = eval_result["predicted_answer"]
        result["hard"] = int(eval_result["anls"] >= 0.999)
        result["soft"] = eval_result["anls"]
        if result["soft"] <= 0.0:
            result["fail_reason"] = f"predicted '{eval_result['predicted_answer']}' but expected one of {item.get('answers', [])}"

        eval_detail = (
            "[EVALUATION RESULT]\n"
            f"Question: {item['question']}\n"
            f"Predicted answer: {eval_result['predicted_answer']!r}\n"
            f"Gold answers: {item.get('answers', [])!r}\n"
            f"ANLS: {eval_result['anls']:.4f}"
        )
        conversation.append({"role": "system", "content": eval_detail})
        with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
            json.dump(conversation, f, ensure_ascii=False, indent=2)
    except Exception as e:  # noqa: BLE001 - best effort: record and keep the batch going
        result["fail_reason"] = f"error: {e}"
    return result


def run_batch(
    items: list[dict],
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    exec_timeout: int = 120,
    workers: int = 16,
    image_detail: str = "auto",
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    task_timeout: int = 600,
) -> list[dict]:
    """Run ``process_one`` over *items* in a thread pool, resuming from disk.

    Rows already present in ``<out_root>/results.jsonl`` are reused and their
    ids skipped; new rows are appended and flushed as they complete.  A task
    that has been running for ``task_timeout`` seconds (raised to at least
    ``exec_timeout + 60``) is recorded as a timeout row.

    Returns the previously stored rows followed by the rows produced by this
    call.
    """
    task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
    results_path = os.path.join(out_root, "results.jsonl")
    os.makedirs(out_root, exist_ok=True)

    done_ids: set[str] = set()
    existing: list[dict] = []
    if os.path.exists(results_path):
        with open(results_path, encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                except ValueError:
                    continue
                # A truncated or foreign line must not abort the resume.
                if not isinstance(row, dict) or "id" not in row:
                    continue
                done_ids.add(str(row["id"]))
                existing.append(row)

    pending = [item for item in items if str(item["id"]) not in done_ids]
    if not pending:
        return existing

    def _timeout_result(item: dict) -> dict:
        return {
            "id": str(item["id"]),
            "question": item.get("question", ""),
            "task_type": item.get("subtask") or item.get("task_type") or "docvqa",
            "task_description": item.get("question", ""),
            "hard": 0,
            "soft": 0.0,
            "predicted_answer": "",
            "response": "",
            "fail_reason": f"task-timeout-{task_timeout}s",
            "agent_ok": False,
            "n_turns": 0,
            "image_paths": item.get("image_paths", []),
            "gold_answer": item.get("answers", []),
            "phase": "timeout",
        }

    def _error_result(item: dict, exc: Exception) -> dict:
        row = _timeout_result(item)
        row["phase"] = "error"
        row["fail_reason"] = f"unexpected: {type(exc).__name__}: {exc}"
        return row

    # Start time per item id, written by the worker thread when it picks the
    # task up; queued tasks have no entry and therefore cannot time out.
    started_at: dict[str, float] = {}

    def _run_one(item: dict) -> dict:
        started_at[str(item["id"])] = time.time()
        return process_one(
            item,
            out_root,
            skill_content,
            max_turns=max_turns,
            exec_timeout=exec_timeout,
            image_detail=image_detail,
            diagnostic_mode=diagnostic_mode,
            diagnostic_instruction=diagnostic_instruction,
        )

    results = list(existing)
    with open(results_path, "a", encoding="utf-8") as outf:

        def _record(row: dict) -> None:
            results.append(row)
            outf.write(json.dumps(row, ensure_ascii=False) + "\n")
            outf.flush()

        # ThreadPoolExecutor rejects max_workers <= 0.
        ex = ThreadPoolExecutor(max_workers=max(1, int(workers)))
        try:
            futs = {ex.submit(_run_one, item): item for item in pending}
            pending_futs = set(futs)
            while pending_futs:
                done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
                now = time.time()
                timed_out = [
                    fut
                    for fut in pending_futs - done
                    if str(futs[fut]["id"]) in started_at
                    and now - started_at[str(futs[fut]["id"])] >= task_timeout
                ]
                for fut in done:
                    pending_futs.remove(fut)
                    item = futs[fut]
                    try:
                        res = fut.result()
                    except Exception as exc:  # noqa: BLE001 - one bad item must not end the batch
                        res = _error_result(item, exc)
                    _record(res)
                for fut in timed_out:
                    pending_futs.remove(fut)
                    # cancel() cannot stop a task that is already running; the
                    # worker thread keeps going and its result is discarded.
                    fut.cancel()
                    _record(_timeout_result(futs[fut]))
        finally:
            ex.shutdown(wait=False, cancel_futures=True)
    return results


# ---------------------------------------------------------------------------
# Remainder of the collapsed source line: opening of
# skillopt/envs/docvqa/skills/initial.md, kept verbatim as comments.
# ---------------------------------------------------------------------------
# # DocVQA Skill
#
# ## Visual Evidence Discipline
# - Read the document carefully before answering.
# - Prefer the smallest exact text span that answers the question.
+- When several nearby strings look similar, choose the one whose surrounding labels or layout best match the question. + +## Exact Answer Discipline +- Copy names, numbers, and dates exactly from the document whenever possible. +- Prefer direct extraction over paraphrase. +- Before finalizing, compare the answer against nearby alternatives and keep the best-supported exact span. diff --git a/skillopt/envs/livemathematicianbench/__init__.py b/skillopt/envs/livemathematicianbench/__init__.py new file mode 100644 index 0000000..bcc2138 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/__init__.py @@ -0,0 +1 @@ +"""LiveMathematicianBench environment package for ReflACT.""" diff --git a/skillopt/envs/livemathematicianbench/adapter.py b/skillopt/envs/livemathematicianbench/adapter.py new file mode 100644 index 0000000..aca03b8 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/adapter.py @@ -0,0 +1,284 @@ +"""LiveMathematicianBench environment adapter for ReflACT.""" +from __future__ import annotations + +import json +import os + +from skillopt.gradient.deep_probe import generate_deep_probe_instruction +from skillopt.datasets.base import BatchSpec +from skillopt.gradient.reflect import run_minibatch_reflect +from skillopt.envs.base import EnvAdapter +from skillopt.envs.livemathematicianbench.dataloader import LiveMathematicianBenchDataLoader +from skillopt.envs.livemathematicianbench.rollout import run_batch +from skillopt.model import get_student_backend + + +class LiveMathematicianBenchAdapter(EnvAdapter): + """LiveMathematicianBench adapter.""" + + def build_reference_text(self, item: dict) -> str: + parts: list[str] = [] + theorem = str(item.get("theorem") or "").strip() + sketch = str(item.get("sketch") or "").strip() + if theorem: + parts.append(f"## Reference Theorem\n{theorem}") + if sketch: + parts.append(f"## Reference Sketch\n{sketch}") + return "\n\n".join(parts) + + def get_reference_metadata(self, item: dict) -> dict: + fields: list[str] = [] 
def __init__(
    self,
    split_dir: str = "",
    data_path: str = "",
    split_mode: str = "ratio",
    split_ratio: str = "2:1:7",
    split_seed: int = 42,
    split_output_dir: str = "",
    max_turns: int = 1,
    exec_timeout: int = 300,
    workers: int = 64,
    analyst_workers: int = 16,
    failure_only: bool = False,
    minibatch_size: int = 8,
    edit_budget: int = 4,
    seed: int = 42,
    limit: int = 0,
    shuffle_choices: bool = True,
    use_theorem: bool = False,
    use_sketch: bool = False,
    use_deep_reflect: bool = False,
    deep_reflect_failures: int = 4,
    deep_reflect_successes: int = 2,
) -> None:
    """Store rollout/reflect settings and build the dataloader.

    Args:
        split_dir, data_path, split_mode, split_ratio, split_seed,
        split_output_dir, seed, limit, shuffle_choices: Forwarded
            unchanged to ``LiveMathematicianBenchDataLoader``.
        max_turns: Stored and later passed to ``run_batch``.
        exec_timeout: Timeout value stored on the adapter; the rollout
            methods forward it to ``run_batch``.
        workers: Worker count for rollouts.
        analyst_workers: Worker count passed to the reflect stage.
        failure_only, minibatch_size, edit_budget: Reflect-stage settings.
        use_theorem, use_sketch: Whether rollouts receive the item's
            theorem / sketch.
        use_deep_reflect: Enables ``deep_reflect``.
        deep_reflect_failures, deep_reflect_successes: How many failed /
            successful items ``deep_reflect`` asks to be selected.
    """
    # NOTE(review): the signature previously declared ``exec_timeout``
    # twice (defaults 300 and 600), which Python rejects as a SyntaxError,
    # and assigned it twice.  Only one timeout attribute is read elsewhere
    # in this class, so a single parameter is kept with the first default.
    # If the second occurrence was meant to be a separate setting (e.g. a
    # per-task timeout), reintroduce it under its own name.
    self.max_turns = max_turns
    self.exec_timeout = exec_timeout
    self.workers = workers
    self.analyst_workers = analyst_workers
    self.failure_only = failure_only
    self.minibatch_size = minibatch_size
    self.edit_budget = edit_budget
    self.use_theorem = use_theorem
    self.use_sketch = use_sketch
    self.use_deep_reflect = use_deep_reflect
    self.deep_reflect_failures = deep_reflect_failures
    self.deep_reflect_successes = deep_reflect_successes
    self.dataloader = LiveMathematicianBenchDataLoader(
        split_dir=split_dir,
        data_path=data_path,
        split_mode=split_mode,
        split_ratio=split_ratio,
        split_seed=split_seed,
        split_output_dir=split_output_dir,
        seed=seed,
        limit=limit,
        shuffle_choices=shuffle_choices,
    )
def rollout(
    self,
    env_manager,
    skill_content: str,
    out_dir: str,
    **kwargs,
) -> list[dict]:
    """Run one batch of tasks with the current skill.

    ``env_manager`` is the item list itself and is handed straight to
    ``run_batch``.

    Args:
        env_manager: List of task items to run.
        skill_content: Current skill document text.
        out_dir: Output root passed to ``run_batch`` as ``out_root``.
        **kwargs: Optional ``diagnostic_mode``, ``diagnostic_instruction``
            and ``diagnostic_trace_context_by_id``; other keys are ignored.

    Returns:
        Whatever ``run_batch`` returns for the batch.
    """
    # Diagnostic options are optional; defaults mean a normal rollout.
    diagnostic_options = {
        "diagnostic_mode": kwargs.get("diagnostic_mode", False),
        "diagnostic_instruction": kwargs.get("diagnostic_instruction", ""),
        "diagnostic_trace_context_by_id": kwargs.get("diagnostic_trace_context_by_id"),
    }
    task_items: list[dict] = env_manager
    batch_results = run_batch(
        items=task_items,
        out_root=out_dir,
        skill_content=skill_content,
        max_turns=self.max_turns,
        exec_timeout=self.exec_timeout,
        workers=self.workers,
        use_theorem=self.use_theorem,
        use_sketch=self.use_sketch,
        task_timeout=self.exec_timeout,
        **diagnostic_options,
    )
    return batch_results
def deep_reflect(
    self,
    results: list[dict],
    skill_content: str,
    out_dir: str,
    **kwargs,
) -> list[dict | None]:
    """Probe a few representative items with a diagnostic prompt, then reflect.

    Steps:
      1. Select representative failed/successful items.
      2. Generate a probe instruction from those examples.
      3. Re-run the selected items in diagnostic mode.
      4. Run the minibatch reflect stage on the diagnostic results.

    Artifacts are written under ``<out_dir>/deep_reflect`` (``probe.json``,
    ``rollout/``, ``patches/``).

    Args:
        results: Rollout results of the current step.
        skill_content: Current skill document text.
        out_dir: Step output directory.
        **kwargs: Optional ``env_manager``, ``prediction_dir``,
            ``random_seed``, ``step_buffer_context``, ``meta_skill_context``.

    Returns:
        The value of ``run_minibatch_reflect``, or ``[]`` when deep reflect
        is disabled, no item was selected, or no probe was generated.
    """
    if not self.use_deep_reflect:
        return []

    env_manager = kwargs.get("env_manager")
    prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
    random_seed = kwargs.get("random_seed")
    step_buffer_context = kwargs.get("step_buffer_context", "")
    meta_skill_context = kwargs.get("meta_skill_context", "")
    codex_backend = get_student_backend() == "codex_exec"

    selected_items = self.select_representative_items(
        results,
        env_manager if isinstance(env_manager, list) else None,
        n_failures=self.deep_reflect_failures,
        n_successes=self.deep_reflect_successes,
        seed=random_seed,
    )
    if not selected_items:
        return []

    selected_ids = {str(item["id"]) for item in selected_items}
    selected_results = [row for row in results if str(row.get("id")) in selected_ids]
    selected_examples = self.attach_reference_context(selected_results, selected_items)
    if codex_backend:
        selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)

    # Summarise which reference fields the selected items carry; this goes
    # into the console line and into probe.json.
    selected_metadata = []
    theorem_count = 0
    sketch_count = 0
    for item in selected_items:
        meta = self.get_reference_metadata(item)
        if "theorem" in meta["fields"]:
            theorem_count += 1
        if "sketch" in meta["fields"]:
            sketch_count += 1
        theorem_types = item.get("theorem_type")
        selected_metadata.append({
            "id": str(item["id"]),
            "task_type": str(theorem_types[0] if theorem_types else "math_mcq"),
            "reference_fields": meta["fields"],
            "reference_preview": meta["preview"],
        })

    deep_dir = os.path.join(out_dir, "deep_reflect")
    rollout_dir = os.path.join(deep_dir, "rollout")
    patches_dir = os.path.join(deep_dir, "patches")
    os.makedirs(deep_dir, exist_ok=True)
    print(
        f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
        f"reference_fields=theorem({theorem_count}/{len(selected_items)}),"
        f"sketch({sketch_count}/{len(selected_items)})"
    )

    probe = generate_deep_probe_instruction(
        skill_content=skill_content,
        items=selected_examples,
        prediction_dir=prediction_dir,
        system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
        step_buffer_context=step_buffer_context,
        meta_skill_context=meta_skill_context,
    )
    if not probe:
        return []

    diagnostic_trace_context_by_id = None
    if codex_backend:
        # The codex probe targets a specific trajectory, so the item list
        # (and possibly the probe itself) is replaced here.
        selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
            selected_items=selected_items,
            selected_examples=selected_examples,
            prediction_dir=prediction_dir,
            probe=probe,
        )

    # NOTE(review): "selected_count" is taken after the codex narrowing
    # above, while "field_counts"/"selected_examples" describe the original
    # selection -- confirm this mix is intended.
    probe_record = {
        **probe,
        "reference_summary": {
            "selected_count": len(selected_items),
            "field_counts": {
                "theorem": theorem_count,
                "sketch": sketch_count,
            },
        },
        "selected_examples": selected_metadata,
    }
    with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
        json.dump(probe_record, f, ensure_ascii=False, indent=2)

    deep_results = run_batch(
        items=selected_items,
        out_root=rollout_dir,
        skill_content=skill_content,
        max_turns=self.max_turns,
        # Fix: pass the configured timeout exactly as rollout() does; it was
        # omitted here, so the diagnostic run used run_batch's default.
        exec_timeout=self.exec_timeout,
        # Never spawn more workers than items, but always at least one.
        workers=min(self.workers, max(len(selected_items), 1)),
        use_theorem=self.use_theorem,
        use_sketch=self.use_sketch,
        diagnostic_mode=True,
        diagnostic_instruction=probe["probe_instruction"],
        diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
        task_timeout=self.exec_timeout,
    )
    deep_results = self.attach_reference_context(deep_results, selected_items)
    return run_minibatch_reflect(
        results=deep_results,
        skill_content=skill_content,
        prediction_dir=os.path.join(rollout_dir, "predictions"),
        patches_dir=patches_dir,
        workers=self.analyst_workers,
        failure_only=self.failure_only,
        minibatch_size=self.minibatch_size,
        edit_budget=self.edit_budget,
        random_seed=random_seed,
        error_system=self.get_error_minibatch_prompt(),
        success_system=self.get_success_minibatch_prompt(),
        step_buffer_context=step_buffer_context,
        meta_skill_context=meta_skill_context,
        update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
    )
+ {"label": str(label).strip(), "text": str(raw_choices[label]).strip()} + for label in labels + if str(raw_choices[label]).strip() + ] + + return [] + + +def _coerce_theorem_types(raw: Any) -> list[str]: + if isinstance(raw, list): + return [str(x).strip() for x in raw if str(x).strip()] + if raw is None: + return [] + text = str(raw).strip() + return [text] if text else [] + + +def _normalize_label(text: str) -> str: + return str(text).strip().upper().rstrip(".):") + + +def _normalize_item(item: dict, row_idx: int, source_path: str) -> dict: + mcq = item.get("mcq", {}) if isinstance(item.get("mcq"), dict) else {} + question = str(mcq.get("question") or item.get("question") or "").strip() + choices = _coerce_choices(mcq.get("choices") or item.get("choices") or []) + correct = mcq.get("correct_choice") or item.get("correct_choice") or {} + + if isinstance(correct, dict): + correct_label = _normalize_label(correct.get("label", "")) + correct_text = str(correct.get("text") or "").strip() + else: + correct_label = _normalize_label(correct) + correct_text = "" + + choice_by_label = { + _normalize_label(choice["label"]): choice["text"] + for choice in choices + } + if correct_label and not correct_text: + correct_text = choice_by_label.get(correct_label, "") + if correct_label and correct_text and correct_label not in choice_by_label: + choices.append({"label": correct_label, "text": correct_text}) + choices.sort(key=lambda choice: _CHOICE_LABELS.index(choice["label"]) if choice["label"] in _CHOICE_LABELS else len(_CHOICE_LABELS)) + choice_by_label[correct_label] = correct_text + + month = str(item.get("month") or "").strip() + item_no = item.get("no", row_idx + 1) + item_id = f"{month}:{item_no}" if month else str(item_no) + + return { + "id": item_id, + "month": month, + "no": item_no, + "paper_link": str(item.get("paper_link") or "").strip(), + "theorem": str(item.get("theorem") or "").strip(), + "sketch": str(item.get("sketch") or "").strip(), + "theorem_type": 
def load_items(data_path: str) -> list[dict]:
    """Load and normalise LiveMathematicianBench items from JSON files.

    Every file found for ``data_path`` must contain a JSON array.  Each
    row is normalised, and rows without a question, choices, or a
    correct-choice label are dropped.

    Raises:
        ValueError: If no input file is found, a file is not a JSON array,
            or no usable item remains.
    """
    source_files = _iter_monthly_files(data_path)
    if not source_files:
        raise ValueError(
            "LiveMathematicianBench requires data_path to be a qa_*_final.json file "
            "or a directory containing monthly qa_*_final.json files."
        )

    def _is_usable(entry: dict) -> bool:
        # A usable item needs a question, choices, and a gold label.
        return bool(
            entry["question"]
            and entry["choices"]
            and entry["correct_choice"]["label"]
        )

    loaded: list[dict] = []
    for source in source_files:
        rows = _load_json(source)
        if not isinstance(rows, list):
            raise ValueError(f"Expected JSON array in {source}, got {type(rows).__name__}")
        normalised = (
            _normalize_item(row, row_idx=position, source_path=source)
            for position, row in enumerate(rows)
        )
        loaded.extend(entry for entry in normalised if _is_usable(entry))

    if not loaded:
        raise ValueError(f"No valid LiveMathematicianBench items loaded from {data_path}")
    return loaded
def _shuffle_item_choices(self, item: dict, seed: int) -> dict:
    """Return a copy of ``item`` with its answer choices shuffled and relabelled.

    The shuffle order comes from an RNG seeded by
    ``_item_shuffle_seed(str(item["id"]), seed)``.  After shuffling, choices
    are relabelled positionally and ``correct_choice`` is re-pointed at the
    choice whose original label matched the original correct label.

    When ``self.shuffle_choices`` is false, only copies of the choices and
    correct choice are returned, with order and labels untouched.

    Args:
        item: Normalised task item with ``id``, ``choices`` and
            ``correct_choice``.
        seed: Batch seed mixed into the per-item shuffle seed.

    Returns:
        A new dict; the input item and its choice dicts are not mutated.
    """
    if not self.shuffle_choices:
        return {
            **item,
            "choices": [dict(c) for c in item["choices"]],
            "correct_choice": dict(item["correct_choice"]),
        }

    shuffled_choices = [dict(c) for c in item["choices"]]
    rng = random.Random(self._item_shuffle_seed(str(item["id"]), seed))
    rng.shuffle(shuffled_choices)

    original_correct = _normalize_label(item["correct_choice"]["label"])
    remapped_choices: list[dict] = []
    # Falls back to the original correct choice if no label matches.
    new_correct_choice = dict(item["correct_choice"])

    for idx, choice in enumerate(shuffled_choices):
        if idx < len(_CHOICE_LABELS):
            new_label = _CHOICE_LABELS[idx]
        else:
            # Fix: items with more choices than the label table used to
            # raise IndexError.  Continue the alphabet instead (H, I, ...);
            # past "Z" this yields the following code points, which are
            # still unique labels.
            new_label = chr(ord("A") + idx)
        old_label = _normalize_label(choice["label"])
        remapped_choices.append({"label": new_label, "text": choice["text"]})
        if old_label == original_correct:
            new_correct_choice = {"label": new_label, "text": choice["text"]}

    transformed = dict(item)
    transformed["choices"] = remapped_choices
    transformed["correct_choice"] = new_correct_choice
    return transformed
def extract_answer(text: str) -> str:
    """Extract the final answer from a model response.

    Resolution order:
      1. The content of the last ``<answer>...</answer>`` block
         (case-insensitive, may span lines), stripped.
      2. Otherwise the last non-blank line of the response, stripped.
      3. Otherwise the stripped response itself (empty for blank input).

    Args:
        text: Raw model response.

    Returns:
        The extracted answer string; may be empty.
    """
    # Fix: the pattern was r"(.*?)" with no surrounding tags.  That matches
    # the empty string at every position, so the function always returned
    # "" and the fallbacks below were unreachable.
    matches = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches[-1].strip()
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    if lines:
        return lines[-1]
    return text.strip()
b/skillopt/envs/livemathematicianbench/prompts/analyst_error.md @@ -0,0 +1,37 @@ +You are an expert failure-analysis agent for theorem-grounded mathematical multiple-choice questions. + +You will be given MULTIPLE failed trajectories from a single minibatch and the current skill document. +Each trajectory includes the student's response and an evaluation result showing the predicted option +versus the correct option. + +Your job is to identify COMMON reasoning failures across the batch and propose concise skill edits. + +## Failure Type Categories +- **quantifier_miss**: the agent missed exact quantifiers, scope, or existence/uniqueness conditions +- **strength_mismatch**: the agent preferred a weaker or stronger statement than what was proved +- **condition_miss**: the agent ignored hypotheses, equality cases, or domain restrictions +- **option_confusion**: the agent confused similar answer choices or failed to compare them exactly +- **other**: none of the above + +## Rules +1. Focus on patterns that recur across the minibatch. +2. Prefer edits that improve exact choice discrimination, not theorem-specific memorization. +3. Do not hardcode paper-specific content. +4. Only patch gaps not already covered by the skill. 
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": <int>, + "failure_summary": [ + {"failure_type": "", "count": <int>, "description": ""} + ], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} diff --git a/skillopt/envs/livemathematicianbench/prompts/analyst_success.md b/skillopt/envs/livemathematicianbench/prompts/analyst_success.md new file mode 100644 index 0000000..7ff47d1 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/prompts/analyst_success.md @@ -0,0 +1,25 @@ +You are an expert success-pattern analyst for theorem-grounded mathematical multiple-choice questions. + +You will be given MULTIPLE successful trajectories from a minibatch and the current skill document. +Identify generalizable behavior patterns that are genuinely helping the agent choose the exact correct option. + +## Rules +- Focus on broadly useful reasoning behaviors. +- Prefer patterns about exact comparison of options, quantifiers, and equality conditions. +- Do not add theorem-specific facts. +- "edits" may be empty if the skill already captures the useful patterns. + +Respond ONLY with a valid JSON object: +{ + "batch_size": <int>, + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} diff --git a/skillopt/envs/livemathematicianbench/prompts/deep_probe.md b/skillopt/envs/livemathematicianbench/prompts/deep_probe.md new file mode 100644 index 0000000..9987da6 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/prompts/deep_probe.md @@ -0,0 +1,23 @@ +You are an expert diagnostic-probe designer for theorem-grounded mathematical multiple-choice tasks. 
+ +You will be shown representative trajectories, the current student skill, and the student's original prompt context. +Design one SMALL diagnostic instruction that exposes the student's intermediate judgment without materially changing the original scaffold. + +## Hard Constraints +1. Do NOT substantially change the original scaffold. +2. Do NOT prescribe a new multi-step theorem-solving procedure. +3. Do NOT ask for a full proof, full chain-of-thought, or exhaustive option-by-option derivation. +4. Ask only for a short readout of the signals already behind the student's current answer. +5. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>. + +## Good Probe Targets +- top choice and runner-up +- decisive constraint +- why the runner-up was rejected +- strongest-vs-weaker discrimination signal + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/envs/livemathematicianbench/prompts/deep_probe_codex.md b/skillopt/envs/livemathematicianbench/prompts/deep_probe_codex.md new file mode 100644 index 0000000..7ed4f3e --- /dev/null +++ b/skillopt/envs/livemathematicianbench/prompts/deep_probe_codex.md @@ -0,0 +1,26 @@ +You are an expert diagnostic-probe designer for theorem-grounded mathematical multiple-choice tasks executed through a Codex trace. + +You will be shown representative trajectories, the current student skill, the student's original prompt context, hidden reference fields, and numbered Codex trace steps. +Choose exactly one trajectory and one probe point. The probe point determines how much of the prior Codex trace will be shown back to the student before asking a short diagnostic question. + +## Hard Constraints +1. Do NOT reveal or paraphrase the hidden reference directly to the student. +2. Do NOT prescribe a new full solving procedure. +3. Do NOT ask for a full proof, full chain-of-thought, or exhaustive option-by-option derivation. +4. 
Ask only for a short readout of the signal that should already exist at that point in the student's process. +5. The probe instruction must explicitly request a short ... block before the final <answer>...</answer>. +6. Select a probe point that is informative about theorem choice, decisive constraint, option elimination, or why a stronger/weaker option should be rejected. + +## Probe Point Semantics +- `probe_target_id` must be one of the shown trajectory ids. +- `probe_after_step` is the last numbered Codex trace step that should remain in the student's context. +- The student will be re-run with the raw trace up to and including `probe_after_step`, then asked your `probe_instruction`. +- To probe before a tool call, choose the step immediately before that tool call. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_target_id": "", + "probe_after_step": <int>, + "probe_instruction": "" +} diff --git a/skillopt/envs/livemathematicianbench/prompts/rollout_system.md b/skillopt/envs/livemathematicianbench/prompts/rollout_system.md new file mode 100644 index 0000000..607153d --- /dev/null +++ b/skillopt/envs/livemathematicianbench/prompts/rollout_system.md @@ -0,0 +1,12 @@ +You are an expert mathematical reasoning agent solving multiple-choice questions. + +{skill_section}## Task Format +You will receive one mathematics multiple-choice question and its answer choices. +Reason carefully about quantifiers, hypotheses, extremal wording, and exact equality conditions. + +## Answer Format +Think step by step, then provide your final answer inside <answer>...</answer> tags. +Inside the tags, output only the single choice label, such as A or C. + +Example: +<answer>B</answer> diff --git a/skillopt/envs/livemathematicianbench/reflect.py b/skillopt/envs/livemathematicianbench/reflect.py new file mode 100644 index 0000000..b738481 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/reflect.py @@ -0,0 +1,4 @@ +"""LiveMathematicianBench Reflect stage. 
+ +Prompts are now loaded from .md files by the base adapter. +""" diff --git a/skillopt/envs/livemathematicianbench/rollout.py b/skillopt/envs/livemathematicianbench/rollout.py new file mode 100644 index 0000000..1243f43 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/rollout.py @@ -0,0 +1,401 @@ +"""LiveMathematicianBench rollout — theorem-grounded math MCQ agent.""" +from __future__ import annotations + +import json +import os +import time +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait + +from skillopt.envs.livemathematicianbench.evaluator import evaluate +from skillopt.model import chat_student, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="livemathematicianbench").format(skill_section=skill_section) + + +def _format_choices(choices: list[dict]) -> str: + return "\n".join( + f"{choice['label']}. {choice['text']}" + for choice in choices + ) + + +def _build_user( + item: dict, + *, + use_theorem: bool = False, + use_sketch: bool = False, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> str: + parts = [f"## Question\n{item['question']}", f"## Choices\n{_format_choices(item['choices'])}"] + if use_theorem and item.get("theorem"): + parts.append(f"## Theorem\n{item['theorem']}") + if use_sketch and item.get("sketch"): + parts.append(f"## Proof Sketch\n{item['sketch']}") + if diagnostic_trace_context.strip(): + parts.append( + "## Previous Codex Trace Snapshot\n" + "This is a partial transcript from an earlier attempt. 
def _build_codex_skill(skill_content: str) -> str:
    """Render the skill document used by the exec-backend student.

    Wraps ``skill_content`` via ``render_skill_md`` with a fixed description
    and a preamble telling the student how to answer.

    Args:
        skill_content: Current skill document text.

    Returns:
        The rendered skill markdown returned by ``render_skill_md``.
    """
    return render_skill_md(
        skill_content,
        description="Dynamic ReflACT skill for solving the current LiveMathematicianBench multiple-choice question.",
        preamble=(
            "Use this skill when solving the current math multiple-choice question.\n"
            # Fix: the instruction read "inside ...." and named no tag, so
            # the student was never told which wrapper to emit.
            "Inspect the option wording carefully and output only the final choice label "
            "inside <answer>...</answer>."
        ),
    )
# Opening tag that marks a final answer in a student response (compared
# against the lower-cased response).
_ANSWER_OPEN_TAG = "<answer>"


def _finalize_prediction(
    result: dict,
    item: dict,
    *,
    pred_dir: str,
    system: str,
    user: str,
    response: str,
    conversation: list[dict],
) -> None:
    """Score ``response``, update ``result`` in place, and persist artifacts.

    Writes ``student_system_prompt.txt``, ``student_user_prompt.txt`` and
    ``conversation.json`` into ``pred_dir``.  An evaluation summary entry is
    appended to ``conversation`` before it is written.
    """
    result["response"] = response
    result["agent_ok"] = True
    result["n_turns"] = len(conversation)

    with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
        f.write(system)
    with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
        f.write(user)

    eval_result = evaluate(response, item["correct_choice"], item["choices"])
    result["hard"] = int(eval_result["em"])
    result["soft"] = eval_result["f1"]
    result["predicted_answer"] = eval_result["predicted_answer"]
    result["predicted_label"] = eval_result["predicted_label"]
    result["predicted_text"] = eval_result["predicted_text"]

    if not result["hard"]:
        result["fail_reason"] = (
            f"MCQ=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
            f"but expected '{eval_result['correct_label']}'"
        )

    eval_detail = (
        f"[EVALUATION RESULT]\n"
        f"Question: {item['question']}\n"
        f"Predicted label: {eval_result['predicted_label']!r}\n"
        f"Predicted text: {eval_result['predicted_text']!r}\n"
        f"Correct label: {eval_result['correct_label']!r}\n"
        f"Correct text: {eval_result['correct_text']!r}\n"
        f"Exact Match: {eval_result['em']}"
    )
    conversation.append({"role": "system", "content": eval_detail})

    # Fix: the file was opened without an encoding although the dump uses
    # ensure_ascii=False, so non-ASCII math text could fail to encode under
    # a non-UTF-8 locale.  The prompt files above already use utf-8.
    with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
        json.dump(conversation, f, ensure_ascii=False, indent=2)


def process_one(
    item: dict,
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    use_theorem: bool = False,
    use_sketch: bool = False,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
    exec_timeout: int = 300,
) -> dict:
    """Run the student on one multiple-choice item and score the answer.

    The student is queried for up to ``max_turns`` turns, stopping as soon
    as a response contains an ``<answer>`` tag.  The exec backend goes
    through ``_run_codex_once``; otherwise ``chat_student`` is called
    directly.  Prompts and the conversation are saved under
    ``<out_root>/predictions/<item id>/``.

    Args:
        item: Task item with ``id``, ``question``, ``choices`` and
            ``correct_choice``; optionally ``theorem_type``, ``theorem``,
            ``sketch``.
        out_root: Root output directory.
        skill_content: Current skill document text.
        max_turns: Maximum number of student turns.
        use_theorem: Include the item's theorem in the user prompt.
        use_sketch: Include the item's proof sketch in the user prompt.
        diagnostic_mode: Enable the diagnostic readout section.
        diagnostic_instruction: Diagnostic instruction text.
        diagnostic_trace_context: Earlier trace snapshot shown to the student.
        exec_timeout: Timeout passed to each student call.

    Returns:
        A result dict.  ``hard`` is the 0/1 score and ``soft`` the float
        score.  On any exception during the run, ``fail_reason`` is set to
        ``"error: ..."`` and the other fields keep their initial or
        partially filled values; the exception is not re-raised, so one bad
        item does not abort the batch.
    """
    item_id = str(item["id"])
    theorem_types = item.get("theorem_type")
    result = {
        "id": item_id,
        "question": item["question"],
        "task_type": theorem_types[0] if theorem_types else "math_mcq",
        "hard": 0,
        "soft": 0.0,
        "predicted_answer": "",
        "predicted_label": "",
        "predicted_text": "",
        "correct_label": item["correct_choice"]["label"],
        "correct_text": item["correct_choice"]["text"],
        "response": "",
        "fail_reason": "",
        "agent_ok": False,
        "n_turns": 0,
    }

    try:
        pred_dir = os.path.join(out_root, "predictions", item_id)
        os.makedirs(pred_dir, exist_ok=True)

        conversation: list[dict] = []
        response = ""

        if is_student_exec_backend():
            from skillopt.model import azure_openai as _llm

            system = ""
            user = ""
            for turn in range(max_turns):
                first_turn = turn == 0
                # Diagnostic context applies to the first turn only; later
                # turns re-run with the previous response instead.
                response, _raw, system, user = _run_codex_once(
                    pred_dir=pred_dir,
                    skill_content=skill_content,
                    item=item,
                    model=_llm.STUDENT_DEPLOYMENT,
                    timeout=exec_timeout,
                    use_theorem=use_theorem,
                    use_sketch=use_sketch,
                    diagnostic_mode=diagnostic_mode if first_turn else False,
                    diagnostic_instruction=diagnostic_instruction if first_turn else "",
                    diagnostic_trace_context=diagnostic_trace_context if first_turn else "",
                    previous_response="" if first_turn else response,
                )
                conversation.append({"type": "message", "turn": turn + 1, "content": response})
                # Fix: the check was `"" in response.lower()`, which is
                # always true, so the loop always stopped after one turn.
                if _ANSWER_OPEN_TAG in response.lower():
                    break
        else:
            system = _build_system(skill_content)
            user = _build_user(
                item,
                use_theorem=use_theorem,
                use_sketch=use_sketch,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
                diagnostic_trace_context=diagnostic_trace_context,
            )
            for turn in range(max_turns):
                if turn == 0:
                    turn_prompt = user
                else:
                    refinement = (
                        f"Your previous answer was:\n{response}\n\n"
                        "Re-evaluate the exact option wording. If needed, correct it. "
                        "Output only the final choice label inside <answer>...</answer>."
                    )
                    # Fix: the follow-up used to send only the refinement
                    # text, so the student no longer saw the question or
                    # choices.  The exec path keeps the task in follow-ups;
                    # do the same here.
                    turn_prompt = f"{user}\n\n{refinement}"
                resp_text, _ = chat_student(
                    system=system,
                    user=turn_prompt,
                    max_completion_tokens=16384,
                    retries=5,
                    stage="rollout",
                    timeout=exec_timeout,
                )
                response = resp_text
                conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
                if _ANSWER_OPEN_TAG in resp_text.lower():
                    break

        _finalize_prediction(
            result,
            item,
            pred_dir=pred_dir,
            system=system,
            user=user,
            response=response,
            conversation=conversation,
        )

    except Exception as e:  # noqa: BLE001
        # Deliberate best-effort: record the failure on this item only.
        result["fail_reason"] = f"error: {e}"

    return result
False, + use_sketch: bool = False, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context_by_id: dict[str, str] | None = None, + task_timeout: int = 600, +) -> list[dict]: + task_timeout = max(int(task_timeout), int(exec_timeout) + 60) + results_path = os.path.join(out_root, "results.jsonl") + os.makedirs(out_root, exist_ok=True) + + done_ids: set[str] = set() + existing: list[dict] = [] + if os.path.exists(results_path): + with open(results_path) as f: + for line in f: + try: + r = json.loads(line) + done_ids.add(str(r["id"])) + existing.append(r) + except Exception: + pass + + pending = [it for it in items if str(it["id"]) not in done_ids] + if not pending: + return existing + + results = list(existing) + + started_at: dict[str, float] = {} + + def _run_one(it: dict) -> dict: + started_at[str(it["id"])] = time.time() + return process_one( + it, + out_root, + skill_content, + max_turns=max_turns, + exec_timeout=exec_timeout, + use_theorem=use_theorem, + use_sketch=use_sketch, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""), + ) + + def _timeout_result(it: dict) -> dict: + correct = it.get("correct_choice") or {} + return { + "id": str(it["id"]), + "question": it.get("question", ""), + "task_type": it.get("theorem_type", ["math_mcq"])[0] if it.get("theorem_type") else "math_mcq", + "hard": 0, + "soft": 0.0, + "predicted_answer": "", + "predicted_label": "", + "predicted_text": "", + "correct_label": correct.get("label", ""), + "correct_text": correct.get("text", ""), + "response": "", + "fail_reason": f"task-timeout-{task_timeout}s", + "agent_ok": False, + "n_turns": 0, + } + + def _error_result(it: dict, exc: Exception) -> dict: + res = _timeout_result(it) + res["fail_reason"] = f"error: {type(exc).__name__}: {exc}" + return res + + with open(results_path, "a") as outf: + ex = 
ThreadPoolExecutor(max_workers=workers) + try: + futs = { + ex.submit(_run_one, it): it + for it in pending + } + pending_futs = set(futs) + while pending_futs: + done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED) + now = time.time() + timed_out = [ + fut for fut in pending_futs - done + if str(futs[fut]["id"]) in started_at + and now - started_at[str(futs[fut]["id"])] >= task_timeout + ] + for fut in done: + pending_futs.remove(fut) + item = futs[fut] + try: + res = fut.result() + except Exception as e: # noqa: BLE001 + res = _error_result(item, e) + results.append(res) + outf.write(json.dumps(res, ensure_ascii=False) + "\n") + outf.flush() + for fut in timed_out: + pending_futs.remove(fut) + res = _timeout_result(futs[fut]) + results.append(res) + outf.write(json.dumps(res, ensure_ascii=False) + "\n") + outf.flush() + finally: + ex.shutdown(wait=False, cancel_futures=True) + + return results diff --git a/skillopt/envs/livemathematicianbench/skills/initial.md b/skillopt/envs/livemathematicianbench/skills/initial.md new file mode 100644 index 0000000..d34f603 --- /dev/null +++ b/skillopt/envs/livemathematicianbench/skills/initial.md @@ -0,0 +1,16 @@ +# Live Mathematical MCQ Heuristics + +## Option Comparison +- Compare all options before committing. The correct choice is often the strongest statement justified by the question, while nearby distractors are weaker, overstrong, or miss an equality case. +- Track exact quantifiers such as "there exists", "for every", "if and only if", and "exactly when". + +## Theorem-Level Precision +- Check whether an option weakens the conclusion by dropping a characterization, equality clause, or full equivalence. +- Check whether an option overstates the theorem by upgrading regularity, removing scale restrictions, or changing an existential statement into a universal one. + +## Hypotheses +- Verify the hypotheses and domain carefully. Distractors often keep the theorem shape but alter the required assumptions. 
+- Pay close attention to equality cases, extremal conditions, and whether a result applies to the full family or only a restricted subfamily. + +## Final Answer +- Output the final answer as the single option label only. diff --git a/skillopt/envs/mathverse/__init__.py b/skillopt/envs/mathverse/__init__.py new file mode 100644 index 0000000..cd96751 --- /dev/null +++ b/skillopt/envs/mathverse/__init__.py @@ -0,0 +1,5 @@ +"""MathVerse environment package.""" + +from skillopt.envs.mathverse.adapter import MathVerseAdapter + +__all__ = ["MathVerseAdapter"] diff --git a/skillopt/envs/mathverse/adapter.py b/skillopt/envs/mathverse/adapter.py new file mode 100644 index 0000000..1c6af51 --- /dev/null +++ b/skillopt/envs/mathverse/adapter.py @@ -0,0 +1,280 @@ +"""MathVerse environment adapter for ReflACT.""" +from __future__ import annotations + +import json +import os + +from skillopt.datasets.base import BatchSpec +from skillopt.envs.base import EnvAdapter +from skillopt.envs.mathverse.dataloader import MathVerseDataLoader +from skillopt.envs.mathverse.rollout import run_batch +from skillopt.gradient.deep_probe import generate_deep_probe_instruction +from skillopt.gradient.reflect import run_minibatch_reflect +from skillopt.model import get_student_backend + + +class MathVerseAdapter(EnvAdapter): + """MathVerse adapter.""" + + def build_reference_text(self, item: dict) -> str: + if not self.use_text_dominant_reference: + return "" + question = str(item.get("text_dominant_question") or "").strip() + if not question: + return "" + return f"## Reference Full Question\n{question}" + + def get_reference_metadata(self, item: dict) -> dict: + if not self.use_text_dominant_reference: + return {"fields": [], "preview": ""} + question = str(item.get("text_dominant_question") or "").strip() + if not question: + return {"fields": [], "preview": ""} + return { + "fields": ["text_dominant_question"], + "preview": question[:400], + } + + def __init__( + self, + split_dir: str = "", + 
data_root: str = "", + problem_version: str = "Text Lite", + use_text_dominant_reference: bool = False, + max_turns: int = 1, + workers: int = 16, + analyst_workers: int = 16, + failure_only: bool = False, + minibatch_size: int = 8, + edit_budget: int = 4, + seed: int = 42, + limit: int = 0, + image_detail: str = "auto", + judge_model: str = "gpt-5.4", + judge_max_completion_tokens: int = 256, + judge_retries: int = 5, + use_deep_reflect: bool = False, + deep_reflect_failures: int = 4, + deep_reflect_successes: int = 2, + ) -> None: + self.max_turns = max_turns + self.workers = workers + self.analyst_workers = analyst_workers + self.failure_only = failure_only + self.minibatch_size = minibatch_size + self.edit_budget = edit_budget + self.image_detail = image_detail + self.judge_model = judge_model + self.judge_max_completion_tokens = judge_max_completion_tokens + self.judge_retries = judge_retries + self.problem_version = problem_version + self.use_text_dominant_reference = use_text_dominant_reference + self.use_deep_reflect = use_deep_reflect + self.deep_reflect_failures = deep_reflect_failures + self.deep_reflect_successes = deep_reflect_successes + self.dataloader = MathVerseDataLoader( + split_dir=split_dir, + seed=seed, + limit=limit, + data_root=data_root, + problem_version=problem_version, + ) + + def setup(self, cfg: dict) -> None: + super().setup(cfg) + self.dataloader.setup(cfg) + + def get_dataloader(self): + return self.dataloader + + def build_env_from_batch(self, batch: BatchSpec, **kwargs): + return list(batch.payload or []) + + def build_train_env(self, batch_size: int, seed: int, **kwargs): + batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs): + batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + 
+ def rollout( + self, + env_manager, + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict]: + items: list[dict] = env_manager + return run_batch( + items=items, + out_root=out_dir, + skill_content=skill_content, + max_turns=self.max_turns, + workers=self.workers, + image_detail=self.image_detail, + judge_model=self.judge_model, + judge_max_completion_tokens=self.judge_max_completion_tokens, + judge_retries=self.judge_retries, + diagnostic_mode=kwargs.get("diagnostic_mode", False), + diagnostic_instruction=kwargs.get("diagnostic_instruction", ""), + diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"), + ) + + def reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")) + patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches")) + random_seed = kwargs.get("random_seed") + step_buffer_context = kwargs.get("step_buffer_context", "") + + return run_minibatch_reflect( + results=results, + skill_content=skill_content, + prediction_dir=prediction_dir, + patches_dir=patches_dir, + workers=self.analyst_workers, + failure_only=self.failure_only, + minibatch_size=self.minibatch_size, + edit_budget=self.edit_budget, + random_seed=random_seed, + error_system=self.get_error_minibatch_prompt(), + success_system=self.get_success_minibatch_prompt(), + step_buffer_context=step_buffer_context, + update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"), + ) + + def deep_reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + if not self.use_deep_reflect: + return [] + + env_manager = kwargs.get("env_manager") + prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")) + random_seed = kwargs.get("random_seed") + step_buffer_context = kwargs.get("step_buffer_context", "") + selected_items = 
self.select_representative_items( + results, + env_manager if isinstance(env_manager, list) else None, + n_failures=self.deep_reflect_failures, + n_successes=self.deep_reflect_successes, + seed=random_seed, + ) + if not selected_items: + return [] + + selected_ids = {str(item["id"]) for item in selected_items} + selected_results = [row for row in results if str(row.get("id")) in selected_ids] + selected_examples = self.attach_reference_context(selected_results, selected_items) + codex_backend = get_student_backend() == "codex_exec" + if codex_backend: + selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir) + selected_metadata = [] + ref_count = 0 + for item in selected_items: + meta = self.get_reference_metadata(item) + if meta["fields"]: + ref_count += 1 + record = { + "id": str(item["id"]), + "task_type": str(item.get("task_type") or item.get("question_type") or "mathverse"), + "reference_fields": meta["fields"], + "reference_preview": meta["preview"], + } + if codex_backend: + record["codex_probe_step_count"] = int( + next( + (row.get("codex_probe_step_count", 0) for row in selected_examples if str(row.get("id")) == str(item["id"])), + 0, + ) + ) + selected_metadata.append(record) + + deep_dir = os.path.join(out_dir, "deep_reflect") + rollout_dir = os.path.join(deep_dir, "rollout") + patches_dir = os.path.join(deep_dir, "patches") + os.makedirs(deep_dir, exist_ok=True) + print( + f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} " + f"reference_fields=text_dominant_question({ref_count}/{len(selected_items)})" + ) + probe = generate_deep_probe_instruction( + skill_content=skill_content, + items=selected_examples, + prediction_dir=prediction_dir, + system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(), + step_buffer_context=step_buffer_context, + ) + if not probe: + return [] + + targeted_items = selected_items + diagnostic_trace_context_by_id: dict[str, str] | None = None + 
if codex_backend: + targeted_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target( + selected_items=selected_items, + selected_examples=selected_examples, + prediction_dir=prediction_dir, + probe=probe, + ) + + with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f: + json.dump( + { + **probe, + "reference_summary": { + "selected_count": len(selected_items), + "field_counts": { + "text_dominant_question": ref_count, + }, + }, + "selected_examples": selected_metadata, + }, + f, + ensure_ascii=False, + indent=2, + ) + + deep_results = run_batch( + items=targeted_items, + out_root=rollout_dir, + skill_content=skill_content, + max_turns=self.max_turns, + workers=min(self.workers, max(len(targeted_items), 1)), + image_detail=self.image_detail, + judge_model=self.judge_model, + judge_max_completion_tokens=self.judge_max_completion_tokens, + judge_retries=self.judge_retries, + diagnostic_mode=True, + diagnostic_instruction=probe["probe_instruction"], + diagnostic_trace_context_by_id=diagnostic_trace_context_by_id, + ) + deep_results = self.attach_reference_context(deep_results, targeted_items) + return run_minibatch_reflect( + results=deep_results, + skill_content=skill_content, + prediction_dir=os.path.join(rollout_dir, "predictions"), + patches_dir=patches_dir, + workers=self.analyst_workers, + failure_only=self.failure_only, + minibatch_size=self.minibatch_size, + edit_budget=self.edit_budget, + random_seed=random_seed, + error_system=self.get_error_minibatch_prompt(), + success_system=self.get_success_minibatch_prompt(), + step_buffer_context=step_buffer_context, + update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"), + ) + + def get_task_types(self) -> list[str]: + return self.dataloader.get_task_types() diff --git a/skillopt/envs/mathverse/dataloader.py b/skillopt/envs/mathverse/dataloader.py new file mode 100644 index 0000000..128e5c2 --- /dev/null +++ b/skillopt/envs/mathverse/dataloader.py @@ -0,0 
+1,228 @@ +"""MathVerse task dataloader.""" +from __future__ import annotations + +import json +import os +import re +from typing import Any + +from skillopt.datasets.base import SplitDataLoader + + +_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"] +_CHOICE_BLOCK_RE = re.compile(r"\bChoices?\s*:\s*", re.IGNORECASE) +_CHOICE_ITEM_RE = re.compile(r"([A-G])\s*[:.)]\s*(.*?)(?=(?:\s+[A-G]\s*[:.)])|$)", re.DOTALL) + + +def _load_json(path: str) -> Any: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def _normalize_space(text: Any) -> str: + return re.sub(r"\s+", " ", str(text or "").strip()) + + +def _resolve_image_path(raw_path: str, *, data_root: str, source_path: str) -> str: + candidates = [] + if raw_path: + if os.path.isabs(raw_path): + candidates.append(raw_path) + else: + if data_root: + candidates.append(os.path.join(data_root, raw_path)) + candidates.append(os.path.join(data_root, "images", raw_path)) + candidates.append(os.path.join(os.path.dirname(source_path), raw_path)) + for candidate in candidates: + if candidate and os.path.exists(candidate): + return os.path.abspath(candidate) + return "" + + +def _split_question_and_choices(question: str) -> tuple[str, list[dict]]: + text = str(question or "").strip() + match = _CHOICE_BLOCK_RE.search(text) + if not match: + return text, [] + + stem = text[:match.start()].strip() + choice_block = text[match.end():].strip() + choices: list[dict] = [] + for idx, m in enumerate(_CHOICE_ITEM_RE.finditer(choice_block)): + label = (m.group(1) or _CHOICE_LABELS[idx]).strip().upper() + choice_text = _normalize_space(m.group(2)) + if choice_text: + choices.append({"label": label, "text": choice_text}) + return stem or text, choices + + +def _build_text_dominant_map(data_root: str) -> dict[str, str]: + if not data_root: + return {} + candidates = [ + os.path.join(data_root, "testmini.json"), + os.path.join(data_root, "data", "testmini.json"), + ] + source_path = next((path for path in candidates if 
os.path.exists(path)), "") + if not source_path: + return {} + + raw = _load_json(source_path) + if not isinstance(raw, list): + return {} + + mapping: dict[str, str] = {} + for item in raw: + if not isinstance(item, dict): + continue + if str(item.get("problem_version") or "").strip() != "Text Dominant": + continue + problem_index = str(item.get("problem_index") or "").strip() + question = str(item.get("question") or "").strip() + if problem_index and question: + mapping[problem_index] = question + return mapping + + +def _normalize_item( + item: dict, + *, + row_idx: int, + source_path: str, + data_root: str, + problem_version: str, + text_dominant_map: dict[str, str], +) -> dict | None: + raw_problem_version = str(item.get("problem_version") or "").strip() + if problem_version and raw_problem_version and raw_problem_version != problem_version: + return None + + question = str(item.get("question") or "").strip() + question_type = str(item.get("question_type") or "").strip() + answer = str(item.get("answer") or "").strip() + image_rel = str(item.get("image") or "").strip() + image_path = _resolve_image_path(image_rel, data_root=data_root, source_path=source_path) + if not answer or not image_path: + return None + + metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} + subject = str(metadata.get("subject") or "").strip() + subfield = str(metadata.get("subfield") or "").strip() + source = str(metadata.get("source") or "").strip() + + question_stem, choices = _split_question_and_choices(question) + is_choice = question_type == "multi-choice" or bool(choices) + + correct_choice = {"label": "", "text": ""} + if is_choice: + label = str(answer).strip().upper().rstrip(".):") + choice_text = "" + for choice in choices: + if choice["label"].upper() == label: + choice_text = choice["text"] + break + correct_choice = {"label": label, "text": choice_text} + + problem_index = str(item.get("problem_index") or "").strip() + sample_index = 
str(item.get("sample_index") or row_idx + 1).strip() + item_id = problem_index or sample_index + task_type = subfield or subject or question_type or "mathverse" + + return { + "id": item_id, + "sample_index": sample_index, + "problem_index": problem_index, + "problem_version": raw_problem_version or problem_version, + "question": question, + "question_stem": question_stem, + "question_for_eval": str(item.get("question_for_eval") or question).strip(), + "question_type": question_type or ("multi-choice" if is_choice else "free-form"), + "is_choice": is_choice, + "choices": choices, + "correct_choice": correct_choice, + "answer": answer, + "gold_answers": [answer] if answer else [], + "image_rel": image_rel, + "image_path": image_path, + "query_wo": str(item.get("query_wo") or "").strip(), + "query_cot": str(item.get("query_cot") or "").strip(), + "metadata": { + "split": str(metadata.get("split") or "").strip(), + "source": source, + "subject": subject, + "subfield": subfield, + }, + "task_type": task_type, + "source_path": os.path.abspath(source_path), + "text_dominant_question": str( + item.get("text_dominant_question") + or text_dominant_map.get(problem_index, "") + ).strip(), + } + + +class MathVerseDataLoader(SplitDataLoader): + """MathVerse dataloader.""" + + def __init__( + self, + split_dir: str = "", + seed: int = 42, + limit: int = 0, + data_root: str = "", + problem_version: str = "Text Lite", + **kwargs, + ) -> None: + super().__init__(split_dir=split_dir, seed=seed, limit=limit) + self.data_root = data_root + self.problem_version = problem_version + self._task_types: list[str] = [] + self._text_dominant_map = _build_text_dominant_map(data_root) + + def setup(self, cfg: dict) -> None: + if not self.data_root: + self.data_root = str(cfg.get("data_root") or "") + if not self.problem_version: + self.problem_version = str(cfg.get("problem_version") or "Text Lite") + self._text_dominant_map = _build_text_dominant_map(self.data_root) + super().setup(cfg) + 
all_items = self.train_items + self.val_items + self.test_items + task_types = { + item.get("task_type") or item.get("question_type") or "mathverse" + for item in all_items + } + self._task_types = sorted(str(x) for x in task_types if str(x).strip()) + + def get_task_types(self) -> list[str]: + return list(self._task_types) + + def load_split_items(self, split_path: str) -> list[dict]: + raw_items = super().load_split_items(split_path) + source_path = next( + ( + os.path.join(split_path, name) + for name in sorted(os.listdir(split_path)) + if name.endswith(".json") + ), + split_path, + ) + items: list[dict] = [] + for row_idx, item in enumerate(raw_items): + if not isinstance(item, dict): + continue + norm = _normalize_item( + item, + row_idx=row_idx, + source_path=source_path, + data_root=self.data_root, + problem_version=self.problem_version, + text_dominant_map=self._text_dominant_map, + ) + if norm is not None: + items.append(norm) + if not items: + raise ValueError( + f"No valid MathVerse items loaded from {split_path} " + f"for problem_version={self.problem_version!r}" + ) + return items diff --git a/skillopt/envs/mathverse/evaluator.py b/skillopt/envs/mathverse/evaluator.py new file mode 100644 index 0000000..b54d0ad --- /dev/null +++ b/skillopt/envs/mathverse/evaluator.py @@ -0,0 +1,180 @@ +"""MathVerse evaluation helpers.""" +from __future__ import annotations + +import re +import string + +from skillopt.model import chat_with_deployment +from skillopt.prompts import load_prompt + + +_EVAL_MODE = "mathverse_choice_or_judge_v1" + + +def normalize_text(text: str) -> str: + text = str(text or "").strip().lower() + text = text.replace("\\,", " ") + text = text.replace("\\ ", " ") + text = "".join(ch for ch in text if ch not in string.punctuation) + return " ".join(text.split()) + + +def normalize_math_text(text: str) -> str: + text = str(text or "").strip() + text = text.replace("$", "") + text = text.replace("\\mathrm", "") + text = text.replace("{", "") + 
text = text.replace("}", "") + text = text.replace("~", " ") + text = text.replace("\\,", " ") + text = text.replace("\\ ", " ") + return " ".join(text.split()).lower() + + +def extract_answer(text: str | None) -> str: + raw = str(text or "").strip() + if not raw: + return "" + + tags = re.findall(r"\s*(.*?)\s*", raw, re.IGNORECASE | re.DOTALL) + if tags: + return tags[-1].strip() + + boxed = re.findall(r"\\boxed\{(.*?)\}", raw, re.IGNORECASE | re.DOTALL) + if boxed: + return boxed[-1].strip() + + lines = [ln.strip() for ln in raw.splitlines() if ln.strip()] + if lines: + return lines[-1] + return raw + + +def _judge_answer( + *, + item: dict, + extracted_answer: str, + judge_model: str, + max_completion_tokens: int, + retries: int, +) -> dict: + question = str(item.get("question_for_eval") or item.get("question") or "").strip() + ground_truth = str(item.get("answer") or "").strip() + raw, _ = chat_with_deployment( + deployment=judge_model, + system="You are a careful and strict mathematical answer evaluator.", + user=load_prompt("judge", env="mathverse").format( + question=question, + groundtruth=ground_truth, + modeloutput=extracted_answer, + ), + max_completion_tokens=max_completion_tokens, + retries=retries, + stage="mathverse_judge", + ) + response = str(raw).strip().lower() + if "true" in response: + correct = True + elif "false" in response: + correct = False + else: + correct = False + return { + "raw": raw, + "correct": correct, + "reason": response, + "matched_gold": ground_truth if correct else "", + } + + +def evaluate_item( + *, + item: dict, + prediction_text: str, + judge_model: str, + max_completion_tokens: int = 256, + retries: int = 5, +) -> dict: + extracted = extract_answer(prediction_text) + + if item.get("is_choice"): + predicted_label = str(extracted).strip().upper().rstrip(".):") + correct_label = str(item["correct_choice"].get("label") or "").strip().upper() + predicted_text = "" + for choice in item.get("choices") or []: + if 
str(choice.get("label") or "").strip().upper() == predicted_label: + predicted_text = str(choice.get("text") or "").strip() + break + hard = 1.0 if predicted_label == correct_label else 0.0 + return { + "evaluation_mode": _EVAL_MODE, + "predicted_answer": extracted, + "predicted_label": predicted_label, + "predicted_text": predicted_text, + "correct_label": correct_label, + "correct_text": str(item["correct_choice"].get("text") or "").strip(), + "em": hard, + "f1": hard, + "sub_em": hard, + "judge_raw": "", + "judge_reason": "exact_label_match" if hard else "label_mismatch", + "matched_gold": correct_label if hard else "", + } + + gold_answer = str(item.get("answer") or "").strip() + pred_norm = normalize_math_text(extracted) + gold_norm = normalize_math_text(gold_answer) + if pred_norm and gold_norm and pred_norm == gold_norm: + return { + "evaluation_mode": _EVAL_MODE, + "predicted_answer": extracted, + "em": 1.0, + "f1": 1.0, + "sub_em": 1.0, + "judge_raw": "", + "judge_reason": "normalized_exact_match", + "matched_gold": gold_answer, + "string_f1": 1.0, + } + + judge = _judge_answer( + item=item, + extracted_answer=extracted, + judge_model=judge_model, + max_completion_tokens=max_completion_tokens, + retries=retries, + ) + hard = 1.0 if judge["correct"] else 0.0 + pred_tokens = normalize_text(extracted).split() + gold_tokens = normalize_text(gold_answer).split() + overlap = 0 + gold_counts: dict[str, int] = {} + for tok in gold_tokens: + gold_counts[tok] = gold_counts.get(tok, 0) + 1 + for tok in pred_tokens: + count = gold_counts.get(tok, 0) + if count > 0: + overlap += 1 + gold_counts[tok] = count - 1 + if pred_tokens and gold_tokens and overlap: + precision = overlap / len(pred_tokens) + recall = overlap / len(gold_tokens) + string_f1 = 2 * precision * recall / (precision + recall) + else: + string_f1 = 0.0 + + return { + "evaluation_mode": _EVAL_MODE, + "predicted_answer": extracted, + "em": hard, + "f1": hard, + "sub_em": hard, + "judge_raw": judge["raw"], 
+ "judge_reason": judge["reason"], + "matched_gold": judge["matched_gold"], + "string_f1": string_f1, + } + + +def evaluation_mode() -> str: + return _EVAL_MODE diff --git a/skillopt/envs/mathverse/prompts/analyst_error.md b/skillopt/envs/mathverse/prompts/analyst_error.md new file mode 100644 index 0000000..78ec605 --- /dev/null +++ b/skillopt/envs/mathverse/prompts/analyst_error.md @@ -0,0 +1,37 @@ +You are an expert failure-analysis agent for visual mathematical reasoning problems. + +You will be given MULTIPLE failed trajectories from a single minibatch and the current skill document. +Each trajectory includes the student's response, the evaluation result, and sometimes a hidden reference +containing the fuller Text Dominant version of the same problem. + +Your job is to identify COMMON reasoning failures across the batch and propose concise skill edits. + +## Failure Type Categories +- **diagram_underuse**: the agent did not recover key constraints from the image +- **constraint_drop**: the agent ignored a condition or relation that should guide the solution +- **option_confusion**: the agent failed to discriminate between close answer choices +- **format_miss**: the agent solved roughly correctly but returned the wrong final form, unit, or expression +- **other**: none of the above + +## Rules +1. Focus on patterns that recur across the minibatch. +2. Prefer edits that improve visual grounding and exact answer selection. +3. Do not hardcode problem-specific formulas or answers. +4. If hidden reference text is present, use it only to infer what information the student failed to recover from the Text Lite version. 
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} diff --git a/skillopt/envs/mathverse/prompts/analyst_success.md b/skillopt/envs/mathverse/prompts/analyst_success.md new file mode 100644 index 0000000..80c7f6e --- /dev/null +++ b/skillopt/envs/mathverse/prompts/analyst_success.md @@ -0,0 +1,26 @@ +You are an expert success-pattern analyst for visual mathematical reasoning problems. + +You will be given MULTIPLE successful trajectories from a minibatch and the current skill document. +Identify generalizable behavior patterns that genuinely help the agent recover the right constraints +from the image and convert them into the exact final answer. + +## Rules +- Focus on broadly useful visual-math reasoning behaviors. +- Prefer patterns about reading decisive diagram cues, checking hidden assumptions, and matching the final answer format exactly. +- Do not add benchmark-specific facts or formulas. +- "edits" may be empty if the skill already captures the useful patterns. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} diff --git a/skillopt/envs/mathverse/prompts/deep_probe.md b/skillopt/envs/mathverse/prompts/deep_probe.md new file mode 100644 index 0000000..04db36b --- /dev/null +++ b/skillopt/envs/mathverse/prompts/deep_probe.md @@ -0,0 +1,25 @@ +You are an expert diagnostic-probe designer for visual mathematical reasoning tasks. 
+ +You will be shown representative trajectories, the current student skill, and the student's original prompt context. +Some trajectories may also include a hidden reference containing the fuller Text Dominant wording of the same problem. +Design one SMALL diagnostic instruction that exposes the student's intermediate judgment without materially changing the original scaffold. + +## Hard Constraints +1. Do NOT substantially change the original scaffold. +2. Do NOT prescribe a new long multi-step solving procedure. +3. Do NOT ask for a full proof or full chain-of-thought. +4. Ask only for a short readout of the signals already behind the student's current answer. +5. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>. +6. If hidden reference text is present, use it only to target what visual or textual constraint the student likely missed. + +## Good Probe Targets +- decisive diagram cue +- top candidate and runner-up +- missing relation or quantity +- why a near-miss option was rejected + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/envs/mathverse/prompts/judge.md b/skillopt/envs/mathverse/prompts/judge.md new file mode 100644 index 0000000..dc75cff --- /dev/null +++ b/skillopt/envs/mathverse/prompts/judge.md @@ -0,0 +1,25 @@ +You are a careful and strict evaluator for visual math problems. + +You will be given: +1. The original question +2. The ground-truth answer +3. A model output + +Decide whether the model output is mathematically equivalent to the ground-truth answer. + +Rules: +- Ignore harmless formatting differences. +- Accept mathematically equivalent expressions, equations, and values. +- Reject answers that are numerically wrong, symbolically different in meaning, missing required units when the unit changes meaning, or correspond to a different choice. +- Do not reward partially correct reasoning if the final answer is wrong.
+ +Return only: +True + +or + +False + +Question: {question} +Ground Truth Answer: {groundtruth} +Model Output: {modeloutput} diff --git a/skillopt/envs/mathverse/prompts/rollout_system.md b/skillopt/envs/mathverse/prompts/rollout_system.md new file mode 100644 index 0000000..8660520 --- /dev/null +++ b/skillopt/envs/mathverse/prompts/rollout_system.md @@ -0,0 +1,11 @@ +You are an expert visual mathematical reasoning agent. + +{skill_section}## Task Format +You will receive one math problem with an image or diagram. +Use the visible diagram as evidence, not just the text. +If some information is abbreviated in the text, recover it from the image before answering. + +## Answer Format +Think step by step, then provide your final answer inside <answer>...</answer>. +- For multiple-choice questions, output only the single option label, such as B. +- For free-form questions, output only the final mathematical answer, such as 14. diff --git a/skillopt/envs/mathverse/reflect.py b/skillopt/envs/mathverse/reflect.py new file mode 100644 index 0000000..9f8f7f2 --- /dev/null +++ b/skillopt/envs/mathverse/reflect.py @@ -0,0 +1,4 @@ +"""MathVerse Reflect stage. + +Prompts are loaded from .md files by the base adapter.
+""" diff --git a/skillopt/envs/mathverse/rollout.py b/skillopt/envs/mathverse/rollout.py new file mode 100644 index 0000000..67eaf97 --- /dev/null +++ b/skillopt/envs/mathverse/rollout.py @@ -0,0 +1,415 @@ +"""MathVerse rollout — single-image multimodal math reasoning.""" +from __future__ import annotations + +import base64 +import json +import mimetypes +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +from skillopt.envs.mathverse.evaluator import evaluate_item, evaluation_mode, extract_answer +from skillopt.model import chat_student_messages, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt + + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="mathverse").format(skill_section=skill_section) + + +def _format_choices(choices: list[dict]) -> str: + return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices) + + +def _build_user_text( + item: dict, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> str: + parts = [] + if diagnostic_trace_context.strip(): + parts.append( + "## Previous Codex Trace Snapshot\n" + "This is a partial transcript from an earlier attempt. 
Use it as your current reasoning context.\n\n" + f"{diagnostic_trace_context.strip()}" + ) + question = str(item.get("question_stem") or item.get("question") or "").strip() + if question: + parts.append(f"## Question\n{question}") + else: + parts.append("## Question\nRead the full problem statement from the image.") + + if item.get("is_choice"): + choices = item.get("choices") or [] + if choices: + parts.append(f"## Choices\n{_format_choices(choices)}") + parts.append("Return only the final option label inside ....") + else: + parts.append("Return only the final mathematical answer inside ....") + + if diagnostic_mode and diagnostic_instruction.strip(): + parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}") + return "\n\n".join(parts) + + +def _image_to_data_uri(path: str) -> str: + mime = mimetypes.guess_type(path)[0] or "image/png" + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{encoded}" + + +def _build_messages( + item: dict, + skill_content: str, + image_detail: str, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> tuple[list[dict], str, str]: + system = _build_system(skill_content) + user_text = _build_user_text( + item, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + image_url = {"url": _image_to_data_uri(item["image_path"])} + if image_detail and image_detail != "auto": + image_url["detail"] = image_detail + messages = [ + {"role": "system", "content": system}, + { + "role": "user", + "content": [ + {"type": "text", "text": user_text}, + {"type": "image_url", "image_url": image_url}, + ], + }, + ] + return messages, system, user_text + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current MathVerse visual math 
def _run_codex_once(
    *,
    pred_dir: str,
    item: dict,
    skill_content: str,
    model: str,
    timeout: int,
    image_detail: str,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
    previous_response: str = "",
) -> tuple[str, str, str, str]:
    """Run one exec-backend attempt on a single problem.

    Builds the task text, renders the skill document, prepares the
    ``codex_exec`` workspace under ``pred_dir`` and invokes the student
    executor on it.

    Args:
        pred_dir: Per-item prediction directory; the workspace is created in
            its ``codex_exec`` subdirectory.
        item: Problem record; ``image_path`` is required.
        skill_content: Current skill text, rendered via ``_build_codex_skill``.
        model: Student model/deployment name passed to the executor.
        timeout: Timeout passed to the executor.
        image_detail: Accepted for signature parity with the chat path; it is
            not used by this function.
        diagnostic_mode: Forwarded to ``_build_user_text``.
        diagnostic_instruction: Forwarded to ``_build_user_text``.
        diagnostic_trace_context: Forwarded to ``_build_user_text``.
        previous_response: Earlier answer; when non-empty it is appended to
            the task as a "Previous Attempt" section asking for a re-check.

    Returns:
        ``(response, raw, skill_md, task_text)`` where ``response`` is the
        executor's final message, falling back to ``raw`` when that is empty.
    """
    user_text = _build_user_text(
        item,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )
    task_parts = [user_text]
    if previous_response:
        task_parts.append(
            "## Previous Attempt\n"
            f"{previous_response}\n\n"
            "Re-check the diagram and the mathematical constraints. "
            "Correct the final answer if needed."
        )
    task_text = "\n\n".join(task_parts)

    skill_md = _build_codex_skill(skill_content)
    work_dir = os.path.join(pred_dir, "codex_exec")
    images = [item["image_path"]]
    prepare_workspace(
        work_dir=work_dir,
        skill_md=skill_md,
        task_text=task_text,
        images=images,
    )

    # NOTE(review): this prompt previously ended "inside ....", naming no
    # delimiter. ``<answer>`` is assumed to be the tag the evaluator parses --
    # confirm against ``extract_answer``.
    prompt = (
        "Use the `skillopt-student` skill available in this workspace.\n"
        "Read `task.md`, inspect the attached image, solve the problem, "
        "and return only the final answer inside <answer>...</answer>."
    )
    final_message, raw = run_student_exec(
        work_dir=work_dir,
        prompt=prompt,
        model=model,
        timeout=timeout,
        images=images,
    )
    return final_message or raw, raw, skill_md, task_text
diagnostic_mode=diagnostic_mode if turn == 0 else False, + diagnostic_instruction=diagnostic_instruction if turn == 0 else "", + diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "", + previous_response=response if turn > 0 else "", + ) + conversation.append({"type": "message", "turn": turn + 1, "content": response}) + if extract_answer(response): + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) - 1 + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system_prompt) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user_text) + else: + messages, system_prompt, user_text = _build_messages( + item, + skill_content, + image_detail, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + response = "" + conversation = [ + {"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"} + ] + for turn in range(max_turns): + if turn == 0: + resp_text, _ = chat_student_messages( + messages=messages, + max_completion_tokens=1024, + retries=5, + stage="rollout", + ) + else: + refinement_text = ( + f"Your previous answer was:\n{response}\n\n" + "Re-check the diagram and the mathematical constraints. " + "If needed, correct your answer. Output only the final answer inside ...." 
+ ) + refinement_messages = [ + messages[0], + messages[1], + {"role": "assistant", "content": response}, + {"role": "user", "content": refinement_text}, + ] + resp_text, _ = chat_student_messages( + messages=refinement_messages, + max_completion_tokens=768, + retries=5, + stage="rollout", + ) + response = resp_text + conversation.append({"type": "message", "turn": turn + 1, "content": resp_text}) + if extract_answer(resp_text): + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) - 1 + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system_prompt) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user_text) + + eval_result = evaluate_item( + item=item, + prediction_text=result["response"], + judge_model=judge_model, + max_completion_tokens=judge_max_completion_tokens, + retries=judge_retries, + ) + result["evaluation_mode"] = eval_result["evaluation_mode"] + result["judge_raw"] = eval_result.get("judge_raw", "") + result["judge_reason"] = eval_result.get("judge_reason", "") + result["matched_gold"] = eval_result.get("matched_gold", "") + + if item.get("is_choice"): + result["predicted_label"] = eval_result["predicted_label"] + result["predicted_text"] = eval_result["predicted_text"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"choice=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' " + f"but expected '{eval_result['correct_label']}'" + ) + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {item['question_for_eval']}\n" + f"Predicted label: {eval_result['predicted_label']!r}\n" + f"Predicted text: {eval_result['predicted_text']!r}\n" + f"Correct label: {eval_result['correct_label']!r}\n" + f"Correct text: 
def run_batch(
    items: list[dict],
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    workers: int = 32,
    image_detail: str = "auto",
    judge_model: str = "gpt-5.4",
    judge_max_completion_tokens: int = 256,
    judge_retries: int = 5,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context_by_id: dict[str, str] | None = None,
) -> list[dict]:
    """Roll out and evaluate ``items`` in parallel, resuming from ``results.jsonl``.

    Rows already present in ``<out_root>/results.jsonl`` are reused when their
    ``evaluation_mode`` matches the current evaluator. Unparseable rows and
    rows written under another evaluation mode are discarded, and the file is
    then rewritten instead of appended to.

    Args:
        items: Problem records; each needs an ``id``.
        out_root: Output directory (created if missing).
        skill_content: Skill text handed to every rollout.
        max_turns: Maximum attempts per item.
        workers: Thread-pool size; values below 1 are treated as 1.
        image_detail: Forwarded to ``process_one``.
        judge_model: Forwarded to ``process_one``.
        judge_max_completion_tokens: Forwarded to ``process_one``.
        judge_retries: Forwarded to ``process_one``.
        diagnostic_mode: Forwarded to ``process_one``.
        diagnostic_instruction: Forwarded to ``process_one``.
        diagnostic_trace_context_by_id: Optional per-item trace context keyed
            by ``str(item["id"])``.

    Returns:
        Reused rows followed by new rows in completion order.
    """
    results_path = os.path.join(out_root, "results.jsonl")
    os.makedirs(out_root, exist_ok=True)

    expected_eval_mode = evaluation_mode()
    done_ids: set[str] = set()
    existing: list[dict] = []
    rewrite_results = False
    if os.path.exists(results_path):
        with open(results_path, encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                    stale = row.get("evaluation_mode") != expected_eval_mode
                    row_id = "" if stale else str(row["id"])
                except (ValueError, KeyError, TypeError, AttributeError):
                    # Truncated/blank line, non-object JSON or a row without
                    # an id: drop it and rewrite the file without it.
                    rewrite_results = True
                    continue
                if stale:
                    rewrite_results = True
                    continue
                done_ids.add(row_id)
                existing.append(row)

    pending = [item for item in items if str(item["id"]) not in done_ids]
    if not pending and not rewrite_results:
        return existing

    trace_by_id = diagnostic_trace_context_by_id or {}
    results = list(existing)
    file_mode = "w" if rewrite_results else "a"
    # ThreadPoolExecutor raises ValueError for max_workers <= 0.
    pool_size = max(1, workers)
    with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=pool_size) as ex:
        if rewrite_results:
            for row in existing:
                outf.write(json.dumps(row, ensure_ascii=False) + "\n")
        futs = {
            ex.submit(
                process_one,
                item,
                out_root,
                skill_content,
                max_turns=max_turns,
                image_detail=image_detail,
                judge_model=judge_model,
                judge_max_completion_tokens=judge_max_completion_tokens,
                judge_retries=judge_retries,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
                diagnostic_trace_context=trace_by_id.get(str(item["id"]), ""),
            ): item
            for item in pending
        }
        for fut in as_completed(futs):
            try:
                row = fut.result()
            except Exception as exc:  # noqa: BLE001
                # A worker that raises must not abort the whole batch and lose
                # the rows still in flight. Record the failure with the same
                # "error: ..." convention used for per-item failures.
                row = {
                    "id": str(futs[fut]["id"]),
                    "hard": 0,
                    "soft": 0.0,
                    "response": "",
                    "fail_reason": f"error: {exc}",
                    "agent_ok": False,
                    "evaluation_mode": expected_eval_mode,
                }
            results.append(row)
            outf.write(json.dumps(row, ensure_ascii=False) + "\n")
            outf.flush()
    return results
def build_reference_text(self, item: dict) -> str:
    """Render the item's annotated reasoning paths as a markdown reference.

    Each entry of ``item["reasoning_steps"]`` that is a non-empty list becomes
    a ``### Reasoning Path <n>`` section, where ``<n>`` is the entry's
    1-based position in the raw list (skipped entries still consume a
    number). Within a path, every dict step with a non-blank ``rationale``
    becomes one line, prefixed with its ``reasoning step`` number (``?`` when
    absent) and, if present, its ``reasoning type`` in brackets. Paths that
    yield no step lines are omitted, and at most the first three rendered
    paths are kept.

    Returns:
        The reference block, or ``""`` when nothing could be rendered.
    """

    def _render_step(step: dict) -> str:
        # One formatted line for a step, or "" when it has no rationale.
        number = step.get("reasoning step", "?")
        kind = str(step.get("reasoning type") or "").strip()
        why = str(step.get("rationale") or "").strip()
        if not why:
            return ""
        if kind:
            return f"{number}. [{kind}] " + why
        return f"{number}. " + why

    sections: list[str] = []
    for position, candidate in enumerate(item.get("reasoning_steps") or [], 1):
        if not (isinstance(candidate, list) and candidate):
            continue
        step_lines = [
            rendered
            for rendered in (
                _render_step(entry) for entry in candidate if isinstance(entry, dict)
            )
            if rendered
        ]
        if step_lines:
            sections.append("\n".join([f"### Reasoning Path {position}", *step_lines]))

    if not sections:
        return ""
    return "## Reference Reasoning Steps\n" + "\n\n".join(sections[:3])
def setup(self, cfg: dict) -> None: + super().setup(cfg) + self.dataloader.setup(cfg) + + def get_dataloader(self): + return self.dataloader + + def build_env_from_batch(self, batch: BatchSpec, **kwargs): + return list(batch.payload or []) + + def build_train_env(self, batch_size: int, seed: int, **kwargs): + batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs): + batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def rollout( + self, + env_manager, + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict]: + items: list[dict] = env_manager + return run_batch( + items=items, + out_root=out_dir, + skill_content=skill_content, + max_turns=self.max_turns, + workers=self.workers, + image_detail=self.image_detail, + diagnostic_mode=kwargs.get("diagnostic_mode", False), + diagnostic_instruction=kwargs.get("diagnostic_instruction", ""), + diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"), + ) + + def reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")) + patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches")) + random_seed = kwargs.get("random_seed") + step_buffer_context = kwargs.get("step_buffer_context", "") + meta_skill_context = kwargs.get("meta_skill_context", "") + + return run_minibatch_reflect( + results=results, + skill_content=skill_content, + prediction_dir=prediction_dir, + patches_dir=patches_dir, + workers=self.analyst_workers, + failure_only=self.failure_only, + minibatch_size=self.minibatch_size, + edit_budget=self.edit_budget, + random_seed=random_seed, + error_system=self.get_error_minibatch_prompt(), 
+ success_system=self.get_success_minibatch_prompt(), + step_buffer_context=step_buffer_context, + meta_skill_context=meta_skill_context, + update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"), + ) + + def deep_reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + if not self.use_deep_reflect: + return [] + + env_manager = kwargs.get("env_manager") + prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")) + random_seed = kwargs.get("random_seed") + step_buffer_context = kwargs.get("step_buffer_context", "") + meta_skill_context = kwargs.get("meta_skill_context", "") + codex_backend = get_student_backend() == "codex_exec" + selected_items = self.select_representative_items( + results, + env_manager if isinstance(env_manager, list) else None, + n_failures=self.deep_reflect_failures, + n_successes=self.deep_reflect_successes, + seed=random_seed, + ) + if not selected_items: + return [] + selected_ids = {str(item["id"]) for item in selected_items} + selected_results = [row for row in results if str(row.get("id")) in selected_ids] + selected_examples = self.attach_reference_context(selected_results, selected_items) + if codex_backend: + selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir) + + reasoning_count = 0 + selected_metadata = [] + for item in selected_items: + meta = self.get_reference_metadata(item) + if meta["fields"]: + reasoning_count += 1 + selected_metadata.append({ + "id": str(item["id"]), + "task_type": str(item.get("subtask") or item.get("task_type") or "mmrb"), + "reference_fields": meta["fields"], + "reference_preview": meta["preview"], + }) + + deep_dir = os.path.join(out_dir, "deep_reflect") + rollout_dir = os.path.join(deep_dir, "rollout") + patches_dir = os.path.join(deep_dir, "patches") + os.makedirs(deep_dir, exist_ok=True) + print( + f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} " + 
f"reference_fields=reasoning_steps({reasoning_count}/{len(selected_items)})" + ) + probe = generate_deep_probe_instruction( + skill_content=skill_content, + items=selected_examples, + prediction_dir=prediction_dir, + system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(), + step_buffer_context=step_buffer_context, + meta_skill_context=meta_skill_context, + ) + if not probe: + return [] + diagnostic_trace_context_by_id = None + if codex_backend: + selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target( + selected_items=selected_items, + selected_examples=selected_examples, + prediction_dir=prediction_dir, + probe=probe, + ) + probe_record = { + **probe, + "reference_summary": { + "selected_count": len(selected_items), + "field_counts": {"reasoning_steps": reasoning_count}, + }, + "selected_examples": selected_metadata, + } + with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f: + json.dump(probe_record, f, ensure_ascii=False, indent=2) + deep_results = run_batch( + items=selected_items, + out_root=rollout_dir, + skill_content=skill_content, + max_turns=self.max_turns, + workers=min(self.workers, max(len(selected_items), 1)), + image_detail=self.image_detail, + diagnostic_mode=True, + diagnostic_instruction=probe["probe_instruction"], + diagnostic_trace_context_by_id=diagnostic_trace_context_by_id, + ) + deep_results = self.attach_reference_context(deep_results, selected_items) + return run_minibatch_reflect( + results=deep_results, + skill_content=skill_content, + prediction_dir=os.path.join(rollout_dir, "predictions"), + patches_dir=patches_dir, + workers=self.analyst_workers, + failure_only=self.failure_only, + minibatch_size=self.minibatch_size, + edit_budget=self.edit_budget, + random_seed=random_seed, + error_system=self.get_error_minibatch_prompt(), + success_system=self.get_success_minibatch_prompt(), + step_buffer_context=step_buffer_context, + 
meta_skill_context=meta_skill_context, + update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"), + ) + + def get_task_types(self) -> list[str]: + return self.dataloader.get_task_types() diff --git a/skillopt/envs/mmrb/dataloader.py b/skillopt/envs/mmrb/dataloader.py new file mode 100644 index 0000000..819d89c --- /dev/null +++ b/skillopt/envs/mmrb/dataloader.py @@ -0,0 +1,146 @@ +"""MMRB task dataloader.""" +from __future__ import annotations + +import glob +import json +import os +import re +from typing import Any + +from skillopt.datasets.base import SplitDataLoader + + +# ── Raw data loading utilities (for preprocessing / standalone eval) ───── + +def _load_json(path: str) -> Any: + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def _iter_data_files(data_path: str) -> list[str]: + if not data_path: + return [] + if os.path.isfile(data_path): + return [data_path] + if os.path.isdir(data_path): + nested = glob.glob(os.path.join(data_path, "**", "*_human.json"), recursive=True) + flat = glob.glob(os.path.join(data_path, "*_human.json")) + return sorted(set(nested + flat)) + return [] + + +def _normalize_space(text: str) -> str: + return re.sub(r"\s+", " ", str(text or "").strip()) + + +def _normalize_item(item: dict, row_idx: int, source_path: str) -> dict | None: + question = _normalize_space(item.get("question") or "") + answer = _normalize_space(item.get("answer") or "") + raw_image_paths = item.get("image_paths") or [] + if not question or not answer or not isinstance(raw_image_paths, list) or not raw_image_paths: + return None + + base_dir = os.path.dirname(source_path) + image_paths: list[str] = [] + for raw_path in raw_image_paths: + rel = str(raw_path or "").strip() + if not rel: + continue + abs_path = rel if os.path.isabs(rel) else os.path.abspath(os.path.join(base_dir, rel)) + if os.path.exists(abs_path): + image_paths.append(abs_path) + if not image_paths: + return None + + options_raw = item.get("options") or [] + 
options = [_normalize_space(opt) for opt in options_raw if _normalize_space(opt)] + source = _normalize_space(item.get("source") or "unknown") + subtask = _normalize_space(item.get("subtask") or "unknown") + item_index = item.get("index", row_idx) + item_id = f"{source}:{subtask}:{item_index}" + + return { + "id": item_id, + "source": source, + "subtask": subtask, + "task_type": subtask, + "question": question, + "answer": answer, + "options": options, + "is_choice": bool(options), + "image_paths": image_paths, + "reasoning_steps": item.get("reasoning_steps") or [], + "annotation_time": item.get("annotation_time"), + "source_path": os.path.abspath(source_path), + } + + +def load_items(data_path: str) -> list[dict]: + """Load and normalise MMRB items from JSON files.""" + files = _iter_data_files(data_path) + if not files: + raise ValueError( + "MMRB requires data_path to be a *_human.json file or a directory " + "containing extracted MMRB subtask folders." + ) + + items: list[dict] = [] + for path in files: + raw = _load_json(path) + if not isinstance(raw, list): + raise ValueError(f"Expected JSON array in {path}, got {type(raw).__name__}") + for row_idx, item in enumerate(raw): + if not isinstance(item, dict): + continue + norm = _normalize_item(item, row_idx=row_idx, source_path=path) + if norm is not None: + items.append(norm) + + if not items: + raise ValueError(f"No valid MMRB items loaded from {data_path}") + return items + + +# ── Dataloader ─────────────────────────────────────────────────────────── + +class MMRBDataLoader(SplitDataLoader): + """MMRB dataloader.""" + + def __init__( + self, + split_dir: str = "", + data_path: str = "", + split_mode: str = "ratio", + split_ratio: str = "2:1:7", + split_seed: int = 42, + split_output_dir: str = "", + seed: int = 42, + limit: int = 0, + **kwargs, + ) -> None: + super().__init__( + split_dir=split_dir, + data_path=data_path, + split_mode=split_mode, + split_ratio=split_ratio, + split_seed=split_seed, + 
_EVAL_MODE = "mmrb_exact_match_v1"


def normalize_text(text: str) -> str:
    """Lower-case *text*, remove ASCII punctuation and collapse whitespace."""
    text = str(text or "").strip().lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())


def extract_answer(text: str | None) -> str:
    """Pull the final answer out of a model response.

    Extraction is attempted in this order, taking the last occurrence where
    several exist:

    1. ``<answer>...</answer>`` tags,
    2. ``Answer [ ... ]``,
    3. ``\\boxed{...}``,
    4. a response that is just one letter (optionally followed by ``.``,
       ``)`` or ``:``),
    5. "final answer ...", "the answer is ...", "answer ..." phrases,
    6. otherwise the stripped response itself.

    Returns ``""`` for empty or ``None`` input.
    """
    raw = str(text or "").strip()
    if not raw:
        return ""

    # The tag delimiters are essential: without them the pattern degenerates
    # to ``\s*(.*?)\s*``, whose last match is always the empty string at the
    # end of the input, so every response would extract as "".
    answer_tags = re.findall(r"<answer>\s*(.*?)\s*</answer>", raw, re.IGNORECASE | re.DOTALL)
    if answer_tags:
        return answer_tags[-1].strip()

    bracket = re.findall(r"Answer\s*\[\s*(.*?)\s*\]", raw, re.IGNORECASE | re.DOTALL)
    if bracket:
        return bracket[-1].strip()

    boxed = re.findall(r"\\boxed\{(.*?)\}", raw, re.IGNORECASE | re.DOTALL)
    if boxed:
        return boxed[-1].strip()

    single = raw.strip().rstrip(".):")
    if re.fullmatch(r"[A-Z]", single, re.IGNORECASE):
        return single.strip()

    # NOTE(review): the colon class was written ``[::]`` (the same character
    # twice); presumably the second one was a full-width colon. Accepting
    # U+FF1A as well is a superset of the previous behaviour -- confirm.
    patterns = [
        r"final answer\s*(?:is)?\s*[:\uff1a]?\s*(.+)",
        r"the answer is\s*[:\uff1a]?\s*(.+)",
        r"answer\s*[:\uff1a]?\s*(.+)$",
    ]
    for pattern in patterns:
        match = re.search(pattern, raw, re.IGNORECASE)
        if match:
            return match.group(1).strip().strip("*")

    return raw


def evaluate_item(*, item: dict, prediction_text: str) -> dict:
    """Score one prediction against the item's gold answer.

    Multiple-choice items (``item["is_choice"]``) are correct when the
    predicted label equals the gold answer case-insensitively. If the
    prediction is not the gold label, it is compared (after
    ``normalize_text``) with the text of each option written as
    ``(X) text`` / ``X) text`` and mapped back to that option's label.

    Open items are correct when the normalised prediction equals the
    normalised gold answer or either one contains the other.

    Returns:
        A dict with ``evaluation_mode``, ``predicted_answer``,
        ``predicted_label``, ``predicted_text``, ``matched_gold`` and the
        scores ``em``, ``f1`` and ``sub_em``, which all carry the same
        0.0/1.0 value.
    """
    predicted_answer = extract_answer(prediction_text)
    gold_answer = str(item.get("answer") or "").strip()
    predicted_norm = normalize_text(predicted_answer)
    gold_norm = normalize_text(gold_answer)

    hard = 0.0
    matched_gold = ""
    predicted_label = ""
    predicted_text = predicted_answer

    if item.get("is_choice"):
        gold_label = gold_answer.upper()
        # Strip wrapping on both sides so "(B)", "[B]" and "B." all reduce to
        # "B"; trimming only the right-hand side left "(B" and lost the match.
        predicted_label = str(predicted_answer).strip().upper().strip(" ()[].:")
        if predicted_label == gold_label:
            hard = 1.0
            matched_gold = gold_answer
        else:
            for option in item.get("options") or []:
                label_match = re.match(r"\(?([A-Z])\)", option)
                if not label_match:
                    continue
                label = label_match.group(1).upper()
                option_text = option[label_match.end():].strip(" .:-")
                if predicted_norm and normalize_text(option_text) == predicted_norm:
                    predicted_label = label
                    predicted_text = option_text
                    break
            if predicted_label == gold_label:
                hard = 1.0
                matched_gold = gold_answer
    else:
        if predicted_norm and gold_norm and (
            predicted_norm == gold_norm or predicted_norm in gold_norm or gold_norm in predicted_norm
        ):
            hard = 1.0
            matched_gold = gold_answer

    return {
        "evaluation_mode": _EVAL_MODE,
        "predicted_answer": predicted_answer,
        "predicted_label": predicted_label,
        "predicted_text": predicted_text,
        "em": hard,
        "f1": hard,
        "sub_em": hard,
        "matched_gold": matched_gold,
    }


def evaluation_mode() -> str:
    """Return the identifier stamped on every result this evaluator produces."""
    return _EVAL_MODE
+- For multiple-choice questions, output only the single option letter inside .... +- For open questions, output only the short final answer inside .... diff --git a/skillopt/envs/mmrb/rollout.py b/skillopt/envs/mmrb/rollout.py new file mode 100644 index 0000000..6c78efb --- /dev/null +++ b/skillopt/envs/mmrb/rollout.py @@ -0,0 +1,439 @@ +"""MMRB rollout.""" +from __future__ import annotations + +import base64 +import json +import mimetypes +import os +import re +from concurrent.futures import ThreadPoolExecutor, as_completed + +from skillopt.envs.mmrb.evaluator import evaluate_item, evaluation_mode +from skillopt.model import chat_student_messages, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt + +_IMAGE_REF_RE = re.compile(r"\{image#(\d+)\}", re.IGNORECASE) + + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="mmrb").format(skill_section=skill_section) + + +def _image_to_data_uri(path: str) -> str: + mime = mimetypes.guess_type(path)[0] or "image/png" + with open(path, "rb") as f: + encoded = base64.b64encode(f.read()).decode("ascii") + return f"data:{mime};base64,{encoded}" + + +def _build_user_content( + item: dict, + image_detail: str, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> tuple[list[dict], str]: + raw_question = str(item["question"]) + content: list[dict] = [] + text_parts: list[str] = [] + used_indices: set[int] = set() + cursor = 0 + + if diagnostic_trace_context.strip(): + prefix = ( + "## Previous Codex Trace Snapshot\n" + "This is a partial transcript from an earlier attempt. 
Use it as your current reasoning context.\n\n" + f"{diagnostic_trace_context.strip()}\n\n" + ) + content.append({"type": "text", "text": prefix}) + text_parts.append(prefix) + + for match in _IMAGE_REF_RE.finditer(raw_question): + if match.start() > cursor: + chunk = raw_question[cursor:match.start()] + if chunk: + content.append({"type": "text", "text": chunk}) + text_parts.append(chunk) + + image_idx = int(match.group(1)) - 1 + marker = f"[Image #{image_idx + 1}]" + text_parts.append(marker) + if 0 <= image_idx < len(item["image_paths"]): + image_url = {"url": _image_to_data_uri(item["image_paths"][image_idx])} + if image_detail and image_detail != "auto": + image_url["detail"] = image_detail + content.append({"type": "image_url", "image_url": image_url}) + used_indices.add(image_idx) + else: + content.append({"type": "text", "text": marker}) + cursor = match.end() + + if cursor < len(raw_question): + tail = raw_question[cursor:] + if tail: + content.append({"type": "text", "text": tail}) + text_parts.append(tail) + + for idx, path in enumerate(item["image_paths"]): + if idx in used_indices: + continue + marker = f"\n[Additional Image #{idx + 1}]" + text_parts.append(marker) + content.append({"type": "text", "text": marker}) + image_url = {"url": _image_to_data_uri(path)} + if image_detail and image_detail != "auto": + image_url["detail"] = image_detail + content.append({"type": "image_url", "image_url": image_url}) + + answer_instruction = ( + "\n\nAnswer with the single correct option letter inside ...." + if item.get("is_choice") + else "\n\nAnswer with the short final answer inside ...." 
+ ) + content.append({"type": "text", "text": answer_instruction}) + text_parts.append(answer_instruction) + + if diagnostic_mode and diagnostic_instruction.strip(): + diag_block = f"\n\n## Training Readout\n{diagnostic_instruction.strip()}" + content.append({"type": "text", "text": diag_block}) + text_parts.append(diag_block) + + return content, "".join(text_parts) + + +def _build_messages( + item: dict, + skill_content: str, + image_detail: str, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", +) -> tuple[list[dict], str, str]: + system = _build_system(skill_content) + user_content, user_text = _build_user_content( + item, + image_detail, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + ) + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": user_content}, + ] + return messages, system, user_text + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current MMRB multi-image reasoning question.", + preamble=( + "Use this skill when solving the current multi-image reasoning task.\n" + "Inspect all attached images carefully and return the final answer inside ...." + ), + ) + + +def _run_codex_once( + *, + pred_dir: str, + item: dict, + skill_content: str, + model: str, + timeout: int, + image_detail: str, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", + previous_response: str = "", +) -> tuple[str, str, str, str]: + user_text = _build_user_content( + item, + image_detail, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + )[1] + task_parts = [user_text] + if previous_response: + task_parts.append( + "## Previous Attempt\n" + f"{previous_response}\n\n" + "Review the same images carefully and answer again." 
+ ) + task_text = "\n\n".join(task_parts) + skill_md = _build_codex_skill(skill_content) + work_dir = os.path.join(pred_dir, "codex_exec") + prepare_workspace( + work_dir=work_dir, + skill_md=skill_md, + task_text=task_text, + images=item["image_paths"], + ) + prompt = ( + "Use the `skillopt-student` skill available in this workspace.\n" + "Read `task.md`, inspect all attached images, and answer the question.\n" + "Keep the final answer inside ...." + ) + final_message, raw = run_student_exec( + work_dir=work_dir, + prompt=prompt, + model=model, + timeout=timeout, + images=item["image_paths"], + ) + return final_message or raw, raw, skill_md, task_text + + +def process_one( + item: dict, + out_root: str, + skill_content: str, + *, + max_turns: int = 1, + image_detail: str = "auto", + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> dict: + item_id = str(item["id"]) + result = { + "id": item_id, + "question": item["question"], + "task_type": item.get("subtask") or item.get("task_type") or "mmrb", + "task_description": item["question"], + "hard": 0, + "soft": 0.0, + "predicted_answer": "", + "predicted_label": "", + "predicted_text": "", + "response": "", + "fail_reason": "", + "agent_ok": False, + "n_turns": 0, + "image_paths": item["image_paths"], + "gold_answer": item["answer"], + "evaluation_mode": evaluation_mode(), + } + + try: + pred_dir = os.path.join(out_root, "predictions", item_id) + os.makedirs(pred_dir, exist_ok=True) + + if is_student_exec_backend(): + from skillopt.model import azure_openai as _llm + + response = "" + conversation: list[dict] = [ + { + "role": "user", + "content": item["question"] + "\n\n" + "\n".join( + f"[image] {os.path.basename(path)}" for path in item["image_paths"] + ), + } + ] + system_prompt = "" + user_text = "" + for turn in range(max_turns): + response, raw, system_prompt, user_text = _run_codex_once( + pred_dir=pred_dir, + item=item, + 
skill_content=skill_content, + model=_llm.STUDENT_DEPLOYMENT, + timeout=120, + image_detail=image_detail, + diagnostic_mode=diagnostic_mode if turn == 0 else False, + diagnostic_instruction=diagnostic_instruction if turn == 0 else "", + diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "", + previous_response=response if turn > 0 else "", + ) + conversation.append({"type": "message", "turn": turn + 1, "content": response}) + if "" in response.lower(): + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) - 1 + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system_prompt) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user_text) + + eval_result = evaluate_item(item=item, prediction_text=response) + result["evaluation_mode"] = eval_result["evaluation_mode"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["predicted_label"] = eval_result["predicted_label"] + result["predicted_text"] = eval_result["predicted_text"] + result["matched_gold"] = eval_result["matched_gold"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"predicted '{eval_result['predicted_answer']}' but expected '{item['answer']}'" + ) + eval_detail = ( + "[EVALUATION RESULT]\n" + f"Question: {item['question']}\n" + f"Predicted answer: {eval_result['predicted_answer']!r}\n" + f"Predicted label: {eval_result['predicted_label']!r}\n" + f"Gold answer: {item['answer']!r}\n" + f"Correct: {eval_result['em']}\n" + ) + conversation.append({"role": "system", "content": eval_detail}) + with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + return result + + messages, system_prompt, user_text = _build_messages( + item, + skill_content, + 
image_detail, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + response = "" + conversation: list[dict] = [ + { + "role": "user", + "content": user_text + "\n\n" + "\n".join( + f"[image] {os.path.basename(path)}" for path in item["image_paths"] + ), + } + ] + + for turn in range(max_turns): + if turn == 0: + resp_text, _ = chat_student_messages( + messages=messages, + max_completion_tokens=768, + retries=5, + stage="rollout", + ) + else: + refinement_messages = [ + messages[0], + messages[1], + {"role": "assistant", "content": response}, + { + "role": "user", + "content": "Review the same images carefully and answer again. Keep the final answer inside ....", + }, + ] + resp_text, _ = chat_student_messages( + messages=refinement_messages, + max_completion_tokens=512, + retries=5, + stage="rollout", + ) + response = resp_text + conversation.append({"type": "message", "turn": turn + 1, "content": resp_text}) + if "" in resp_text.lower(): + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) - 1 + + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system_prompt) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user_text) + + eval_result = evaluate_item(item=item, prediction_text=response) + result["evaluation_mode"] = eval_result["evaluation_mode"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["predicted_label"] = eval_result["predicted_label"] + result["predicted_text"] = eval_result["predicted_text"] + result["matched_gold"] = eval_result["matched_gold"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if not result["hard"]: + result["fail_reason"] = ( + f"predicted '{eval_result['predicted_answer']}' but expected '{item['answer']}'" + ) + + eval_detail = ( + 
"[EVALUATION RESULT]\n" + f"Question: {item['question']}\n" + f"Predicted answer: {eval_result['predicted_answer']!r}\n" + f"Predicted label: {eval_result['predicted_label']!r}\n" + f"Gold answer: {item['answer']!r}\n" + f"Correct: {eval_result['em']}\n" + ) + conversation.append({"role": "system", "content": eval_detail}) + with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + except Exception as e: # noqa: BLE001 + result["fail_reason"] = f"error: {e}" + return result + + +def run_batch( + items: list[dict], + out_root: str, + skill_content: str, + *, + max_turns: int = 1, + workers: int = 16, + image_detail: str = "auto", + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context_by_id: dict[str, str] | None = None, +) -> list[dict]: + results_path = os.path.join(out_root, "results.jsonl") + os.makedirs(out_root, exist_ok=True) + + expected_eval_mode = evaluation_mode() + done_ids: set[str] = set() + existing: list[dict] = [] + rewrite_results = False + if os.path.exists(results_path): + with open(results_path, encoding="utf-8") as f: + for line in f: + try: + row = json.loads(line) + if row.get("evaluation_mode") != expected_eval_mode: + rewrite_results = True + continue + done_ids.add(str(row["id"])) + existing.append(row) + except Exception: + rewrite_results = True + + pending = [item for item in items if str(item["id"]) not in done_ids] + if not pending and not rewrite_results: + return existing + + results = list(existing) + file_mode = "w" if rewrite_results else "a" + with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex: + if rewrite_results: + for row in existing: + outf.write(json.dumps(row, ensure_ascii=False) + "\n") + futs = { + ex.submit( + process_one, + item, + out_root, + skill_content, + max_turns=max_turns, + image_detail=image_detail, + 
diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""), + ): item + for item in pending + } + for fut in as_completed(futs): + row = fut.result() + results.append(row) + outf.write(json.dumps(row, ensure_ascii=False) + "\n") + outf.flush() + return results diff --git a/skillopt/envs/mmrb/skills/initial.md b/skillopt/envs/mmrb/skills/initial.md new file mode 100644 index 0000000..92a8be3 --- /dev/null +++ b/skillopt/envs/mmrb/skills/initial.md @@ -0,0 +1,17 @@ +# MMRB Multi-Image Reasoning Heuristics + +## Cross-Image Alignment +- Track the role of each image by its index and compare evidence across all referenced images before deciding. +- When the question depends on sequence, correspondence, or retrieval, verify the relation between images instead of judging each image independently. + +## Option Elimination +- For multiple-choice tasks, compare all options and reject choices that match only part of the visual evidence. +- If options differ by a small visual detail, use the most discriminative cue rather than a coarse scene impression. + +## Open Answers +- For open-ended tasks, give the shortest answer that is fully supported by the combined images. +- Preserve exact entities, attributes, counts, and directions when the images support them directly. + +## Final Answer +- Output only the final answer inside .... 
+ diff --git a/skillopt/envs/officeqa/__init__.py b/skillopt/envs/officeqa/__init__.py new file mode 100644 index 0000000..5316aaf --- /dev/null +++ b/skillopt/envs/officeqa/__init__.py @@ -0,0 +1 @@ +"""OfficeQA environment package for ReflACT.""" diff --git a/skillopt/envs/officeqa/adapter.py b/skillopt/envs/officeqa/adapter.py new file mode 100644 index 0000000..f383893 --- /dev/null +++ b/skillopt/envs/officeqa/adapter.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import os + +from skillopt.datasets.base import BatchSpec +from skillopt.envs.base import EnvAdapter +from skillopt.envs.deep_reflect import run_no_reference_deep_reflect +from skillopt.envs.officeqa.dataloader import OfficeQADataLoader +from skillopt.envs.officeqa.rollout import run_batch +from skillopt.gradient.reflect import run_minibatch_reflect + + +class OfficeQAAdapter(EnvAdapter): + def __init__( + self, + split_dir: str = "", + workers: int = 8, + analyst_workers: int = 8, + failure_only: bool = False, + minibatch_size: int = 8, + edit_budget: int = 4, + seed: int = 42, + limit: int = 0, + max_tool_turns: int = 12, + data_dirs: list[str] | str | None = None, + docs_dirs: list[str] | str | None = None, + use_deep_reflect: bool = False, + deep_reflect_failures: int = 4, + deep_reflect_successes: int = 2, + ) -> None: + self.workers = workers + self.analyst_workers = analyst_workers + self.failure_only = failure_only + self.minibatch_size = minibatch_size + self.edit_budget = edit_budget + self.max_tool_turns = max_tool_turns + self.data_dirs = data_dirs if data_dirs is not None else docs_dirs + self.use_deep_reflect = use_deep_reflect + self.deep_reflect_failures = deep_reflect_failures + self.deep_reflect_successes = deep_reflect_successes + self.dataloader = OfficeQADataLoader(split_dir=split_dir, seed=seed, limit=limit) + + def setup(self, cfg: dict) -> None: + super().setup(cfg) + self.dataloader.setup(cfg) + + def get_dataloader(self): + return self.dataloader + + def 
build_env_from_batch(self, batch: BatchSpec, **kwargs): + return list(batch.payload or []) + + def build_train_env(self, batch_size: int, seed: int, **kwargs): + batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs): + batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]: + items: list[dict] = env_manager + return run_batch( + items=items, + out_root=out_dir, + skill_content=skill_content, + workers=self.workers, + max_tool_turns=self.max_tool_turns, + data_dirs=self.data_dirs, + diagnostic_mode=kwargs.get("diagnostic_mode", False), + diagnostic_instruction=kwargs.get("diagnostic_instruction", ""), + ) + + def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]: + prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")) + patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches")) + random_seed = kwargs.get("random_seed") + step_buffer_context = kwargs.get("step_buffer_context", "") + return run_minibatch_reflect( + results=results, + skill_content=skill_content, + prediction_dir=prediction_dir, + patches_dir=patches_dir, + workers=self.analyst_workers, + failure_only=self.failure_only, + minibatch_size=self.minibatch_size, + edit_budget=self.edit_budget, + random_seed=random_seed, + error_system=self.get_error_minibatch_prompt(), + success_system=self.get_success_minibatch_prompt(), + step_buffer_context=step_buffer_context, + update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"), + ) + + def deep_reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + return 
run_no_reference_deep_reflect( + self, + results, + skill_content, + out_dir, + env_manager=kwargs.get("env_manager"), + prediction_dir=kwargs.get("prediction_dir"), + random_seed=kwargs.get("random_seed"), + step_buffer_context=kwargs.get("step_buffer_context", ""), + output_requirements=[ + "- There is no hidden reference block. Use only the question, candidate files, tool trace, student output, and evaluation result to infer what intermediate state is worth probing.", + "- The instruction must explicitly request a short ... block before the final ....", + "- The readout should focus on selected document/file, evidence span or table, extracted value, units, and any date or fiscal-period normalization.", + "- Do not ask for exhaustive copying of source text or a full chain-of-thought.", + "- The instruction text should be ready to append directly to the student's prompt.", + ], + metadata_builder=lambda item: { + "id": str(item.get("id")), + "task_type": str(item.get("task_type") or "officeqa"), + "question_preview": str(item.get("question") or "")[:200], + "source_files": item.get("source_files", []), + "source_docs": item.get("source_docs", []), + }, + ) + + def get_task_types(self) -> list[str]: + seen: list[str] = [] + for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items: + task_type = str(item.get("task_type") or "officeqa") + if task_type not in seen: + seen.append(task_type) + return seen or ["officeqa"] diff --git a/skillopt/envs/officeqa/dataloader.py b/skillopt/envs/officeqa/dataloader.py new file mode 100644 index 0000000..a9c22b4 --- /dev/null +++ b/skillopt/envs/officeqa/dataloader.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import csv +import json +import os +from pathlib import Path + +from skillopt.datasets.base import SplitDataLoader + + +def _parse_list_field(value: str | list[str] | None) -> list[str]: + if value is None: + return [] + if isinstance(value, list): + return 
[str(item).strip() for item in value if str(item).strip()] + text = str(value).strip() + if not text: + return [] + try: + loaded = json.loads(text) + except json.JSONDecodeError: + loaded = None + if isinstance(loaded, list): + return [str(item).strip() for item in loaded if str(item).strip()] + if "\n" in text: + return [part.strip() for part in text.splitlines() if part.strip()] + if "," in text and not text.lower().endswith(".txt"): + return [part.strip() for part in text.split(",") if part.strip()] + return [text] + + +def _normalize_row(row: dict[str, str]) -> dict: + item_id = str(row.get("uid") or row.get("id") or "").strip() + question = str(row.get("question") or "").strip() + ground_truth = str(row.get("ground_truth") or row.get("answer") or "").strip() + task_type = str(row.get("category") or row.get("difficulty") or "officeqa").strip() or "officeqa" + source_files = _parse_list_field(row.get("source_files")) + source_docs = _parse_list_field(row.get("source_docs")) + split = str(row.get("split") or "").strip() + return { + "id": item_id, + "uid": item_id, + "question": question, + "ground_truth": ground_truth, + "answers": [ground_truth] if ground_truth else [], + "task_type": task_type, + "category": task_type, + "source_files": source_files, + "source_docs": source_docs, + "split": split, + } + + +class OfficeQADataLoader(SplitDataLoader): + def load_split_items(self, split_path: str) -> list[dict]: + path = Path(split_path) + csv_files = sorted(path.glob("*.csv")) + if csv_files: + with csv_files[0].open(encoding="utf-8", newline="") as f: + reader = csv.DictReader(f) + return [_normalize_row(row) for row in reader] + + json_files = sorted(path.glob("*.json")) + if json_files: + with json_files[0].open(encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, list): + raise ValueError(f"Expected JSON array in {json_files[0]}") + return [_normalize_row(item) for item in data] + + raise FileNotFoundError(f"No .csv or .json file found in 
{split_path}") diff --git a/skillopt/envs/officeqa/evaluator.py b/skillopt/envs/officeqa/evaluator.py new file mode 100644 index 0000000..124d25d --- /dev/null +++ b/skillopt/envs/officeqa/evaluator.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import re +import string +from collections import Counter + + +_NUMERIC_CHARS = set("0123456789.-") + + +def normalize_answer(text: str) -> str: + text = text.lower().strip() + text = text.replace(",", "") + text = "".join(ch for ch in text if ch not in string.punctuation or ch in _NUMERIC_CHARS or ch == "%") + text = re.sub(r"\b(million|millions|billion|billions|dollars|dollar|nominal)\b", " ", text) + text = " ".join(text.split()) + return text + + +def exact_match(prediction: str, gold: str) -> float: + return 1.0 if normalize_answer(prediction) == normalize_answer(gold) else 0.0 + + +def token_f1(prediction: str, gold: str) -> float: + pred_tokens = normalize_answer(prediction).split() + gold_tokens = normalize_answer(gold).split() + if not pred_tokens or not gold_tokens: + return 1.0 if pred_tokens == gold_tokens else 0.0 + common = Counter(pred_tokens) & Counter(gold_tokens) + n_common = sum(common.values()) + if n_common == 0: + return 0.0 + precision = n_common / len(pred_tokens) + recall = n_common / len(gold_tokens) + return 2 * precision * recall / (precision + recall) + + +def evaluate(prediction: str, gold: str) -> dict: + em = exact_match(prediction, gold) + f1 = token_f1(prediction, gold) + return { + "em": em, + "f1": f1, + "predicted_answer": prediction.strip(), + "gold_answer": gold, + } diff --git a/skillopt/envs/officeqa/prompts/analyst_error.md b/skillopt/envs/officeqa/prompts/analyst_error.md new file mode 100644 index 0000000..ec9a87e --- /dev/null +++ b/skillopt/envs/officeqa/prompts/analyst_error.md @@ -0,0 +1,37 @@ +You are an expert failure-analysis agent for OfficeQA document-retrieval question answering tasks. 
+ +You will be given MULTIPLE failed OfficeQA trajectories from a single minibatch and the current skill document. The trajectories may include local document tool calls such as file search, grep, and partial reads. + +Your job is to identify COMMON failure patterns across the batch and propose concise skill edits. + +## Failure Type Categories +- retrieval_miss: the agent searched the wrong file or failed to narrow to the right file +- evidence_miss: the agent read documents but missed the decisive evidence span +- operand_error: the agent extracted the wrong value or the wrong operands +- calculation_error: the agent identified the right evidence but computed the result incorrectly +- answer_format: the agent reached the right result but formatted it wrong +- other: none of the above + +## Rules +- Focus on patterns common across multiple trajectories. +- Prefer general retrieval and evidence-grounding rules over task-specific hacks. +- Only patch gaps in the skill; do not duplicate rules already present. +- Do not hardcode file names, years, or question-specific constants unless the pattern truly requires a reusable retrieval heuristic. + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. 
diff --git a/skillopt/envs/officeqa/prompts/analyst_success.md b/skillopt/envs/officeqa/prompts/analyst_success.md new file mode 100644 index 0000000..4ce3da5 --- /dev/null +++ b/skillopt/envs/officeqa/prompts/analyst_success.md @@ -0,0 +1,25 @@ +You are an expert success-pattern analyst for OfficeQA document-retrieval question answering tasks. + +You will be given MULTIPLE successful OfficeQA trajectories from a single minibatch and the current skill document. Your job is to identify common retrieval, evidence-selection, and numeric-grounding behaviors worth encoding in the skill. + +## Rules +- Focus on patterns shared across multiple successful trajectories. +- Prefer reusable retrieval and extraction discipline over question-specific tips. +- Reinforce compact, high-value behaviors such as narrowing files early, reading only the relevant span, building a clean operand ledger, and copying the final answer from checked evidence. +- Only propose patches for patterns not already captured in the current skill. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. diff --git a/skillopt/envs/officeqa/prompts/rollout_system.md b/skillopt/envs/officeqa/prompts/rollout_system.md new file mode 100644 index 0000000..db22931 --- /dev/null +++ b/skillopt/envs/officeqa/prompts/rollout_system.md @@ -0,0 +1,15 @@ +You are an expert OfficeQA agent working over local Treasury bulletin text files. + +{skill_section}## Rules +1. Use only the provided local document tools to inspect candidate files. +2. Narrow to the most relevant file before reading long passages. +3. Prefer short targeted searches, then small reads around matching evidence. +4. 
Do not invent values that are not grounded in the retrieved text. +5. When the question requires arithmetic, compute only after extracting the exact operands. +6. If you have enough evidence, return the final answer inside .... + +## Tool Use +Use the provided function tools directly when you need them. Prefer searching and small reads before answering. Do not ask the user for permission to use tools; just call the tools. + +## Final Answer Format +When you are ready to answer, emit the final answer inside ... and do not request another tool. diff --git a/skillopt/envs/officeqa/rollout.py b/skillopt/envs/officeqa/rollout.py new file mode 100644 index 0000000..30a28d5 --- /dev/null +++ b/skillopt/envs/officeqa/rollout.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import json +import os +import re +from concurrent.futures import ThreadPoolExecutor, as_completed + +from skillopt.envs.officeqa.evaluator import evaluate +from skillopt.envs.officeqa.tool_runtime import resolve_candidate_files, resolve_docs_roots, run_tool +from skillopt.model import chat_student_messages, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt + +_TOOL_SCHEMAS = [ + { + "type": "function", + "function": { + "name": "glob", + "description": "Find candidate local document files by filename or relative-path glob pattern.", + "parameters": { + "type": "object", + "properties": {"pattern": {"type": "string"}}, + "required": ["pattern"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "read", + "description": "Read a local text document excerpt by path and line window.", + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string"}, + "start": {"type": "integer"}, + "limit": {"type": "integer"}, + }, + "required": ["path"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "grep", + "description": 
"Search a local text document for a literal pattern and return matching lines.", + "parameters": { + "type": "object", + "properties": { + "pattern": {"type": "string"}, + "path": {"type": "string"}, + }, + "required": ["pattern", "path"], + }, + }, + }, +] + +_FINAL_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) + + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="officeqa").format(skill_section=skill_section) + + +def _build_user( + item: dict, + candidate_files: list[str], + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + corpus_note: str = "", +) -> str: + file_block = "\n".join(f"- {path}" for path in candidate_files[:20]) or "- none resolved" + parts = [f"## Question\n{item['question']}"] + if corpus_note.strip(): + parts.append(f"## Document Corpus\n{corpus_note.strip()}") + parts.append(f"## Candidate Files\n{file_block}") + if item.get("source_docs"): + parts.append("## Source Hints\n" + "\n".join(f"- {hint}" for hint in item["source_docs"])) + if diagnostic_mode and diagnostic_instruction.strip(): + parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}") + return "\n\n".join(parts) + + +def _extract_answer(text: str) -> str: + match = _FINAL_RE.search(text) + if match: + return match.group(1).strip() + lines = [line.strip() for line in text.splitlines() if line.strip()] + return lines[-1] if lines else text.strip() + + +def _docs_link_targets(docs_roots: list[str]) -> list[tuple[str, str]]: + return [(root, os.path.join("docs", f"root_{idx}")) for idx, root in enumerate(docs_roots, start=1)] + + +def _workspace_doc_path(path: str, docs_roots: list[str]) -> str: + resolved_path = os.path.realpath(path) + for idx, root in enumerate(docs_roots, start=1): + resolved_root = os.path.realpath(root) + if resolved_path == resolved_root or 
resolved_path.startswith(resolved_root + os.sep): + rel_path = os.path.relpath(resolved_path, resolved_root) + return os.path.join("docs", f"root_{idx}", rel_path) + return path + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current OfficeQA local-document question.", + preamble=( + "Use this skill when answering the current OfficeQA question.\n" + "Inspect the provided local document excerpts or files, ground the answer in the evidence,\n" + "and return the final answer inside ...." + ), + ) + + +def _run_codex_once( + *, + pred_dir: str, + item: dict, + skill_content: str, + candidate_files: list[str], + docs_roots: list[str], + model: str, + timeout: int, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + previous_response: str = "", +) -> tuple[str, str, str, str]: + rel_files = [_workspace_doc_path(path, docs_roots) for path in candidate_files[:20]] + corpus_note = ( + "The full OfficeQA document corpus is available under `docs/`. " + "The candidate files below are source hints or likely starting points; search the full corpus if needed." + ) + user = _build_user( + item, + rel_files, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + corpus_note=corpus_note, + ) + task_parts = [user] + if previous_response: + task_parts.append( + "## Previous Attempt\n" + f"{previous_response}\n\n" + "Review the local documents again and correct the answer if needed." 
+ ) + task_text = "\n\n".join(task_parts) + skill_md = _build_codex_skill(skill_content) + work_dir = os.path.join(pred_dir, "codex_exec") + prepare_workspace( + work_dir=work_dir, + skill_md=skill_md, + task_text=task_text, + link_dirs=_docs_link_targets(docs_roots), + ) + prompt = ( + "Use the `skillopt-student` skill available in this workspace.\n" + "Read `task.md`, inspect or search the full OfficeQA corpus under `docs/`, and answer the question.\n" + "Treat candidate files in `task.md` as hints, not an access limit.\n" + "Return the final answer inside ...." + ) + final_message, raw = run_student_exec( + work_dir=work_dir, + prompt=prompt, + model=model, + timeout=timeout, + data_dirs=docs_roots, + ) + return final_message or raw, raw, skill_md, task_text + + +def process_one( + item: dict, + out_root: str, + skill_content: str, + *, + max_tool_turns: int = 12, + data_dirs: list[str] | str | None = None, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", +) -> dict: + item_id = str(item["id"]) + pred_dir = os.path.join(out_root, "predictions", item_id) + os.makedirs(pred_dir, exist_ok=True) + + docs_roots = resolve_docs_roots(data_dirs) + candidate_files = resolve_candidate_files(item.get("source_files", []), docs_roots) + system = _build_system(skill_content) + user = _build_user(item, candidate_files, diagnostic_mode=diagnostic_mode, diagnostic_instruction=diagnostic_instruction) + + messages: list[dict] = [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ] + conversation: list[dict] = [{"role": "user", "content": user}] + final_response = "" + final_answer = "" + fail_reason = "" + + allowed_files = [os.path.basename(path) for path in candidate_files] + + try: + if is_student_exec_backend(): + from skillopt.model import azure_openai as _llm + + response = "" + system = "" + user = "" + for turn in range(1, max_tool_turns + 1): + response, _raw, system, user = _run_codex_once( + pred_dir=pred_dir, + 
item=item, + skill_content=skill_content, + candidate_files=candidate_files, + docs_roots=docs_roots, + model=_llm.STUDENT_DEPLOYMENT, + timeout=180, + diagnostic_mode=diagnostic_mode if turn == 1 else False, + diagnostic_instruction=diagnostic_instruction if turn == 1 else "", + previous_response=response if turn > 1 else "", + ) + final_response = response + conversation.append({"type": "message", "turn": turn, "content": response}) + if "" in response.lower(): + final_answer = _extract_answer(response) + break + if not final_answer: + fail_reason = f"Exceeded codex turn budget ({max_tool_turns})" + system = system or _build_codex_skill(skill_content) + user = user or _build_user(item, [_workspace_doc_path(path, docs_roots) for path in candidate_files]) + else: + for turn in range(1, max_tool_turns + 1): + message, _ = chat_student_messages( + messages=messages, + max_completion_tokens=768, + retries=5, + stage="rollout", + tools=_TOOL_SCHEMAS, + tool_choice="auto", + return_message=True, + ) + response = message.content or "" + final_response = response + assistant_message = {"role": "assistant", "content": response} + if getattr(message, "tool_calls", None): + assistant_message["tool_calls"] = [tool_call.model_dump(mode="json") for tool_call in message.tool_calls] + messages.append(assistant_message) + conversation.append({"type": "message", "content": response}) + + if getattr(message, "tool_calls", None): + for tool_call in message.tool_calls: + tool_name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {} + cmd, obs = run_tool(tool_name, arguments, allowed_roots=docs_roots, allowed_files=allowed_files) + conversation.append({"type": "tool_call", "cmd": cmd, "obs": obs}) + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": obs, + }) + continue + + if "" in response.lower(): + final_answer = _extract_answer(response) + break + if turn == max_tool_turns: + 
fail_reason = f"Exceeded tool-turn budget ({max_tool_turns})" + else: + fail_reason = "Model neither produced a tool request nor a final answer" + break + except Exception as e: # noqa: BLE001 + fail_reason = f"error: {e}" + + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f: + f.write(system) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f: + f.write(user) + with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + + eval_result = evaluate(final_answer, item.get("ground_truth", "")) if final_answer else {"em": 0.0, "f1": 0.0, "predicted_answer": "", "gold_answer": item.get("ground_truth", "")} + result = { + "id": item_id, + "question": item.get("question", ""), + "task_type": item.get("task_type", "officeqa"), + "task_description": item.get("question", ""), + "predicted_answer": eval_result["predicted_answer"], + "response": final_response, + "ground_truth": item.get("ground_truth", ""), + "source_files": item.get("source_files", []), + "resolved_source_paths": candidate_files, + "hard": int(eval_result["em"]), + "soft": eval_result["f1"], + "fail_reason": fail_reason or ("" if eval_result["em"] else f"predicted '{eval_result['predicted_answer']}' but expected '{item.get('ground_truth', '')}'"), + "agent_ok": not fail_reason, + "n_turns": len(conversation), + "student_system_prompt": system, + "student_user_prompt": user, + } + return result + + +def run_batch( + items: list[dict], + out_root: str, + skill_content: str, + *, + workers: int = 8, + max_tool_turns: int = 12, + data_dirs: list[str] | str | None = None, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", +) -> list[dict]: + results_path = os.path.join(out_root, "results.jsonl") + os.makedirs(out_root, exist_ok=True) + + done_ids: set[str] = set() + existing: list[dict] = [] + if os.path.exists(results_path): + with 
open(results_path, encoding="utf-8") as f: + for line in f: + try: + row = json.loads(line) + except json.JSONDecodeError: + continue + done_ids.add(str(row.get("id"))) + existing.append(row) + + pending = [item for item in items if str(item["id"]) not in done_ids] + if not pending: + return existing + + results = list(existing) + with open(results_path, "a", encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex: + futs = { + ex.submit( + process_one, + item, + out_root, + skill_content, + max_tool_turns=max_tool_turns, + data_dirs=data_dirs, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + ): item + for item in pending + } + for fut in as_completed(futs): + res = fut.result() + results.append(res) + outf.write(json.dumps(res, ensure_ascii=False) + "\n") + outf.flush() + return results diff --git a/skillopt/envs/officeqa/skills/initial.md b/skillopt/envs/officeqa/skills/initial.md new file mode 100644 index 0000000..530b753 --- /dev/null +++ b/skillopt/envs/officeqa/skills/initial.md @@ -0,0 +1,15 @@ +# OfficeQA Skill + +## Retrieval Discipline +- Start by narrowing to the most likely candidate file before reading long passages. +- Prefer targeted search terms that name the exact entity, period, measure, or table concept from the question. +- After a promising match, read only a small surrounding span and verify it matches the requested year, basis, and unit. + +## Evidence Discipline +- Extract the exact value from the retrieved text before doing any arithmetic. +- Keep track of each operand's period, unit, and semantic role so nearby proxy values are not mixed in. +- If the question asks for a transformed or derived quantity, compute only after confirming every operand. + +## Final Answer Discipline +- Return the final answer only after one last consistency check against the retrieved evidence. +- Copy the final answer from a checked value, not from an unverified intermediate guess. 
diff --git a/skillopt/envs/officeqa/tool_runtime.py b/skillopt/envs/officeqa/tool_runtime.py new file mode 100644 index 0000000..0e29bf3 --- /dev/null +++ b/skillopt/envs/officeqa/tool_runtime.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import fnmatch +import os +from pathlib import Path + +_MAX_READ_CHARS = 4000 +_MAX_GREP_MATCHES = 20 +_MAX_GLOB_MATCHES = 50 + + +def _normalize_data_dirs(data_dirs: list[str] | tuple[str, ...] | str | None, project_root: Path) -> list[str]: + if data_dirs is None: + return [] + if isinstance(data_dirs, str): + items = [part.strip() for chunk in data_dirs.split(os.pathsep) for part in chunk.split(",")] + else: + items = [str(item).strip() for item in data_dirs] + resolved: list[str] = [] + for item in items: + if not item: + continue + path = Path(item).expanduser() + if not path.is_absolute(): + path = project_root / path + resolved.append(str(path)) + return resolved + + +def resolve_docs_roots(data_dirs: list[str] | tuple[str, ...] | str | None = None) -> list[str]: + project_root = Path(__file__).resolve().parents[3] + env_value = os.environ.get("OFFICEQA_DOCS_DIR", "").strip() + candidates = _normalize_data_dirs(data_dirs, project_root) + candidates.extend(_normalize_data_dirs(env_value, project_root)) + candidates.extend([ + str(project_root / "data" / "officeqa_docs_official"), + str(project_root / "data" / "officeqa_smoke_docs"), + os.path.expanduser("~/officeqa-sparse/treasury_bulletins_parsed"), + os.path.expanduser("~/officeqa/treasury_bulletins_parsed"), + ]) + roots: list[str] = [] + seen: set[str] = set() + for candidate in candidates: + path = Path(candidate).expanduser() + if not path.is_dir(): + continue + transformed = path / "transformed" + resolved = str((transformed if transformed.is_dir() else path).resolve()) + if resolved in seen: + continue + seen.add(resolved) + roots.append(resolved) + if not roots: + raise FileNotFoundError("OfficeQA docs directory not found. 
Set OFFICEQA_DOCS_DIR or env.data_dirs.") + return roots + + +def _is_allowed(path: str, allowed_roots: list[str], allowed_files: list[str]) -> bool: + try: + resolved = str(Path(path).resolve()) + except FileNotFoundError: + return False + if not any(resolved.startswith(root + os.sep) or resolved == root for root in allowed_roots): + return False + if not allowed_files: + return True + base = os.path.basename(resolved) + return base in allowed_files + + +def resolve_candidate_files(source_files: list[str], allowed_roots: list[str]) -> list[str]: + resolved: list[str] = [] + seen: set[str] = set() + for root in allowed_roots: + for dirpath, _, filenames in os.walk(root): + for filename in filenames: + if source_files and filename not in source_files: + continue + full = str(Path(dirpath, filename).resolve()) + if full in seen: + continue + seen.add(full) + resolved.append(full) + return resolved + + +def run_tool(name: str, arguments: dict, *, allowed_roots: list[str], allowed_files: list[str]) -> tuple[str, str]: + if name == "glob": + pattern = str(arguments.get("pattern") or "*") + matches: list[str] = [] + for root in allowed_roots: + for dirpath, _, filenames in os.walk(root): + for filename in filenames: + if allowed_files and filename not in allowed_files: + continue + rel = os.path.relpath(os.path.join(dirpath, filename), root) + if fnmatch.fnmatch(rel, pattern) or fnmatch.fnmatch(filename, pattern): + matches.append(os.path.join(dirpath, filename)) + if len(matches) >= _MAX_GLOB_MATCHES: + break + if len(matches) >= _MAX_GLOB_MATCHES: + break + return f"glob(pattern={pattern!r})", "\n".join(matches) if matches else "[no matches]" + + if name == "read": + path = str(arguments.get("path") or "") + if not path: + return "read(path='')", "[read error: missing path]" + if not _is_allowed(path, allowed_roots, allowed_files): + return f"read(path={path!r})", "[read error: path not allowed]" + start = max(int(arguments.get("start") or 1), 1) + limit = 
max(int(arguments.get("limit") or 80), 1) + with open(path, encoding="utf-8") as f: + lines = f.readlines() + excerpt = "".join(lines[start - 1:start - 1 + limit]) + return f"read(path={path!r}, start={start}, limit={limit})", excerpt[:_MAX_READ_CHARS] or "[empty file]" + + if name == "grep": + pattern = str(arguments.get("pattern") or "").lower() + path = str(arguments.get("path") or "") + if not pattern or not path: + return f"grep(pattern={pattern!r}, path={path!r})", "[grep error: missing pattern or path]" + if not _is_allowed(path, allowed_roots, allowed_files): + return f"grep(pattern={pattern!r}, path={path!r})", "[grep error: path not allowed]" + matches: list[str] = [] + with open(path, encoding="utf-8") as f: + for idx, line in enumerate(f, start=1): + if pattern in line.lower(): + matches.append(f"{idx}: {line.rstrip()}") + if len(matches) >= _MAX_GREP_MATCHES: + break + return f"grep(pattern={pattern!r}, path={path!r})", "\n".join(matches) if matches else "[no matches]" + + return name, f"[tool error: unknown tool {name}]" diff --git a/skillopt/envs/sealqa/__init__.py b/skillopt/envs/sealqa/__init__.py new file mode 100644 index 0000000..4672f51 --- /dev/null +++ b/skillopt/envs/sealqa/__init__.py @@ -0,0 +1 @@ +"""SealQA environment package for ReflACT.""" diff --git a/skillopt/envs/sealqa/adapter.py b/skillopt/envs/sealqa/adapter.py new file mode 100644 index 0000000..d5b6665 --- /dev/null +++ b/skillopt/envs/sealqa/adapter.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import os + +from skillopt.datasets.base import BatchSpec +from skillopt.envs.base import EnvAdapter +from skillopt.envs.deep_reflect import run_no_reference_deep_reflect +from skillopt.envs.sealqa.dataloader import SealQADataLoader +from skillopt.envs.sealqa.rollout import run_batch +from skillopt.gradient.reflect import run_minibatch_reflect + + +class SealQAAdapter(EnvAdapter): + def __init__( + self, + split_dir: str = '', + workers: int = 4, + analyst_workers: int = 
8, + failure_only: bool = False, + minibatch_size: int = 8, + edit_budget: int = 4, + seed: int = 42, + limit: int = 0, + max_tool_turns: int = 12, + use_deep_reflect: bool = False, + deep_reflect_failures: int = 4, + deep_reflect_successes: int = 2, + ) -> None: + self.workers = workers + self.analyst_workers = analyst_workers + self.failure_only = failure_only + self.minibatch_size = minibatch_size + self.edit_budget = edit_budget + self.max_tool_turns = max_tool_turns + self.use_deep_reflect = use_deep_reflect + self.deep_reflect_failures = deep_reflect_failures + self.deep_reflect_successes = deep_reflect_successes + self.dataloader = SealQADataLoader(split_dir=split_dir, seed=seed, limit=limit) + + def setup(self, cfg: dict) -> None: + super().setup(cfg) + self.dataloader.setup(cfg) + + def get_dataloader(self): + return self.dataloader + + def build_env_from_batch(self, batch: BatchSpec, **kwargs): + return list(batch.payload or []) + + def build_train_env(self, batch_size: int, seed: int, **kwargs): + batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs): + batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs) + return self.build_env_from_batch(batch, **kwargs) + + def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]: + items: list[dict] = env_manager + return run_batch( + items=items, + out_root=out_dir, + skill_content=skill_content, + workers=self.workers, + max_tool_turns=self.max_tool_turns, + diagnostic_mode=kwargs.get('diagnostic_mode', False), + diagnostic_instruction=kwargs.get('diagnostic_instruction', ''), + ) + + def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]: + prediction_dir = kwargs.get('prediction_dir', os.path.join(out_dir, 'predictions')) + patches_dir = 
kwargs.get('patches_dir', os.path.join(out_dir, 'patches')) + random_seed = kwargs.get('random_seed') + step_buffer_context = kwargs.get('step_buffer_context', '') + return run_minibatch_reflect( + results=results, + skill_content=skill_content, + prediction_dir=prediction_dir, + patches_dir=patches_dir, + workers=self.analyst_workers, + failure_only=self.failure_only, + minibatch_size=self.minibatch_size, + edit_budget=self.edit_budget, + random_seed=random_seed, + error_system=self.get_error_minibatch_prompt(), + success_system=self.get_success_minibatch_prompt(), + step_buffer_context=step_buffer_context, + update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"), + ) + + def deep_reflect( + self, + results: list[dict], + skill_content: str, + out_dir: str, + **kwargs, + ) -> list[dict | None]: + return run_no_reference_deep_reflect( + self, + results, + skill_content, + out_dir, + env_manager=kwargs.get('env_manager'), + prediction_dir=kwargs.get('prediction_dir'), + random_seed=kwargs.get('random_seed'), + step_buffer_context=kwargs.get('step_buffer_context', ''), + output_requirements=[ + "- There is no hidden reference block. Use only the question, provided evidence, URL/fetch trace, student output, and evaluation result to infer what intermediate state is worth probing.", + "- The instruction must explicitly request a short ... 
block before the final ....", + "- The readout should focus on effective time frame, conflicting evidence, decisive source, candidate answer, and answer-finalization rule.", + "- Do not ask for exhaustive web summaries or a full chain-of-thought.", + "- The instruction text should be ready to append directly to the student's prompt.", + ], + metadata_builder=lambda item: { + "id": str(item.get('id')), + "task_type": str(item.get('task_type') or item.get('topic') or 'sealqa'), + "question_preview": str(item.get('question') or '')[:200], + "freshness": item.get('freshness', ''), + "question_types": item.get('question_types', ''), + "topic": item.get('topic', ''), + }, + ) + + def get_task_types(self) -> list[str]: + seen: list[str] = [] + for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items: + task_type = str(item.get('task_type') or 'sealqa') + if task_type not in seen: + seen.append(task_type) + return seen or ['sealqa'] diff --git a/skillopt/envs/sealqa/dataloader.py b/skillopt/envs/sealqa/dataloader.py new file mode 100644 index 0000000..ed6afd0 --- /dev/null +++ b/skillopt/envs/sealqa/dataloader.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import csv +from pathlib import Path + +from skillopt.datasets.base import SplitDataLoader + + +def _normalize_row(row: dict[str, str], index: int) -> dict: + canary = str(row.get('canary') or '').strip() + base_id = str(row.get('question_id') or row.get('id') or '').strip() + if not base_id: + base_id = f"{canary or 'sealqa'}:{index:04d}" + return { + 'id': base_id, + 'question': str(row.get('question') or '').strip(), + 'ground_truth': str(row.get('answer') or row.get('ground_truth') or '').strip(), + 'answers': [str(row.get('answer') or row.get('ground_truth') or '').strip()], + 'task_type': str(row.get('topic') or 'sealqa').strip() or 'sealqa', + 'topic': str(row.get('topic') or 'sealqa').strip() or 'sealqa', + 'urls': str(row.get('urls') or '').strip(), + 
'search_results': str(row.get('search_results') or '').strip(), + 'freshness': str(row.get('freshness') or '').strip(), + 'question_types': str(row.get('question_types') or '').strip(), + 'canary': canary, + } + + +class SealQADataLoader(SplitDataLoader): + def load_split_items(self, split_path: str) -> list[dict]: + path = Path(split_path) + csv_files = sorted(path.glob('*.csv')) + if not csv_files: + raise FileNotFoundError(f'No .csv file found in {split_path}') + with csv_files[0].open(encoding='utf-8', newline='') as f: + reader = csv.DictReader(f) + return [_normalize_row(row, idx) for idx, row in enumerate(reader, start=1)] diff --git a/skillopt/envs/sealqa/evaluator.py b/skillopt/envs/sealqa/evaluator.py new file mode 100644 index 0000000..65249ab --- /dev/null +++ b/skillopt/envs/sealqa/evaluator.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import re + +from openai import AzureOpenAI, OpenAI + +GRADER_TEMPLATE = """ +Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either [\"CORRECT\", \"INCORRECT\", \"NOT_ATTEMPTED\"]. +First, I will give examples of each grade, and then you will grade a new example. + +The following are examples of CORRECT predicted answers. +``` +Question: What are the names of Barack Obama's children? +Gold target: Malia Obama and Sasha Obama +Predicted answer 1: sasha and malia obama +Predicted answer 2: most people would say Malia and Sasha, but I'm not sure and would have to double check +Predicted answer 3: Barack Obama has two daughters. Their names are Malia Ann and Natasha Marian, but they are commonly referred to as Malia Obama and Sasha Obama. Malia was born on July 4, 1998, and Sasha was born on June 10, 2001. +``` +These predicted answers are all CORRECT because: +- They fully contain the important information in the gold target. +- They do not contain any information that contradicts the gold target. 
+- Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter. +- Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions. + +The following are examples of INCORRECT predicted answers. +``` +Question: What are the names of Barack Obama's children? +Gold target: Malia and Sasha +Predicted answer 1: Malia. +Predicted answer 2: Malia, Sasha, and Susan. +Predicted answer 3: Barack Obama does not have any children. +``` +These predicted answers are all INCORRECT because a factual statement in the answer contradicts the gold target. + +The following are examples of NOT_ATTEMPTED predicted answers. +``` +Question: What are the names of Barack Obama's children? +Gold target: Malia and Sasha +Predicted answer 1: I don't know. +Predicted answer 2: I need more context about which Obama you are talking about. +``` +These predicted answers are all NOT_ATTEMPTED because the important information in the gold target is not included and there is no contradiction. + +Here is a new example. Simply reply with either CORRECT, INCORRECT, NOT ATTEMPTED. Just return the letters \"A\", \"B\", or \"C\", with no text around it. 
+``` +Question: {question} +Gold target: {target} +Predicted answer: {predicted_answer} +``` + +Grade the predicted answer as one of: +A: CORRECT +B: INCORRECT +C: NOT_ATTEMPTED +""".strip() + + +def _build_grader_client() -> tuple[OpenAI | AzureOpenAI, str]: + import os + + endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT', '').strip() + api_version = os.environ.get('AZURE_OPENAI_API_VERSION', '').strip() or '2025-04-01-preview' + azure_key = os.environ.get('AZURE_OPENAI_API_KEY', '').strip() + openai_key = os.environ.get('OPENAI_API_KEY', '').strip() + api_key = azure_key or openai_key + if endpoint and api_version and api_key: + model = os.environ.get('SEALQA_GRADER_AZURE_MODEL', '').strip() or os.environ.get('SEALQA_GRADER_MODEL', '').strip() or os.environ.get('AZURE_MODEL_NAME', '').strip() or os.environ.get('TEACHER_DEPLOYMENT', '').strip() or 'gpt-5.4' + client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint.rstrip('/')) + return client, model + + if openai_key: + model = os.environ.get('SEALQA_GRADER_OPENAI_MODEL', '').strip() or os.environ.get('SEALQA_GRADER_MODEL', '').strip() or 'gpt-4.1-mini' + return OpenAI(api_key=openai_key), model + + raise ValueError('Missing grader credentials for SealQA scoring.') + + +def _extract_text_content(content) -> str: + if content is None: + return '' + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for part in content: + if isinstance(part, dict) and part.get('type') == 'text': + parts.append(str(part.get('text', ''))) + else: + text = getattr(part, 'text', None) + if text: + parts.append(str(text)) + return '\n'.join(parts).strip() + return str(content).strip() + + +def _normalize_text(text: str) -> str: + lowered = text.strip().lower() + lowered = re.sub(r'\s+', ' ', lowered) + lowered = re.sub(r'[^\w\s%.-]', '', lowered) + return lowered.strip() + + +def _fallback_score(ground_truth: str, predicted: str) -> float: + gold = 
_normalize_text(ground_truth) + pred = _normalize_text(predicted) + if not gold or not pred: + return 0.0 + if gold == pred: + return 1.0 + if gold in pred or pred in gold: + return 1.0 + return 0.0 + + +def score_sealqa(question: str, ground_truth: str, predicted: str) -> float: + try: + client, model = _build_grader_client() + except ValueError: + return _fallback_score(ground_truth, predicted) + + prompt = GRADER_TEMPLATE.format(question=question, target=ground_truth, predicted_answer=predicted) + completion = client.chat.completions.create(model=model, messages=[{'role': 'user', 'content': prompt}]) + content = _extract_text_content(completion.choices[0].message.content).strip().upper() + if content.startswith('A'): + return 1.0 + return 0.0 diff --git a/skillopt/envs/sealqa/prompts/analyst_error.md b/skillopt/envs/sealqa/prompts/analyst_error.md new file mode 100644 index 0000000..8bda1f9 --- /dev/null +++ b/skillopt/envs/sealqa/prompts/analyst_error.md @@ -0,0 +1,30 @@ +You are an expert failure-analysis agent for evidence-seeking factual question answering tasks. + +You will be given MULTIPLE failed SealQA trajectories from a single minibatch and the current skill document. The trajectories may include tool calls such as search, fetch, local reads, or evidence gathering steps. + +Your job is to identify COMMON failure patterns across the batch and propose concise skill edits. 
+ +## Failure Type Categories +- retrieval_miss: the agent failed to gather the right evidence +- evidence_conflict: the agent saw conflicting evidence but resolved it badly +- answer_selection: the agent found evidence but chose the wrong final answer +- not_attempted: the agent never reached a grounded answer +- other: none of the above + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": <int>, + "failure_summary": [ + {"failure_type": "<category>", "count": <int>, "description": "<short description>"} + ], + "patch": { + "reasoning": "<why these edits address the failures>", + "edits": [ + {"op": "append", "content": "<text>"}, + {"op": "insert_after", "target": "<existing text>", "content": "<text>"}, + {"op": "replace", "target": "<existing text>", "content": "<text>"}, + {"op": "delete", "target": "<existing text>"} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. diff --git a/skillopt/envs/sealqa/prompts/analyst_success.md b/skillopt/envs/sealqa/prompts/analyst_success.md new file mode 100644 index 0000000..6856877 --- /dev/null +++ b/skillopt/envs/sealqa/prompts/analyst_success.md @@ -0,0 +1,19 @@ +You are an expert success-pattern analyst for evidence-seeking factual question answering tasks. + +You will be given MULTIPLE successful SealQA trajectories from a single minibatch and the current skill document. Your job is to identify common evidence-gathering and answer-selection behaviors worth encoding in the skill. + +Respond ONLY with a valid JSON object: +{ + "batch_size": <int>, + "success_patterns": ["<pattern>", "<pattern>"], + "patch": { + "reasoning": "<why these edits capture the patterns>", + "edits": [ + {"op": "append", "content": "<text>"}, + {"op": "insert_after", "target": "<existing text>", "content": "<text>"}, + {"op": "replace", "target": "<existing text>", "content": "<text>"}, + {"op": "delete", "target": "<existing text>"} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. 
diff --git a/skillopt/envs/sealqa/prompts/rollout_system.md b/skillopt/envs/sealqa/prompts/rollout_system.md new file mode 100644 index 0000000..63a95be --- /dev/null +++ b/skillopt/envs/sealqa/prompts/rollout_system.md @@ -0,0 +1,3 @@ +You are an expert research assistant. Use the provided search evidence first, and only if that is insufficient, inspect the provided URL content fetched for you. Reconcile conflicting information when necessary and return a concise final answer grounded in the evidence you found. + +{skill_section}Return the final answer inside ... when you are ready. diff --git a/skillopt/envs/sealqa/rollout.py b/skillopt/envs/sealqa/rollout.py new file mode 100644 index 0000000..7f1554e --- /dev/null +++ b/skillopt/envs/sealqa/rollout.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import json +import os +import re +from concurrent.futures import ThreadPoolExecutor, as_completed + +from skillopt.envs.sealqa.evaluator import score_sealqa +from skillopt.envs.sealqa.tool_runtime import web_fetch +from skillopt.model import chat_student, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt + +_FINAL_RE = re.compile(r"(.*?)", re.IGNORECASE | re.DOTALL) + + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="sealqa").format(skill_section=skill_section) + + +def _build_user(item: dict, *, diagnostic_mode: bool = False, diagnostic_instruction: str = '') -> str: + parts = [f"## Question\n{item['question']}"] + if item.get('search_results'): + parts.append(f"## Search Results\n{item['search_results']}") + if item.get('urls'): + parts.append(f"## URL Hints\n{item['urls']}") + if item.get('freshness'): + parts.append(f"## Freshness\n{item['freshness']}") + if 
item.get('question_types'): + parts.append(f"## Question Types\n{item['question_types']}") + if diagnostic_mode and diagnostic_instruction.strip(): + parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}") + parts.append('Use the provided search evidence as your primary context. Do not rely on external tool use.') + return "\n\n".join(parts) + + +def _extract_answer(text: str) -> str: + match = _FINAL_RE.search(text) + if match: + return match.group(1).strip() + lines = [line.strip() for line in text.splitlines() if line.strip()] + return lines[-1] if lines else text.strip() + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current SealQA evidence-grounded question.", + preamble=( + "Use this skill when answering the current SealQA question.\n" + "Use the provided search evidence first, reconcile conflicts carefully,\n" + "and return the final answer inside ...." + ), + ) + + +def _run_codex_once( + *, + pred_dir: str, + skill_content: str, + task_text: str, + model: str, + timeout: int, + previous_response: str = '', +) -> tuple[str, str, str, str]: + task_parts = [task_text] + if previous_response: + task_parts.append( + "## Previous Attempt\n" + f"{previous_response}\n\n" + "Review the evidence again and correct the final answer if needed." + ) + final_task_text = "\n\n".join(task_parts) + skill_md = _build_codex_skill(skill_content) + work_dir = os.path.join(pred_dir, 'codex_exec') + prepare_workspace( + work_dir=work_dir, + skill_md=skill_md, + task_text=final_task_text, + ) + prompt = ( + "Use the `skillopt-student` skill available in this workspace.\n" + "Read `task.md`, answer the SealQA question using the provided evidence,\n" + "and return the final answer inside ...." 
+ ) + final_message, raw = run_student_exec( + work_dir=work_dir, + prompt=prompt, + model=model, + timeout=timeout, + ) + return final_message or raw, raw, skill_md, final_task_text + + +def process_one( + item: dict, + out_root: str, + skill_content: str, + *, + max_tool_turns: int = 12, + diagnostic_mode: bool = False, + diagnostic_instruction: str = '', +) -> dict: + item_id = str(item['id']) + pred_dir = os.path.join(out_root, 'predictions', item_id) + os.makedirs(pred_dir, exist_ok=True) + + system = _build_system(skill_content) + user = _build_user( + item, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + ) + conversation: list[dict] = [{'role': 'user', 'content': user}] + final_response = '' + final_answer = '' + fail_reason = '' + + try: + if is_student_exec_backend(): + from skillopt.model import azure_openai as _llm + + response, _raw, system, user_for_save = _run_codex_once( + pred_dir=pred_dir, + skill_content=skill_content, + task_text=user, + model=_llm.STUDENT_DEPLOYMENT, + timeout=120, + ) + final_response = response + conversation.append({'type': 'message', 'content': response}) + if '' in response.lower(): + final_answer = _extract_answer(response) + else: + user = user_for_save + else: + response, _ = chat_student( + system=system, + user=user, + max_completion_tokens=768, + retries=5, + stage='rollout', + ) + final_response = response + conversation.append({'type': 'message', 'content': response}) + if '' in response.lower(): + final_answer = _extract_answer(response) + + if not final_answer: + urls_text = str(item.get('urls') or '').strip() + fetched_blocks = [] + for raw_url in re.findall(r'https?://[^\s\]\[\'\",]+', urls_text)[:2]: + try: + fetched = web_fetch(raw_url) + except Exception as fetch_error: # noqa: BLE001 + fetched = f'URL: {raw_url}\n\n[fetch error: {fetch_error}]' + fetched_blocks.append(fetched) + conversation.append({'type': 'tool_call', 'cmd': f'web_fetch({raw_url!r})', 'obs': fetched}) 
+ if fetched_blocks: + retry_user = user + '\n\n## Fetched URL Content\n' + '\n\n'.join(fetched_blocks) + if is_student_exec_backend(): + retry_response, _raw, system, retry_user = _run_codex_once( + pred_dir=pred_dir, + skill_content=skill_content, + task_text=retry_user, + model=_llm.STUDENT_DEPLOYMENT, + timeout=120, + previous_response=final_response, + ) + else: + retry_response, _ = chat_student( + system=system, + user=retry_user, + max_completion_tokens=768, + retries=5, + stage='rollout', + ) + final_response = retry_response + conversation.append({'type': 'message', 'content': retry_response}) + if '' in retry_response.lower(): + final_answer = _extract_answer(retry_response) + else: + fail_reason = 'Model did not produce a final answer' + else: + fail_reason = 'Model did not produce a final answer' + except Exception as e: # noqa: BLE001 + fail_reason = f'error: {e}' + + with open(os.path.join(pred_dir, 'student_system_prompt.txt'), 'w', encoding='utf-8') as f: + f.write(system) + with open(os.path.join(pred_dir, 'student_user_prompt.txt'), 'w', encoding='utf-8') as f: + f.write(user) + with open(os.path.join(pred_dir, 'conversation.json'), 'w', encoding='utf-8') as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + + score = score_sealqa(item.get('question', ''), item.get('ground_truth', ''), final_answer) if final_answer else 0.0 + result = { + 'id': item_id, + 'question': item.get('question', ''), + 'task_type': item.get('task_type', 'sealqa'), + 'task_description': item.get('question', ''), + 'predicted_answer': final_answer, + 'response': final_response, + 'ground_truth': item.get('ground_truth', ''), + 'hard': int(score >= 1.0), + 'soft': float(score), + 'fail_reason': fail_reason or ('' if score >= 1.0 else f"predicted '{final_answer}' but expected '{item.get('ground_truth', '')}'"), + 'agent_ok': not fail_reason, + 'n_turns': len(conversation), + 'student_system_prompt': system, + 'student_user_prompt': user, + } + return result + + 
def run_batch(
    items: list[dict],
    out_root: str,
    skill_content: str,
    *,
    workers: int = 4,
    max_tool_turns: int = 12,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = '',
) -> list[dict]:
    """Run ``process_one`` over *items* in parallel, resuming from prior output.

    Results are appended to ``<out_root>/results.jsonl`` as they complete.
    On entry, any rows already in that file are loaded and the items whose
    ids appear there are not run again.

    Parameters
    ----------
    items : list[dict]
        Items to process; each must have an ``id`` key.
    out_root : str
        Output directory (created if missing).
    skill_content : str
        Skill text forwarded to ``process_one``.
    workers : int
        Thread-pool size.
    max_tool_turns, diagnostic_mode, diagnostic_instruction
        Forwarded unchanged to ``process_one``.

    Returns
    -------
    list[dict]
        Previously stored rows followed by the newly produced results, the
        latter in completion order.

    Raises
    ------
    Exception
        Whatever ``process_one`` raises is re-raised by ``fut.result()``.
    """
    results_path = os.path.join(out_root, 'results.jsonl')
    os.makedirs(out_root, exist_ok=True)

    done_ids: set[str] = set()
    existing: list[dict] = []
    if os.path.exists(results_path):
        with open(results_path, encoding='utf-8') as f:
            for line in f:
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    # Blank or partially written line: ignore it.
                    continue
                # Only JSON objects carrying an id are usable result rows.
                # Anything else (a list, a number, an id-less object) would
                # otherwise crash on .get or register the bogus id 'None'.
                if not isinstance(row, dict) or row.get('id') is None:
                    continue
                done_ids.add(str(row['id']))
                existing.append(row)

    pending = [item for item in items if str(item['id']) not in done_ids]
    if not pending:
        return existing

    results = list(existing)
    with open(results_path, 'a', encoding='utf-8') as outf, ThreadPoolExecutor(max_workers=workers) as ex:
        futs = {
            ex.submit(
                process_one,
                item,
                out_root,
                skill_content,
                max_tool_turns=max_tool_turns,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
            ): item
            for item in pending
        }
        for fut in as_completed(futs):
            res = fut.result()
            results.append(res)
            # Flush after every row so an interrupted run can be resumed.
            outf.write(json.dumps(res, ensure_ascii=False) + '\n')
            outf.flush()
    return results
diff --git a/skillopt/envs/sealqa/tool_runtime.py b/skillopt/envs/sealqa/tool_runtime.py new file mode 100644 index 0000000..40a6b50 --- /dev/null +++ b/skillopt/envs/sealqa/tool_runtime.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import html +import re +from urllib.request import Request, urlopen + +DEFAULT_USER_AGENT = ( + 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/135.0 Safari/537.36' +) +_MAX_FETCH_CHARS = 6000 + + +def _strip_html(raw_html: str) -> str: + cleaned = re.sub(r'(?is).*?', ' ', raw_html) + cleaned = re.sub(r'(?is).*?', ' ', cleaned) + cleaned = re.sub(r'(?is)<[^>]+>', ' ', cleaned) + cleaned = html.unescape(cleaned) + return re.sub(r'\s+', ' ', cleaned).strip() + + +def web_fetch(url: str, max_chars: int = _MAX_FETCH_CHARS) -> str: + req = Request(url, headers={'User-Agent': DEFAULT_USER_AGENT}) + with urlopen(req, timeout=20) as response: + body = response.read().decode('utf-8', errors='ignore') + text = _strip_html(body) + if len(text) > max_chars: + omitted = len(text) - max_chars + text = text[:max_chars] + f"\n\n[... 
class SearchQAAdapter(EnvAdapter):
    """SearchQA environment adapter.

    Wires the SearchQA dataloader, the batch rollout (``run_batch``) and the
    minibatch reflection step (``run_minibatch_reflect``) behind the
    ``EnvAdapter`` interface. For this environment the "env manager" handed
    between methods is simply the list of item dicts of a batch.
    """

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        max_turns: int = 1,
        exec_timeout: int = 120,
        workers: int = 64,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        seed: int = 42,
        limit: int = 0,
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        """Store rollout/reflect settings and build the dataloader.

        ``split_*``, ``data_path``, ``seed`` and ``limit`` are forwarded to
        ``SearchQADataLoader``; every other argument is kept on the instance
        and read by ``rollout``, ``reflect`` or ``deep_reflect``.

        ``exec_timeout`` is declared exactly once: Python rejects a
        signature that repeats a parameter name, which made this module
        unimportable. Its default matches the ``exec_timeout`` default of
        the SearchQA ``process_one`` rollout function.
        """
        self.max_turns = max_turns
        self.exec_timeout = exec_timeout
        self.workers = workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = SearchQADataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )

    def setup(self, cfg: dict) -> None:
        """Run the base-class setup, then set up the dataloader with *cfg*."""
        super().setup(cfg)
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        """Return the ``SearchQADataLoader`` owned by this adapter."""
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        """Return the batch payload as a list (empty when there is none)."""
        return list(batch.payload or [])

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        """Build a training batch via the dataloader and return its items."""
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        """Build an evaluation batch for *split* and return its items."""
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def rollout(
        self,
        env_manager,  # actually list[dict] for SearchQA
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict]:
        """Run QA agent on items. Resume-aware.

        Delegates to ``run_batch`` with the adapter's turn, timeout and
        worker settings. The optional ``diagnostic_mode``,
        ``diagnostic_instruction`` and ``diagnostic_trace_context_by_id``
        keyword arguments are passed through.
        """
        items: list[dict] = env_manager  # type alias for clarity
        return run_batch(
            items=items,
            out_root=out_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            exec_timeout=self.exec_timeout,
            workers=self.workers,
            diagnostic_mode=kwargs.get("diagnostic_mode", False),
            diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
            diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
            task_timeout=self.exec_timeout,
        )

    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Run minibatch reflection over rollout *results*.

        ``prediction_dir`` and ``patches_dir`` default to ``predictions`` and
        ``patches`` under *out_dir* unless supplied in *kwargs*. The skill
        update mode is read from the adapter config (``skill_update_mode``,
        default ``"patch"``).
        """
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")

        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Probe a few representative items and reflect on the probed runs.

        Steps: select representative failures/successes, ask for a
        diagnostic probe instruction, re-run the selected items with that
        instruction in diagnostic mode, then run minibatch reflection on the
        new results. All artifacts go under ``<out_dir>/deep_reflect``.

        Returns an empty list when deep reflection is disabled, when
        ``env_manager`` in *kwargs* is not a list, when nothing is selected,
        or when no probe is produced.
        """
        if not self.use_deep_reflect:
            return []

        env_manager = kwargs.get("env_manager")
        if not isinstance(env_manager, list):
            return []

        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        codex_backend = get_student_backend() == "codex_exec"
        selected_items = self.select_representative_items(
            results,
            env_manager,
            n_failures=self.deep_reflect_failures,
            n_successes=self.deep_reflect_successes,
            seed=random_seed,
        )
        if not selected_items:
            return []

        selected_ids = {str(item["id"]) for item in selected_items}
        selected_results = [row for row in results if str(row.get("id")) in selected_ids]
        selected_examples = (
            self.attach_codex_probe_context(selected_results, prediction_dir)
            if codex_backend
            else selected_results
        )
        # Compact description of the selection, stored next to the probe.
        selected_metadata = [
            {
                "id": str(item["id"]),
                "question_preview": str(item.get("question") or "")[:200],
                "has_context": bool(str(item.get("context") or "").strip()),
                "n_gold_answers": len(item.get("answers") or []),
            }
            for item in selected_items
        ]

        deep_dir = os.path.join(out_dir, "deep_reflect")
        rollout_dir = os.path.join(deep_dir, "rollout")
        patches_dir = os.path.join(deep_dir, "patches")
        os.makedirs(deep_dir, exist_ok=True)
        print(
            f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
            f"mode=no_reference_probe"
        )
        probe = generate_deep_probe_instruction(
            skill_content=skill_content,
            items=selected_examples,
            prediction_dir=prediction_dir,
            system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            output_requirements=[
                "- There is no hidden reference block. Use only the question, provided context, the student's output, and the evaluation result to infer what intermediate state is worth probing.",
                "- The instruction must explicitly request a short ... block before the final ....",
                "- The readout should focus on likely evidence span, top candidate and runner-up, decisive clue, or a few short intermediate conclusions.",
                "- Do not ask for exhaustive copying of the context or a full chain-of-thought.",
                "- The instruction text should be ready to append directly to the student's prompt.",
            ],
        )
        if not probe:
            return []
        diagnostic_trace_context_by_id = None
        if codex_backend:
            # With the codex backend the probe target may be narrowed and
            # per-item trace context attached before the diagnostic rollout.
            selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
                selected_items=selected_items,
                selected_examples=selected_examples,
                prediction_dir=prediction_dir,
                probe=probe,
            )

        with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
            json.dump(
                {
                    **probe,
                    "selected_examples": selected_metadata,
                },
                f,
                ensure_ascii=False,
                indent=2,
            )

        deep_results = self.rollout(
            selected_items,
            skill_content,
            rollout_dir,
            diagnostic_mode=True,
            diagnostic_instruction=probe["probe_instruction"],
            diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
        )
        return run_minibatch_reflect(
            results=deep_results,
            skill_content=skill_content,
            prediction_dir=os.path.join(rollout_dir, "predictions"),
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def get_task_types(self) -> list[str]:
        """Return the task types of this environment (only ``"qa"``)."""
        return ["qa"]
__future__ import annotations + +import json + +from skillopt.datasets.base import SplitDataLoader + + +# ── Raw data loading utilities (for preprocessing / standalone eval) ───── + +def _load_items(path: str) -> list[dict]: + """Load items from JSON or JSONL file.""" + with open(path) as f: + content = f.read().strip() + try: + data = json.loads(content) + if isinstance(data, list): + return data + if isinstance(data, dict): + return data.get("data") or list(data.values()) + except json.JSONDecodeError: + pass + + items = [] + for line in content.splitlines(): + line = line.strip() + if line: + items.append(json.loads(line)) + return items + + +# ── Dataloader ─────────────────────────────────────────────────────────── + +class SearchQADataLoader(SplitDataLoader): + """SearchQA dataloader. + + Each split directory (train/, val/, test/) contains a .json file — + a JSON array of question items. + """ + + def load_raw_items(self, data_path: str) -> list[dict]: + return _load_items(data_path) diff --git a/skillopt/envs/searchqa/evaluator.py b/skillopt/envs/searchqa/evaluator.py new file mode 100644 index 0000000..8c6c488 --- /dev/null +++ b/skillopt/envs/searchqa/evaluator.py @@ -0,0 +1,100 @@ +"""SearchQA evaluation metrics: Exact Match, F1, and Substring Match. + +Normalization follows the SQuAD convention: + - lowercase + - remove punctuation + - remove articles (a, an, the) + - collapse whitespace + +Answer extraction looks for ... XML tags, +falling back to the last non-empty line of the response. +""" +from __future__ import annotations + +import re +import string +from collections import Counter + + +def normalize_answer(s: str) -> str: + """Normalize answer string (SQuAD convention).""" + s = s.lower() + s = "".join(ch for ch in s if ch not in string.punctuation) + s = re.sub(r"\b(a|an|the)\b", " ", s) + s = " ".join(s.split()) + return s.strip() + + +def extract_answer(text: str) -> str: + """Extract answer from ... tags. 
def sub_em(prediction: str, gold_answers: list[str]) -> float:
    """1.0 if any normalized gold is a substring of prediction, or vice versa.

    Empty strings are handled explicitly, because ``"" in s`` is true for
    every ``s``: a prediction (or gold answer) that normalizes to nothing
    only matches a counterpart that is empty as well. This mirrors
    ``f1_score``, which scores two empty answers as a match.

    Parameters
    ----------
    prediction : str
        Predicted answer text (normalized internally).
    gold_answers : list[str]
        Acceptable answers.

    Returns
    -------
    float
        ``1.0`` on a match with any gold answer, otherwise ``0.0``.
    """
    norm_pred = normalize_answer(prediction)
    for gold in gold_answers:
        norm_gold = normalize_answer(gold)
        if not norm_pred or not norm_gold:
            # Substring containment is meaningless with an empty side.
            if norm_pred == norm_gold:
                return 1.0
            continue
        if norm_gold in norm_pred or norm_pred in norm_gold:
            return 1.0
    return 0.0
+ """ + answer = extract_answer(prediction_text) + return { + "em": exact_match(answer, gold_answers), + "f1": f1_score(answer, gold_answers), + "sub_em": sub_em(answer, gold_answers), + "predicted_answer": answer, + "gold_answers": gold_answers, + } diff --git a/skillopt/envs/searchqa/prompts/analyst_error.md b/skillopt/envs/searchqa/prompts/analyst_error.md new file mode 100644 index 0000000..a60e73d --- /dev/null +++ b/skillopt/envs/searchqa/prompts/analyst_error.md @@ -0,0 +1,46 @@ +You are an expert failure-analysis agent for question answering tasks. + +You will be given MULTIPLE failed QA agent responses from a single minibatch +and the current skill document. Each trajectory includes the agent's response +and an evaluation result showing the predicted answer vs. the gold answer(s). + +Your job is to identify the most important COMMON failure patterns across +the batch and propose a concise set of skill edits. + +## Failure Type Categories +- **rule_missing**: the skill lacks a relevant rule for this type of question +- **rule_wrong**: an existing skill rule is misleading or incorrect +- **rule_ignored**: the skill has the right rule but the agent did not follow it +- **answer_format**: the agent found the right information but formatted it incorrectly +- **other**: none of the above + +## Analysis Process +1. Read ALL failed trajectories in the minibatch. +2. Carefully compare each predicted answer against the gold answer(s) — + understand exactly WHY the Exact Match failed. +3. Identify the most prevalent, systematic failure patterns across them. +4. For each pattern, classify its failure type. +5. Propose skill edits that address the COMMON patterns — not individual edge cases. +6. Edits must be generalizable; do not hardcode question-specific values. +7. Only patch gaps in the skill — do not duplicate existing content. + +You will be told the maximum number of edits (the budget L). Produce AT MOST L edits, +focusing on the highest-impact patterns. 
You may produce fewer if warranted. + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. diff --git a/skillopt/envs/searchqa/prompts/analyst_success.md b/skillopt/envs/searchqa/prompts/analyst_success.md new file mode 100644 index 0000000..6476d94 --- /dev/null +++ b/skillopt/envs/searchqa/prompts/analyst_success.md @@ -0,0 +1,32 @@ +You are an expert success-pattern analyst for AI question answering agents. + +You will be given MULTIPLE successful QA agent responses from a single minibatch +and the current skill document. Your job is to identify generalizable behavior +patterns that are COMMON across the batch and worth encoding in the skill. + +## Rules +- Only propose patches for patterns NOT already covered in the skill. +- Focus on patterns that appear across MULTIPLE trajectories in the batch. +- Be concise. Patterns must generalize beyond specific questions. +- Prefer reinforcing existing sections over adding new top-level sections. +- If the agents' success involved a smart reading strategy or disambiguation + approach, consider reinforcing that in the patch. + +You will be told the maximum number of edits (the budget L). Produce AT MOST L edits, +focusing on the most broadly applicable patterns. You may produce fewer if warranted. 
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. diff --git a/skillopt/envs/searchqa/prompts/deep_probe.md b/skillopt/envs/searchqa/prompts/deep_probe.md new file mode 100644 index 0000000..63ab811 --- /dev/null +++ b/skillopt/envs/searchqa/prompts/deep_probe.md @@ -0,0 +1,27 @@ +You are an expert diagnostic-probe designer for retrieval-style question answering tasks. + +You will be shown representative trajectories, the current student skill, the student's prompt context, +and the evaluation result including the gold answer. There is NO hidden chain-of-thought reference. +Design one SMALL diagnostic instruction that exposes the student's intermediate reading or evidence-selection state +without materially changing the original scaffold. + +## Hard Constraints +1. Do NOT substantially change the original scaffold. +2. Do NOT prescribe a brand-new multi-step solving procedure. +3. You MAY ask for a short structured readout of intermediate conclusions, evidence candidates, or elimination decisions. +4. Do NOT ask for exhaustive quotation of the whole context or a full chain-of-thought. +5. Keep it brief and structured, and require the final answer to remain in .... +6. Use the gold answer only to target a useful probe; do not simply force the student to restate the gold answer. 
+ +## Good Probe Targets +- the most likely supporting span or document cue +- top answer candidate and runner-up +- decisive lexical clue / entity / date / title +- why a tempting alternative was rejected +- 2-4 short intermediate conclusions that directly support the final answer + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/envs/searchqa/prompts/rollout_system.md b/skillopt/envs/searchqa/prompts/rollout_system.md new file mode 100644 index 0000000..1befe4e --- /dev/null +++ b/skillopt/envs/searchqa/prompts/rollout_system.md @@ -0,0 +1,13 @@ +You are an expert question answering agent. + +{skill_section}## Task Format +You will receive a CONTEXT containing document passages and a QUESTION. +Read the context carefully and answer the question based on the information provided. + +## Answer Format +Think step by step, then provide your final answer inside ... tags. +Keep your answer concise — typically a few words or a short phrase. +Do not repeat the question. Do not include unnecessary explanation in the answer tags. + +Example: +Abraham Lincoln diff --git a/skillopt/envs/searchqa/reflect.py b/skillopt/envs/searchqa/reflect.py new file mode 100644 index 0000000..7a99207 --- /dev/null +++ b/skillopt/envs/searchqa/reflect.py @@ -0,0 +1,4 @@ +"""SearchQA Reflect stage. + +Prompts are now loaded from .md files by the base adapter. +""" diff --git a/skillopt/envs/searchqa/rollout.py b/skillopt/envs/searchqa/rollout.py new file mode 100644 index 0000000..85f15d6 --- /dev/null +++ b/skillopt/envs/searchqa/rollout.py @@ -0,0 +1,455 @@ +"""SearchQA rollout — single-turn QA agent + batch execution. + +The QA agent receives a skill document, question, and context passages, +then produces an answer in ... tags. 
+ +Public API +---------- +- :func:`process_one` — run + evaluate one QA item +- :func:`run_batch` — parallel execution of a list of items +""" +from __future__ import annotations + +import json +import os +import time +import traceback +from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait + +from skillopt.model import chat_student, get_student_backend, is_student_exec_backend +from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec +from skillopt.prompts import load_prompt +from skillopt.envs.searchqa.evaluator import evaluate + + +# ── Prompt templates ───────────────────────────────────────────────────────── + +_MAX_CONTEXT_CHARS = 6000 + + +def _truncate_context(context: str, max_chars: int = _MAX_CONTEXT_CHARS) -> str: + """Truncate context at [DOC] boundaries to stay within budget.""" + if len(context) <= max_chars: + return context + docs = context.split("[DOC]") + result = "" + for doc in docs: + candidate = result + "[DOC]" + doc if result else doc + if len(candidate) > max_chars: + break + result = candidate + if not result: + result = context[:max_chars] + "\n...[truncated]" + return result + + +def _build_system(skill_content: str) -> str: + if skill_content.strip(): + skill_section = f"## Skill\n{skill_content.strip()}\n\n" + else: + skill_section = "" + return load_prompt("rollout_system", env="searchqa").format(skill_section=skill_section) + + +def _build_user( + question: str, + context: str, + *, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", +) -> str: + context = _truncate_context(context) + parts = [ + f"## Context\n{context}", + f"## Question\n{question}", + ] + if diagnostic_trace_context.strip(): + parts.append( + "## Previous Codex Trace Snapshot\n" + "This is a partial transcript from an earlier attempt. 
Use it as your current reasoning context.\n\n" + f"{diagnostic_trace_context.strip()}" + ) + if diagnostic_mode and diagnostic_instruction.strip(): + parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}") + return "\n\n".join(parts) + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current SearchQA example.", + preamble=( + "Use this skill when solving the current SearchQA task.\n" + "Read the provided context carefully, ground the answer in that context,\n" + "and return the final answer inside ...." + ), + ) + + +def _run_codex_once( + *, + pred_dir: str, + skill_content: str, + question: str, + context: str, + model: str, + timeout: int, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", + previous_response: str = "", +) -> tuple[str, str, str, str]: + user = _build_user( + question, + context, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + task_parts = [user] + if previous_response: + task_parts.append( + "## Previous Attempt\n" + f"{previous_response}\n\n" + "Review it against the same context and question. If needed, correct it." + ) + task_text = "\n\n".join(task_parts) + skill_md = _build_codex_skill(skill_content) + work_dir = os.path.join(pred_dir, "codex_exec") + prepare_workspace( + work_dir=work_dir, + skill_md=skill_md, + task_text=task_text, + ) + prompt = ( + "Use the `skillopt-student` skill available in this workspace.\n" + "Read `task.md` and answer the SearchQA question.\n" + "Return the final answer inside ...." 
+ ) + final_message, raw = run_student_exec( + work_dir=work_dir, + prompt=prompt, + model=model, + timeout=timeout, + ) + return final_message or raw, raw, skill_md, task_text + + +# ── Single-item execution ─────────────────────────────────────────────────── + + +def process_one( + item: dict, + out_root: str, + skill_content: str, + max_turns: int = 1, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: str = "", + exec_timeout: int = 120, +) -> dict: + """Process a single QA item: run agent + evaluate. + + Parameters + ---------- + item : dict + Must have keys: ``id``, ``question``, ``context``, ``answers``. + out_root : str + Output directory (predictions saved under ``predictions//``). + skill_content : str + Current skill document text. + max_turns : int + Max reasoning turns (1 = single-turn QA). + + Returns + ------- + dict + Result with ``hard`` (EM as int), ``soft`` (F1), etc. + """ + item_id = str(item["id"]) + question = item["question"] + context = item.get("context", "") + gold_answers = item.get("answers", []) + + result = { + "id": item_id, + "question": question, + "em": 0.0, + "f1": 0.0, + "sub_em": 0.0, + "hard": 0, + "soft": 0.0, + "predicted_answer": "", + "gold_answers": gold_answers, + "response": "", + "fail_reason": "", + "agent_ok": False, + "n_turns": 0, + } + + try: + pred_dir = os.path.join(out_root, "predictions", item_id) + os.makedirs(pred_dir, exist_ok=True) + + if is_student_exec_backend(): + from skillopt.model import azure_openai as _llm + + conversation: list[dict] = [] + response = "" + system = "" + user = "" + for turn in range(max_turns): + response, raw, system, user = _run_codex_once( + pred_dir=pred_dir, + skill_content=skill_content, + question=question, + context=context, + model=_llm.STUDENT_DEPLOYMENT, + timeout=exec_timeout, + diagnostic_mode=diagnostic_mode if turn == 0 else False, + diagnostic_instruction=diagnostic_instruction if turn == 0 else "", + 
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "", + previous_response=response if turn > 0 else "", + ) + conversation.append({"type": "message", "turn": turn + 1, "content": response}) + if turn > 0 and "" in response.lower(): + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) + + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w") as f: + f.write(system) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w") as f: + f.write(user) + with open(os.path.join(pred_dir, "conversation.json"), "w") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + + eval_result = evaluate(response, gold_answers) + result["em"] = eval_result["em"] + result["f1"] = eval_result["f1"] + result["sub_em"] = eval_result["sub_em"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + if eval_result["em"] < 1.0: + result["fail_reason"] = ( + f"EM=0: predicted '{eval_result['predicted_answer']}' " + f"but expected {gold_answers}" + ) + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {question}\n" + f"Predicted answer: {eval_result['predicted_answer']!r}\n" + f"Gold answers: {gold_answers!r}\n" + f"Exact Match: {eval_result['em']}\n" + f"F1: {eval_result['f1']:.4f}" + ) + conversation.append({"role": "system", "content": eval_detail}) + with open(os.path.join(pred_dir, "conversation.json"), "w") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + return result + + system = _build_system(skill_content) + user = _build_user( + question, + context, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + + conversation: list[dict] = [] + response = "" + + for turn in range(max_turns): + if turn == 0: + resp_text, _ = chat_student( + system=system, user=user, + 
max_completion_tokens=512, + retries=5, stage="rollout", + timeout=exec_timeout, + ) + else: + refinement = ( + f"Your previous answer was:\n{response}\n\n" + f"Review it against the context and question. " + f"If correct, repeat it. If wrong, provide a corrected answer.\n" + f"Use ... tags for your final answer." + ) + resp_text, _ = chat_student( + system=system, user=refinement, + max_completion_tokens=512, + retries=5, stage="rollout", + timeout=exec_timeout, + ) + + response = resp_text + conversation.append({"type": "message", "turn": turn + 1, "content": resp_text}) + + if turn > 0 and "" in resp_text.lower(): + break + + result["response"] = response + result["agent_ok"] = True + result["n_turns"] = len(conversation) + + # Save conversation + with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w") as f: + f.write(system) + with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w") as f: + f.write(user) + with open(os.path.join(pred_dir, "conversation.json"), "w") as f: + json.dump(conversation, f, ensure_ascii=False, indent=2) + + # Evaluate + eval_result = evaluate(response, gold_answers) + result["em"] = eval_result["em"] + result["f1"] = eval_result["f1"] + result["sub_em"] = eval_result["sub_em"] + result["predicted_answer"] = eval_result["predicted_answer"] + result["hard"] = int(eval_result["em"]) + result["soft"] = eval_result["f1"] + + if eval_result["em"] < 1.0: + result["fail_reason"] = ( + f"EM=0: predicted '{eval_result['predicted_answer']}' " + f"but expected {gold_answers}" + ) + + # Append eval details to conversation for the analyst + eval_detail = ( + f"[EVALUATION RESULT]\n" + f"Question: {question}\n" + f"Predicted answer: {eval_result['predicted_answer']!r}\n" + f"Gold answers: {gold_answers!r}\n" + f"Exact Match: {eval_result['em']}\n" + f"F1: {eval_result['f1']:.4f}" + ) + conversation.append({ + "role": "system", + "content": eval_detail, + }) + # Re-save enriched conversation + with open(os.path.join(pred_dir, 
def run_batch(
    items: list[dict],
    out_root: str,
    skill_content: str,
    max_turns: int = 1,
    exec_timeout: int = 120,
    workers: int = 64,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context_by_id: dict[str, str] | None = None,
    task_timeout: int = 600,
) -> list[dict]:
    """Run the QA agent on *items* in a thread pool, resuming prior work.

    Results are appended to ``<out_root>/results.jsonl`` one JSON object per
    line as they become available.  Items whose ``id`` already appears in
    that file are not re-run.

    Parameters
    ----------
    items : list[dict]
        QA items; each must have an ``id`` key.
    out_root : str
        Output directory (created if missing).
    skill_content : str
        Skill document text forwarded to :func:`process_one`.
    max_turns, exec_timeout, diagnostic_mode, diagnostic_instruction
        Forwarded to :func:`process_one`.
    workers : int
        Thread-pool size.
    diagnostic_trace_context_by_id : dict[str, str] | None
        Optional per-item trace context keyed by ``str(item["id"])``.
    task_timeout : int
        Wall-clock budget per started task; raised to at least
        ``exec_timeout + 60`` seconds.

    Returns
    -------
    list[dict]
        Previously stored rows followed by the rows produced in this call
        (completed, errored, or timed out).
    """
    task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
    results_path = os.path.join(out_root, "results.jsonl")
    os.makedirs(out_root, exist_ok=True)

    # Resume: load rows written by an earlier run.  The file is always
    # handled as UTF-8 because rows are dumped with ensure_ascii=False.
    done_ids: set[str] = set()
    existing: list[dict] = []
    if os.path.exists(results_path):
        with open(results_path, encoding="utf-8") as f:
            for line in f:
                try:
                    r = json.loads(line)
                    done_ids.add(str(r["id"]))
                    existing.append(r)
                except (ValueError, KeyError, TypeError):
                    # Malformed / partial line or a row without an id:
                    # skip it so the item is simply run again.
                    continue

    pending = [it for it in items if str(it["id"]) not in done_ids]
    if not pending:
        return existing

    results = list(existing)
    trace_by_id = diagnostic_trace_context_by_id or {}

    def _timeout_result(item: dict) -> dict:
        return {
            "id": str(item["id"]),
            "question": item.get("question", ""),
            "task_description": item.get("question", ""),
            "task_type": item.get("task_type") or "searchqa",
            "hard": 0,
            "soft": 0.0,
            "predicted_answer": "",
            "response": "",
            "fail_reason": f"task-timeout-{task_timeout}s",
            "agent_ok": False,
            "n_turns": 0,
            "gold_answer": item.get("answers", []),
            "phase": "timeout",
        }

    def _error_result(item: dict, exc: Exception) -> dict:
        row = _timeout_result(item)
        row["phase"] = "error"
        row["fail_reason"] = f"unexpected: {type(exc).__name__}: {exc}"
        return row

    # Start time per item id, recorded by the worker when it actually begins,
    # so time spent waiting in the executor queue does not count.
    started_at: dict[str, float] = {}

    def _run_one(item: dict) -> dict:
        started_at[str(item["id"])] = time.time()
        return process_one(
            item,
            out_root,
            skill_content,
            max_turns,
            diagnostic_mode,
            diagnostic_instruction,
            trace_by_id.get(str(item["id"]), ""),
            exec_timeout,
        )

    with open(results_path, "a", encoding="utf-8") as outf:
        ex = ThreadPoolExecutor(max_workers=workers)
        try:
            futs = {ex.submit(_run_one, it): it for it in pending}
            pending_futs = set(futs)
            while pending_futs:
                done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
                now = time.time()
                timed_out = [
                    fut for fut in pending_futs - done
                    if str(futs[fut]["id"]) in started_at
                    and now - started_at[str(futs[fut]["id"])] >= task_timeout
                ]
                for fut in done:
                    pending_futs.remove(fut)
                    item = futs[fut]
                    try:
                        res = fut.result()
                    except Exception as exc:  # noqa: BLE001
                        res = _error_result(item, exc)
                    results.append(res)
                    outf.write(json.dumps(res, ensure_ascii=False) + "\n")
                    outf.flush()
                for fut in timed_out:
                    # cancel() cannot stop a future that is already running;
                    # dropping it from pending_futs means any late result is
                    # ignored and only the timeout row is recorded.
                    pending_futs.remove(fut)
                    fut.cancel()
                    res = _timeout_result(futs[fut])
                    results.append(res)
                    outf.write(json.dumps(res, ensure_ascii=False) + "\n")
                    outf.flush()
        finally:
            # Do not block on stragglers that exceeded their budget.
            ex.shutdown(wait=False, cancel_futures=True)

    return results
class SpreadsheetBenchAdapter(EnvAdapter):
    """SpreadsheetBench environment adapter.

    Wires the SpreadsheetBench data loader and rollout functions into the
    training loop: builds item batches, runs the student on them
    (:meth:`rollout`), and turns results into skill patches
    (:meth:`reflect`, :meth:`deep_reflect`).
    """

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        data_root: str = "",
        mode: str = "single",
        max_turns: int = 30,
        exec_timeout: int = 600,
        workers: int = 64,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        seed: int = 42,
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        """Store rollout/reflect settings and build the data loader.

        The split-related arguments, ``data_root`` and ``seed`` are passed
        to :class:`SpreadsheetBenchDataLoader`; the remaining arguments are
        kept on the instance and consumed by :meth:`rollout`,
        :meth:`reflect` and :meth:`deep_reflect`.
        """
        self.data_root = data_root
        self.mode = mode  # "single", "multi", or "react"
        self.max_turns = max_turns
        self.exec_timeout = exec_timeout
        self.workers = workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = SpreadsheetBenchDataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            data_root=data_root,
            seed=seed,
        )

    def setup(self, cfg: dict) -> None:
        """Run base setup, validate the backend/mode pair, set up the loader.

        Raises
        ------
        NotImplementedError
            If an exec student backend is active and ``mode`` is not
            ``"single"``.
        """
        super().setup(cfg)
        if is_student_exec_backend() and self.mode != "single":
            raise NotImplementedError(
                "Exec student backends are currently supported only for SpreadsheetBench mode=single."
            )
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        """Return the :class:`SpreadsheetBenchDataLoader` owned by this adapter."""
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        """Return the batch payload as a plain list (empty if there is none)."""
        return list(batch.payload or [])

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        """Build a training batch via the data loader and return its items."""
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        """Build an evaluation batch via the data loader and return its items."""
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def rollout(
        self,
        env_manager,
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict]:
        """Run agent on all items and return results.

        Dispatches based on ``self.mode``:
        - ``"single"`` / ``"multi"``: codegen agent (no tool-call)
        - ``"react"``: ReAct agent with tool-call (legacy)

        If ``<out_dir>/results.jsonl`` already holds at least one parseable
        row, those rows are returned and nothing is re-run.  Otherwise the
        batch is executed and the file is rewritten with the new rows.
        Recognised ``kwargs``: ``use_eval_feedback`` (codegen modes only),
        ``diagnostic_mode``, ``diagnostic_instruction``,
        ``diagnostic_trace_context_by_id``.
        """
        items = env_manager  # For static datasets, env_manager is a list of items
        results_path = os.path.join(out_dir, "results.jsonl")
        os.makedirs(out_dir, exist_ok=True)

        # Resume support.  Rows are written with ensure_ascii=False, so the
        # file is always read and written as UTF-8.
        if os.path.exists(results_path):
            existing: list[dict] = []
            with open(results_path, encoding="utf-8") as f:
                for line in f:
                    try:
                        existing.append(json.loads(line))
                    except ValueError:
                        # Blank or partially written line: ignore it.
                        continue
            if existing:
                return existing

        if self.mode in ("single", "multi"):
            results = run_spreadsheet_batch_codegen(
                items=items,
                data_root=self.data_root,
                out_root=out_dir,
                skill_content=skill_content,
                mode=self.mode,
                max_turns=self.max_turns,
                max_api_workers=self.workers,
                task_timeout=self.exec_timeout,
                use_eval_feedback=kwargs.get("use_eval_feedback", False),
                diagnostic_mode=kwargs.get("diagnostic_mode", False),
                diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
                diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
            )
        else:
            results = run_spreadsheet_batch(
                items=items,
                data_root=self.data_root,
                out_root=out_dir,
                skill_content=skill_content,
                max_turns=self.max_turns,
                max_api_workers=self.workers,
                task_timeout=max(600, int(self.exec_timeout) + 60),
                diagnostic_mode=kwargs.get("diagnostic_mode", False),
                diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
                diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
            )

        with open(results_path, "w", encoding="utf-8") as f:
            for r in results:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

        return results

    def _spreadsheet_minibatch_reflect(
        self,
        results: list[dict],
        skill_content: str,
        *,
        prediction_dir: str,
        patches_dir: str,
        random_seed,
        step_buffer_context: str,
        meta_skill_context: str,
    ) -> list[dict | None]:
        """Call ``run_minibatch_reflect`` with this adapter's reflect settings.

        Shared by :meth:`reflect` and :meth:`deep_reflect` so both use the
        same worker count, minibatch size, edit budget, prompts and update
        mode (``skill_update_mode`` from ``self._cfg`` when present,
        ``"patch"`` otherwise).
        """
        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Analyze rollout results and produce patches (minibatch mode).

        ``prediction_dir`` and ``patches_dir`` default to ``predictions``
        and ``patches`` under *out_dir* unless given in ``kwargs``.
        """
        return self._spreadsheet_minibatch_reflect(
            results,
            skill_content,
            prediction_dir=kwargs.get("prediction_dir", os.path.join(out_dir, "predictions")),
            patches_dir=kwargs.get("patches_dir", os.path.join(out_dir, "patches")),
            random_seed=kwargs.get("random_seed"),
            step_buffer_context=kwargs.get("step_buffer_context", ""),
            meta_skill_context=kwargs.get("meta_skill_context", ""),
        )

    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Re-run a few representative items with a diagnostic probe and reflect.

        Steps: select representative failures/successes, generate a probe
        instruction, save it to ``<out_dir>/deep_reflect/probe.json``, roll
        the selected items out again in diagnostic mode, then run minibatch
        reflection on those diagnostic results.

        Returns an empty list when deep reflection is disabled, when
        ``kwargs["env_manager"]`` is not a list, when no items are selected,
        or when no probe is generated.
        """
        if not self.use_deep_reflect:
            return []

        env_manager = kwargs.get("env_manager")
        if not isinstance(env_manager, list):
            return []

        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        codex_backend = get_student_backend() == "codex_exec"
        selected_items = self.select_representative_items(
            results,
            env_manager,
            n_failures=self.deep_reflect_failures,
            n_successes=self.deep_reflect_successes,
            seed=random_seed,
        )
        if not selected_items:
            return []

        selected_ids = {str(item["id"]) for item in selected_items}
        selected_results = [row for row in results if str(row.get("id")) in selected_ids]
        selected_examples = (
            self.attach_codex_probe_context(selected_results, prediction_dir)
            if codex_backend
            else selected_results
        )
        # Captured before the codex branch below may replace selected_items,
        # so probe.json lists the items the probe was generated from.
        selected_metadata = [
            {
                "id": str(item["id"]),
                "instruction_type": str(item.get("instruction_type") or ""),
                "answer_position": str(item.get("answer_position") or ""),
            }
            for item in selected_items
        ]

        deep_dir = os.path.join(out_dir, "deep_reflect")
        rollout_dir = os.path.join(deep_dir, "rollout")
        patches_dir = os.path.join(deep_dir, "patches")
        os.makedirs(deep_dir, exist_ok=True)
        print(
            f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
            f"mode={self.mode}"
        )
        probe = generate_deep_probe_instruction(
            skill_content=skill_content,
            items=selected_examples,
            prediction_dir=prediction_dir,
            system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            output_requirements=[
                "- The instruction must ask for a short structured diagnostic readout before the student writes code or starts tool use.",
                "- The readout should focus on task family, source/target region, and decisive transformation rule.",
                "- The student must still complete the original spreadsheet task.",
                "- Keep the readout concise and avoid exhaustive cell enumeration.",
                "- The instruction text should be ready to append directly to the student's prompt.",
            ],
        )
        if not probe:
            return []
        diagnostic_trace_context_by_id = None
        if codex_backend:
            selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
                selected_items=selected_items,
                selected_examples=selected_examples,
                prediction_dir=prediction_dir,
                probe=probe,
            )

        with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
            json.dump(
                {
                    **probe,
                    "selected_examples": selected_metadata,
                },
                f,
                ensure_ascii=False,
                indent=2,
            )

        deep_results = self.rollout(
            selected_items,
            skill_content,
            rollout_dir,
            diagnostic_mode=True,
            diagnostic_instruction=probe["probe_instruction"],
            diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
        )
        return self._spreadsheet_minibatch_reflect(
            deep_results,
            skill_content,
            prediction_dir=os.path.join(rollout_dir, "predictions"),
            patches_dir=patches_dir,
            random_seed=random_seed,
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
        )

    def get_task_types(self) -> list[str]:
        """Return a copy of the module-level ``TASK_TYPES`` list."""
        return list(TASK_TYPES)
def _timeout_handler(signum, frame):
    """Raise :class:`TaskTimeout` unconditionally.

    The ``(signum, frame)`` signature is the one ``signal.signal`` passes to
    a handler; both arguments are ignored.

    NOTE(review): presumably installed as an alarm-signal handler by a
    caller; the registration is not visible in this part of the file.
    """
    raise TaskTimeout("Task timed out")
" + "Return a complete corrected Python script inside a ```python``` block." + ) + return "\n".join(lines) + + +# ── Workbook preview (same as official prompt.py) ──────────────────────────── + +def _preview_workbook(path: str, max_rows: int = 5, max_cols: int = 20) -> str: + """Generate a text preview of the first few rows of each sheet.""" + wb = openpyxl.load_workbook(path, data_only=False) + chunks: list[str] = [] + for sheet_name in wb.sheetnames: + ws = wb[sheet_name] + chunks.append( + f"## Sheet: {sheet_name} " + f"(dim={ws.dimensions}, max_row={ws.max_row}, max_col={ws.max_column})" + ) + for row in ws.iter_rows( + min_row=1, + max_row=min(ws.max_row, max_rows), + max_col=min(ws.max_column, max_cols), + values_only=False, + ): + cells = [] + for cell in row: + v = cell.value + if v is None: + cells.append(f"{cell.coordinate}=") + else: + s = str(v) + if len(s) > 40: + s = s[:37] + "..." + cells.append(f"{cell.coordinate}={s}") + chunks.append(" | ".join(cells)) + if ws.max_row > max_rows: + chunks.append(f"... 
def extract_code(text: str) -> str:
    """Extract the first ```python``` fenced code block from LLM output.

    Selection rules, in order:

    1. No fence at all: return the whole text, stripped.
    2. The first fence tagged ``python`` / ``py`` / ``python3`` wins, even
       if an untagged or differently tagged fence appears before it.
    3. Otherwise the first fenced block of any kind is used.
    4. A fence that is never closed (for example output cut off by the token
       limit) yields everything after its opening line.

    Args:
        text: Raw model output.

    Returns:
        The selected code with surrounding whitespace removed.
    """
    fence = "```"
    if fence not in text:
        return text.strip()

    python_tags = ("python", "py", "python3")
    first_body = None
    pos = 0
    while True:
        start = text.find(fence, pos)
        if start == -1:
            break
        newline = text.find("\n", start)
        if newline == -1:
            # Fence marker on the last line with no body after it.
            break
        tag = text[start + len(fence):newline].strip().lower()
        end = text.find(fence, newline + 1)
        # An unterminated fence runs to the end of the text.
        body = (text[newline + 1:] if end == -1 else text[newline + 1:end]).strip()
        if tag in python_tags:
            return body
        if first_body is None:
            first_body = body
        if end == -1:
            break
        pos = end + len(fence)

    return text.strip() if first_body is None else first_body
Use it as your current reasoning context.\n\n" + f"{diagnostic_trace_context.strip()}\n\n" + ) + return ( + f"{prefix}" + f"# Instruction\n{instruction}\n{extra}\n\n" + f"# Input spreadsheet preview\n{preview}\n\n" + "# Task\n" + "Write a Python script that reads the workbook from the variable `INPUT_PATH`, " + "applies the instruction, and writes the modified workbook to `OUTPUT_PATH`. " + "Preserve all other cells unchanged. " + "The preview may be truncated — do not hardcode row counts or assume the data ends at the last previewed row; " + "iterate over all actual rows in the workbook instead. " + f"{task_suffix}" + f"{diagnostic}" + ) + + +# ── LLM call with retry ──────────────────────────────────────────────────── + +def _llm_call_with_retry(call_fn, *, retries: int = 5, timeout: int = 120): + """Wrap an LLM API call with retry and per-call timeout.""" + last_err = None + for attempt in range(retries): + try: + return call_fn(timeout=timeout) + except Exception as e: # noqa: BLE001 + last_err = e + sleep = min(2 ** attempt + random.random(), 60) + time.sleep(sleep) + raise RuntimeError(f"LLM call failed after {retries} retries: {last_err}") + + +def _get_deployment() -> str: + from skillopt.model import azure_openai as _llm + return _llm.STUDENT_DEPLOYMENT + + +def _build_codex_skill(skill_content: str) -> str: + return render_skill_md( + skill_content, + description="Dynamic ReflACT skill for solving the current SpreadsheetBench task.", + preamble=( + "Use this skill when solving the current SpreadsheetBench task in this workspace.\n" + "Write a single self-contained Python solution to `solution.py`.\n" + "The solution must operate on the provided `INPUT_PATH` and `OUTPUT_PATH` variables.\n" + "You may inspect `input.xlsx` and run `python run_solution.py` to validate locally,\n" + "but do not hardcode values from the preview or from one specific workbook." 
+ ), + ) + + +def _build_codex_task( + instruction: str, + input_xlsx: str, + instruction_type: str, + answer_position: str, + *, + diagnostic_mode: bool, + diagnostic_instruction: str, + diagnostic_trace_context: str, +) -> str: + prompt = _build_user( + instruction, + input_xlsx, + instruction_type, + answer_position, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + return ( + f"{prompt}\n\n" + "## Codex Harness Task\n" + "- Read `.agents/skills/skillopt-student/SKILL.md` before writing code; do not call a Skill tool.\n" + "- Read and optionally inspect `input.xlsx` in this workspace.\n" + "- Write the final Python solution to `solution.py`.\n" + "- The script should use the provided `INPUT_PATH` and `OUTPUT_PATH` variables.\n" + "- If you want to validate locally, run `python run_solution.py`.\n" + "- Do not return a code fence as the primary artifact; the source of truth is `solution.py`.\n" + ) + + +def _build_codex_driver() -> str: + return ( + "import pathlib\n" + "import re\n" + "import sys\n" + "import traceback\n\n" + 'INPUT_PATH = "input.xlsx"\n' + 'OUTPUT_PATH = "output.xlsx"\n' + "code = pathlib.Path('solution.py').read_text(encoding='utf-8')\n" + "code = re.sub(r'^\\s*(INPUT_PATH|OUTPUT_PATH)\\s*=\\s*.+$', '', code, flags=re.MULTILINE)\n" + "globals_dict = {'__name__': '__main__', 'INPUT_PATH': INPUT_PATH, 'OUTPUT_PATH': OUTPUT_PATH}\n" + "try:\n" + " exec(compile(code, 'solution.py', 'exec'), globals_dict, globals_dict)\n" + "except Exception:\n" + " traceback.print_exc()\n" + " sys.exit(2)\n" + ) + + +def _prepare_codex_workspace( + *, + instruction: str, + input_xlsx: str, + output_path: str, + instruction_type: str, + answer_position: str, + skill_content: str, + diagnostic_mode: bool, + diagnostic_instruction: str, + diagnostic_trace_context: str, + workspace_name: str = "codex_single", +) -> tuple[str, str, str, str]: + task_out_dir = 
def _run_exec_backend(
    *,
    work_dir: str,
    prompt: str,
    model: str,
    timeout: int,
) -> tuple[str, str]:
    """Run the exec student backend in *work_dir* with file edits allowed.

    Thin wrapper that forwards all arguments to ``run_student_exec`` and
    adds ``allow_file_edits=True``.  The result is returned unchanged; the
    caller in this module unpacks it as ``(final_message, raw)``.
    """
    return run_student_exec(
        work_dir=work_dir,
        prompt=prompt,
        model=model,
        timeout=timeout,
        allow_file_edits=True,
    )
Returns raw text.""" + reasoning_effort = get_reasoning_effort() + if _needs_responses_api(deployment): + # Responses API + system = "" + api_input = [] + for m in messages: + if m["role"] == "system": + system = m["content"] + else: + api_input.append({"role": m["role"], "content": m["content"]}) + resp = _llm_call_with_retry(lambda timeout: client.responses.create( + model=deployment, + instructions=system, + input=api_input, + max_output_tokens=max_output_tokens, + **({"reasoning": {"effort": reasoning_effort}} if reasoning_effort else {}), + timeout=timeout, + ), timeout=llm_timeout) + if hasattr(resp, "usage") and resp.usage: + tracker.record( + "rollout", + getattr(resp.usage, "input_tokens", 0) or 0, + getattr(resp.usage, "output_tokens", 0) or 0, + ) + text = getattr(resp, "output_text", None) or "" + if text: + return text + for item in getattr(resp, "output", None) or []: + for part in getattr(item, "content", []): + if getattr(part, "type", "") == "output_text": + return part.text or "" + return "" + else: + # Chat Completions API — no tools + kwargs = { + "model": deployment, + "messages": messages, + "max_completion_tokens": max_output_tokens, + } + if reasoning_effort is not None: + kwargs["reasoning_effort"] = reasoning_effort + resp = _llm_call_with_retry(lambda timeout: client.chat.completions.create( + **kwargs, + timeout=timeout, + ), timeout=llm_timeout) + if resp.usage: + tracker.record( + "rollout", + resp.usage.prompt_tokens or 0, + resp.usage.completion_tokens or 0, + ) + return resp.choices[0].message.content or "" + + +# ── Public API ────────────────────────────────────────────────────────────── + +def run_single( + instruction: str, + input_xlsx: str, + output_path: str, + instruction_type: str = "", + answer_position: str = "", + skill_content: str = "", + max_output_tokens: int = 16384, + llm_timeout: int = 120, + task_timeout: int = 300, + diagnostic_mode: bool = False, + diagnostic_instruction: str = "", + diagnostic_trace_context: 
str = "", +) -> dict: + """Single-round code generation. One LLM call, no tools. + + Args: + llm_timeout: Per-LLM-call timeout in seconds (default 120). + task_timeout: Total task timeout in seconds (default 300). + + Returns ``{"code": str, "raw": str, "n_turns": 1}``. + """ + if is_student_exec_backend(): + deadline = time.time() + task_timeout + deployment = _get_deployment() + work_dir, skill_md, task_md, prompt = _prepare_codex_workspace( + instruction=instruction, + input_xlsx=input_xlsx, + output_path=output_path, + instruction_type=instruction_type, + answer_position=answer_position, + skill_content=skill_content, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + remaining = max(10, int(deadline - time.time())) + effective_timeout = min(task_timeout, remaining) + final_message, raw = _run_exec_backend( + work_dir=work_dir, + prompt=prompt, + model=deployment, + timeout=effective_timeout, + ) + solution_path = os.path.join(work_dir, "solution.py") + if os.path.exists(solution_path): + with open(solution_path, encoding="utf-8") as f: + code = f.read() + else: + code = extract_code(final_message or raw) + return { + "code": code, + "raw": raw or final_message, + "n_turns": 1, + "conversation": [{"role": "assistant", "content": final_message or raw}], + "student_system_prompt": skill_md, + "student_user_prompt": f"{prompt}\n\n## Task File\n\n{task_md}", + } + + deadline = time.time() + task_timeout + client = get_student_client() + deployment = _get_deployment() + system = _build_system(skill_content) + user = _build_user( + instruction, + input_xlsx, + instruction_type, + answer_position, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ] + + remaining = max(10, int(deadline - 
def run_multi(
    instruction: str,
    input_xlsx: str,
    output_path: str,
    instruction_type: str = "",
    answer_position: str = "",
    skill_content: str = "",
    max_turns: int = 5,
    max_output_tokens: int = 16384,
    llm_timeout: int = 120,
    task_timeout: int = 600,
    gold_path: str = "",
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> dict:
    """Multi-round code generation with execution feedback. No tools.

    Each round: LLM generates code → execute → if error, feed back and retry.
    Two backends share this entry point: when ``is_student_exec_backend()``
    is true the student edits ``solution.py`` inside a prepared workspace and
    feedback is handed over through ``feedback_turn_NN.md`` files; otherwise
    the student is driven through plain chat messages.

    Args:
        instruction: Natural-language task instruction.
        input_xlsx: Path of the input workbook.
        output_path: Path the generated code must write its result to.
        instruction_type: Task type label, forwarded to prompts and evaluation.
        answer_position: Graded cell range(s); also gates the gold check.
        skill_content: Skill document injected into the student prompt.
        max_turns: Maximum number of student turns.
        max_output_tokens: Output-token cap per chat call.
        llm_timeout: Per-LLM-call timeout in seconds (default 120).
        task_timeout: Total task timeout in seconds (default 600).
        gold_path: Path to golden answer xlsx for eval feedback during
            training. When non-empty, a successful execution is followed
            by an eval check; if the output is wrong the agent receives
            cell-level feedback (without revealing expected values) and
            gets another turn. Leave empty for eval/test to avoid
            data leakage.
        diagnostic_mode: Forwarded to the prompt builders.
        diagnostic_instruction: Forwarded to the prompt builders.
        diagnostic_trace_context: Forwarded to the prompt builders.

    Returns:
        ``{"code": str, "raw": str, "n_turns": int, "conversation": [...],
        "student_system_prompt": str, "student_user_prompt": str}`` where
        ``n_turns`` is the number of student (assistant) turns actually taken.
    """
    if is_student_exec_backend():
        deadline = time.time() + task_timeout
        deployment = _get_deployment()
        work_dir, skill_md, task_md, initial_prompt = _prepare_codex_workspace(
            instruction=instruction,
            input_xlsx=input_xlsx,
            output_path=output_path,
            instruction_type=instruction_type,
            answer_position=answer_position,
            skill_content=skill_content,
            diagnostic_mode=diagnostic_mode,
            diagnostic_instruction=diagnostic_instruction,
            diagnostic_trace_context=diagnostic_trace_context,
            workspace_name="codex_multi",
        )
        prompt = (
            f"{initial_prompt}\n\n"
            "## Multi-Turn Repair Mode\n"
            "- This is turn 1. Write or overwrite `solution.py`.\n"
            "- After each turn, the harness will execute your `solution.py`; if it fails, you will receive feedback and may revise it.\n"
            "- Keep the script general: use `INPUT_PATH` and `OUTPUT_PATH`, and do not hardcode one workbook's values."
        )
        conversation: list[dict] = []
        code = ""
        raw = ""
        final_message = ""
        solution_path = os.path.join(work_dir, "solution.py")

        for turn in range(max_turns):
            remaining = deadline - time.time()
            if remaining <= 10:
                # Not enough budget left for another student turn.
                break

            effective_timeout = max(10, int(remaining))
            final_message, raw = _run_exec_backend(
                work_dir=work_dir,
                prompt=prompt,
                model=deployment,
                timeout=effective_timeout,
            )
            conversation.append({"role": "assistant", "content": final_message or raw})

            # Prefer the file the student wrote; otherwise fall back to a code
            # block in its reply and persist it so the next turn can edit it.
            if os.path.exists(solution_path):
                with open(solution_path, encoding="utf-8") as f:
                    code = f.read()
            else:
                code = extract_code(final_message or raw)
                if code.strip():
                    with open(solution_path, "w", encoding="utf-8") as f:
                        f.write(code)

            if not code.strip():
                feedback = (
                    "No usable `solution.py` or Python code block was produced. "
                    "Write a complete `solution.py` that reads `INPUT_PATH` and saves `OUTPUT_PATH`."
                )
            else:
                ok, err = run_generated_code(code, input_xlsx, output_path)
                if ok:
                    if gold_path and answer_position:
                        from skillopt.envs.spreadsheetbench.rollout import _auto_verify_output
                        eval_result = evaluate(
                            output_path, gold_path, instruction_type, answer_position,
                        )
                        if eval_result["ok"]:
                            break
                        verify = _auto_verify_output(output_path, gold_path, answer_position)
                        feedback = _build_eval_feedback(verify)
                    else:
                        break
                else:
                    feedback = (
                        "The current `solution.py` raised an error during harness execution:\n\n"
                        f"```\n{err[:3000]}\n```\n\n"
                        "Revise `solution.py` to fix the error. Keep using `INPUT_PATH` and `OUTPUT_PATH`."
                    )

            feedback_path = os.path.join(work_dir, f"feedback_turn_{turn + 1:02d}.md")
            with open(feedback_path, "w", encoding="utf-8") as f:
                f.write(feedback)
            conversation.append({"role": "user", "content": feedback})
            prompt = (
                f"The previous `solution.py` was evaluated and needs another revision.\n"
                f"Read `{os.path.basename(feedback_path)}` and update `solution.py` accordingly.\n"
                "You may run `python run_solution.py` for a local syntax/runtime check, but the harness will run the final code separately.\n"
                "Do not hardcode workbook-specific answers; preserve unrelated cells."
            )

        return {
            "code": code,
            "raw": raw or final_message,
            "n_turns": sum(1 for m in conversation if m["role"] == "assistant"),
            "conversation": conversation,
            "student_system_prompt": skill_md,
            "student_user_prompt": f"{initial_prompt}\n\n## Task File\n\n{task_md}",
        }

    deadline = time.time() + task_timeout
    client = get_student_client()
    deployment = _get_deployment()
    system = _build_system(skill_content)
    user = _build_user(
        instruction,
        input_xlsx,
        instruction_type,
        answer_position,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )

    messages: list[dict] = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    conversation: list[dict] = []
    code = ""
    raw = ""

    for turn in range(max_turns):
        remaining = deadline - time.time()
        if remaining <= 10:
            # Not enough time for another round
            break

        effective_timeout = min(llm_timeout, int(remaining))
        raw = _chat_call(client, deployment, messages, max_output_tokens, llm_timeout=effective_timeout)
        time.sleep(3)  # Rate-limit cooldown after successful LLM call
        code = extract_code(raw)
        conversation.append({"role": "assistant", "content": raw})
        messages.append({"role": "assistant", "content": raw})

        if not code.strip():
            # No code extracted — ask again
            feedback = (
                "No Python code block was found in your response. "
                "Please return a complete Python script inside a ```python``` block."
            )
            messages.append({"role": "user", "content": feedback})
            conversation.append({"role": "user", "content": feedback})
            continue

        # Execute the code
        ok, err = run_generated_code(code, input_xlsx, output_path)
        if ok:
            # Execution succeeded — check correctness if gold_path available
            if gold_path and answer_position:
                from skillopt.envs.spreadsheetbench.rollout import _auto_verify_output
                eval_result = evaluate(
                    output_path, gold_path, instruction_type, answer_position,
                )
                if eval_result["ok"]:
                    break  # Genuinely correct — stop

                # Output is wrong — build feedback without leaking golden values
                verify = _auto_verify_output(output_path, gold_path, answer_position)
                feedback = _build_eval_feedback(verify)
                messages.append({"role": "user", "content": feedback})
                conversation.append({"role": "user", "content": feedback})
                continue
            else:
                # No gold path (eval/test) — accept execution success
                break

        # Execution failed — feed error back
        feedback = (
            f"The code raised an error during execution:\n\n"
            f"```\n{err[:3000]}\n```\n\n"
            f"Please fix the code and return a complete corrected Python script "
            f"inside a ```python``` block."
        )
        messages.append({"role": "user", "content": feedback})
        conversation.append({"role": "user", "content": feedback})

    return {
        "code": code,
        "raw": raw,
        # Count the turns that actually happened. ``turn + 1`` is undefined
        # when the loop body never runs (max_turns <= 0) and is one too many
        # when the deadline guard breaks out before calling the LLM.
        "n_turns": sum(1 for m in conversation if m["role"] == "assistant"),
        "conversation": conversation,
        "student_system_prompt": system,
        "student_user_prompt": user,
    }
+ Spreadsheet files referenced by items live under a separate ``data_root``. + """ + + def __init__( + self, + split_dir: str = "", + data_path: str = "", + split_mode: str = "ratio", + split_ratio: str = "2:1:7", + split_seed: int = 42, + split_output_dir: str = "", + data_root: str = "", + seed: int = 42, + limit: int = 0, + **kwargs, + ) -> None: + super().__init__( + split_dir=split_dir, + data_path=data_path, + split_mode=split_mode, + split_ratio=split_ratio, + split_seed=split_seed, + split_output_dir=split_output_dir, + seed=seed, + limit=limit, + ) + self.data_root = data_root diff --git a/skillopt/envs/spreadsheetbench/evaluator.py b/skillopt/envs/spreadsheetbench/evaluator.py new file mode 100644 index 0000000..3d8b84a --- /dev/null +++ b/skillopt/envs/spreadsheetbench/evaluator.py @@ -0,0 +1,158 @@ +"""Cell-value evaluator faithful to the official SpreadsheetBench +`evaluation/evaluation.py` (https://github.com/RUCKBReasoning/SpreadsheetBench). + +Key rules (copied from the official `transform_value` / `compare_cell_value`): + * numeric values (int/float and numeric strings) are compared after + ``round(float(v), 2)`` — a fixed 2-decimal quantization (NOT a tolerance); + * ``datetime.time`` is stringified and the trailing microseconds stripped; + * ``datetime.datetime`` is converted to an Excel serial day and rounded + to an integer day; + * an empty string ``""`` and ``None`` are considered equal, but otherwise + ``type(v1) != type(v2)`` fails the comparison. + +Format/style comparison is deliberately NOT performed — the official +reference evaluator also skips it (the relevant lines are commented out +in `cell_level_compare`). soft vs hard is defined at the run_bench level +across a task's multiple test cases, not here. 
# ---------- value transform / compare (official port) ----------

def _datetime_to_float(dt: datetime.datetime) -> float:
    """Convert *dt* to a serial day count measured from 1899-12-30."""
    excel_start_date = datetime.datetime(1899, 12, 30)
    delta = dt - excel_start_date
    return delta.days + delta.seconds / 86400.0


def _transform_value(v):
    """Normalize a cell value before comparison.

    * ``bool``/``int``/``float`` and numeric strings become
      ``round(float(v), 2)``;
    * ``datetime.time`` becomes ``str(v)`` with its last three characters
      dropped;
    * ``datetime.datetime`` becomes its serial day rounded to a whole day;
    * anything else (including non-numeric strings) is returned unchanged.
    """
    # bool is a subclass of int, so True/False are quantized like numbers
    # (True -> 1.0); this keeps 1 and True comparable after the type check
    # in _compare_cell_value.
    if isinstance(v, (bool, int, float)):
        return round(float(v), 2)
    if isinstance(v, datetime.time):
        return str(v)[:-3]
    if isinstance(v, datetime.datetime):
        return round(_datetime_to_float(v), 0)
    if isinstance(v, str):
        try:
            return round(float(v), 2)
        except ValueError:
            return v
    return v


def _compare_cell_value(v1, v2) -> bool:
    """Return True when two cell values are equal after normalization.

    ``""`` and ``None`` are interchangeable; otherwise the normalized values
    must have the same type and compare equal.
    """
    v1 = _transform_value(v1)
    v2 = _transform_value(v2)
    if (v1 == "" and v2 is None) or (v1 is None and v2 == ""):
        return True
    if (v1 == "" and v2 == "") or (v1 is None and v2 is None):
        return True
    if type(v1) is not type(v2):
        return False
    return v1 == v2


# ---------- range parsing (official port) ----------

def _col_num2name(n: int) -> str:
    """Convert a 1-based column number to its letter name (1 -> "A", 27 -> "AA")."""
    name = ""
    while n > 0:
        n, r = divmod(n - 1, 26)
        name = chr(65 + r) + name
    return name


def _col_name2num(name: str) -> int:
    """Convert a column letter name to its 1-based number ("A" -> 1, "AA" -> 27).

    Letters are upper-cased first so that lowercase references such as
    ``"a1:b2"`` map to the same columns as their uppercase form.
    """
    num = 0
    for c in name.upper():
        num = num * 26 + (ord(c) - ord("A") + 1)
    return num


def _parse_range(range_str: str):
    """Split ``"A1:C3"`` into ``((col, row), (col, row))`` with 1-based numbers.

    Only letters and digits of each endpoint are used, so absolute markers
    (``$A$1``) are ignored.
    """
    start_cell, end_cell = range_str.split(":")
    sc = "".join(ch for ch in start_cell if ch.isalpha())
    sr = "".join(ch for ch in start_cell if ch.isdigit())
    ec = "".join(ch for ch in end_cell if ch.isalpha())
    er = "".join(ch for ch in end_cell if ch.isdigit())
    return (_col_name2num(sc), int(sr)), (_col_name2num(ec), int(er))


def _generate_cell_names(range_str: str):
    """Expand a range into cell names, column by column.

    A reference without ``":"`` is returned as a single-element list unchanged.
    """
    if ":" not in range_str:
        return [range_str]
    (sc, sr), (ec, er) = _parse_range(range_str)
    cols = [_col_num2name(i) for i in range(sc, ec + 1)]
    return [f"{c}{r}" for c in cols for r in range(sr, er + 1)]


def _cell_level_compare(wb_gt, wb_proc, sheet_name: str, cell_range: str):
    """Compare one sheet range cell by cell.

    Returns ``(True, "")`` when every cell matches, otherwise ``(False, msg)``
    describing the missing worksheet or the first mismatching cell.
    """
    if sheet_name not in wb_proc.sheetnames:
        return False, f"worksheet not found: {sheet_name}"
    ws_gt = wb_gt[sheet_name]
    ws_proc = wb_proc[sheet_name]
    for cn in _generate_cell_names(cell_range):
        cg = ws_gt[cn]
        cp = ws_proc[cn]
        if not _compare_cell_value(cg.value, cp.value):
            return False, f"value@{sheet_name}!{cn}: gt={cg.value!r} pred={cp.value!r}"
    return True, ""


# ---------- public API ----------

def compare_workbooks(gt_file: str, proc_file: str, answer_position: str) -> tuple[bool, str]:
    """Return (ok, msg). Single test-case comparison, official semantics.

    ``answer_position`` is a comma-separated list of ``Sheet!Range`` or bare
    ``Range`` entries; a bare range is checked on the first sheet of the
    ground-truth workbook. ``msg`` is the first mismatch found (empty when
    everything matches), ``"file not exist"`` when *proc_file* is missing, or
    ``"load error: ..."`` when either workbook cannot be opened.
    """
    if not os.path.exists(proc_file):
        return False, "file not exist"
    try:
        wb_gt = openpyxl.load_workbook(filename=gt_file, data_only=True)
        wb_proc = openpyxl.load_workbook(filename=proc_file, data_only=True)
    except Exception as e:  # noqa: BLE001
        return False, f"load error: {e}"
    try:
        ok_all = True
        msg_first = ""
        for scr in (answer_position or "").split(","):
            scr = scr.strip()
            if not scr:
                continue
            if "!" in scr:
                sheet_name, cell_range = scr.split("!", 1)
                sheet_name = sheet_name.strip().strip("'\"")
            else:
                sheet_name = wb_gt.sheetnames[0]
                cell_range = scr
            cell_range = cell_range.strip().strip("'\"")
            ok, msg = _cell_level_compare(wb_gt, wb_proc, sheet_name, cell_range)
            if not ok:
                ok_all = False
                if not msg_first:
                    msg_first = msg
        return ok_all, msg_first
    finally:
        wb_gt.close()
        wb_proc.close()
# Wrapper script: injects the real paths, runs the user code inside ``try``
# and exits with status 2 (after printing the traceback) on any exception.
RUNNER_TEMPLATE = textwrap.dedent(
    """
    import os, sys, traceback
    INPUT_PATH = {input_path!r}
    OUTPUT_PATH = {output_path!r}
    try:
    {user_code_indented}
    except Exception:
        traceback.print_exc()
        sys.exit(2)
    """
)

# Regex to strip user-defined INPUT_PATH / OUTPUT_PATH assignments,
# since the runner template injects the correct values.
_PATH_ASSIGN_RE = re.compile(
    r'^\s*(INPUT_PATH|OUTPUT_PATH)\s*=\s*.+$', re.MULTILINE
)


def _strip_path_assignments(code: str) -> str:
    """Remove INPUT_PATH/OUTPUT_PATH assignments from user code."""
    return _PATH_ASSIGN_RE.sub("", code)


def run_generated_code(code: str, input_path: str, output_path: str, timeout: int = 120) -> tuple[bool, str]:
    """Run *code* in a fresh interpreter with the real paths injected.

    The user's own ``INPUT_PATH`` / ``OUTPUT_PATH`` assignments are stripped,
    the code is embedded in ``RUNNER_TEMPLATE`` and executed with
    ``sys.executable`` from a temporary script file.

    Args:
        code: Python source produced by the student.
        input_path: Value injected as ``INPUT_PATH``.
        output_path: Value injected as ``OUTPUT_PATH``; its parent directory
            is created when it has one.
        timeout: Wall-clock limit for the subprocess, in seconds.

    Returns:
        ``(True, "")`` when the process exits with status 0 and *output_path*
        exists; otherwise ``(False, message)`` where *message* is the captured
        stdout/stderr, ``"output file was not created"``, or a timeout note.
    """
    # A bare filename has no directory part; os.makedirs("") would raise.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    cleaned = _strip_path_assignments(code)
    indented = textwrap.indent(cleaned, "    ")
    script = RUNNER_TEMPLATE.format(
        input_path=input_path,
        output_path=output_path,
        user_code_indented=indented,
    )
    # Python decodes source files as UTF-8 by default, so the script must be
    # written as UTF-8 regardless of the locale's preferred encoding.
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(script)
        tmp = f.name
    try:
        proc = subprocess.run(
            [sys.executable, tmp],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if proc.returncode != 0:
            return False, (proc.stdout + "\n" + proc.stderr).strip()
        if not os.path.exists(output_path):
            return False, "output file was not created"
        return True, ""
    except subprocess.TimeoutExpired:
        return False, f"timeout after {timeout}s"
    finally:
        # Best-effort cleanup of the temporary runner script.
        try:
            os.unlink(tmp)
        except OSError:
            pass
You may produce fewer if warranted. + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": <int>, + "failure_summary": [ + {"failure_type": "<failure type category>", "count": <int>, "description": "<short description of the pattern>"} + ], + "patch": { + "reasoning": "<why these edits address the common patterns>", + "edits": [ + {"op": "append", "content": "<text to add>"}, + {"op": "insert_after", "target": "<existing skill text>", "content": "<text to add>"}, + {"op": "replace", "target": "<existing skill text>", "content": "<replacement text>"}, + {"op": "delete", "target": "<existing skill text>"} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. diff --git a/skillopt/envs/spreadsheetbench/prompts/analyst_success.md b/skillopt/envs/spreadsheetbench/prompts/analyst_success.md new file mode 100644 index 0000000..7e97330 --- /dev/null +++ b/skillopt/envs/spreadsheetbench/prompts/analyst_success.md @@ -0,0 +1,32 @@ +You are an expert success-pattern analyst for AI spreadsheet agents. + +You will be given MULTIPLE successful agent trajectories from a single minibatch +and the current skill document. Your job is to identify generalizable behavior +patterns that are COMMON across the batch and worth encoding in the skill. + +## Rules +- Only propose patches for patterns NOT already covered in the skill. +- Focus on patterns that appear across MULTIPLE trajectories in the batch. +- Be concise. Patterns must generalize beyond specific tasks. +- Prefer reinforcing existing sections over adding new top-level sections. +- If the agents' success involved reading enough data rows or using a smart + exploration strategy, consider reinforcing that in the patch. + +You will be told the maximum number of edits (the budget L). Produce AT MOST L edits, +focusing on the most broadly applicable patterns. You may produce fewer if warranted.
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": <int>, + "success_patterns": ["<pattern>", "<pattern>"], + "patch": { + "reasoning": "<why these edits are worth encoding>", + "edits": [ + {"op": "append", "content": "<text to add>"}, + {"op": "insert_after", "target": "<existing skill text>", "content": "<text to add>"}, + {"op": "replace", "target": "<existing skill text>", "content": "<replacement text>"}, + {"op": "delete", "target": "<existing skill text>"} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. diff --git a/skillopt/envs/spreadsheetbench/prompts/codegen_system.md b/skillopt/envs/spreadsheetbench/prompts/codegen_system.md new file mode 100644 index 0000000..1ec7134 --- /dev/null +++ b/skillopt/envs/spreadsheetbench/prompts/codegen_system.md @@ -0,0 +1 @@ +You are an expert Python programmer specializing in spreadsheet manipulation. You will be given a user instruction together with a preview of an input .xlsx file. Your job is to write a single self-contained Python script that reads the input file at the path stored in the variable INPUT_PATH, performs the requested manipulation, and saves the result to OUTPUT_PATH. Use only the standard library, openpyxl, and pandas. Do not print anything. Do not use input(). Do not hardcode file paths. Return ONLY the Python code inside a single ```python ... ``` fenced block. \ No newline at end of file diff --git a/skillopt/envs/spreadsheetbench/prompts/critical_rules.md b/skillopt/envs/spreadsheetbench/prompts/critical_rules.md new file mode 100644 index 0000000..822da17 --- /dev/null +++ b/skillopt/envs/spreadsheetbench/prompts/critical_rules.md @@ -0,0 +1,9 @@ +## Critical Rules (MUST follow) +1. NEVER write Excel formulas to cells that will be graded on their displayed value. + openpyxl does NOT compute formulas -- the evaluator will see None. + Instead, compute results in Python and write literal values (numbers/strings). +2. After saving the workbook, ALWAYS reopen and verify the written values: + `wb2 = openpyxl.load_workbook(OUTPUT_PATH); print(wb2[sheet][cell].value)` +3.
Use the `write_file` tool to create solution.py -- it avoids shell escaping issues. + Do NOT use `echo "..." > solution.py` for multi-line scripts. + diff --git a/skillopt/envs/spreadsheetbench/prompts/deep_probe.md b/skillopt/envs/spreadsheetbench/prompts/deep_probe.md new file mode 100644 index 0000000..a33f518 --- /dev/null +++ b/skillopt/envs/spreadsheetbench/prompts/deep_probe.md @@ -0,0 +1,35 @@ +You are an expert diagnostic-probe designer for spreadsheet manipulation tasks. + +You will design one short diagnostic instruction to append to the student's +existing SpreadsheetBench prompt for a handful of representative trajectories. + +The goal is to expose whether the student already knows the right task +decomposition, source range, target range, and transformation rule without +substantially changing the current scaffold. + +## Hard Constraints +1. Do NOT substantially change the student's current scaffold. +2. Do NOT prescribe a brand-new full algorithm. +3. Do NOT ask for exhaustive cell-by-cell enumeration. +4. Keep the diagnostic readout brief and structured. +5. The student must still complete the original spreadsheet task. +6. Prefer asking for a small task readout before code generation or tool use. +7. Never ask for hidden reference content or golden values. 
+ +## Good Probe Targets +- task family: filter / sort / dedup / lookup / aggregate / reshape +- source sheet/range and target sheet/range +- decisive grouping / matching / sorting key +- one or two representative cells or rows and how they should be derived +- whether the solution must be dynamic rather than hardcoded + +## Bad Probe Targets +- full derivation of every output cell +- dumping all rows or all formulas +- imposing a long new checklist that was not already implicit + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/envs/spreadsheetbench/prompts/react_system.md b/skillopt/envs/spreadsheetbench/prompts/react_system.md new file mode 100644 index 0000000..afbbffb --- /dev/null +++ b/skillopt/envs/spreadsheetbench/prompts/react_system.md @@ -0,0 +1,21 @@ +You are an expert spreadsheet manipulation agent. + +{critical_rules}{skill_section}## Tools +You have two tools: +- `bash` -- execute any shell command and receive its output. +- `write_file` -- write content to a file (path, content). Use this for solution.py. + +## Protocol +1. Explore the input spreadsheet to understand its structure (sheets, headers, row count). +2. Use the `write_file` tool to create `solution.py` in the current directory. + solution.py MUST start with: + INPUT_PATH = "" + OUTPUT_PATH = "" + Then perform the manipulation and save the result to OUTPUT_PATH. + Use only: standard library, openpyxl, pandas. +3. Run `python solution.py` via `bash` and verify the output was created. +4. Fix any errors and re-run until the output is correct. +5. Once OUTPUT_PATH exists and is correct, stop calling tools. + +Do NOT use any libraries other than standard library, openpyxl, and pandas. +Do NOT hardcode cell values from the preview -- iterate over actual rows. 
diff --git a/skillopt/envs/spreadsheetbench/react_agent.py b/skillopt/envs/spreadsheetbench/react_agent.py new file mode 100644 index 0000000..3aa7481 --- /dev/null +++ b/skillopt/envs/spreadsheetbench/react_agent.py @@ -0,0 +1,395 @@ +"""ReAct agent with bash tool for SpreadsheetBench evaluation. + +Adapted from workspace-yqh/refleAct/spreadsheetbench/src/react_agent.py. + +Uses the unified ``skillopt.model`` router so SpreadsheetBench follows the same +backend selection as the rest of the framework. +""" +from __future__ import annotations + +import json +import os +import subprocess + +from skillopt.model import chat_student_messages +from skillopt.prompts import load_prompt + +# ── Tool schemas ───────────────────────────────────────────────────────────── + +BASH_TOOL_CHAT = { + "type": "function", + "function": { + "name": "bash", + "description": ( + "Execute a bash command and receive stdout+stderr (truncated to 4000 chars). " + "Use Python to read / write Excel files." + ), + "parameters": { + "type": "object", + "properties": { + "cmd": {"type": "string", "description": "Bash command to execute."} + }, + "required": ["cmd"], + }, + }, +} + +BASH_TOOL_RESPONSES = { + "type": "function", + "name": "bash", + "description": ( + "Execute a bash command and receive stdout+stderr (truncated to 4000 chars). " + "Use Python to read / write Excel files." + ), + "parameters": { + "type": "object", + "properties": { + "cmd": {"type": "string", "description": "Bash command to execute."} + }, + "required": ["cmd"], + }, +} + +WRITE_FILE_TOOL_CHAT = { + "type": "function", + "function": { + "name": "write_file", + "description": ( + "Write content to a file. Use this instead of echo/cat for multi-line " + "Python scripts to avoid shell escaping issues." 
def _build_system(skill_content: str) -> str:
    """Render the ReAct system prompt.

    Fills the ``react_system`` prompt template with the ``critical_rules``
    prompt and, when *skill_content* is non-blank, a ``## Skill`` section.
    """
    if skill_content.strip():
        skill_section = f"## Skill\n{skill_content.strip()}\n\n"
    else:
        skill_section = ""
    return load_prompt("react_system", env="spreadsheetbench").format(
        critical_rules=load_prompt("critical_rules", env="spreadsheetbench"),
        skill_section=skill_section,
    )


def _build_user(
    instruction: str,
    input_path: str,
    output_path: str,
    instruction_type: str,
    answer_position: str,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> str:
    """Assemble the user message for one task.

    Sections are joined by blank lines in this order: optional previous-trace
    snapshot, instruction, input file, output file, optional instruction
    type, optional answer position, optional training readout (only when
    *diagnostic_mode* is set and the instruction is non-blank), and the
    closing directive.
    """
    parts = []
    if diagnostic_trace_context.strip():
        parts.append(
            "# Previous Codex Trace Snapshot\n"
            "This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
            f"{diagnostic_trace_context.strip()}"
        )
    parts.extend([
        f"# Instruction\n{instruction}",
        f"# Input file\n{input_path}",
        f"# Output file\n{output_path}",
    ])
    if instruction_type:
        parts.append(f"# Instruction type\n{instruction_type}")
    if answer_position:
        parts.append(f"# Answer position\n{answer_position}")
    if diagnostic_mode and diagnostic_instruction.strip():
        parts.append(f"# Training readout\n{diagnostic_instruction.strip()}")
    parts.append(
        "Manipulate the input spreadsheet according to the instruction "
        "and save the result to the output file."
    )
    return "\n\n".join(parts)


# ── File write (bypass shell escaping) ────────────────────────────────────────

def _write_file(path: str, content: str, work_dir: str) -> str:
    """Write content to a file, bypassing shell escaping issues.

    Relative paths are resolved against *work_dir*; missing parent
    directories are created. The file is written as UTF-8 so that scripts
    containing non-ASCII text are stored in the encoding Python uses to read
    source files, independent of the locale.

    Returns:
        A confirmation string on success, or ``"[write_file error: ...]"``
        when anything goes wrong (the error is reported to the agent rather
        than raised).
    """
    try:
        full_path = os.path.join(work_dir, path) if not os.path.isabs(path) else path
        parent = os.path.dirname(full_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(content)
        return f"File written: {full_path} ({len(content)} chars)"
    except Exception as e:  # noqa: BLE001
        return f"[write_file error: {e}]"
xlsx_files: + output_path = xlsx_files[0] + + if not output_path or not os.path.exists(output_path): + return ( + "\n\n[AUTO-VERIFY] WARNING: Output file not found! " + "Make sure OUTPUT_PATH is correct and wb.save(OUTPUT_PATH) is called." + ) + + try: + import openpyxl + + wb_formula = openpyxl.load_workbook(output_path, data_only=False) + wb_value = openpyxl.load_workbook(output_path, data_only=True) + lines = [f"\n\n[AUTO-VERIFY] Output file exists: {output_path}"] + + sn = wb_formula.sheetnames[0] + ws_f = wb_formula[sn] + ws_v = wb_value[sn] + lines.append(f" Sheet '{sn}': {ws_f.dimensions}") + + for row in ws_v.iter_rows( + min_row=1, max_row=min(5, ws_v.max_row), values_only=True, + ): + lines.append(f" {list(row)}") + + none_cells: list[str] = [] + for row_f, row_v in zip( + ws_f.iter_rows(min_row=1, max_row=min(30, ws_f.max_row)), + ws_v.iter_rows(min_row=1, max_row=min(30, ws_v.max_row)), + ): + for cf, cv in zip(row_f, row_v): + formula_val = cf.value + cached_val = cv.value + if ( + isinstance(formula_val, str) + and formula_val.startswith("=") + and cached_val is None + ): + none_cells.append(cf.coordinate) + + if none_cells: + lines.append( + f" WARNING: {len(none_cells)} cells have formulas but NO cached " + f"value -- evaluator will see None: {none_cells[:10]}" + ) + lines.append( + " FIX: Compute values in Python and write literal " + "numbers/strings instead of formulas." + ) + else: + lines.append(" All cells have concrete values. 
def _run_bash(cmd: str, work_dir: str, timeout: int = 60) -> str:
    """Execute *cmd* through the shell in *work_dir* and return its output.

    stdout and stderr are concatenated (stdout first) and stripped. Output
    longer than 4000 characters is cut to its first 3800 characters followed
    by a truncation notice. When the command mentions both ``solution.py``
    and ``python``, the auto-verification report is appended.

    Args:
        cmd: Shell command line to execute.
        work_dir: Working directory for the command.
        timeout: Seconds to wait before giving up.

    Returns:
        The observation text; ``"(no output)"`` when the command printed
        nothing, ``"[timeout after Ns]"`` on timeout, ``"[error: ...]"`` when
        the command could not be run.
    """
    try:
        completed = subprocess.run(
            cmd,
            cwd=work_dir,
            shell=True,
            text=True,
            capture_output=True,
            timeout=timeout,
        )
        combined = "".join((completed.stdout, completed.stderr)).strip()
    except subprocess.TimeoutExpired:
        return f"[timeout after {timeout}s]"
    except Exception as exc:  # noqa: BLE001
        return f"[error: {exc}]"

    total_chars = len(combined)
    if total_chars > 4000:
        combined = f"{combined[:3800]}\n...[truncated, {total_chars} total chars]"

    report = combined if combined else "(no output)"

    ran_solution = "solution.py" in cmd and "python" in cmd.lower()
    if ran_solution:
        report += _auto_verify(work_dir)
    return report
def run_react(
    instruction: str,
    input_path: str,
    output_path: str,
    work_dir: str,
    instruction_type: str = "",
    answer_position: str = "",
    skill_content: str = "",
    max_turns: int = 30,
    max_output_tokens: int = 4096,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> dict:
    """Run the ReAct agent on a single task and return its trace.

    Builds the system and user prompts, runs the tool-calling loop, and
    attaches both prompts to the loop's result.

    Returns:
        {
            "conversation": [...],          # list of {type, cmd/content, obs?}
            "n_turns": int,                 # tool calls executed (bash and write_file)
            "student_system_prompt": str,   # system prompt sent to the model
            "student_user_prompt": str,     # user prompt sent to the model
        }
    """
    system_prompt = _build_system(skill_content)
    user_prompt = _build_user(
        instruction=instruction,
        input_path=input_path,
        output_path=output_path,
        instruction_type=instruction_type,
        answer_position=answer_position,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )

    outcome = _react_loop(
        system_prompt, user_prompt, work_dir, max_turns, max_output_tokens,
    )
    outcome.update(
        student_system_prompt=system_prompt,
        student_user_prompt=user_prompt,
    )
    return outcome
def load_items(path: str) -> list[dict]:
    """Load a benchmark file as a list of task dicts.

    Two layouts are supported, selected by file extension:

    * ``.json`` -- either a list of items, or a dict. For a dict, the list
      stored under ``"data"`` is used when present (even if empty);
      otherwise the dict's values are returned.
    * anything else -- treated as JSON Lines: one JSON document per
      non-blank line.

    Files are read as UTF-8 regardless of the platform locale.

    Args:
        path: Path to the ``.json`` or ``.jsonl`` file.

    Returns:
        The loaded items, in file order.
    """
    if path.endswith(".json"):
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict):
            payload = data.get("data")
            if isinstance(payload, list):
                # An explicit (possibly empty) list wins; previously an empty
                # "data" list fell through and produced ``[[]]``.
                data = payload
            else:
                data = payload or list(data.values())
        return list(data)

    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def _collect_formula_cells(path: str, limit: int = 10, max_rows: int = 200) -> list[str]:
    """Return up to *limit* ``Sheet!A1=<formula>`` entries found in *path*."""
    # Formula text is only visible with data_only=False; a data_only=True
    # workbook exposes the cached values instead.
    wb = openpyxl.load_workbook(path, data_only=False)
    try:
        found: list[str] = []
        for sn in wb.sheetnames:
            ws = wb[sn]
            for row in ws.iter_rows(max_row=min(ws.max_row, max_rows)):
                for cell in row:
                    value = cell.value
                    if isinstance(value, str) and value.startswith("="):
                        found.append(f"{sn}!{cell.coordinate}={value}")
                        if len(found) >= limit:
                            return found
        return found
    finally:
        wb.close()


def _auto_verify_output(
    pred_path: str,
    gold_path: str,
    answer_position: str,
) -> str:
    """Reopen the predicted xlsx and compare cells at answer_position with gold.

    Returns a human-readable verification report that can be appended to the
    trajectory so the error analyst can see exactly what went wrong (e.g.
    ``cell A1: got=None, expected=420``).

    Args:
        pred_path: Path of the workbook produced by the student.
        gold_path: Path of the reference workbook.
        answer_position: Comma-separated ranges, each optionally prefixed with
            ``Sheet!``; ranges without a sheet use the gold workbook's first
            sheet.

    Returns:
        The report text. This helper never raises: problems while comparing
        a range or scanning for formulas are recorded as lines of the report
        so a diagnostics failure cannot change the task's outcome.
    """
    if not os.path.exists(pred_path):
        return "Verification: output file does not exist."

    wb_pred = None
    try:
        wb_pred = openpyxl.load_workbook(pred_path, data_only=True)
        wb_gold = openpyxl.load_workbook(gold_path, data_only=True)
    except Exception as e:  # noqa: BLE001
        if wb_pred is not None:
            # The prediction opened but the gold file did not; don't leak it.
            wb_pred.close()
        return f"Verification: could not open workbooks: {e}"

    lines = ["## Output Verification"]
    try:
        for scr in (answer_position or "").split(","):
            scr = scr.strip()
            if not scr:
                continue
            try:
                if "!" in scr:
                    sheet_name, cell_range = scr.split("!", 1)
                    sheet_name = sheet_name.strip().strip("'\"")
                else:
                    sheet_name = wb_gold.sheetnames[0]
                    cell_range = scr
                cell_range = cell_range.strip().strip("'\"")

                cell_names = _generate_cell_names(cell_range)
                ws_pred = wb_pred[sheet_name] if sheet_name in wb_pred.sheetnames else None
                ws_gold = wb_gold[sheet_name] if sheet_name in wb_gold.sheetnames else None

                if ws_pred is None:
                    lines.append(f" Sheet '{sheet_name}' NOT FOUND in output.")
                    continue

                for cn in cell_names:
                    gv = ws_gold[cn].value if ws_gold is not None else "N/A"
                    pv = ws_pred[cn].value
                    match = "✓" if repr(gv) == repr(pv) else "✗"
                    lines.append(f" {sheet_name}!{cn}: got={pv!r}, expected={gv!r} {match}")
            except Exception as e:  # noqa: BLE001
                lines.append(f" {scr}: could not verify: {type(e).__name__}: {e}")

        # Also check if any cells in the output contain formula strings.
        # This needs a separate data_only=False load: the workbooks above
        # were opened with data_only=True and only expose cached values.
        try:
            formula_cells = _collect_formula_cells(pred_path)
        except Exception as e:  # noqa: BLE001
            formula_cells = []
            lines.append(f" Formula scan failed: {type(e).__name__}: {e}")

        if formula_cells:
            lines.append(f"\n WARNING: {len(formula_cells)} cells contain Excel formulas (openpyxl cannot evaluate them):")
            for fc in formula_cells[:5]:
                lines.append(f" {fc}")
            if len(formula_cells) > 5:
                lines.append(f" ... and {len(formula_cells) - 5} more")
    finally:
        wb_pred.close()
        wb_gold.close()

    return "\n".join(lines)
os.path.join(task_out_dir, f"{no1}_pred.xlsx") + student_prompt_parts = [ + f"# Instruction\n{instruction}", + f"# Input file\n{ip1}", + f"# Output file\n{pred_path_1}", + ] + if instruction_type: + student_prompt_parts.append(f"# Instruction type\n{instruction_type}") + if answer_position_eval: + student_prompt_parts.append(f"# Answer position\n{answer_position_eval}") + if diagnostic_trace_context.strip(): + student_prompt_parts.insert( + 0, + "# Previous Codex Trace Snapshot\n" + "This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n" + f"{diagnostic_trace_context.strip()}", + ) + if diagnostic_mode and diagnostic_instruction.strip(): + student_prompt_parts.append(f"# Training readout\n{diagnostic_instruction.strip()}") + student_user_prompt = "\n\n".join(student_prompt_parts) + try: + from skillopt.envs.spreadsheetbench.react_agent import _build_system + student_system_prompt = _build_system(skill_content) + except Exception: + student_system_prompt = "" + if student_system_prompt: + with open(os.path.join(task_out_dir, "student_system_prompt.txt"), "w") as f: + f.write(student_system_prompt) + result["student_system_prompt"] = student_system_prompt + with open(os.path.join(task_out_dir, "student_user_prompt.txt"), "w") as f: + f.write(student_user_prompt) + result["student_user_prompt"] = student_user_prompt + + # ── Stage 1: run ReAct agent on test case 1 ───────────────────── + result["phase"] = "agent" + + work_dir = tempfile.mkdtemp(prefix=f"react_{task_id}_") + try: + # Copy input so agent works in an isolated directory + work_input = os.path.join(work_dir, os.path.basename(ip1)) + shutil.copy2(ip1, work_input) + + agent_result = run_react( + instruction=instruction, + input_path=work_input, + output_path=pred_path_1, + work_dir=work_dir, + instruction_type=instruction_type, + answer_position=answer_position_eval, + skill_content=skill_content, + max_turns=max_turns, + diagnostic_mode=diagnostic_mode, + 
diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + result["n_turns"] = agent_result.get("n_turns", 0) + if agent_result.get("student_system_prompt"): + with open(os.path.join(task_out_dir, "student_system_prompt.txt"), "w") as f: + f.write(agent_result["student_system_prompt"]) + result["student_system_prompt"] = agent_result["student_system_prompt"] + if agent_result.get("student_user_prompt"): + with open(os.path.join(task_out_dir, "student_user_prompt.txt"), "w") as f: + f.write(agent_result["student_user_prompt"]) + result["student_user_prompt"] = agent_result["student_user_prompt"] + + # Save conversation log + with open(os.path.join(task_out_dir, "conversation.json"), "w") as f: + json.dump( + agent_result.get("conversation", []), + f, ensure_ascii=False, indent=2, + ) + + # Copy solution.py if the agent wrote one + solution_src = os.path.join(work_dir, "solution.py") + solution_dst = os.path.join(task_out_dir, "solution.py") + if os.path.exists(solution_src): + shutil.copy2(solution_src, solution_dst) + + except Exception as e: + result["fail_reason"] = f"agent-error: {type(e).__name__}: {e}" + result["error"] = traceback.format_exc() + return result + finally: + shutil.rmtree(work_dir, ignore_errors=True) + + result["agent_ok"] = True + + # ── Stage 2: evaluate all test cases ───────────────────────────── + result["phase"] = "eval" + solution_path = os.path.join(task_out_dir, "solution.py") + all_exec = True + + for i, (no, ip, ap) in enumerate(cases): + pred_path = os.path.join(task_out_dir, f"{no}_pred.xlsx") + + if i > 0: + # Re-apply solution.py to subsequent test cases + if not os.path.exists(solution_path): + all_exec = False + result["cases"].append( + {"no": no, "stage": "exec", "ok": False, "error": "no-solution-py"} + ) + if not result["fail_reason"]: + result["fail_reason"] = "no-solution-py-for-other-cases" + continue + + with open(solution_path) as f: + code = f.read() + + # Prepend new 
def run_spreadsheet_batch(
    items: list[dict],
    data_root: str,
    out_root: str,
    skill_content: str,
    max_turns: int = 30,
    max_api_workers: int = 64,
    task_timeout: int = 600,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context_by_id: dict[str, str] | None = None,
) -> list[dict]:
    """Run the ReAct agent on all items with ThreadPoolExecutor.

    Items whose id already appears in ``<out_root>/results.jsonl`` are
    skipped and their stored results returned as-is (resume support). This
    function only reads that file; it does not write it.

    Args:
        items: Task dicts; each must carry an ``"id"``.
        data_root: Root directory of the benchmark data.
        out_root: Output directory (created if missing).
        skill_content: Skill text passed to every task.
        max_turns: Per-task turn budget forwarded to ``process_one``.
        max_api_workers: Thread-pool size.
        task_timeout: Per-task wall-clock budget in seconds, measured from
            the moment the task starts running. Values <= 0 fall back to 600.
        diagnostic_mode: Forwarded to ``process_one``.
        diagnostic_instruction: Forwarded to ``process_one``.
        diagnostic_trace_context_by_id: Optional per-task trace context,
            keyed by the string task id.

    Returns:
        Previously stored results followed by the new ones, in completion
        order; result dicts are compatible with ``compute_score()``.
    """
    if task_timeout <= 0:
        # A non-positive budget would mark every started task as timed out on
        # the first poll; use the default instead.
        task_timeout = 600

    os.makedirs(out_root, exist_ok=True)

    # Check for already-done items (resume support)
    results_path = os.path.join(out_root, "results.jsonl")
    done_ids: set[str] = set()
    existing: list[dict] = []
    if os.path.exists(results_path):
        with open(results_path) as f:
            for line in f:
                try:
                    r = json.loads(line)
                    done_ids.add(str(r["id"]))
                except (ValueError, KeyError, TypeError):
                    # Blank, truncated or id-less line: the task simply stays
                    # pending and is run again.
                    continue
                existing.append(r)

    pending = [it for it in items if str(it["id"]) not in done_ids]
    print(
        f" [spreadsheet rollout] total={len(items)} done={len(done_ids)} "
        f"pending={len(pending)} workers={max_api_workers} task_timeout={task_timeout}s"
    )

    if not pending:
        return existing

    t0 = time.time()
    results = list(existing)
    # Start times are recorded by the worker threads so that queue wait does
    # not count against a task's budget.
    started_at: dict[str, float] = {}
    trace_by_id = diagnostic_trace_context_by_id or {}

    def _failure(item: dict, phase: str, reason: str, error: str) -> dict:
        """Build a placeholder result for a task that produced none."""
        return {
            "id": str(item["id"]),
            "ok": False,
            # Same keys as the codegen batch runner's placeholder results.
            "instruction_type": item.get("instruction_type", ""),
            "task_type": "other",
            "phase": phase,
            "fail_reason": reason,
            "n_cases": 0, "n_pass": 0, "soft": 0.0, "hard": 0,
            "n_turns": 0, "cases": [], "error": error,
        }

    def _timeout_result(item: dict) -> dict:
        return _failure(item, "timeout", f"task-timeout-{task_timeout}s", "timeout")

    def _error_result(item: dict, exc: Exception) -> dict:
        return _failure(
            item, "error", f"unexpected: {type(exc).__name__}: {exc}", str(exc),
        )

    def _run_one(it: dict) -> dict:
        started_at[str(it["id"])] = time.time()
        return process_one(
            it,
            data_root,
            out_root,
            skill_content,
            max_turns,
            diagnostic_mode,
            diagnostic_instruction,
            trace_by_id.get(str(it["id"]), ""),
        )

    def _record(res: dict, i: int) -> None:
        """Store one result and print its progress line."""
        results.append(res)
        status = "PASS" if res.get("hard") else ("TIMEOUT" if res.get("phase") == "timeout" else "FAIL")
        dt = time.time() - t0
        print(
            f" {i}/{len(pending)} id={res['id']:<10} {status} "
            f"turns={res.get('n_turns', 0):<3} "
            f"cases={res.get('n_pass', 0)}/{res.get('n_cases', 0)} "
            f"dt={dt:.0f}s"
        )

    ex = ThreadPoolExecutor(max_workers=max_api_workers)
    try:
        futs = {ex.submit(_run_one, it): it for it in pending}
        pending_futs = set(futs)
        finished = 0
        while pending_futs:
            # Poll every 5s so running tasks can be checked against the
            # budget even when nothing completes.
            done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
            now = time.time()
            timed_out = [
                fut for fut in pending_futs - done
                if str(futs[fut]["id"]) in started_at
                and now - started_at[str(futs[fut]["id"])] >= task_timeout
            ]
            for fut in done:
                pending_futs.remove(fut)
                item = futs[fut]
                try:
                    res = fut.result()
                except FuturesTimeoutError:
                    res = _timeout_result(item)
                except Exception as e:  # noqa: BLE001
                    res = _error_result(item, e)
                finished += 1
                _record(res, finished)
            for fut in timed_out:
                # The worker thread cannot be interrupted; its result is
                # simply no longer waited for.
                pending_futs.remove(fut)
                fut.cancel()
                finished += 1
                _record(_timeout_result(futs[fut]), finished)
    finally:
        # Do not block on abandoned (timed-out) workers.
        ex.shutdown(wait=False, cancel_futures=True)

    return results
+ """ + from skillopt.envs.spreadsheetbench.codegen_agent import run_single, run_multi + + task_id = str(item["id"]) + instruction = item["instruction"] + instruction_type = item.get("instruction_type", "") + answer_position = item.get("answer_position", "") + answer_sheet = item.get("answer_sheet", "") + if answer_position and answer_sheet and "!" not in answer_position: + answer_position_eval = f"{answer_sheet}!{answer_position}" + else: + answer_position_eval = answer_position + + itype_lower = (instruction_type or "").lower() + if "cell" in itype_lower: + task_type = "cell_level" + elif "sheet" in itype_lower: + task_type = "sheet_level" + else: + task_type = "other" + + sp = item.get("spreadsheet_path", f"spreadsheet/{task_id}") + task_dir = sp if os.path.isabs(sp) else os.path.join(data_root, sp) + + result = { + "id": task_id, + "ok": False, + "instruction_type": instruction_type, + "task_type": task_type, + "task_description": instruction, + "phase": "setup", + "fail_reason": "", + "llm_ok": False, + "code_ok": False, + "exec_ok": False, + "n_cases": 0, + "n_exec_pass": 0, + "n_pass": 0, + "soft": 0.0, + "hard": 0, + "n_turns": 0, + "cases": [], + "error": "", + } + + try: + cases = _find_test_cases(task_dir) + result["n_cases"] = len(cases) + if not cases: + result["fail_reason"] = "no-test-cases" + return result + + task_out_dir = os.path.join(out_root, "predictions", task_id) + os.makedirs(task_out_dir, exist_ok=True) + + # ── Save context for Teacher (Reflect stage) ────────────────── + from skillopt.envs.spreadsheetbench.codegen_agent import ( + _preview_workbook, _build_system, _build_user, + ) + first_input_for_preview = cases[0][1] + try: + preview_text = _preview_workbook(first_input_for_preview) + except Exception: + preview_text = "(preview failed)" + student_system = _build_system(skill_content) + student_user = _build_user( + instruction, + first_input_for_preview, + instruction_type, + answer_position_eval, + diagnostic_mode=diagnostic_mode, + 
diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + + with open(os.path.join(task_out_dir, "spreadsheet_preview.txt"), "w") as f: + f.write(preview_text) + with open(os.path.join(task_out_dir, "student_system_prompt.txt"), "w") as f: + f.write(student_system) + with open(os.path.join(task_out_dir, "student_user_prompt.txt"), "w") as f: + f.write(student_user) + + result["spreadsheet_preview"] = preview_text + result["student_system_prompt"] = student_system + result["student_user_prompt"] = student_user + + # ── LLM phase ────────────────────────────────────────────────── + result["phase"] = "llm" + first_input = cases[0][1] + first_gold = cases[0][2] + first_pred = os.path.join(task_out_dir, f"{cases[0][0]}_pred.xlsx") + + try: + if mode == "multi": + agent_result = run_multi( + instruction=instruction, + input_xlsx=first_input, + output_path=first_pred, + instruction_type=instruction_type, + answer_position=answer_position_eval, + skill_content=skill_content, + max_turns=max_turns, + gold_path=first_gold if use_eval_feedback else "", + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + else: + agent_result = run_single( + instruction=instruction, + input_xlsx=first_input, + output_path=first_pred, + instruction_type=instruction_type, + answer_position=answer_position_eval, + skill_content=skill_content, + diagnostic_mode=diagnostic_mode, + diagnostic_instruction=diagnostic_instruction, + diagnostic_trace_context=diagnostic_trace_context, + ) + except Exception as e: # noqa: BLE001 + result["fail_reason"] = f"llm-call-failed: {type(e).__name__}: {e}" + result["error"] = traceback.format_exc() + return result + + result["llm_ok"] = True + result["n_turns"] = agent_result.get("n_turns", 1) + code = agent_result.get("code", "") + raw = agent_result.get("raw", "") + + # Save artifacts + with open(os.path.join(task_out_dir, 
"code.py"), "w") as f: + f.write(code) + with open(os.path.join(task_out_dir, "raw.txt"), "w") as f: + f.write(raw) + if agent_result.get("conversation"): + with open(os.path.join(task_out_dir, "conversation.json"), "w") as f: + json.dump(agent_result["conversation"], f, ensure_ascii=False, indent=2) + + if not code.strip(): + result["phase"] = "extract" + result["fail_reason"] = "empty-code-block" + return result + result["code_ok"] = True + + # ── Exec + eval per test case ────────────────────────────────── + result["phase"] = "exec" + all_exec = True + # Collect enrichment info for the conversation/trajectory + enrichment_parts: list[str] = [] + + for no, ip, ap in cases: + pred_path = os.path.join(task_out_dir, f"{no}_pred.xlsx") + + # For multi mode, the first case may already be produced + if not os.path.exists(pred_path): + ok_exec, err = run_generated_code(code, ip, pred_path) + if not ok_exec: + all_exec = False + result["cases"].append( + {"no": no, "stage": "exec", "ok": False, "error": err[:500]} + ) + if not result["fail_reason"]: + tail = err.strip().splitlines()[-1][:200] if err.strip() else "unknown" + result["fail_reason"] = f"exec-error: {tail}" + enrichment_parts.append( + f"## Execution (case {no})\nERROR: {err[:500]}" + ) + continue + + if not os.path.exists(pred_path): + all_exec = False + result["cases"].append( + {"no": no, "stage": "exec", "ok": False, "error": "output-not-found"} + ) + if not result["fail_reason"]: + result["fail_reason"] = "output-not-found" + continue + + result["n_exec_pass"] += 1 + try: + ev = evaluate(pred_path, ap, instruction_type, answer_position_eval) + except Exception as e: # noqa: BLE001 + ev = {"ok": False, "reason": f"eval-exception: {type(e).__name__}: {e}"} + + if ev["ok"]: + result["n_pass"] += 1 + else: + if not result["fail_reason"]: + result["fail_reason"] = f"eval-mismatch: {ev['reason'][:200]}" + result["cases"].append( + {"no": no, "stage": "eval", "ok": ev["ok"], "reason": ev.get("reason", "")} + ) 
def run_spreadsheet_batch_codegen(
    items: list[dict],
    data_root: str,
    out_root: str,
    skill_content: str,
    mode: str = "single",
    max_turns: int = 5,
    max_api_workers: int = 32,
    task_timeout: int = 0,
    use_eval_feedback: bool = False,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context_by_id: dict[str, str] | None = None,
) -> list[dict]:
    """Run the codegen agent (no tool-call) over every item in a thread pool.

    Items whose id is already present in ``<out_root>/results.jsonl`` are
    skipped and their stored results are returned unchanged.

    Args:
        mode: "single" or "multi".
        task_timeout: Hard per-task timeout in seconds at the future level,
            counted from when the task starts running.
            0 = auto (single: 300s, multi: 600s).

    Returns:
        Stored results followed by new results in completion order.
    """
    if task_timeout <= 0:
        task_timeout = 300 if mode == "single" else 600

    os.makedirs(out_root, exist_ok=True)

    # Resume support: collect ids that already have a stored result.
    prior_results: list[dict] = []
    seen_ids: set[str] = set()
    resume_file = os.path.join(out_root, "results.jsonl")
    if os.path.exists(resume_file):
        with open(resume_file) as fh:
            for raw_line in fh:
                try:
                    record = json.loads(raw_line)
                    seen_ids.add(str(record["id"]))
                    prior_results.append(record)
                except Exception:
                    pass

    todo = [it for it in items if str(it["id"]) not in seen_ids]
    print(
        f" [spreadsheet codegen-{mode}] total={len(items)} done={len(seen_ids)} "
        f"pending={len(todo)} workers={max_api_workers} task_timeout={task_timeout}s"
    )
    if not todo:
        return prior_results

    t_start = time.time()
    results = list(prior_results)
    started_at: dict[str, float] = {}

    def _placeholder(item: dict, phase: str, reason: str, error: str) -> dict:
        # Result for a task that never returned one of its own.
        return {
            "id": str(item["id"]),
            "ok": False,
            "instruction_type": item.get("instruction_type", ""),
            "task_type": "other",
            "phase": phase,
            "fail_reason": reason,
            "n_cases": 0, "n_pass": 0, "soft": 0.0, "hard": 0,
            "n_turns": 0, "cases": [], "error": error,
        }

    def _timeout_result(item: dict) -> dict:
        return _placeholder(item, "timeout", f"task-timeout-{task_timeout}s", "timeout")

    def _error_result(item: dict, e: Exception) -> dict:
        return _placeholder(
            item, "error", f"unexpected: {type(e).__name__}: {e}", str(e),
        )

    def _run_one(it: dict) -> dict:
        task_key = str(it["id"])
        started_at[task_key] = time.time()
        return process_one_codegen(
            it,
            data_root,
            out_root,
            skill_content,
            mode,
            max_turns,
            use_eval_feedback,
            diagnostic_mode,
            diagnostic_instruction,
            (diagnostic_trace_context_by_id or {}).get(task_key, ""),
        )

    def _record(res: dict, i: int) -> None:
        results.append(res)
        if res.get("hard"):
            status = "PASS"
        elif res.get("phase") == "timeout":
            status = "TIMEOUT"
        else:
            status = "FAIL"
        dt = time.time() - t_start
        print(
            f" {i}/{len(todo)} id={res['id']:<10} {status} "
            f"turns={res.get('n_turns', 0):<3} "
            f"cases={res.get('n_pass', 0)}/{res.get('n_cases', 0)} "
            f"dt={dt:.0f}s"
        )

    def _over_budget(fut, now: float) -> bool:
        # Only tasks that have actually started can exceed their budget.
        task_key = str(item_of[fut]["id"])
        return task_key in started_at and now - started_at[task_key] >= task_timeout

    pool = ThreadPoolExecutor(max_workers=max_api_workers)
    try:
        item_of = {pool.submit(_run_one, it): it for it in todo}
        outstanding = set(item_of)
        n_finished = 0
        while outstanding:
            completed, _ = wait(outstanding, timeout=5, return_when=FIRST_COMPLETED)
            now = time.time()
            expired = [fut for fut in outstanding - completed if _over_budget(fut, now)]

            for fut in completed:
                outstanding.remove(fut)
                try:
                    res = fut.result()
                except FuturesTimeoutError:
                    res = _timeout_result(item_of[fut])
                except Exception as e:  # noqa: BLE001
                    res = _error_result(item_of[fut], e)
                n_finished += 1
                _record(res, n_finished)

            for fut in expired:
                outstanding.remove(fut)
                fut.cancel()
                n_finished += 1
                _record(_timeout_result(item_of[fut]), n_finished)
    finally:
        pool.shutdown(wait=False, cancel_futures=True)

    return results
+ +--- + +## Library Selection + +| Use case | Library | +|----------|---------| +| Preserve formulas, formatting, named ranges | `openpyxl` | +| Bulk data transformation, aggregation, sorting | `pandas` → write back with `openpyxl` | +| Simple cell read/write | `openpyxl` | + +**Warning**: `pandas.to_excel()` silently destroys existing formulas and named ranges. +When writing back to a spreadsheet that contains formulas, always use `openpyxl.save()`. + +--- + +## solution.py Template + +```python +import openpyxl +import pandas as pd + +INPUT_PATH = "..." # set to the actual input path +OUTPUT_PATH = "..." # set to the actual output path + +wb = openpyxl.load_workbook(INPUT_PATH) +ws = wb.active # or wb["SheetName"] + +# --- perform manipulation --- + +wb.save(OUTPUT_PATH) +``` + +--- + +## Output Requirements + +- Save the result to `OUTPUT_PATH`. +- Do not hardcode row counts or column letters — iterate over actual rows in the workbook. +- Preserve sheets and cells not mentioned in the instruction. diff --git a/skillopt/envs/spreadsheetbench/skills/xlsx_initial.md b/skillopt/envs/spreadsheetbench/skills/xlsx_initial.md new file mode 100644 index 0000000..17eb7cd --- /dev/null +++ b/skillopt/envs/spreadsheetbench/skills/xlsx_initial.md @@ -0,0 +1,56 @@ +# Spreadsheet Manipulation Skill (xlsx) + +## Overview +This skill guides agents in manipulating Excel (.xlsx) spreadsheets using Python. + +**Primary libraries**: `openpyxl` (structure-preserving read/write), `pandas` (data transformation). +Never use any other third-party libraries. + +--- + +## Common Workflow + +1. **Explore** the input file: list sheets, inspect headers, check dimensions. +2. **Write `solution.py`** with `INPUT_PATH` and `OUTPUT_PATH` defined at the top. +3. **Execute** `python solution.py` and verify the output file was created. +4. **Confirm** the target cells/range contain the expected values. 
+ +--- + +## Library Selection + +| Use case | Library | +|----------|---------| +| Preserve formulas, formatting, named ranges | `openpyxl` | +| Bulk data transformation, aggregation, sorting | `pandas` → write back with `openpyxl` | +| Simple cell read/write | `openpyxl` | + +**Warning**: `pandas.to_excel()` silently destroys existing formulas and named ranges. +When writing back to a spreadsheet that contains formulas, always use `openpyxl.save()`. + +--- + +## solution.py Template + +```python +import openpyxl +import pandas as pd + +INPUT_PATH = "..." # set to the actual input path +OUTPUT_PATH = "..." # set to the actual output path + +wb = openpyxl.load_workbook(INPUT_PATH) +ws = wb.active # or wb["SheetName"] + +# --- perform manipulation --- + +wb.save(OUTPUT_PATH) +``` + +--- + +## Output Requirements + +- Save the result to `OUTPUT_PATH`. +- Do not hardcode row counts or column letters — iterate over actual rows in the workbook. +- Preserve sheets and cells not mentioned in the instruction. diff --git a/skillopt/envs/spreadsheetbench/skills/xlsx_skill0.md b/skillopt/envs/spreadsheetbench/skills/xlsx_skill0.md new file mode 100644 index 0000000..80ab2a1 --- /dev/null +++ b/skillopt/envs/spreadsheetbench/skills/xlsx_skill0.md @@ -0,0 +1,4 @@ +# No Skill + +This is a placeholder. No domain skill is loaded for this run. +The agent relies solely on its parametric knowledge. 
class SWEBenchAdapter(EnvAdapter):
    """Environment adapter that wires SWE-bench data, rollout and reflection.

    The adapter stores its tuning knobs as attributes, owns a
    ``SWEBenchDataLoader`` built from the split/dataset arguments, and
    delegates the heavy lifting to ``run_batch`` (rollout) and
    ``run_minibatch_reflect`` (reflection).
    """

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        dataset_name: str = "lite",
        hf_split: str = "test",
        workers: int = 8,
        eval_workers: int = 8,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 4,
        edit_budget: int = 4,
        seed: int = 42,
        limit: int = 0,
        step_limit: int = 50,
        cost_limit: float = 3.0,
        timeout_per_instance: int = 600,
        student_model: str = "",
    ) -> None:
        # Dataset identity.
        self.dataset_name = dataset_name
        self.hf_split = hf_split
        # Parallelism for rollout, evaluation and reflection.
        self.workers = workers
        self.eval_workers = eval_workers
        self.analyst_workers = analyst_workers
        # Reflection settings.
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        # Per-instance limits handed to the rollout.
        self.step_limit = step_limit
        self.cost_limit = cost_limit
        self.timeout_per_instance = timeout_per_instance
        self.student_model = student_model

        loader_kwargs = {
            "split_dir": split_dir,
            "data_path": data_path,
            "split_mode": split_mode,
            "split_ratio": split_ratio,
            "split_seed": split_seed,
            "split_output_dir": split_output_dir,
            "seed": seed,
            "limit": limit,
            "dataset_name": dataset_name,
            "hf_split": hf_split,
        }
        self.dataloader = SWEBenchDataLoader(**loader_kwargs)

    def setup(self, cfg: dict) -> None:
        """Run the base setup, fill unset string options from *cfg*, then set up the loader.

        For each option the constructor value wins; an empty value falls back
        to ``cfg[<name>]`` and finally to a hard-coded default.
        """
        super().setup(cfg)
        fallbacks = (
            ("student_model", "gpt-5.4"),
            ("dataset_name", "lite"),
            ("hf_split", "test"),
        )
        for attr, default in fallbacks:
            resolved = getattr(self, attr) or cfg.get(attr) or default
            setattr(self, attr, str(resolved).strip())
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        """Return the ``SWEBenchDataLoader`` owned by this adapter."""
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        """Return the batch payload as a fresh list (empty when the payload is falsy)."""
        payload = batch.payload or []
        return list(payload)

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        """Draw a training batch from the loader and turn it into an item list."""
        train_batch = self.dataloader.build_train_batch(
            batch_size=batch_size, seed=seed, **kwargs
        )
        return self.build_env_from_batch(train_batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        """Draw an evaluation batch for *split* and turn it into an item list."""
        eval_batch = self.dataloader.build_eval_batch(
            env_num=env_num, split=split, seed=seed, **kwargs
        )
        return self.build_env_from_batch(eval_batch, **kwargs)

    def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]:
        """Run ``run_batch`` over the items in *env_manager* using the adapter's settings.

        Extra keyword arguments are accepted but not forwarded.
        """
        items: list[dict] = env_manager
        batch_args = {
            "items": items,
            "out_root": out_dir,
            "skill_content": skill_content,
            "student_model": self.student_model,
            "dataset_name": self.dataset_name,
            "hf_split": self.hf_split,
            "workers": self.workers,
            "eval_workers": self.eval_workers,
            "step_limit": self.step_limit,
            "cost_limit": self.cost_limit,
            "timeout_per_instance": self.timeout_per_instance,
        }
        return run_batch(**batch_args)

    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        """Delegate minibatch reflection to ``run_minibatch_reflect``.

        Recognised keyword arguments: ``prediction_dir`` and ``patches_dir``
        (default to ``<out_dir>/predictions`` and ``<out_dir>/patches``),
        ``random_seed`` (default ``None``), ``step_buffer_context`` and
        ``meta_skill_context`` (default ``""``). The update mode is read from
        ``self._cfg["skill_update_mode"]`` and defaults to ``"patch"``.
        """
        default_predictions = os.path.join(out_dir, "predictions")
        default_patches = os.path.join(out_dir, "patches")
        prediction_dir = kwargs.get("prediction_dir", default_predictions)
        patches_dir = kwargs.get("patches_dir", default_patches)
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def get_task_types(self) -> list[str]:
        """Return the sorted, non-empty repo names across all splits.

        Falls back to ``["swebench"]`` when no item carries a repo name.
        """
        every_item = (
            self.dataloader.train_items
            + self.dataloader.val_items
            + self.dataloader.test_items
        )
        repo_names = set()
        for entry in every_item:
            repo_name = str(entry.get("repo") or "").strip()
            if repo_name:
                repo_names.add(repo_name)
        return sorted(repo_names) or ["swebench"]
def _materialize_ratio_split(self, cfg: dict) -> str:
    """Create a repo-stratified train/val/test split on disk and return its directory.

    Items are grouped by their ``repo`` field (``"unknown"`` when empty),
    each group is shuffled with an RNG seeded by ``split_seed`` and cut
    according to ``split_ratio``. The three resulting lists are shuffled
    again with an RNG seeded by ``split_seed + 1``, written through
    ``write_split_items`` and described in ``split_manifest.json``.

    Raises:
        ValueError: if no items could be loaded from the dataset reference.
    """
    # An existing local path is made absolute; anything else is kept as typed.
    raw_ref = str(self.data_path or "").strip()
    if raw_ref and os.path.exists(raw_ref):
        dataset_ref = os.path.abspath(raw_ref)
    else:
        dataset_ref = raw_ref
    if not dataset_ref:
        dataset_ref = _normalize_dataset_name(self.dataset_name)

    items = self.load_raw_items(dataset_ref)
    if not (isinstance(items, list) and items):
        raise ValueError(f"No SWE-bench items available from {dataset_ref!r}")

    weights = list(_parse_split_ratio(self.split_ratio))
    weight_sum = sum(weights)
    group_rng = random.Random(self.split_seed)

    repo_groups: dict[str, list[dict]] = defaultdict(list)
    for entry in items:
        repo_name = str(entry.get("repo") or "unknown").strip() or "unknown"
        repo_groups[repo_name].append(dict(entry))

    def _quota(size: int) -> tuple[int, int]:
        """Return (n_train, n_val) for a repo group holding *size* items."""
        n_train = round(size * weights[0] / weight_sum)
        n_val = round(size * weights[1] / weight_sum)

        if size >= 3:
            # Large enough for every split to get at least one item.
            n_train = max(1, n_train)
            n_val = max(1, n_val)
        elif size == 2:
            n_train, n_val = 1, 0
        else:
            n_train, n_val = 0, 0

        # Shrink val first, then train, until something is left for test.
        while n_train + n_val >= size and size >= 2:
            if n_val > 1:
                n_val -= 1
            elif n_train > 1:
                n_train -= 1
            else:
                break
        return n_train, n_val

    train_items: list[dict] = []
    val_items: list[dict] = []
    test_items: list[dict] = []

    # Sorted repo order keeps the sequence of RNG draws reproducible.
    for repo_name in sorted(repo_groups):
        members = list(repo_groups[repo_name])
        group_rng.shuffle(members)
        n_train, n_val = _quota(len(members))
        val_end = n_train + n_val
        train_items.extend(members[:n_train])
        val_items.extend(members[n_train:val_end])
        test_items.extend(members[val_end:])

    mix_rng = random.Random(self.split_seed + 1)
    for bucket in (train_items, val_items, test_items):
        mix_rng.shuffle(bucket)

    split_dir = self._resolve_split_output_dir(cfg)
    os.makedirs(split_dir, exist_ok=True)
    for split_name, bucket in (
        ("train", train_items),
        ("val", val_items),
        ("test", test_items),
    ):
        self.write_split_items(os.path.join(split_dir, split_name), bucket)

    manifest = {
        "source_data_path": dataset_ref,
        "dataset_name": _normalize_dataset_name(self.dataset_name),
        "hf_split": self.hf_split,
        "split_mode": "ratio",
        "split_ratio": self.split_ratio,
        "split_seed": self.split_seed,
        "strategy": "stratified_by_repo",
        "counts": {
            "train": len(train_items),
            "val": len(val_items),
            "test": len(test_items),
        },
    }
    manifest_path = os.path.join(split_dir, "split_manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    print(
        f" [SWEBenchDataLoader] generated repo-stratified split {self.split_ratio} "
        f"at {split_dir} from {dataset_ref}"
    )
    return split_dir
("princeton-nlp/SWE-Bench_Verified", "SWE-bench/SWE-bench_Verified"), + "full": ("princeton-nlp/SWE-Bench", "SWE-bench/SWE-bench"), +} + + +def _normalize_dataset_names(dataset_name: str) -> tuple[str, str]: + key = str(dataset_name or "lite").strip() + pair = _DATASET_ALIASES.get(key.lower()) + if pair: + return pair + return key, key + + +def _setup_litellm_env() -> None: + mapping = { + "AZURE_API_KEY": os.environ.get("AZURE_API_KEY") or os.environ.get("AZURE_OPENAI_API_KEY", ""), + "AZURE_API_BASE": os.environ.get("AZURE_API_BASE") or os.environ.get("AZURE_OPENAI_ENDPOINT", ""), + "AZURE_API_VERSION": os.environ.get("AZURE_API_VERSION") or os.environ.get("AZURE_OPENAI_API_VERSION", ""), + } + for key, value in mapping.items(): + if value and not os.environ.get(key): + os.environ[key] = value + + +def _normalize_student_model(student_model: str) -> str: + model = str(student_model or "").strip() + if not model: + return "azure/gpt-5.4" + if "/" in model: + return model + if os.environ.get("AZURE_OPENAI_ENDPOINT"): + return f"azure/{model}" + return model + + +def _load_json(path: str) -> dict | list | None: + if not os.path.exists(path): + return None + with open(path, encoding="utf-8") as f: + return json.load(f) + + +def _build_agent_config( + *, + skill_content: str, + student_model: str, + step_limit: int, + cost_limit: float, +) -> tuple[dict, str]: + try: + from minisweagent.config import get_config_from_spec + from minisweagent.utils.serialize import recursive_merge + except ImportError as exc: + raise ImportError( + "SWEBench rollout requires minisweagent. Install the mini-swe-agent environment first." 
def _run_rollout(
    *,
    items: list[dict],
    predictions_dir: str,
    skill_content: str,
    student_model: str,
    workers: int,
    step_limit: int,
    cost_limit: float,
) -> tuple[list[dict], str]:
    """Run the mini-swe-agent on *items* and collect one result row per item.

    Instances whose id is already a key of ``<predictions_dir>/preds.json``
    are not re-run. For every item (re-run or not) the conversation, the
    system prompt and the user prompt are written to
    ``<predictions_dir>/<instance_id>/``.

    Args:
        items: SWE-bench instances; ``instance_id``, ``repo`` and
            ``problem_statement`` are read from each one.
        predictions_dir: Output directory, created if missing.
        skill_content: Skill text passed to ``_build_agent_config``.
        student_model: Model name passed to ``_build_agent_config``.
        workers: Thread-pool size (values below 1 are treated as 1).
        step_limit: Passed to ``_build_agent_config``.
        cost_limit: Passed to ``_build_agent_config``.

    Returns:
        ``(rows, preds_path)``. Each row carries ``hard=0`` / ``soft=0.0``
        as placeholders; ``fail_reason`` holds the text of any exception
        raised while processing that instance.

    Raises:
        ImportError: if minisweagent (with swebench support) is not installed.
    """
    try:
        from minisweagent.run.benchmarks.swebench import process_instance
        from minisweagent.run.benchmarks.utils.batch_progress import RunBatchProgressManager
    except ImportError as exc:
        raise ImportError(
            "SWEBench rollout requires minisweagent with swebench benchmark support."
        ) from exc

    _setup_litellm_env()
    config, system_prompt = _build_agent_config(
        skill_content=skill_content,
        student_model=student_model,
        step_limit=step_limit,
        cost_limit=cost_limit,
    )

    out_path = Path(predictions_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    preds_path = out_path / "preds.json"

    # Resume support: anything already predicted is skipped.
    done_ids: set[str] = set()
    if preds_path.exists():
        data = _load_json(str(preds_path))
        if isinstance(data, dict):
            done_ids = set(data.keys())

    pending = [item for item in items if str(item.get("instance_id")) not in done_ids]
    progress_manager = RunBatchProgressManager(
        len(pending),
        out_path / f"exit_statuses_{int(time.time())}.yaml",
    )

    task_errors: dict[str, str] = {}

    def _process(instance: dict) -> None:
        process_instance(instance, out_path, config, progress_manager)

    with ThreadPoolExecutor(max_workers=max(int(workers), 1)) as executor:
        futures = {
            executor.submit(_process, item): str(item.get("instance_id"))
            for item in pending
        }
        for fut in as_completed(futures):
            iid = futures[fut]
            try:
                fut.result()
            except Exception as exc:  # noqa: BLE001
                # One failing instance must not abort the batch; remember why.
                task_errors[iid] = str(exc)

    preds_data = _load_json(str(preds_path))
    preds_dict = preds_data if isinstance(preds_data, dict) else {}
    results: list[dict] = []

    for item in items:
        iid = str(item.get("instance_id"))

        # A prediction entry that is missing or not a mapping is treated as
        # "no patch" rather than crashing on `.get`.
        pred = preds_dict.get(iid)
        if not isinstance(pred, dict):
            pred = {}
        patch = str(pred.get("model_patch") or "")

        # `dict.get(key, "")` only covers a missing key; an explicit None
        # would break `.strip()`, so normalise the same way the result row does.
        repo = str(item.get("repo") or "").strip()
        problem = str(item.get("problem_statement") or "").strip()

        traj_path = out_path / iid / f"{iid}.traj.json"
        messages = _load_messages_from_traj(traj_path)
        task_dir = out_path / iid
        task_dir.mkdir(parents=True, exist_ok=True)
        user_prompt = (
            f"Repository: {item.get('repo', '')}\n\n"
            f"Issue:\n{problem}"
        ).strip()
        with open(task_dir / "conversation.json", "w", encoding="utf-8") as f:
            json.dump(messages, f, ensure_ascii=False, indent=2)
        with open(task_dir / "student_system_prompt.txt", "w", encoding="utf-8") as f:
            f.write(system_prompt)
        with open(task_dir / "student_user_prompt.txt", "w", encoding="utf-8") as f:
            f.write(user_prompt)

        results.append(
            {
                "id": iid,
                "instance_id": iid,
                "repo": repo,
                "task_type": str(item.get("repo") or "swebench").strip() or "swebench",
                "task_description": problem,
                "instruction": problem,
                "hard": 0,
                "soft": 0.0,
                "response": patch,
                "submission": patch,
                "predicted_patch": patch,
                "agent_ok": bool(messages),
                "n_turns": sum(1 for msg in messages if msg.get("role") == "assistant"),
                "fail_reason": task_errors.get(iid, ""),
                "exit_status": _load_exit_status(traj_path),
            }
        )

    return results, str(preds_path)
def run_batch(
    *,
    items: list[dict],
    out_root: str,
    skill_content: str,
    student_model: str,
    dataset_name: str,
    hf_split: str,
    workers: int,
    eval_workers: int,
    step_limit: int,
    cost_limit: float,
    timeout_per_instance: int,
) -> list[dict]:
    """Roll out the agent on *items*, score the patches, and cache the rows.

    If ``<out_root>/results.jsonl`` already holds at least one row, those
    rows are returned immediately and nothing is re-run. An unreadable cache
    (for example a file truncated by an interrupted run) is ignored and the
    batch is recomputed.

    Args:
        items: SWE-bench instances; ``instance_id`` is read from each.
        out_root: Output directory, created if missing.
        skill_content: Passed to ``_run_rollout``.
        student_model: Passed to ``_run_rollout``.
        dataset_name: Passed to ``_run_evaluation``.
        hf_split: Passed to ``_run_evaluation`` as ``split``.
        workers: Passed to ``_run_rollout``.
        eval_workers: Passed to ``_run_evaluation``.
        step_limit: Passed to ``_run_rollout``.
        cost_limit: Passed to ``_run_rollout``.
        timeout_per_instance: Only recorded on each row; this function does
            not pass it to the rollout or the evaluation.

    Returns:
        One row per item with ``hard``/``soft`` set from the evaluation
        report's ``resolved_ids``; unresolved rows get a ``fail_reason``.
    """
    os.makedirs(out_root, exist_ok=True)
    results_path = os.path.join(out_root, "results.jsonl")
    if os.path.exists(results_path):
        cached: list[dict] = []
        try:
            with open(results_path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        cached.append(json.loads(line))
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
            # A damaged cache must not abort the run; recompute instead.
            print(f" [swebench] ignoring unreadable results cache {results_path}: {exc}")
            cached = []
        if cached:
            return cached

    predictions_dir = os.path.join(out_root, "predictions")
    results, preds_path = _run_rollout(
        items=items,
        predictions_dir=predictions_dir,
        skill_content=skill_content,
        student_model=student_model,
        workers=workers,
        step_limit=step_limit,
        cost_limit=cost_limit,
    )
    eval_report = _run_evaluation(
        preds_path=preds_path,
        dataset_name=dataset_name,
        split=hf_split,
        run_id=f"skillopt_{int(time.time())}",
        eval_workers=eval_workers,
        report_dir=os.path.join(out_root, "evaluation"),
        instance_ids=[str(item.get("instance_id")) for item in items],
    )
    resolved_ids = {str(i) for i in eval_report.get("resolved_ids", [])}
    for row in results:
        resolved = str(row["instance_id"]) in resolved_ids
        row["hard"] = int(resolved)
        row["soft"] = float(int(resolved))
        if not resolved:
            status = row.get("exit_status") or "not_resolved"
            base_reason = str(row.get("fail_reason") or "").strip()
            unresolved = f"swebench unresolved ({status})"
            row["fail_reason"] = f"{base_reason}; {unresolved}" if base_reason else unresolved
        row["timeout_per_instance"] = int(timeout_per_instance)

    # Write to a sibling temp file and rename it into place, so an
    # interrupted write can never leave a partial results.jsonl that a later
    # call would mistake for a complete cache.
    tmp_path = f"{results_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        for row in results:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    os.replace(tmp_path, results_path)
    return results
GateAction = Literal["accept_new_best", "accept", "reject"]


@dataclass(frozen=True)
class GateResult:
    """Frozen record describing what the validation gate decided."""

    action: GateAction
    current_skill: str
    current_score: float
    best_skill: str
    best_score: float
    best_step: int


def evaluate_gate(
    candidate_skill: str,
    cand_hard: float,
    current_skill: str,
    current_score: float,
    best_skill: str,
    best_score: float,
    best_step: int,
    global_step: int,
) -> GateResult:
    """Decide whether a candidate skill replaces the current and/or best skill.

    The function has no side effects. The candidate is accepted only when
    its score is strictly greater than ``current_score``; it additionally
    becomes the new best (stamped with ``global_step``) only when it is
    strictly greater than ``best_score``. In every other case the incoming
    state is returned unchanged with action ``"reject"``.
    """
    improves_current = cand_hard > current_score
    if not improves_current:
        # Ties and regressions leave every field as it came in.
        return GateResult(
            action="reject",
            current_skill=current_skill,
            current_score=current_score,
            best_skill=best_skill,
            best_score=best_score,
            best_step=best_step,
        )

    if cand_hard > best_score:
        # Candidate beats everything seen so far: it becomes current and best.
        return GateResult(
            action="accept_new_best",
            current_skill=candidate_skill,
            current_score=cand_hard,
            best_skill=candidate_skill,
            best_score=cand_hard,
            best_step=global_step,
        )

    # Better than current but not a record: only the current slot changes.
    return GateResult(
        action="accept",
        current_skill=candidate_skill,
        current_score=cand_hard,
        best_skill=best_skill,
        best_score=best_score,
        best_step=best_step,
    )
+Failure-driven patches take priority over success-driven ones. +""" +from __future__ import annotations + +import json +from concurrent.futures import ThreadPoolExecutor, as_completed + +from skillopt.model import chat_teacher +from skillopt.optimizer.meta_skill import format_meta_skill_context +from skillopt.optimizer.update_modes import ( + get_payload_items, + is_full_rewrite_minibatch_mode, + is_rewrite_mode, + normalize_update_mode, + payload_key, + payload_label, +) +from skillopt.prompts import load_prompt +from skillopt.utils import extract_json + + +# ── Internal helpers ────────────────────────────────────────────────────────── + +def _merge_batch( + skill_content: str, + patches: list[dict], + system_prompt: str, + update_mode: str, + meta_skill_context: str = "", + level: int = 1, +) -> dict: + """Call teacher LLM to merge a batch of patches into one.""" + patches_text = json.dumps(patches, ensure_ascii=False, indent=2) + user = ( + f"## Current Skill\n{skill_content}\n\n" + f"## Patches to merge ({len(patches)} total, merge level {level})\n{patches_text}" + ) + teacher_ctx = format_meta_skill_context(meta_skill_context) + if teacher_ctx: + user = f"{teacher_ctx}\n\n{user}" + try: + response, _ = chat_teacher( + system=system_prompt, + user=user, + max_completion_tokens=64000 if is_full_rewrite_minibatch_mode(update_mode) else 4096, + retries=3, + stage="merge", + ) + merged = extract_json(response) + key = payload_key(update_mode) + if merged and key in merged: + for e in merged.get(key, []): + e["merge_level"] = level + return merged + except Exception: # noqa: BLE001 + pass + # Fallback: concatenate all edits + all_edits = [] + for p in patches: + for e in get_payload_items(p, update_mode): + e.setdefault("merge_level", level) + all_edits.append(e) + return {"reasoning": "fallback concatenation", payload_key(update_mode): all_edits} + + +def _hierarchical_merge( + skill_content: str, + patches: list[dict], + system_prompt: str, + update_mode: str, + 
def merge_patches(
    skill_content: str,
    failure_patches: list[dict],
    success_patches: list[dict],
    batch_size: int = 8,
    verbose: bool = True,
    workers: int = 16,
    update_mode: str = "patch",
    meta_skill_context: str = "",
) -> dict:
    """Failure-first hierarchical merge with support count tracking.

    1. Merge failure patches independently (parallel)
    2. Merge success patches independently (parallel)
    3. Final merge: combine both groups with failure priority

    Parameters
    ----------
    skill_content : str
        Current skill document text shown to the teacher.
    failure_patches, success_patches : list[dict]
        Patches derived from failed / successful trajectories.
    batch_size : int
        Patches merged per teacher call.  Values below 2 are raised to 2.
    verbose : bool
        Print progress lines.
    workers : int
        Thread-pool size used for same-level batches.
    update_mode : str
        Update mode; normalized before use.
    meta_skill_context : str
        Optional meta-skill text prepended to the teacher prompt.

    Returns a merged :class:`~skillopt.types.Patch` dict (``edits`` + ``reasoning``).
    When only one group yields payload items that group's merged patch is
    returned unchanged; when the final teacher merge fails, failure items are
    concatenated ahead of success items.
    """
    if verbose:
        print(
            f" [3/6 AGGREGATE] "
            f"failure={len(failure_patches)} success={len(success_patches)} "
            f"(parallel, workers={workers})"
        )

    # With fewer than two patches per batch nothing is ever merged: each level
    # would pass every patch through untouched and the level loop in
    # _hierarchical_merge would never terminate (0 makes range() raise).
    batch_size = max(2, batch_size)

    update_mode = normalize_update_mode(update_mode)
    full_rewrite = is_full_rewrite_minibatch_mode(update_mode)
    if full_rewrite:
        prompt_suffix = "_full_rewrite"
    elif is_rewrite_mode(update_mode):
        prompt_suffix = "_rewrite"
    else:
        prompt_suffix = ""
    merge_failure_prompt = load_prompt(f"merge_failure{prompt_suffix}")
    merge_success_prompt = load_prompt(f"merge_success{prompt_suffix}")
    merge_final_prompt = load_prompt(f"merge_final{prompt_suffix}")

    failure_merged = _hierarchical_merge(
        skill_content, failure_patches, merge_failure_prompt, update_mode,
        batch_size, verbose, label="failure", workers=workers,
        meta_skill_context=meta_skill_context,
    )

    success_merged = _hierarchical_merge(
        skill_content, success_patches, merge_success_prompt, update_mode,
        batch_size, verbose, label="success", workers=workers,
        meta_skill_context=meta_skill_context,
    )

    f_edits = get_payload_items(failure_merged, update_mode)
    s_edits = get_payload_items(success_merged, update_mode)

    # Nothing to combine when at most one group produced payload items.
    if not f_edits and not s_edits:
        return {"reasoning": "no updates from either group", payload_key(update_mode): []}
    if not s_edits:
        return failure_merged
    if not f_edits:
        return success_merged

    combined_patches = [failure_merged, success_merged]
    combined_text = json.dumps(combined_patches, ensure_ascii=False, indent=2)
    if full_rewrite:
        item_label = payload_label(update_mode)
        user = (
            f"## Current Skill\n{skill_content}\n\n"
            f"## Two pre-merged candidate groups to combine\n"
            f"Group 1 (from failed trajectories): "
            f"{len(f_edits)} {item_label}\n"
            f"Group 2 (from successful trajectories): "
            f"{len(s_edits)} {item_label}\n\n"
            f"{combined_text}"
        )
    else:
        user = (
            f"## Current Skill\n{skill_content}\n\n"
            f"## Two pre-merged patch groups to combine\n"
            f"Group 1 (failure-driven, HIGH priority): "
            f"{len(f_edits)} edits\n"
            f"Group 2 (success-driven, lower priority): "
            f"{len(s_edits)} edits\n\n"
            f"{combined_text}"
        )
    teacher_ctx = format_meta_skill_context(meta_skill_context)
    if teacher_ctx:
        user = f"{teacher_ctx}\n\n{user}"
    try:
        response, _ = chat_teacher(
            system=merge_final_prompt,
            user=user,
            max_completion_tokens=64000 if full_rewrite else 4096,
            retries=3,
            stage="merge",
        )
        final = extract_json(response)
        key = payload_key(update_mode)
        if final and key in final:
            if verbose:
                print(
                    f" [aggregate final] "
                    f"{len(f_edits)}+{len(s_edits)} → {len(final[key])} {payload_label(update_mode)}"
                )
            return final
    except Exception as exc:  # noqa: BLE001
        # Still best effort, but no longer invisible: say why the fallback ran.
        if verbose:
            print(f" [aggregate final] teacher merge failed ({exc!r}); using fallback")

    return {
        "reasoning": "fallback: failure first, then success",
        payload_key(update_mode): f_edits + s_edits,
    }
def generate_deep_probe_instruction(
    skill_content: str,
    items: list[dict],
    prediction_dir: str,
    *,
    system_prompt: str | None = None,
    step_buffer_context: str = "",
    output_requirements: list[str] | None = None,
    meta_skill_context: str = "",
) -> dict | None:
    """Ask the teacher for a single short diagnostic probe instruction.

    Returns a dict with ``reasoning`` and ``probe_instruction`` (plus
    ``probe_target_id`` / ``probe_after_step`` when the teacher supplies
    usable values), or ``None`` when there are no formatted trajectories,
    the teacher call fails, or the reply has no probe instruction.
    """
    trajectories_text = fmt_minibatch_trajectories(items, prediction_dir)
    if not trajectories_text.strip():
        return None

    actual_system = system_prompt or load_prompt("deep_probe")

    # Assemble the user prompt section by section, then join once.
    sections = [
        f"## Current Skill\n{skill_content}\n\n"
        "## Probe Design Goal\n"
        "Design one short diagnostic instruction to append to the student prompt.\n"
        "The instruction should expose the student's current intermediate judgment\n"
        "without materially changing the original scaffold.\n\n"
    ]
    if step_buffer_context.strip():
        sections.append(f"## Previous Steps in This Epoch\n{step_buffer_context}\n\n")
    teacher_ctx = format_meta_skill_context(meta_skill_context)
    if teacher_ctx:
        sections.append(teacher_ctx + "\n\n")

    if output_requirements:
        requirements = output_requirements
    else:
        requirements = [
            "- Some trajectories may include a hidden Reference block. Use it to identify what intermediate conclusion matters, but do not reveal or paraphrase that reference directly to the student.",
            "- The instruction must explicitly request a short ... block before the final ....",
            "- Keep the readout concise and structured.",
            "- Do not ask for exhaustive listing, full derivation, or a new solving protocol.",
            "- The instruction text should be ready to append directly to the student's prompt.",
        ]
    sections.append(
        f"## Representative Trajectories ({len(items)} total)\n{trajectories_text}\n\n"
        "## Output Requirements\n"
        + "\n".join(requirements)
        + "\n"
    )
    user = "".join(sections)

    try:
        response, _ = chat_teacher(
            system=actual_system,
            user=user,
            max_completion_tokens=1024,
            retries=3,
            stage="deep_probe",
        )
        result = extract_json(response)
        if not result:
            return None
        instruction = str(result.get("probe_instruction", "")).strip()
        if not instruction:
            return None

        parsed = {
            "reasoning": str(result.get("reasoning", "")).strip(),
            "probe_instruction": instruction,
        }
        target_id = str(result.get("probe_target_id", "")).strip()
        if target_id:
            parsed["probe_target_id"] = target_id
        try:
            after_step = result.get("probe_after_step")
            if after_step is not None:
                parsed["probe_after_step"] = int(after_step)
        except Exception:  # noqa: BLE001
            # A step index that cannot be converted is simply left out.
            pass
        return parsed
    except Exception:  # noqa: BLE001
        return None
def fmt_trajectory(
    conversation: list[dict],
    max_chars: int = _MAX_TRAJ_CHARS,
) -> str:
    """Render a conversation as line-oriented text for the analyst.

    Record shapes handled, checked in this order:

    1. Non-dict entries: ``[agent] <text>``.
    2. Tool-call records (``"type" == "tool_call"``): ``[action]`` / ``[obs]``.
    3. Step records (have both ``"action"`` and ``"env_feedback"``):
       ``[step N think]`` (only when reasoning is non-empty),
       ``[step N action]``, ``[step N obs]``.
    4. ``"role" == "system"`` records: ``[verification] <content>``.
    5. Anything else: ``[<role>] <content>`` with role defaulting to ``agent``.

    Output longer than *max_chars* keeps the head and the tail and replaces
    the middle with a truncation marker.
    """

    def _render(entry) -> list[str]:
        # One record in, zero or more formatted lines out.
        if not isinstance(entry, dict):
            return [f"[agent] {_clip_text(entry, 500)}"]

        if entry.get("type") == "tool_call":
            command = _clip_text(entry.get("cmd"), 500)
            observation = _clip_text(entry.get("obs"), 800)
            return [f"[action] {command}", f"[obs] {observation}"]

        if "action" in entry and "env_feedback" in entry:
            step = entry.get("step", "?")
            thought = _clip_text(entry.get("reasoning"), 300)
            action = _clip_text(entry.get("action"), 200)
            feedback = _clip_text(entry.get("env_feedback"), 500)
            rendered = [f"[step {step} think] {thought}"] if thought else []
            rendered.append(f"[step {step} action] {action}")
            rendered.append(f"[step {step} obs] {feedback}")
            return rendered

        if entry.get("role") == "system":
            # Post-execution verification / enrichment info
            return [f"[verification] {_clip_text(entry.get('content'), 2000)}"]

        body = _clip_text(entry.get("content"), 500)
        speaker = entry.get("role", "agent")
        return [f"[{speaker}] {body}"]

    rendered_lines: list[str] = []
    for entry in conversation:
        rendered_lines.extend(_render(entry))

    text = "\n".join(rendered_lines)
    if len(text) <= max_chars:
        return text
    # Floor division on the negated length is intentional: for odd budgets the
    # tail gets the extra character.
    return text[: max_chars // 2] + "\n...[middle truncated]...\n" + text[-max_chars // 2 :]
+ prediction_dir : str + Path to ``predictions/`` directory containing per-task + ``/conversation.json`` files. + + Returns + ------- + str + Formatted text with all trajectories separated by ``---``. + """ + parts: list[str] = [] + for idx, item in enumerate(items, 1): + tid = str(item["id"]) + conv_path = os.path.join(prediction_dir, tid, "conversation.json") + if not os.path.exists(conv_path): + continue + with open(conv_path) as f: + conversation = json.load(f) + if not conversation: + continue + + traj_text = fmt_trajectory(conversation) + header = ( + f"### Trajectory {idx} (id={tid})\n" + f"Task: {item.get('task_description', item.get('instruction', ''))}\n" + f"Task type: {item.get('task_type', item.get('instruction_type', ''))}\n" + ) + fail_reason = item.get("fail_reason", "") + if fail_reason: + header += f"Failure reason: {fail_reason}\n" + header += f"Steps: {item.get('n_turns', '?')}\n" + + reference_text = str(item.get("reference_text") or "").strip() + if reference_text: + header += ( + f"\n#### Hidden Reference\n" + f"{reference_text[:4000]}\n" + ) + + # ── Append student context (what the agent saw) ────────────── + student_prompt = item.get("student_system_prompt", "") + if not student_prompt: + prompt_path = os.path.join(prediction_dir, tid, "student_system_prompt.txt") + if os.path.exists(prompt_path): + with open(prompt_path) as f: + student_prompt = f.read() + if student_prompt: + header += ( + f"\n#### Student System Prompt\n" + f"{student_prompt[:3000]}\n" + ) + + user_prompt = item.get("student_user_prompt", "") + if not user_prompt: + user_prompt_path = os.path.join(prediction_dir, tid, "student_user_prompt.txt") + if os.path.exists(user_prompt_path): + with open(user_prompt_path) as f: + user_prompt = f.read() + if user_prompt: + header += ( + f"\n#### Student User Prompt\n" + f"{user_prompt[:3000]}\n" + ) + + if os.environ.get("REFLACT_CODEX_TRACE_TO_TEACHER", "0") == "1": + codex_trace_summary = item.get("codex_trace_summary", "") + if 
not codex_trace_summary: + codex_trace_summary_path = os.path.join(prediction_dir, tid, "codex_trace_summary.txt") + if os.path.exists(codex_trace_summary_path): + with open(codex_trace_summary_path) as f: + codex_trace_summary = f.read() + if codex_trace_summary: + header += ( + f"\n#### Codex Trace Summary\n" + f"{codex_trace_summary}\n" + ) + + codex_probe_trace_steps = str(item.get("codex_probe_trace_steps") or "").strip() + if codex_probe_trace_steps: + header += ( + f"\n#### Codex Trace Steps\n" + f"{codex_probe_trace_steps}\n" + ) + + preview = item.get("spreadsheet_preview", "") + if not preview: + preview_path = os.path.join(prediction_dir, tid, "spreadsheet_preview.txt") + if os.path.exists(preview_path): + with open(preview_path) as f: + preview = f.read() + if preview: + header += ( + f"\n#### Spreadsheet Preview\n" + f"{preview[:3000]}\n" + ) + + parts.append(header + "\n" + traj_text) + + return "\n\n---\n\n".join(parts) + + +# ── Prompt resolution ─────────────────────────────────────────────────────── + + +def _resolve_prompt(custom: str | None, default_name: str, update_mode: str = "patch") -> str: + """Return *custom* if provided (non-None), otherwise load from file.""" + if custom is not None: + return custom + mode = normalize_update_mode(update_mode) + actual_name = default_name + if is_full_rewrite_minibatch_mode(mode): + full_name = f"{default_name}_full_rewrite" + try: + return load_prompt(full_name) + except FileNotFoundError: + actual_name = default_name + elif mode == "rewrite_from_suggestions": + rewrite_name = f"{default_name}_rewrite" + try: + return load_prompt(rewrite_name) + except FileNotFoundError: + actual_name = default_name + return load_prompt(actual_name) + + +# ── Minibatch analysts ────────────────────────────────────────────────────── + + +def run_error_analyst_minibatch( + skill_content: str, + items: list[dict], + prediction_dir: str, + edit_budget: int = 4, + *, + system_prompt: str | None = None, + rejection_context: 
str = "", + trajectory_memory_context: str = "", + step_buffer_context: str = "", + meta_skill_context: str = "", + update_mode: str = "patch", +) -> dict | None: + """Analyze a minibatch of failed trajectories in one teacher call. + + Parameters + ---------- + skill_content : str + Current skill document text. + items : list[dict] + Rollout result dicts (all should have ``hard=0``). + prediction_dir : str + Path to ``predictions/`` directory. + edit_budget : int + Maximum number of edits (L). + system_prompt : str | None + Custom system prompt. ``None`` = use generic default. + rejection_context : str + *Deprecated* — use ``step_buffer_context``. + trajectory_memory_context : str + *Deprecated* — use ``step_buffer_context``. + step_buffer_context : str + Unified summary of previous steps (failure patterns + rejected edits). + + Returns + ------- + dict | None + Patch dict with ``source_type="failure"``, or ``None`` on error. + """ + mode = normalize_update_mode(update_mode) + actual_system = _resolve_prompt(system_prompt, "analyst_error", mode) + + trajectories_text = fmt_minibatch_trajectories(items, prediction_dir) + if not trajectories_text.strip(): + return None + + user = ( + f"## Current Skill\n{skill_content}\n\n" + ) + if is_full_rewrite_minibatch_mode(mode): + user += ( + f"## Update Format\n" + f"Produce one complete replacement skill candidate for this minibatch. 
" + f"Do not output edits, patches, or revise suggestions.\n\n" + ) + else: + user += ( + f"## {payload_label(mode, title=True)} Budget\n" + f"Produce at most L={edit_budget} {payload_label(mode)}.\n\n" + ) + # Unified step buffer context (preferred) + ctx = step_buffer_context or rejection_context or "" + if trajectory_memory_context: + ctx = f"{ctx}\n{trajectory_memory_context}" if ctx else trajectory_memory_context + if ctx.strip(): + user += f"## Previous Steps in This Epoch\n{ctx}\n\n" + teacher_ctx = format_meta_skill_context(meta_skill_context) + if teacher_ctx: + user += teacher_ctx + "\n\n" + user += f"## Failed Trajectories ({len(items)} total)\n{trajectories_text}" + + try: + response, _ = chat_teacher( + system=actual_system, user=user, + max_completion_tokens=64000 if is_full_rewrite_minibatch_mode(mode) else 4096, + retries=3, + stage="analyst", + ) + result = extract_json(response) + if result and "patch" in result: + result["source_type"] = "failure" + if not is_full_rewrite_minibatch_mode(mode): + truncate_payload(result["patch"], edit_budget, mode) + return result + except Exception: # noqa: BLE001 + traceback.print_exc() + return None + + +def run_success_analyst_minibatch( + skill_content: str, + items: list[dict], + prediction_dir: str, + edit_budget: int = 4, + *, + system_prompt: str | None = None, + trajectory_memory_context: str = "", + step_buffer_context: str = "", + meta_skill_context: str = "", + update_mode: str = "patch", +) -> dict | None: + """Analyze a minibatch of successful trajectories in one teacher call. + + Parameters + ---------- + system_prompt : str | None + Custom system prompt. ``None`` = use generic default. + trajectory_memory_context : str + *Deprecated* — use ``step_buffer_context``. + step_buffer_context : str + Unified summary of previous steps (failure patterns + rejected edits). + + Returns + ------- + dict | None + Patch dict with ``source_type="success"``, or ``None`` on error. 
+ """ + mode = normalize_update_mode(update_mode) + actual_system = _resolve_prompt(system_prompt, "analyst_success", mode) + + trajectories_text = fmt_minibatch_trajectories(items, prediction_dir) + if not trajectories_text.strip(): + return None + + user = ( + f"## Current Skill\n{skill_content}\n\n" + ) + if is_full_rewrite_minibatch_mode(mode): + user += ( + f"## Update Format\n" + f"Produce one complete replacement skill candidate for this minibatch. " + f"Do not output edits, patches, or revise suggestions.\n\n" + ) + else: + user += ( + f"## {payload_label(mode, title=True)} Budget\n" + f"Produce at most L={edit_budget} {payload_label(mode)}.\n\n" + ) + ctx = step_buffer_context or trajectory_memory_context or "" + if ctx.strip(): + user += f"## Previous Steps in This Epoch\n{ctx}\n\n" + teacher_ctx = format_meta_skill_context(meta_skill_context) + if teacher_ctx: + user += teacher_ctx + "\n\n" + user += f"## Successful Trajectories ({len(items)} total)\n{trajectories_text}" + + try: + response, _ = chat_teacher( + system=actual_system, user=user, + max_completion_tokens=64000 if is_full_rewrite_minibatch_mode(mode) else 4096, + retries=3, + stage="analyst", + ) + result = extract_json(response) + if result and "patch" in result: + result["source_type"] = "success" + if not is_full_rewrite_minibatch_mode(mode): + truncate_payload(result["patch"], edit_budget, mode) + return result + except Exception: # noqa: BLE001 + traceback.print_exc() + return None + + +# ── Minibatch reflect dispatcher ──────────────────────────────────────────── + + +def _split_minibatches(items: list, batch_size: int) -> list[list]: + """Split items into minibatches of at most *batch_size*.""" + return [items[i : i + batch_size] for i in range(0, len(items), batch_size)] + + +def _shuffle_for_minibatch(items: list, seed: int | None) -> list: + """Return items in minibatch order. 
+ + Uses a deterministic shuffle when a seed is provided so resume runs keep + the same minibatch composition. Falls back to input order when no seed is + available. + """ + ordered = list(items) + if seed is None: + return ordered + random.Random(seed).shuffle(ordered) + return ordered + + +def run_minibatch_reflect( + results: list[dict], + skill_content: str, + prediction_dir: str, + patches_dir: str, + workers: int, + failure_only: bool, + minibatch_size: int = 8, + edit_budget: int = 4, + random_seed: int | None = None, + *, + error_system: str | None = None, + success_system: str | None = None, + rejection_context: str = "", + trajectory_memory_context: str = "", + step_buffer_context: str = "", + meta_skill_context: str = "", + update_mode: str = "patch", +) -> list[dict | None]: + """Full minibatch reflect stage: group → parallel teacher calls → patches. + + Separates failure and success trajectories, splits each into minibatches + of size M, runs all minibatches in parallel, and saves patch files. + + Parameters + ---------- + results : list[dict] + Rollout result dicts; see :class:`~skillopt.types.RolloutResult`. + skill_content : str + Current skill document. + prediction_dir : str + Path to ``predictions/`` with ``conversation.json`` files. + patches_dir : str + Path to save per-minibatch patch JSON files. + workers : int + Max parallel teacher calls. + failure_only : bool + If True, skip success trajectories. + minibatch_size : int + Trajectories per group (M). + edit_budget : int + Max edits per minibatch (L). + random_seed : int | None + Optional seed used to shuffle trajectories before minibatch splitting. + error_system, success_system : str | None + Optional custom prompts. ``None`` = use generic defaults. + + Returns + ------- + list[dict | None] + Patch dicts (with ``source_type`` "failure" or "success"). 
+ """ + os.makedirs(patches_dir, exist_ok=True) + + # Separate failure / success + failures = [r for r in results if not r.get("hard")] + successes = [r for r in results if r.get("hard")] if not failure_only else [] + + failures = _shuffle_for_minibatch(failures, random_seed) + successes = _shuffle_for_minibatch(successes, None if random_seed is None else random_seed + 1) + + # Split into minibatches + fail_batches = _split_minibatches(failures, minibatch_size) + succ_batches = _split_minibatches(successes, minibatch_size) + + n_fail_batches = len(fail_batches) + n_succ_batches = len(succ_batches) + print( + f" [2/6 REFLECT minibatch] " + f"failure={len(failures)}→{n_fail_batches} groups " + f"success={len(successes)}→{n_succ_batches} groups " + f"(M={minibatch_size}, L={edit_budget}, workers={workers})" + ) + + raw_patches: list[dict | None] = [] + + # Resume support: check for already-done minibatch patches + pending_fail: list[tuple[int, list[dict]]] = [] + for idx, batch in enumerate(fail_batches): + path = os.path.join(patches_dir, f"minibatch_fail_{idx:03d}.json") + if os.path.exists(path): + with open(path) as f: + raw_patches.append(json.load(f)) + else: + pending_fail.append((idx, batch)) + + pending_succ: list[tuple[int, list[dict]]] = [] + for idx, batch in enumerate(succ_batches): + path = os.path.join(patches_dir, f"minibatch_succ_{idx:03d}.json") + if os.path.exists(path): + with open(path) as f: + raw_patches.append(json.load(f)) + else: + pending_succ.append((idx, batch)) + + # ── Worker functions ────────────────────────────────────────────────── + def _do_fail(idx: int, batch: list[dict]) -> tuple[str, dict | None]: + patch = run_error_analyst_minibatch( + skill_content, batch, prediction_dir, + edit_budget=edit_budget, + system_prompt=error_system, + step_buffer_context=step_buffer_context, + # backward compat fallback + rejection_context=rejection_context, + trajectory_memory_context=trajectory_memory_context, + 
meta_skill_context=meta_skill_context, + update_mode=update_mode, + ) + return f"minibatch_fail_{idx:03d}", patch + + def _do_succ(idx: int, batch: list[dict]) -> tuple[str, dict | None]: + patch = run_success_analyst_minibatch( + skill_content, batch, prediction_dir, + edit_budget=edit_budget, + system_prompt=success_system, + step_buffer_context=step_buffer_context, + trajectory_memory_context=trajectory_memory_context, + meta_skill_context=meta_skill_context, + update_mode=update_mode, + ) + return f"minibatch_succ_{idx:03d}", patch + + # Run all pending minibatches in parallel + all_pending = ( + [("fail", idx, batch) for idx, batch in pending_fail] + + [("succ", idx, batch) for idx, batch in pending_succ] + ) + + with ThreadPoolExecutor(max_workers=workers) as ex: + futs = {} + for kind, idx, batch in all_pending: + if kind == "fail": + futs[ex.submit(_do_fail, idx, batch)] = (kind, idx, len(batch)) + else: + futs[ex.submit(_do_succ, idx, batch)] = (kind, idx, len(batch)) + + for i, fut in enumerate(as_completed(futs), 1): + kind, idx, batch_len = futs[fut] + tag, patch = fut.result() + if patch: + path = os.path.join(patches_dir, f"{tag}.json") + with open(path, "w") as f: + json.dump(patch, f, ensure_ascii=False, indent=2) + raw_patches.append(patch) + n_edits = len(get_payload_items(patch.get("patch", {}) if patch else {}, update_mode)) + print( + f" [analyst] {i}/{len(all_pending)} {tag} " + f"({batch_len} trajs) → {n_edits} {payload_label(update_mode)}" + ) + + return raw_patches diff --git a/skillopt/model/__init__.py b/skillopt/model/__init__.py new file mode 100644 index 0000000..cca4b25 --- /dev/null +++ b/skillopt/model/__init__.py @@ -0,0 +1,343 @@ +"""ReflACT model API with runtime backend selection for the student path.""" + +from __future__ import annotations + +from typing import Any + +from skillopt.model import azure_openai as _openai +from skillopt.model import claude_backend as _claude +from skillopt.model.backend_config import ( # noqa: 
def set_backend(name: str | None) -> str:
    """Backward-compatible global backend setter.

    Older scripts configured a single backend shared by teacher and student.
    This entry point maps each legacy name onto the split teacher/student
    configuration and returns the legacy summary name.  A missing or empty
    *name* is treated as ``"azure_openai"``; an unknown name raises
    ``ValueError``.
    """
    normalized = str(name or "azure_openai").strip().lower()

    # alias -> (teacher backend, student backend, returned summary name)
    routes: dict[str, tuple[str, str, str]] = {}
    for alias in ("azure_openai", "openai_chat", "azure", "azure-openai"):
        routes[alias] = ("openai_chat", "openai_chat", "azure_openai")
    for alias in ("claude", "claude_chat", "anthropic"):
        routes[alias] = ("claude_chat", "claude_chat", "claude_chat")
    routes["codex"] = ("openai_chat", "codex_exec", "codex")
    for alias in ("codex_exec", "claude_code_exec"):
        # Exec student backends keep the OpenAI chat teacher.
        routes[alias] = ("openai_chat", alias, alias)

    route = routes.get(normalized)
    if route is None:
        raise ValueError(f"Unsupported legacy backend: {name!r}")

    teacher, student, summary = route
    set_teacher_backend(teacher)
    set_student_backend(student)
    return summary
def chat_student(
    system: str,
    user: str,
    max_completion_tokens: int = 16384,
    retries: int = 5,
    stage: str = "student",
    reasoning_effort: str | None = None,
    timeout: int | None = None,
) -> tuple[str, dict]:
    """Send one system/user exchange to the configured student chat backend.

    The Claude backend is used when the student backend is ``claude_chat``
    (``reasoning_effort`` is not forwarded there); any other chat backend is
    served by the OpenAI module.  Non-chat (exec) student backends raise
    ``NotImplementedError``.
    """
    # Arguments every chat backend accepts.
    shared = {
        "system": system,
        "user": user,
        "max_completion_tokens": max_completion_tokens,
        "retries": retries,
        "stage": stage,
        "timeout": timeout,
    }

    if get_student_backend() == "claude_chat":
        return _claude.chat_student(**shared)

    if is_student_chat_backend():
        return _openai.chat_student(reasoning_effort=reasoning_effort, **shared)

    raise NotImplementedError(
        "chat_student is only supported with student_backend=openai_chat or claude_chat. "
        "Exec backends are handled in environment-specific rollout code."
    )
" + "Exec backends are handled in environment-specific rollout code." + ) + return _openai.chat_student_messages( + messages=messages, + max_completion_tokens=max_completion_tokens, + retries=retries, + stage=stage, + reasoning_effort=reasoning_effort, + tools=tools, + tool_choice=tool_choice, + return_message=return_message, + timeout=timeout, + ) + + +def chat_messages_with_deployment( + deployment: str, + messages: list[dict[str, Any]], + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "custom", + reasoning_effort: str | None = None, + *, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + return_message: bool = False, + timeout: int | None = None, +) -> tuple[Any, dict]: + return _openai.chat_messages_with_deployment( + deployment=deployment, + messages=messages, + max_completion_tokens=max_completion_tokens, + retries=retries, + stage=stage, + reasoning_effort=reasoning_effort, + tools=tools, + tool_choice=tool_choice, + return_message=return_message, + timeout=timeout, + ) + + +def chat_with_deployment( + deployment: str, + system: str, + user: str, + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "custom", + reasoning_effort: str | None = None, + timeout: int | None = None, +) -> tuple[str, dict]: + return _openai.chat_with_deployment( + deployment=deployment, + system=system, + user=user, + max_completion_tokens=max_completion_tokens, + retries=retries, + stage=stage, + reasoning_effort=reasoning_effort, + timeout=timeout, + ) + + +def get_token_summary() -> dict: + summary = _openai.get_token_summary() + claude_summary = _claude.get_token_summary() + for stage, values in claude_summary.items(): + if stage == "_total": + continue + if stage not in summary: + summary[stage] = values + continue + summary[stage]["calls"] += values["calls"] + summary[stage]["prompt_tokens"] += values["prompt_tokens"] + summary[stage]["completion_tokens"] += values["completion_tokens"] + 
summary[stage]["total_tokens"] += values["total_tokens"] + total = { + "calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + for stage, values in summary.items(): + if stage == "_total": + continue + total["calls"] += values["calls"] + total["prompt_tokens"] += values["prompt_tokens"] + total["completion_tokens"] += values["completion_tokens"] + total["total_tokens"] += values["total_tokens"] + summary["_total"] = total + return summary + + +def reset_token_tracker() -> None: + _openai.reset_token_tracker() + _claude.reset_token_tracker() + + +def configure_azure_openai( + *, + endpoint: str | None = None, + api_version: str | None = None, + api_key: str | None = None, + auth_mode: str | None = None, + ad_scope: str | None = None, + managed_identity_client_id: str | None = None, + teacher_endpoint: str | None = None, + teacher_api_version: str | None = None, + teacher_api_key: str | None = None, + teacher_auth_mode: str | None = None, + teacher_ad_scope: str | None = None, + teacher_managed_identity_client_id: str | None = None, + student_endpoint: str | None = None, + student_api_version: str | None = None, + student_api_key: str | None = None, + student_auth_mode: str | None = None, + student_ad_scope: str | None = None, + student_managed_identity_client_id: str | None = None, +) -> None: + _openai.configure_azure_openai( + endpoint=endpoint, + api_version=api_version, + api_key=api_key, + auth_mode=auth_mode, + ad_scope=ad_scope, + managed_identity_client_id=managed_identity_client_id, + teacher_endpoint=teacher_endpoint, + teacher_api_version=teacher_api_version, + teacher_api_key=teacher_api_key, + teacher_auth_mode=teacher_auth_mode, + teacher_ad_scope=teacher_ad_scope, + teacher_managed_identity_client_id=teacher_managed_identity_client_id, + student_endpoint=student_endpoint, + student_api_version=student_api_version, + student_api_key=student_api_key, + student_auth_mode=student_auth_mode, + 
student_ad_scope=student_ad_scope, + student_managed_identity_client_id=student_managed_identity_client_id, + ) + + +def set_reasoning_effort(effort: str | None) -> None: + _openai.set_reasoning_effort(effort) + _claude.set_reasoning_effort(effort) + + +def set_student_deployment(deployment: str) -> None: + _openai.set_student_deployment(deployment) + _claude.set_student_deployment(deployment) + + +def set_teacher_deployment(deployment: str) -> None: + _openai.set_teacher_deployment(deployment) + _claude.set_teacher_deployment(deployment) diff --git a/skillopt/model/azure_openai.py b/skillopt/model/azure_openai.py new file mode 100644 index 0000000..763c35f --- /dev/null +++ b/skillopt/model/azure_openai.py @@ -0,0 +1,871 @@ +"""ReflACT Model backend — Azure OpenAI wrapper with token tracking. + +Provides teacher/student dual-deployment chat functions and a global +TokenTracker for per-stage cost accounting. Previously llm/azure_openai.py. +""" +from __future__ import annotations + +import json +import os +import subprocess +import threading +import time +from types import SimpleNamespace +from typing import Any +from openai import AzureOpenAI + +# ── Configuration ───────────────────────────────────────────────────────────── + +ENDPOINT = os.environ.get( + "AZURE_OPENAI_ENDPOINT", + "https://t2vgoaigpt4o3.openai.azure.com/", +) +API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION", "2024-12-01-preview") +API_KEY = os.environ.get( + "AZURE_OPENAI_API_KEY", + "", +) +AUTH_MODE = os.environ.get("AZURE_OPENAI_AUTH_MODE", "azure_cli").strip().lower() +AD_SCOPE = os.environ.get( + "AZURE_OPENAI_AD_SCOPE", + "https://cognitiveservices.azure.com/.default", +) +MANAGED_IDENTITY_CLIENT_ID = os.environ.get( + "AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID", + "", +).strip() + +TEACHER_ENDPOINT = ( + os.environ.get("TEACHER_AZURE_OPENAI_ENDPOINT") + or os.environ.get("AZURE_OPENAI_TEACHER_ENDPOINT") + or ENDPOINT +) +STUDENT_ENDPOINT = ( + 
os.environ.get("STUDENT_AZURE_OPENAI_ENDPOINT") + or os.environ.get("AZURE_OPENAI_STUDENT_ENDPOINT") + or ENDPOINT +) +TEACHER_API_VERSION = ( + os.environ.get("TEACHER_AZURE_OPENAI_API_VERSION") + or os.environ.get("AZURE_OPENAI_TEACHER_API_VERSION") + or API_VERSION +) +STUDENT_API_VERSION = ( + os.environ.get("STUDENT_AZURE_OPENAI_API_VERSION") + or os.environ.get("AZURE_OPENAI_STUDENT_API_VERSION") + or API_VERSION +) +TEACHER_API_KEY = ( + os.environ.get("TEACHER_AZURE_OPENAI_API_KEY") + or os.environ.get("AZURE_OPENAI_TEACHER_API_KEY") + or API_KEY +) +STUDENT_API_KEY = ( + os.environ.get("STUDENT_AZURE_OPENAI_API_KEY") + or os.environ.get("AZURE_OPENAI_STUDENT_API_KEY") + or API_KEY +) +TEACHER_AUTH_MODE = ( + os.environ.get("TEACHER_AZURE_OPENAI_AUTH_MODE") + or os.environ.get("AZURE_OPENAI_TEACHER_AUTH_MODE") + or AUTH_MODE +).strip().lower() +STUDENT_AUTH_MODE = ( + os.environ.get("STUDENT_AZURE_OPENAI_AUTH_MODE") + or os.environ.get("AZURE_OPENAI_STUDENT_AUTH_MODE") + or AUTH_MODE +).strip().lower() +TEACHER_AD_SCOPE = ( + os.environ.get("TEACHER_AZURE_OPENAI_AD_SCOPE") + or os.environ.get("AZURE_OPENAI_TEACHER_AD_SCOPE") + or AD_SCOPE +) +STUDENT_AD_SCOPE = ( + os.environ.get("STUDENT_AZURE_OPENAI_AD_SCOPE") + or os.environ.get("AZURE_OPENAI_STUDENT_AD_SCOPE") + or AD_SCOPE +) +TEACHER_MANAGED_IDENTITY_CLIENT_ID = ( + os.environ.get("TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID") + or os.environ.get("AZURE_OPENAI_TEACHER_MANAGED_IDENTITY_CLIENT_ID") + or MANAGED_IDENTITY_CLIENT_ID +).strip() +STUDENT_MANAGED_IDENTITY_CLIENT_ID = ( + os.environ.get("STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID") + or os.environ.get("AZURE_OPENAI_STUDENT_MANAGED_IDENTITY_CLIENT_ID") + or MANAGED_IDENTITY_CLIENT_ID +).strip() + +TEACHER_DEPLOYMENT = os.environ.get("TEACHER_DEPLOYMENT", "gpt-5.5") +STUDENT_DEPLOYMENT = os.environ.get("STUDENT_DEPLOYMENT", "gpt-5.5") + +REASONING_EFFORT: str | None = None + +_AZ_CLI_TOKEN_CACHE: dict[str, dict[str, Any]] = {} + +# 
class TokenTracker:
    """Thread-safe per-stage token counter.

    Each stage name maps to three running counters: ``calls``,
    ``prompt_tokens`` and ``completion_tokens``.  Every read and write is
    done while holding ``self._lock``.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._data: dict[str, dict[str, int]] = {}

    @staticmethod
    def _record_view(calls: int, prompt: int, completion: int) -> dict[str, int]:
        """Build the four-field public record; ``total_tokens`` is derived."""
        return {
            "calls": calls,
            "prompt_tokens": prompt,
            "completion_tokens": completion,
            "total_tokens": prompt + completion,
        }

    def record(
        self, stage: str, prompt_tokens: int | None, completion_tokens: int | None,
    ) -> None:
        """Add one call and its token usage to *stage*.

        A count of ``None`` is treated as zero.  Both counts are converted
        before any counter is touched, so a bad value raises without leaving
        the stage with a bumped ``calls`` but unchanged token totals.
        """
        prompt = int(prompt_tokens or 0)
        completion = int(completion_tokens or 0)
        with self._lock:
            counters = self._data.setdefault(
                stage,
                {"calls": 0, "prompt_tokens": 0, "completion_tokens": 0},
            )
            counters["calls"] += 1
            counters["prompt_tokens"] += prompt
            counters["completion_tokens"] += completion

    def summary(self) -> dict:
        """Return per-stage records (sorted by stage name) plus a ``_total`` entry."""
        with self._lock:
            out: dict = {
                stage: self._record_view(
                    d["calls"], d["prompt_tokens"], d["completion_tokens"],
                )
                for stage, d in sorted(self._data.items())
            }
            # Sum before inserting "_total" so the aggregate never counts itself.
            rows = list(out.values())
            out["_total"] = self._record_view(
                sum(r["calls"] for r in rows),
                sum(r["prompt_tokens"] for r in rows),
                sum(r["completion_tokens"] for r in rows),
            )
            return out

    def reset(self) -> None:
        """Forget every stage."""
        with self._lock:
            self._data.clear()

    def stage_snapshot(self, stage: str) -> dict:
        """Return a copy of one stage's counters (or zeros if not tracked yet)."""
        with self._lock:
            d = self._data.get(stage, {})
            return self._record_view(
                d.get("calls", 0),
                d.get("prompt_tokens", 0),
                d.get("completion_tokens", 0),
            )
def _make_token_provider(
    auth_mode: str,
    ad_scope: str,
    managed_identity_client_id: str,
):
    """Return a bearer-token provider callable for Azure AD authentication.

    Args:
        auth_mode: ``"azure_cli"``, or one of ``"managed_identity"``,
            ``"aad"``, ``"azure_ad"`` (all three select a managed identity).
        ad_scope: Scope handed to ``get_bearer_token_provider`` (or to the
            ``az`` fallback).
        managed_identity_client_id: Client id for a user-assigned managed
            identity; an empty string uses ``ManagedIdentityCredential()``
            with no arguments.

    Raises:
        ValueError: *auth_mode* is not one of the modes listed above.
        ImportError: azure-identity is missing and *auth_mode* is not
            ``"azure_cli"`` (which falls back to shelling out to ``az``).
    """
    managed_identity_modes = {"managed_identity", "aad", "azure_ad"}
    # Validate first: a typo in the auth mode must be reported as such even on
    # machines where azure-identity is not installed, instead of surfacing as
    # an unrelated "install azure-identity" ImportError.
    if auth_mode != "azure_cli" and auth_mode not in managed_identity_modes:
        raise ValueError(
            "Unsupported Azure OpenAI auth mode "
            f"{auth_mode!r}; expected api_key, managed_identity, azure_ad, aad, or azure_cli."
        )

    try:
        from azure.identity import (  # type: ignore[import-not-found]
            AzureCliCredential,
            ManagedIdentityCredential,
            get_bearer_token_provider,
        )
    except ImportError as e:
        if auth_mode == "azure_cli":
            return _make_azure_cli_token_provider(ad_scope)
        raise ImportError(
            "Azure AD auth requires azure-identity. Install it with `pip install azure-identity`."
        ) from e

    if auth_mode == "azure_cli":
        credential = AzureCliCredential()
    elif managed_identity_client_id:
        credential = ManagedIdentityCredential(client_id=managed_identity_client_id)
    else:
        credential = ManagedIdentityCredential()
    return get_bearer_token_provider(credential, ad_scope)


def _make_azure_cli_token_provider(ad_scope: str):
    """Return an Azure CLI token provider compatible with AzureOpenAI.

    This fallback avoids requiring azure-identity in environments where `az`
    is already logged in. The returned zero-argument callable yields a bearer
    token string; tokens are cached per resource in ``_AZ_CLI_TOKEN_CACHE``
    and reused while more than 300 seconds of lifetime remain.

    Raises (from the returned callable):
        subprocess.CalledProcessError: ``az`` exited non-zero.
        FileNotFoundError: ``az`` is not on ``PATH``.
    """

    resource = ad_scope.removesuffix("/.default")

    def _provider() -> str:
        now = int(time.time())
        cache = _AZ_CLI_TOKEN_CACHE.setdefault(resource, {"token": "", "expires_on": 0})
        cached = str(cache.get("token") or "")
        expires_on = int(cache.get("expires_on") or 0)
        if cached and expires_on - now > 300:
            return cached

        # Capture stdout and stderr separately: anything `az` prints on stderr
        # (warnings, notices) must not be mixed into the JSON parsed below.
        completed = subprocess.run(
            [
                "az",
                "account",
                "get-access-token",
                "--resource",
                resource,
                "-o",
                "json",
            ],
            check=True,
            capture_output=True,
            text=True,
        )
        payload = json.loads(completed.stdout)
        token = str(payload["accessToken"])
        cache["token"] = token
        # When the payload carries no `expires_on`, assume 3000 s of validity.
        cache["expires_on"] = int(payload.get("expires_on") or now + 3000)
        return token

    return _provider
def _chat_impl(
    client: AzureOpenAI,
    deployment: str,
    system: str,
    user: str,
    max_completion_tokens: int,
    retries: int,
    stage: str,
    reasoning_effort: str | None = None,
    timeout: int | None = None,
) -> tuple[str, dict]:
    """Call LLM, track tokens, return (text, usage_dict).

    Deployments for which ``_needs_responses_api`` is true go through
    ``client.responses.create``; all others use
    ``client.chat.completions.create``.  A per-call *reasoning_effort* takes
    precedence over the module-level ``REASONING_EFFORT``.

    Args:
        retries: Number of attempts.  Values below 1 still make one attempt.
        stage: Key under which token usage is recorded in ``tracker``.
        timeout: Passed through to the SDK call when not ``None``.

    Raises:
        RuntimeError: every attempt failed; the last error is chained as
            ``__cause__``.
    """
    last_err: Exception | None = None
    usage_info = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    attempts = max(1, retries)

    for attempt in range(attempts):
        try:
            actual_effort = reasoning_effort or REASONING_EFFORT
            if _needs_responses_api(deployment):
                kwargs: dict[str, Any] = {
                    "model": deployment,
                    "instructions": system,
                    "input": [{"role": "user", "content": user}],
                    "max_output_tokens": max_completion_tokens,
                }
                if actual_effort:
                    kwargs["reasoning"] = {"effort": actual_effort}
                if timeout is not None:
                    kwargs["timeout"] = timeout
                resp = client.responses.create(**kwargs)
                text = getattr(resp, "output_text", None) or ""
                if not text:
                    # No aggregated text: concatenate every output_text part
                    # instead of keeping only the last one seen.
                    text = "".join(
                        getattr(part, "text", "") or ""
                        for item in getattr(resp, "output", None) or []
                        for part in getattr(item, "content", None) or []
                        if getattr(part, "type", "") == "output_text"
                    )
                usage = getattr(resp, "usage", None)
                if usage:
                    prompt_tokens = getattr(usage, "input_tokens", 0) or 0
                    completion_tokens = getattr(usage, "output_tokens", 0) or 0
                    usage_info = {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                        "total_tokens": prompt_tokens + completion_tokens,
                    }
            else:
                kwargs = {
                    "model": deployment,
                    "messages": [
                        {"role": "system", "content": system},
                        {"role": "user", "content": user},
                    ],
                    "max_completion_tokens": max_completion_tokens,
                }
                if actual_effort:
                    kwargs["reasoning_effort"] = actual_effort
                if timeout is not None:
                    kwargs["timeout"] = timeout
                resp = client.chat.completions.create(**kwargs)
                text = resp.choices[0].message.content or ""
                if resp.usage:
                    usage_info = {
                        "prompt_tokens": resp.usage.prompt_tokens or 0,
                        "completion_tokens": resp.usage.completion_tokens or 0,
                        "total_tokens": resp.usage.total_tokens or 0,
                    }

            tracker.record(
                stage,
                usage_info["prompt_tokens"],
                usage_info["completion_tokens"],
            )
            return text, usage_info

        except Exception as e:  # noqa: BLE001
            last_err = e
            # Exponential back-off capped at 30 s, but only when another
            # attempt follows; sleeping after the final failure just delays
            # the error.
            if attempt < attempts - 1:
                time.sleep(min(2 ** attempt, 30))

    raise RuntimeError(
        f"LLM call failed after {retries} retries: {last_err}"
    ) from last_err
if _needs_responses_api(deployment): + input_items, instructions = _messages_to_responses_input(messages) + kwargs: dict[str, Any] = { + "model": deployment, + "input": input_items, + "max_output_tokens": max_completion_tokens, + } + if instructions: + kwargs["instructions"] = instructions + actual_effort = reasoning_effort or REASONING_EFFORT + if actual_effort: + kwargs["reasoning"] = {"effort": actual_effort} + if tools: + kwargs["tools"] = [_chat_tool_to_responses_tool(tool) for tool in tools] + if tool_choice is not None: + kwargs["tool_choice"] = tool_choice + if timeout is not None: + kwargs["timeout"] = timeout + resp = client.responses.create(**kwargs) + message, text = _responses_to_chat_message(resp) + if hasattr(resp, "usage") and resp.usage: + usage_info = { + "prompt_tokens": getattr(resp.usage, "input_tokens", 0) or 0, + "completion_tokens": getattr(resp.usage, "output_tokens", 0) or 0, + "total_tokens": ( + (getattr(resp.usage, "input_tokens", 0) or 0) + + (getattr(resp.usage, "output_tokens", 0) or 0) + ), + } + else: + kwargs = dict( + model=deployment, + messages=messages, + max_completion_tokens=max_completion_tokens, + ) + actual_effort = reasoning_effort or REASONING_EFFORT + if actual_effort is not None: + kwargs["reasoning_effort"] = actual_effort + if tools: + kwargs["tools"] = tools + if tool_choice is not None: + kwargs["tool_choice"] = tool_choice + if timeout is not None: + kwargs["timeout"] = timeout + resp = client.chat.completions.create(**kwargs) + message = resp.choices[0].message + text = message.content or "" + if resp.usage: + usage_info = { + "prompt_tokens": resp.usage.prompt_tokens or 0, + "completion_tokens": resp.usage.completion_tokens or 0, + "total_tokens": resp.usage.total_tokens or 0, + } + tracker.record( + stage, + usage_info["prompt_tokens"], + usage_info["completion_tokens"], + ) + return (message if return_message else text), usage_info + except Exception as e: # noqa: BLE001 + last_err = e + sleep = min(2 ** 
def _chat_tool_to_responses_tool(tool: dict[str, Any]) -> dict[str, Any]:
    """Convert a Chat Completions function tool to Responses API format.

    The nested ``{"type": "function", "function": {...}}`` shape is flattened
    so ``name``, ``description`` and ``parameters`` sit at the top level.  A
    ``strict`` flag is carried over when the source tool has one.  Anything
    that is not a function tool with a dict ``function`` is returned unchanged.
    """
    fn = tool.get("function")
    if tool.get("type") != "function" or not isinstance(fn, dict):
        return tool
    converted: dict[str, Any] = {
        "type": "function",
        "name": fn.get("name", ""),
        "description": fn.get("description", ""),
        "parameters": fn.get("parameters", {"type": "object", "properties": {}}),
    }
    if "strict" in fn:
        converted["strict"] = fn["strict"]
    return converted


def _content_to_text(content: Any) -> str:
    """Flatten chat message content to plain text.

    Strings pass through.  For list content, string items and the ``text`` of
    text-typed parts are joined with newlines; non-text parts are left out.
    Any other value is stringified.
    """
    if not content:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        pieces: list[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") in {
                "text", "input_text", "output_text",
            }:
                pieces.append(str(part.get("text") or ""))
        return "\n".join(pieces)
    return str(content)


def _user_content_to_responses(content: Any) -> Any:
    """Convert user message content to a Responses API ``content`` value.

    Non-list content is stringified.  List content becomes a list of parts:
    chat ``text`` parts map to ``input_text`` and ``image_url`` parts map to
    ``input_image``; dict parts of any other type are passed through as-is.
    """
    if not isinstance(content, (list, tuple)):
        return str(content)
    parts: list[Any] = []
    for part in content:
        if not isinstance(part, dict):
            parts.append({"type": "input_text", "text": str(part)})
            continue
        part_type = part.get("type")
        if part_type == "text":
            parts.append({"type": "input_text", "text": str(part.get("text") or "")})
        elif part_type == "image_url":
            image = part.get("image_url")
            url = image.get("url", "") if isinstance(image, dict) else image
            converted: dict[str, Any] = {
                "type": "input_image",
                "image_url": str(url or ""),
            }
            if isinstance(image, dict) and image.get("detail"):
                converted["detail"] = image["detail"]
            parts.append(converted)
        else:
            parts.append(part)
    return parts


def _messages_to_responses_input(messages: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], str]:
    """Convert chat-style messages, including tool results, to Responses input.

    Returns ``(input_items, instructions)``:

    * ``system`` messages are removed from the item list and their non-empty
      text is joined with blank lines into ``instructions``.
    * ``tool`` messages become ``function_call_output`` items.
    * ``assistant`` messages contribute a text item (when non-empty) followed
      by one ``function_call`` item per entry in ``tool_calls``.
    * ``user`` and ``developer`` messages both become ``user`` items.
    * Messages with any other role are ignored.

    List-style content is converted rather than stringified: user content
    keeps its parts (see ``_user_content_to_responses``); every other role is
    flattened to text.
    """
    instructions: list[str] = []
    input_items: list[dict[str, Any]] = []
    for message in messages:
        role = message.get("role")
        content = message.get("content") or ""
        if role == "system":
            system_text = _content_to_text(content)
            if system_text:
                instructions.append(system_text)
        elif role == "tool":
            input_items.append({
                "type": "function_call_output",
                "call_id": str(message.get("tool_call_id", "")),
                "output": _content_to_text(content),
            })
        elif role == "assistant":
            assistant_text = _content_to_text(content)
            if assistant_text:
                input_items.append({"role": "assistant", "content": assistant_text})
            for tool_call in message.get("tool_calls") or []:
                function = tool_call.get("function", {}) or {}
                input_items.append({
                    "type": "function_call",
                    "call_id": str(tool_call.get("id", "")),
                    "name": str(function.get("name", "")),
                    # Empty or missing arguments are sent as an empty JSON object.
                    "arguments": str(function.get("arguments", "{}") or "{}"),
                })
        elif role in {"user", "developer"}:
            input_items.append({
                "role": "user",
                "content": _user_content_to_responses(content),
            })
    return input_items, "\n\n".join(instructions)


def _responses_to_chat_message(resp: Any) -> tuple[Any, str]:
    """Convert Responses output into the subset of Chat message API we use.

    Returns ``(message, text)`` where ``message`` is a ``SimpleNamespace``
    with ``content`` and ``tool_calls`` (a list of chat-style function-call
    dicts).  ``resp.output_text`` is used when non-empty; otherwise the text
    is the concatenation of the ``output_text`` parts of every ``message``
    output item, not only the first item that yields text.
    """
    aggregated = getattr(resp, "output_text", None) or ""
    collected: list[str] = []
    tool_calls: list[dict[str, Any]] = []
    for item in getattr(resp, "output", None) or []:
        item_type = getattr(item, "type", "")
        if item_type == "function_call":
            tool_calls.append({
                "id": getattr(item, "call_id", "") or getattr(item, "id", ""),
                "type": "function",
                "function": {
                    "name": getattr(item, "name", ""),
                    "arguments": getattr(item, "arguments", "") or "{}",
                },
            })
        elif item_type == "message" and not aggregated:
            for part in getattr(item, "content", None) or []:
                if getattr(part, "type", "") == "output_text":
                    collected.append(getattr(part, "text", "") or "")
    text = aggregated or "".join(collected)
    return SimpleNamespace(content=text, tool_calls=tool_calls), text
return None + str_value = str(value).strip() + if not str_value: + return None + if lower: + str_value = str_value.lower() + return str_value + + def _set(global_name: str, value: str | None, env_key: str) -> None: + if value is None: + return + globals()[global_name] = value + os.environ[env_key] = value + + shared_endpoint = _clean(endpoint) + shared_api_version = _clean(api_version) + shared_api_key = _clean(api_key) + shared_auth_mode = _clean(auth_mode, lower=True) + shared_ad_scope = _clean(ad_scope) + shared_managed_identity_client_id = _clean(managed_identity_client_id) + + _set("ENDPOINT", shared_endpoint, "AZURE_OPENAI_ENDPOINT") + _set("API_VERSION", shared_api_version, "AZURE_OPENAI_API_VERSION") + _set("API_KEY", shared_api_key, "AZURE_OPENAI_API_KEY") + _set("AUTH_MODE", shared_auth_mode, "AZURE_OPENAI_AUTH_MODE") + _set("AD_SCOPE", shared_ad_scope, "AZURE_OPENAI_AD_SCOPE") + _set( + "MANAGED_IDENTITY_CLIENT_ID", + shared_managed_identity_client_id, + "AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID", + ) + + resolved_teacher_endpoint = _clean(teacher_endpoint) or shared_endpoint + resolved_teacher_api_version = _clean(teacher_api_version) or shared_api_version + resolved_teacher_api_key = _clean(teacher_api_key) or shared_api_key + resolved_teacher_auth_mode = _clean(teacher_auth_mode, lower=True) or shared_auth_mode + resolved_teacher_ad_scope = _clean(teacher_ad_scope) or shared_ad_scope + resolved_teacher_mi = ( + _clean(teacher_managed_identity_client_id) + or shared_managed_identity_client_id + ) + resolved_student_endpoint = _clean(student_endpoint) or shared_endpoint + resolved_student_api_version = _clean(student_api_version) or shared_api_version + resolved_student_api_key = _clean(student_api_key) or shared_api_key + resolved_student_auth_mode = _clean(student_auth_mode, lower=True) or shared_auth_mode + resolved_student_ad_scope = _clean(student_ad_scope) or shared_ad_scope + resolved_student_mi = ( + _clean(student_managed_identity_client_id) + 
or shared_managed_identity_client_id + ) + + _set("TEACHER_ENDPOINT", resolved_teacher_endpoint, "TEACHER_AZURE_OPENAI_ENDPOINT") + _set( + "TEACHER_API_VERSION", + resolved_teacher_api_version, + "TEACHER_AZURE_OPENAI_API_VERSION", + ) + _set("TEACHER_API_KEY", resolved_teacher_api_key, "TEACHER_AZURE_OPENAI_API_KEY") + _set("TEACHER_AUTH_MODE", resolved_teacher_auth_mode, "TEACHER_AZURE_OPENAI_AUTH_MODE") + _set("TEACHER_AD_SCOPE", resolved_teacher_ad_scope, "TEACHER_AZURE_OPENAI_AD_SCOPE") + _set( + "TEACHER_MANAGED_IDENTITY_CLIENT_ID", + resolved_teacher_mi, + "TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID", + ) + _set("STUDENT_ENDPOINT", resolved_student_endpoint, "STUDENT_AZURE_OPENAI_ENDPOINT") + _set( + "STUDENT_API_VERSION", + resolved_student_api_version, + "STUDENT_AZURE_OPENAI_API_VERSION", + ) + _set("STUDENT_API_KEY", resolved_student_api_key, "STUDENT_AZURE_OPENAI_API_KEY") + _set("STUDENT_AUTH_MODE", resolved_student_auth_mode, "STUDENT_AZURE_OPENAI_AUTH_MODE") + _set("STUDENT_AD_SCOPE", resolved_student_ad_scope, "STUDENT_AZURE_OPENAI_AD_SCOPE") + _set( + "STUDENT_MANAGED_IDENTITY_CLIENT_ID", + resolved_student_mi, + "STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID", + ) + + with _teacher_lock: + _teacher_client = None + with _student_lock: + _student_client = None + + +def chat_teacher( + system: str, + user: str, + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "teacher", + reasoning_effort: str | None = None, + timeout: int | None = None, +) -> tuple[str, dict]: + """Call the teacher model. 
Returns (response_text, usage_dict).""" + return _chat_impl( + get_teacher_client(), TEACHER_DEPLOYMENT, + system, user, max_completion_tokens, retries, stage, reasoning_effort, timeout, + ) + + +def chat_with_deployment( + deployment: str, + system: str, + user: str, + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "custom", + reasoning_effort: str | None = None, + timeout: int | None = None, +) -> tuple[str, dict]: + """Call an arbitrary deployment using the shared Azure client.""" + return _chat_impl( + get_teacher_client(), + deployment, + system, + user, + max_completion_tokens, + retries, + stage, + reasoning_effort, + timeout, + ) + + +def chat_student( + system: str, + user: str, + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "student", + reasoning_effort: str | None = None, + timeout: int | None = None, +) -> tuple[str, dict]: + """Call the student model. Returns (response_text, usage_dict).""" + return _chat_impl( + get_student_client(), STUDENT_DEPLOYMENT, + system, user, max_completion_tokens, retries, stage, reasoning_effort, timeout, + ) + + +def chat_teacher_messages( + messages: list[dict[str, Any]], + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "teacher", + reasoning_effort: str | None = None, + *, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + return_message: bool = False, + timeout: int | None = None, +) -> tuple[Any, dict]: + """Call the teacher model with a pre-built chat message list.""" + return _chat_messages_impl( + get_teacher_client(), + TEACHER_DEPLOYMENT, + messages, + max_completion_tokens, + retries, + stage, + reasoning_effort, + tools=tools, + tool_choice=tool_choice, + return_message=return_message, + timeout=timeout, + ) + + +def chat_messages_with_deployment( + deployment: str, + messages: list[dict[str, Any]], + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "custom", + 
reasoning_effort: str | None = None, + *, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + return_message: bool = False, + timeout: int | None = None, +) -> tuple[Any, dict]: + """Call an arbitrary deployment with a pre-built chat message list.""" + return _chat_messages_impl( + get_teacher_client(), + deployment, + messages, + max_completion_tokens, + retries, + stage, + reasoning_effort, + tools=tools, + tool_choice=tool_choice, + return_message=return_message, + timeout=timeout, + ) + + +def chat_student_messages( + messages: list[dict[str, Any]], + max_completion_tokens: int = 16384, + retries: int = 5, + stage: str = "student", + reasoning_effort: str | None = None, + *, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + return_message: bool = False, + timeout: int | None = None, +) -> tuple[Any, dict]: + """Call the student model with a pre-built chat message list.""" + return _chat_messages_impl( + get_student_client(), + STUDENT_DEPLOYMENT, + messages, + max_completion_tokens, + retries, + stage, + reasoning_effort, + tools=tools, + tool_choice=tool_choice, + return_message=return_message, + timeout=timeout, + ) + + +def get_token_summary() -> dict: + """Return per-stage and total token usage.""" + return tracker.summary() + + +def reset_token_tracker() -> None: + tracker.reset() + + +def set_student_deployment(deployment: str) -> None: + """Change student deployment at runtime.""" + global _student_client, STUDENT_DEPLOYMENT + STUDENT_DEPLOYMENT = deployment + os.environ["STUDENT_DEPLOYMENT"] = deployment + os.environ["AZURE_OPENAI_DEPLOYMENT"] = deployment + with _student_lock: + _student_client = None + try: + import llm_client as _legacy + _legacy.DEPLOYMENT = deployment + _legacy._client = None + except Exception: + pass + + +def set_reasoning_effort(effort: str | None) -> None: + """Set reasoning effort for all LLM calls. 
def _parse_bool(value: str | None, default: bool) -> bool:
    """Parse a boolean setting, falling back to *default* when it is unusable.

    Matching is case-insensitive and ignores surrounding whitespace:
    ``1/true/yes/on`` give ``True`` and ``0/false/no/off`` give ``False``.
    ``None``, an empty string, or any other text returns *default*, the same
    fallback ``_parse_int`` applies to unparsable input, so a typo in an
    environment variable no longer silently disables a default-on option.
    """
    if value is None:
        return default
    token = str(value).strip().lower()
    if token in {"1", "true", "yes", "on"}:
        return True
    if token in {"0", "false", "no", "off"}:
        return False
    return default
+CLAUDE_CODE_EXEC_PATH = os.environ.get("CLAUDE_CODE_EXEC_PATH", "claude") +CLAUDE_CODE_EXEC_PROFILE = os.environ.get("CLAUDE_CODE_EXEC_PROFILE", "") +CLAUDE_CODE_EXEC_USE_SDK = os.environ.get("CLAUDE_CODE_EXEC_USE_SDK", "auto") +CLAUDE_CODE_EXEC_EFFORT = os.environ.get("CLAUDE_CODE_EXEC_EFFORT", "medium") + + +def _parse_int(value: str | None, default: int) -> int: + if value is None: + return default + try: + return int(str(value).strip()) + except ValueError: + return default + + +EXEC_EMPTY_RESPONSE_RETRIES = max(0, _parse_int(os.environ.get("EXEC_EMPTY_RESPONSE_RETRIES"), 1)) +CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS = max( + 0, + _parse_int(os.environ.get("CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS"), 16384), +) + + +def set_teacher_backend(backend: str) -> None: + global TEACHER_BACKEND + TEACHER_BACKEND = normalize_backend_name(backend or "openai_chat") + if TEACHER_BACKEND not in {"openai_chat", "claude_chat"}: + raise ValueError( + f"Unsupported teacher backend: {TEACHER_BACKEND!r}. " + "Supported values are 'openai_chat' and 'claude_chat'." + ) + os.environ["TEACHER_BACKEND"] = TEACHER_BACKEND + + +def get_teacher_backend() -> str: + return TEACHER_BACKEND + + +def set_student_backend(backend: str) -> None: + global STUDENT_BACKEND + STUDENT_BACKEND = normalize_backend_name(backend or "openai_chat") + if STUDENT_BACKEND not in {"openai_chat", "claude_chat", "codex_exec", "claude_code_exec"}: + raise ValueError( + f"Unsupported student backend: {STUDENT_BACKEND!r}. " + "Supported values are 'openai_chat', 'claude_chat', 'codex_exec', and 'claude_code_exec'." 
+ ) + os.environ["STUDENT_BACKEND"] = STUDENT_BACKEND + + +def get_student_backend() -> str: + return STUDENT_BACKEND + + +def is_student_exec_backend() -> bool: + return STUDENT_BACKEND in {"codex_exec", "claude_code_exec"} + + +def is_teacher_chat_backend() -> bool: + return TEACHER_BACKEND in {"openai_chat", "claude_chat"} + + +def is_student_chat_backend() -> bool: + return STUDENT_BACKEND in {"openai_chat", "claude_chat"} + + +def configure_codex_exec( + *, + path: str | None = None, + sandbox: str | None = None, + profile: str | None = None, + full_auto: bool | None = None, + reasoning_effort: str | None = None, + use_sdk: str | None = None, + network_access: bool | None = None, + web_search: bool | None = None, + approval_policy: str | None = None, +) -> None: + global CODEX_EXEC_PATH, CODEX_EXEC_SANDBOX, CODEX_EXEC_PROFILE, CODEX_EXEC_FULL_AUTO, CODEX_EXEC_REASONING_EFFORT, CODEX_EXEC_USE_SDK, CODEX_EXEC_NETWORK_ACCESS, CODEX_EXEC_WEB_SEARCH, CODEX_EXEC_APPROVAL_POLICY + if path is not None: + CODEX_EXEC_PATH = str(path).strip() or "codex" + os.environ["CODEX_EXEC_PATH"] = CODEX_EXEC_PATH + if sandbox is not None: + CODEX_EXEC_SANDBOX = str(sandbox).strip() or "workspace-write" + os.environ["CODEX_EXEC_SANDBOX"] = CODEX_EXEC_SANDBOX + if profile is not None: + CODEX_EXEC_PROFILE = str(profile).strip() + os.environ["CODEX_EXEC_PROFILE"] = CODEX_EXEC_PROFILE + if full_auto is not None: + CODEX_EXEC_FULL_AUTO = bool(full_auto) + os.environ["CODEX_EXEC_FULL_AUTO"] = "true" if CODEX_EXEC_FULL_AUTO else "false" + if reasoning_effort is not None: + CODEX_EXEC_REASONING_EFFORT = str(reasoning_effort).strip() or "none" + os.environ["CODEX_EXEC_REASONING_EFFORT"] = CODEX_EXEC_REASONING_EFFORT + if use_sdk is not None: + CODEX_EXEC_USE_SDK = str(use_sdk).strip().lower() or "auto" + os.environ["CODEX_EXEC_USE_SDK"] = CODEX_EXEC_USE_SDK + if network_access is not None: + CODEX_EXEC_NETWORK_ACCESS = bool(network_access) + os.environ["CODEX_EXEC_NETWORK_ACCESS"] = "true" 
if CODEX_EXEC_NETWORK_ACCESS else "false" + if web_search is not None: + CODEX_EXEC_WEB_SEARCH = bool(web_search) + os.environ["CODEX_EXEC_WEB_SEARCH"] = "true" if CODEX_EXEC_WEB_SEARCH else "false" + if approval_policy is not None: + CODEX_EXEC_APPROVAL_POLICY = str(approval_policy).strip() or "never" + os.environ["CODEX_EXEC_APPROVAL_POLICY"] = CODEX_EXEC_APPROVAL_POLICY + + +def get_codex_exec_config() -> dict[str, str | bool | int]: + return { + "path": CODEX_EXEC_PATH, + "sandbox": CODEX_EXEC_SANDBOX, + "profile": CODEX_EXEC_PROFILE, + "full_auto": CODEX_EXEC_FULL_AUTO, + "reasoning_effort": CODEX_EXEC_REASONING_EFFORT, + "use_sdk": CODEX_EXEC_USE_SDK, + "network_access": CODEX_EXEC_NETWORK_ACCESS, + "web_search": CODEX_EXEC_WEB_SEARCH, + "approval_policy": CODEX_EXEC_APPROVAL_POLICY, + "empty_response_retries": EXEC_EMPTY_RESPONSE_RETRIES, + } + + +def configure_claude_code_exec( + *, + path: str | None = None, + profile: str | None = None, + use_sdk: str | None = None, + effort: str | None = None, + max_thinking_tokens: int | str | None = None, +) -> None: + global CLAUDE_CODE_EXEC_PATH, CLAUDE_CODE_EXEC_PROFILE, CLAUDE_CODE_EXEC_USE_SDK, CLAUDE_CODE_EXEC_EFFORT, CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS + if path is not None: + CLAUDE_CODE_EXEC_PATH = str(path).strip() or "claude" + os.environ["CLAUDE_CODE_EXEC_PATH"] = CLAUDE_CODE_EXEC_PATH + if profile is not None: + CLAUDE_CODE_EXEC_PROFILE = str(profile).strip() + os.environ["CLAUDE_CODE_EXEC_PROFILE"] = CLAUDE_CODE_EXEC_PROFILE + if use_sdk is not None: + CLAUDE_CODE_EXEC_USE_SDK = str(use_sdk).strip().lower() or "auto" + os.environ["CLAUDE_CODE_EXEC_USE_SDK"] = CLAUDE_CODE_EXEC_USE_SDK + if effort is not None: + CLAUDE_CODE_EXEC_EFFORT = str(effort).strip().lower() or "medium" + os.environ["CLAUDE_CODE_EXEC_EFFORT"] = CLAUDE_CODE_EXEC_EFFORT + if max_thinking_tokens is not None: + CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS = max( + 0, + _parse_int(str(max_thinking_tokens), 16384), + ) + 
os.environ["CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS"] = str(CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS) + + +def get_claude_code_exec_config() -> dict[str, str | int]: + return { + "path": CLAUDE_CODE_EXEC_PATH, + "profile": CLAUDE_CODE_EXEC_PROFILE, + "use_sdk": CLAUDE_CODE_EXEC_USE_SDK, + "effort": CLAUDE_CODE_EXEC_EFFORT, + "max_thinking_tokens": CLAUDE_CODE_EXEC_MAX_THINKING_TOKENS, + "empty_response_retries": EXEC_EMPTY_RESPONSE_RETRIES, + } diff --git a/skillopt/model/claude_backend.py b/skillopt/model/claude_backend.py new file mode 100644 index 0000000..a4ca3a1 --- /dev/null +++ b/skillopt/model/claude_backend.py @@ -0,0 +1,359 @@ +"""Claude CLI chat backend for ReflACT.""" +from __future__ import annotations + +import base64 +import json +import mimetypes +import os +import shutil +import subprocess +import tempfile +import time +from typing import Any +from urllib.parse import unquote, urlparse + +from skillopt.model.common import CompatAssistantMessage, CompatToolCall, CompatToolFunction, default_model_for_backend, tracker + +CLAUDE_BIN = os.environ.get("CLAUDE_CLI_BIN", "claude") +CLAUDE_PERMISSION_MODE = os.environ.get("CLAUDE_PERMISSION_MODE", "dontAsk") +CLAUDE_SETTING_SOURCES = os.environ.get("CLAUDE_SETTING_SOURCES", "user,project") +CLAUDE_ALLOW_ATTACHMENT_READ = os.environ.get("CLAUDE_ALLOW_ATTACHMENT_READ", "1").strip().lower() not in {"0", "false", "no"} + +TEACHER_DEPLOYMENT = os.environ.get("TEACHER_DEPLOYMENT", "claude-sonnet-4-6") +STUDENT_DEPLOYMENT = os.environ.get("STUDENT_DEPLOYMENT", "claude-sonnet-4-6") +REASONING_EFFORT: str | None = None +_VALID_EFFORTS = {"low", "medium", "high", "xhigh", "max"} + + +def _parse_data_uri(url: str) -> tuple[bytes, str]: + header, data = url.split(",", 1) + mime = header[5:].split(";", 1)[0] or "image/png" + return base64.b64decode(data), mime + + +def _content_to_text(content: Any, attachments: list[dict[str, Any]], *, image_counter: int) -> tuple[str, int]: + if isinstance(content, str): + return content, 
image_counter + if not isinstance(content, list): + return str(content), image_counter + parts: list[str] = [] + for item in content: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type == "text": + parts.append(str(item.get("text", ""))) + continue + if item_type != "image_url": + continue + image_counter += 1 + label = f"[Attached image {image_counter}]" + parts.append(label) + image_url = item.get("image_url", {}) or {} + url = str(image_url.get("url", "") or "") + if not url: + continue + if url.startswith("data:") and ";base64," in url: + data, mime = _parse_data_uri(url) + attachments.append({"bytes": data, "mime": mime, "label": label}) + continue + if url.startswith("file://"): + parsed = urlparse(url) + path = unquote(parsed.path) + if path: + attachments.append({"path": path, "label": label}) + continue + if os.path.exists(url): + attachments.append({"path": url, "label": label}) + return "".join(parts), image_counter + + +def _simplify_tool_schemas(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]: + simplified: list[dict[str, Any]] = [] + for tool in tools or []: + function = tool.get("function", tool) + simplified.append({ + "name": function.get("name", ""), + "description": function.get("description", ""), + "parameters": function.get("parameters", {}), + }) + return simplified + + +def _build_prompt_from_messages(messages: list[dict[str, Any]], *, tools: list[dict[str, Any]] | None = None, tool_choice: str | dict[str, Any] | None = None, structured_output: bool = False) -> tuple[str, str, list[dict[str, Any]]]: + system_parts: list[str] = [] + history_parts: list[str] = [] + attachments: list[dict[str, Any]] = [] + image_counter = 0 + + def _history_line(label: str, body: str) -> str: + stripped = body.strip() + if not stripped: + return f"- {label}:" + indented = stripped.replace("\n", "\n ") + return f"- {label}: {indented}" + + for message in messages: + role = str(message.get("role", "user")) + 
text, image_counter = _content_to_text(message.get("content", ""), attachments, image_counter=image_counter) + if role == "system": + if text.strip(): + system_parts.append(text.strip()) + continue + if role == "assistant": + block = _history_line("Assistant", text) + tool_calls = message.get("tool_calls") or [] + if tool_calls: + simplified_calls = [] + for tool_call in tool_calls: + function = tool_call.get("function", {}) or {} + simplified_calls.append({ + "name": function.get("name", ""), + "arguments": function.get("arguments", "{}"), + }) + block += "\n Compatibility tool requests:\n" + json.dumps(simplified_calls, ensure_ascii=False, indent=2) + history_parts.append(block) + continue + if role == "tool": + tool_call_id = str(message.get("tool_call_id", "") or "") + history_parts.append(_history_line(f"Tool result (tool_call_id={tool_call_id})", text)) + continue + history_parts.append(_history_line(role.capitalize(), text)) + + prompt_parts: list[str] = [] + if tools: + simplified_tools = _simplify_tool_schemas(tools) + prompt_parts.append("Available compatibility tools:\n" + json.dumps(simplified_tools, ensure_ascii=False, indent=2)) + prompt_parts.append("Do not execute these compatibility tools yourself. If you need one, request it in `tool_calls`. 
Each `arguments` field must be a JSON string.") + if tool_choice == "required": + prompt_parts.append("Tool choice policy: you must request at least one compatibility tool.") + elif isinstance(tool_choice, dict) and tool_choice.get("type") == "function": + function = tool_choice.get("function", {}) or {} + prompt_parts.append(f"Tool choice policy: you must request the compatibility tool `{function.get('name', '')}`.") + history_text = "\n".join(part for part in history_parts if part).strip() + if history_text: + prompt_parts.append("History:\n" + history_text) + if structured_output: + prompt_parts.append("Return only JSON matching the provided schema.") + if tools: + prompt_parts.append("Set `content` to the assistant-visible reply. Set `tool_calls` to an empty array when no compatibility tool is needed.") + else: + prompt_parts.append("Answer the latest user request.") + return "\n\n".join(part for part in system_parts if part).strip(), "\n\n".join(prompt_parts), attachments + + +def _copy_attachments_to_temp(attachments: list[dict[str, Any]], temp_dir: str) -> list[dict[str, str]]: + copied: list[dict[str, str]] = [] + for index, attachment in enumerate(attachments, 1): + source_path = attachment.get("path") + if source_path: + source_path = str(source_path) + source_suffix = os.path.splitext(source_path)[1] + target_path = os.path.join(temp_dir, f"image_{index}{source_suffix or '.bin'}") + shutil.copyfile(source_path, target_path) + copied.append({"path": target_path, "label": str(attachment.get("label", ""))}) + continue + mime = str(attachment.get("mime", "image/png")) + suffix = mimetypes.guess_extension(mime) or ".png" + target_path = os.path.join(temp_dir, f"image_{index}{suffix}") + with open(target_path, "wb") as f: + f.write(attachment.get("bytes", b"") or b"") + copied.append({"path": target_path, "label": str(attachment.get("label", ""))}) + return copied + + +def _append_attachment_instructions(prompt: str, copied_attachments: list[dict[str, str]]) 
-> str: + if not copied_attachments or not CLAUDE_ALLOW_ATTACHMENT_READ: + return prompt + lines = [ + "Attached image files:", + *[f"- {item['label'] or f'Attached image {index}'}: {item['path']}" for index, item in enumerate(copied_attachments, 1)], + "If you need to inspect an attached image, you may use the built-in `Read` tool on those listed paths only. Do not use built-in tools for any other purpose.", + ] + return prompt.rstrip() + "\n\n" + "\n".join(lines) + + +def _usage_from_result(result_event: dict[str, Any] | None) -> dict[str, int]: + usage = (result_event or {}).get("usage", {}) or {} + input_tokens = int(usage.get("input_tokens", 0) or 0) + output_tokens = int(usage.get("output_tokens", 0) or 0) + return { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + } + + +def _extract_result(event_stream: list[dict[str, Any]]) -> tuple[str, dict[str, Any] | None]: + result_event = None + for event in reversed(event_stream): + if event.get("type") == "result": + result_event = event + break + if result_event is None: + raise RuntimeError("Claude backend did not return a result event.") + content = result_event.get("result") or result_event.get("content") or "" + return str(content), result_event + + +def _check_claude_error(stderr_text: str, model: str) -> None: + lowered = stderr_text.lower() + if "invalid api key" in lowered or "authentication" in lowered or "login" in lowered: + raise RuntimeError("Claude CLI is not logged in. Run `claude auth login` (or start `claude` and use `/login`) first.") + if "unknown model" in lowered or "not available" in lowered or "invalid model" in lowered: + default_model = default_model_for_backend("claude") + raise RuntimeError(f"Claude backend tried to use model {model!r}, but your current Claude CLI/account rejected it. 
Try an available Claude model such as {default_model!r}.") + + +def _normalize_reasoning_effort(effort: str | None) -> str | None: + normalized = str(effort or "").strip().lower() + if not normalized or normalized == "off": + return None + if normalized in _VALID_EFFORTS: + return normalized + return None + + +def _assistant_message_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "content": {"type": "string"}, + "tool_calls": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": {"type": "string"}, + }, + "required": ["name", "arguments"], + "additionalProperties": False, + }, + }, + }, + "required": ["content", "tool_calls"], + "additionalProperties": False, + } + + +def _assistant_message_schema_wrapper() -> str: + return json.dumps(_assistant_message_schema(), ensure_ascii=False) + + +def _run_claude_print(*, system: str, prompt: str, model: str, tools: list[dict[str, Any]] | None, tool_choice: str | dict[str, Any] | None, return_message: bool, timeout: int | None, attachments: list[dict[str, Any]] | None = None) -> tuple[str, dict[str, Any], dict[str, int]]: + effort = _normalize_reasoning_effort(REASONING_EFFORT) + with tempfile.TemporaryDirectory(prefix="skillopt_claude_") as temp_dir: + copied_attachments = _copy_attachments_to_temp(attachments or [], temp_dir) + prompt_for_cli = _append_attachment_instructions(prompt, copied_attachments) + cmd = [CLAUDE_BIN, "-p", "--output-format", "json", "--permission-mode", CLAUDE_PERMISSION_MODE, "--add-dir", temp_dir] + if model: + cmd.extend(["--model", model]) + if CLAUDE_SETTING_SOURCES: + cmd.extend(["--setting-sources", CLAUDE_SETTING_SOURCES]) + if system: + cmd.extend(["--append-system-prompt", system]) + if effort: + cmd.extend(["--thinking", effort]) + structured_output = bool(return_message) + if structured_output: + cmd.extend(["--schema", _assistant_message_schema_wrapper()]) + proc = subprocess.run(cmd + 
[prompt_for_cli], capture_output=True, text=True, timeout=timeout or 300, cwd=temp_dir) + stderr_text = (proc.stderr or "").strip() + if proc.returncode != 0: + _check_claude_error(stderr_text, model) + raise RuntimeError(stderr_text or f"Claude CLI exited with code {proc.returncode}") + stream = [] + for raw_line in (proc.stdout or "").splitlines(): + raw_line = raw_line.strip() + if not raw_line: + continue + try: + stream.append(json.loads(raw_line)) + except json.JSONDecodeError: + continue + raw_text, result_event = _extract_result(stream) + usage_info = _usage_from_result(result_event) + return raw_text, result_event or {}, usage_info + + +def _compat_message_from_payload(payload: Any) -> CompatAssistantMessage: + if not isinstance(payload, dict): + return CompatAssistantMessage(content=str(payload or ""), tool_calls=[]) + content = str(payload.get("content", "") or "") + tool_calls: list[CompatToolCall] = [] + for index, tool_call in enumerate(payload.get("tool_calls", []) or [], start=1): + name = str(tool_call.get("name", "") or "") + arguments = str(tool_call.get("arguments", "{}") or "{}") + tool_calls.append(CompatToolCall(id=f"claude_tool_{index}", function=CompatToolFunction(name=name, arguments=arguments))) + return CompatAssistantMessage(content=content, tool_calls=tool_calls) + + +def _call_messages(messages: list[dict[str, Any]], max_completion_tokens: int, retries: int, stage: str, *, tools: list[dict[str, Any]] | None = None, tool_choice: str | dict[str, Any] | None = None, return_message: bool = False, deployment: str | None = None, timeout: int | None = None) -> tuple[Any, dict[str, int]]: + del max_completion_tokens + system, prompt, attachments = _build_prompt_from_messages(messages, tools=tools, tool_choice=tool_choice, structured_output=return_message) + model = deployment or STUDENT_DEPLOYMENT + last_err = None + for attempt in range(retries): + try: + raw_text, payload, usage_info = _run_claude_print(system=system, prompt=prompt, 
model=model, tools=tools, tool_choice=tool_choice, return_message=return_message, timeout=timeout, attachments=attachments) + tracker.record(stage, usage_info["prompt_tokens"], usage_info["completion_tokens"]) + if return_message: + return _compat_message_from_payload(payload.get("result", payload)), usage_info + return raw_text, usage_info + except Exception as e: # noqa: BLE001 + last_err = e + time.sleep(min(2 ** attempt, 15)) + raise RuntimeError(f"Claude backend failed after {retries} retries: {last_err}") + + +def chat_teacher(system: str, user: str, max_completion_tokens: int = 16384, retries: int = 5, stage: str = "teacher", timeout: int | None = None) -> tuple[str, dict[str, int]]: + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + return _call_messages(messages, max_completion_tokens, retries, stage, deployment=TEACHER_DEPLOYMENT, timeout=timeout) + + +def chat_student(system: str, user: str, max_completion_tokens: int = 16384, retries: int = 5, stage: str = "student", timeout: int | None = None) -> tuple[str, dict[str, int]]: + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + return _call_messages(messages, max_completion_tokens, retries, stage, deployment=STUDENT_DEPLOYMENT, timeout=timeout) + + +def chat_with_deployment(deployment: str, system: str, user: str, max_completion_tokens: int = 16384, retries: int = 5, stage: str = "custom", timeout: int | None = None) -> tuple[str, dict[str, int]]: + messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] + return _call_messages(messages, max_completion_tokens, retries, stage, deployment=deployment, timeout=timeout) + + +def chat_teacher_messages(messages: list[dict[str, Any]], max_completion_tokens: int = 16384, retries: int = 5, stage: str = "teacher", *, tools: list[dict[str, Any]] | None = None, tool_choice: str | dict[str, Any] | None = None, return_message: bool = False, timeout: int | None = 
None) -> tuple[Any, dict[str, int]]: + return _call_messages(messages, max_completion_tokens, retries, stage, tools=tools, tool_choice=tool_choice, return_message=return_message, deployment=TEACHER_DEPLOYMENT, timeout=timeout) + + +def chat_student_messages(messages: list[dict[str, Any]], max_completion_tokens: int = 16384, retries: int = 5, stage: str = "student", *, tools: list[dict[str, Any]] | None = None, tool_choice: str | dict[str, Any] | None = None, return_message: bool = False, timeout: int | None = None) -> tuple[Any, dict[str, int]]: + return _call_messages(messages, max_completion_tokens, retries, stage, tools=tools, tool_choice=tool_choice, return_message=return_message, deployment=STUDENT_DEPLOYMENT, timeout=timeout) + + +def chat_messages_with_deployment(deployment: str, messages: list[dict[str, Any]], max_completion_tokens: int = 16384, retries: int = 5, stage: str = "custom", *, tools: list[dict[str, Any]] | None = None, tool_choice: str | dict[str, Any] | None = None, return_message: bool = False, timeout: int | None = None) -> tuple[Any, dict[str, int]]: + return _call_messages(messages, max_completion_tokens, retries, stage, tools=tools, tool_choice=tool_choice, return_message=return_message, deployment=deployment, timeout=timeout) + + +def get_token_summary() -> dict[str, dict[str, int]]: + return tracker.summary() + + +def reset_token_tracker() -> None: + tracker.reset() + + +def set_reasoning_effort(effort: str | None) -> None: + global REASONING_EFFORT + REASONING_EFFORT = effort if effort else None + + +def set_student_deployment(deployment: str) -> None: + global STUDENT_DEPLOYMENT + STUDENT_DEPLOYMENT = deployment or default_model_for_backend("claude") + os.environ["STUDENT_DEPLOYMENT"] = STUDENT_DEPLOYMENT + + +def set_teacher_deployment(deployment: str) -> None: + global TEACHER_DEPLOYMENT + TEACHER_DEPLOYMENT = deployment or default_model_for_backend("claude") + os.environ["TEACHER_DEPLOYMENT"] = TEACHER_DEPLOYMENT diff --git 
a/skillopt/model/codex_backend.py b/skillopt/model/codex_backend.py new file mode 100644 index 0000000..c69599f --- /dev/null +++ b/skillopt/model/codex_backend.py @@ -0,0 +1,664 @@ +"""Codex CLI backend for ReflACT.""" +from __future__ import annotations + +import base64 +import json +import mimetypes +import os +import subprocess +import tempfile +import time +import uuid +from typing import Any +from urllib.parse import unquote, urlparse + +from skillopt.model.common import ( + CompatAssistantMessage, + CompatToolCall, + CompatToolFunction, + tracker, +) + + +CODEX_BIN = os.environ.get("CODEX_CLI_BIN", "codex") +CODEX_PROFILE = os.environ.get("CODEX_PROFILE", "review") +CODEX_SANDBOX_MODE = os.environ.get("CODEX_SANDBOX_MODE", "read-only") + +TEACHER_DEPLOYMENT = os.environ.get("TEACHER_DEPLOYMENT", "gpt-5.5") +STUDENT_DEPLOYMENT = os.environ.get("STUDENT_DEPLOYMENT", "gpt-5.5") + +REASONING_EFFORT: str | None = None + + +def _default_working_directory() -> str: + return os.environ.get("CODEX_WORKING_DIRECTORY", os.getcwd()) + + +def _parse_data_uri(url: str) -> tuple[bytes, str]: + header, data = url.split(",", 1) + mime = header[5:].split(";", 1)[0] or "image/png" + return base64.b64decode(data), mime + + +def _content_to_text( + content: Any, + attachments: list[dict[str, Any]], + *, + image_counter: int, +) -> tuple[str, int]: + if isinstance(content, str): + return content, image_counter + + if not isinstance(content, list): + return str(content), image_counter + + parts: list[str] = [] + for item in content: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type == "text": + parts.append(str(item.get("text", ""))) + continue + if item_type != "image_url": + continue + + image_counter += 1 + label = f"[Attached image {image_counter}]" + parts.append(label) + + image_url = item.get("image_url", {}) or {} + url = str(image_url.get("url", "") or "") + if not url: + continue + if url.startswith("data:") and ";base64," in url: 
+ data, mime = _parse_data_uri(url) + attachments.append({"bytes": data, "mime": mime}) + continue + if url.startswith("file://"): + parsed = urlparse(url) + path = unquote(parsed.path) + if path: + attachments.append({"path": path}) + continue + if os.path.exists(url): + attachments.append({"path": url}) + + return "".join(parts), image_counter + + +def _simplify_tool_schemas(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]: + simplified: list[dict[str, Any]] = [] + for tool in tools or []: + function = tool.get("function", tool) + simplified.append( + { + "name": function.get("name", ""), + "description": function.get("description", ""), + "parameters": function.get("parameters", {}), + } + ) + return simplified + + +def _build_prompt_from_messages( + messages: list[dict[str, Any]], + *, + tools: list[dict[str, Any]] | None = None, + tool_choice: str | dict[str, Any] | None = None, + structured_output: bool = False, +) -> tuple[str, list[dict[str, Any]]]: + system_parts: list[str] = [] + history_parts: list[str] = [] + attachments: list[dict[str, Any]] = [] + image_counter = 0 + + def _history_line(label: str, body: str) -> str: + stripped = body.strip() + if not stripped: + return f"- {label}:" + indented = stripped.replace("\n", "\n ") + return f"- {label}: {indented}" + + for message in messages: + role = str(message.get("role", "user")) + text, image_counter = _content_to_text( + message.get("content", ""), + attachments, + image_counter=image_counter, + ) + + if role == "system": + if text.strip(): + system_parts.append(text.strip()) + continue + + if role == "assistant": + block = _history_line("Assistant", text) + tool_calls = message.get("tool_calls") or [] + if tool_calls: + simplified_calls = [] + for tool_call in tool_calls: + function = tool_call.get("function", {}) or {} + simplified_calls.append( + { + "name": function.get("name", ""), + "arguments": function.get("arguments", "{}"), + } + ) + block += ( + "\n Compatibility tool 
requests:\n" + + json.dumps(simplified_calls, ensure_ascii=False, indent=2) + ) + history_parts.append(block) + continue + + if role == "tool": + tool_call_id = str(message.get("tool_call_id", "") or "") + label = f"Tool result (tool_call_id={tool_call_id})" + history_parts.append(_history_line(label, text)) + continue + + history_parts.append(_history_line(role.capitalize(), text)) + + prompt_parts: list[str] = [] + + system_text = "\n\n".join(part for part in system_parts if part).strip() + if system_text: + prompt_parts.append(system_text) + + if tools: + simplified_tools = _simplify_tool_schemas(tools) + prompt_parts.append( + "Available compatibility tools:\n" + + json.dumps(simplified_tools, ensure_ascii=False, indent=2) + ) + prompt_parts.append( + "Do not execute these tools yourself. If you need one, request it in " + "`tool_calls`. Each `arguments` field must be a JSON string." + ) + + if tool_choice == "required": + prompt_parts.append( + "Tool choice policy: you must request at least one compatibility tool." + ) + elif isinstance(tool_choice, dict) and tool_choice.get("type") == "function": + function = tool_choice.get("function", {}) or {} + prompt_parts.append( + "Tool choice policy: you must request the compatibility tool " + f"`{function.get('name', '')}`." + ) + + history_text = "\n".join(part for part in history_parts if part).strip() + if history_text: + prompt_parts.append("History:\n" + history_text) + + if structured_output: + prompt_parts.append("Return only JSON matching the provided schema.") + if tools: + prompt_parts.append( + "Set `content` to the assistant-visible reply. Set `tool_calls` to " + "an empty array when no tool is needed." 
+ ) + else: + prompt_parts.append("Answer the latest user request.") + + return "\n\n".join(prompt_parts), attachments + + +def _assistant_message_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "content": {"type": "string"}, + "tool_calls": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": {"type": "string"}, + }, + "required": ["name", "arguments"], + "additionalProperties": False, + }, + }, + }, + "required": ["content", "tool_calls"], + "additionalProperties": False, + } + + +def _materialize_attachments( + attachments: list[dict[str, Any]], + temp_dir: str, +) -> list[str]: + image_paths: list[str] = [] + for index, attachment in enumerate(attachments, 1): + path = attachment.get("path") + if path: + image_paths.append(str(path)) + continue + + mime = str(attachment.get("mime", "image/png")) + suffix = mimetypes.guess_extension(mime) or ".png" + image_path = os.path.join(temp_dir, f"image_{index}{suffix}") + with open(image_path, "wb") as f: + f.write(attachment.get("bytes", b"")) + image_paths.append(image_path) + return image_paths + + +def _usage_from_event(usage: dict[str, Any] | None) -> dict[str, int]: + usage = usage or {} + prompt_tokens = int(usage.get("input_tokens", 0) or 0) + completion_tokens = int(usage.get("output_tokens", 0) or 0) + return { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + + +def _extract_error(stdout: str, stderr: str) -> str: + for raw_line in reversed(stdout.splitlines()): + line = raw_line.strip() + if not line: + continue + try: + payload = json.loads(line) + except json.JSONDecodeError: + continue + if payload.get("type") == "turn.failed": + error = payload.get("error", {}) or {} + return str(error.get("message", "") or "Codex turn failed") + if payload.get("type") == "error": + return str(payload.get("message", "") or "Codex execution 
failed") + return stderr.strip() or stdout.strip() or "Codex execution failed" + + +def _run_codex_exec( + *, + model: str, + prompt: str, + attachments: list[dict[str, Any]], + output_schema: dict[str, Any] | None, + timeout: int | None, +) -> tuple[str, dict[str, int]]: + with tempfile.TemporaryDirectory(prefix="skillopt_codex_") as temp_dir: + output_path = os.path.join(temp_dir, "last_message.txt") + image_paths = _materialize_attachments(attachments, temp_dir) + + command = [ + CODEX_BIN, + "exec", + "--json", + "--ephemeral", + "--profile", + CODEX_PROFILE, + "-c", + "approval_policy=\"never\"", + "--sandbox", + CODEX_SANDBOX_MODE, + "--skip-git-repo-check", + "--cd", + _default_working_directory(), + "--model", + model, + "--output-last-message", + output_path, + ] + + if REASONING_EFFORT: + command.extend(["-c", f"model_reasoning_effort={json.dumps(REASONING_EFFORT)}"]) + + schema_path = None + if output_schema is not None: + schema_path = os.path.join(temp_dir, "schema.json") + with open(schema_path, "w", encoding="utf-8") as f: + json.dump(output_schema, f, ensure_ascii=False) + command.extend(["--output-schema", schema_path]) + + for image_path in image_paths: + command.extend(["--image", image_path]) + + command.append("-") + + proc = subprocess.run( + command, + input=prompt, + text=True, + capture_output=True, + timeout=timeout, + check=False, + ) + + usage_info = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + fallback_text = "" + for raw_line in proc.stdout.splitlines(): + line = raw_line.strip() + if not line: + continue + try: + payload = json.loads(line) + except json.JSONDecodeError: + continue + if payload.get("type") == "item.completed": + item = payload.get("item", {}) or {} + if item.get("type") == "agent_message": + fallback_text = str(item.get("text", "") or fallback_text) + if payload.get("type") == "turn.completed": + usage_info = _usage_from_event(payload.get("usage")) + + last_message = "" + if 
def _chat_messages_impl(
    model: str,
    messages: list[dict[str, Any]],
    max_completion_tokens: int,
    retries: int,
    stage: str,
    *,
    tools: list[dict[str, Any]] | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    return_message: bool = False,
    timeout: int | None = None,
) -> tuple[Any, dict[str, int]]:
    """Run one chat request through the Codex CLI, retrying on failure.

    Args:
        model: Model name forwarded to ``_run_codex_exec``.
        messages: Chat messages rendered into a prompt by
            ``_build_prompt_from_messages``.
        max_completion_tokens: Accepted for signature compatibility and
            discarded; it is not forwarded anywhere.
        retries: Total number of attempts.
        stage: Label under which token usage is recorded on ``tracker``.
        tools: Optional tool definitions; a non-empty list switches the call
            to structured (JSON schema) output.
        tool_choice: Optional tool-choice constraint checked against the
            parsed response.
        return_message: When true, return the compat message object instead
            of its text content (also forces structured output).
        timeout: Per-attempt timeout in seconds passed to the CLI call.

    Returns:
        ``(result, usage_info)`` where ``result`` is the raw text, the
        message content, or the compat message depending on the flags above.

    Raises:
        RuntimeError: When every attempt failed; the last underlying error is
            attached as ``__cause__``.
    """
    del max_completion_tokens
    last_err: Exception | None = None
    structured_output = bool(tools) or return_message

    for attempt in range(retries):
        try:
            prompt, attachments = _build_prompt_from_messages(
                messages,
                tools=tools,
                tool_choice=tool_choice,
                structured_output=structured_output,
            )
            raw_text, usage_info = _run_codex_exec(
                model=model,
                prompt=prompt,
                attachments=attachments,
                output_schema=_assistant_message_schema() if structured_output else None,
                timeout=timeout,
            )
            tracker.record(
                stage,
                usage_info["prompt_tokens"],
                usage_info["completion_tokens"],
            )

            if not structured_output:
                return raw_text, usage_info

            payload = json.loads(raw_text)
            compat = _compat_message_from_payload(payload, tool_choice=tool_choice)
            return (compat if return_message else compat.content), usage_info
        except subprocess.TimeoutExpired as exc:
            last_err = RuntimeError(f"Codex CLI timed out after {timeout}s") if timeout else exc
        except Exception as exc:  # noqa: BLE001
            last_err = exc

        # Back off only when another attempt follows; sleeping after the last
        # failure just delays the error by up to 30 seconds.
        if attempt + 1 < retries:
            time.sleep(min(2 ** attempt, 30))

    # Chain the final error so the original traceback stays inspectable.
    raise RuntimeError(f"Codex call failed after {retries} retries: {last_err}") from last_err
def chat_with_deployment(
    deployment: str,
    system: str,
    user: str,
    max_completion_tokens: int = 16384,
    retries: int = 5,
    stage: str = "custom",
    timeout: int | None = None,
) -> tuple[str, dict[str, int]]:
    """Send a system + user prompt to an explicitly named deployment.

    This is an alias for ``chat_with_model``: ``deployment`` is passed as its
    ``model`` argument and every other argument is forwarded unchanged.
    """
    call_args = {
        "model": deployment,
        "system": system,
        "user": user,
        "max_completion_tokens": max_completion_tokens,
        "retries": retries,
        "stage": stage,
        "timeout": timeout,
    }
    return chat_with_model(**call_args)
def set_reasoning_effort(effort: str | None) -> None:
    """Set the module-wide reasoning effort.

    Any falsy value (``None`` or an empty string) is stored as ``None``.
    """
    global REASONING_EFFORT
    REASONING_EFFORT = effort or None
def render_skill_md(
    skill_content: str,
    *,
    name: str = "skillopt-student",
    description: str = "Dynamic ReflACT skill for the current benchmark task.",
    preamble: str = "",
) -> str:
    """Render the SKILL.md document handed to the student harness.

    The document consists of a YAML front-matter header (``name`` and
    ``description``), a fixed title, an optional preamble, and a
    "Dynamic Guidance" section containing ``skill_content``.

    Args:
        skill_content: Guidance text; when blank, a fixed placeholder
            sentence is used instead.
        name: Skill name written to the front matter.
        description: Skill description written to the front matter.
        preamble: Optional text inserted before the guidance section; it is
            stripped and skipped entirely when blank.

    Returns:
        The rendered markdown, lines joined with ``"\\n"``.
    """
    body = (skill_content or "").strip() or (
        "No additional dynamic guidance was provided for this task."
    )
    intro = (preamble or "").strip()

    # json.dumps emits a double-quoted scalar with ``"``, ``\`` and control
    # characters escaped. For plain values the output is identical to the
    # previous hand-written quoting; for values containing quotes or newlines
    # the front matter stays a well-formed single line per key.
    chunks = [
        "---",
        f"name: {json.dumps(name, ensure_ascii=False)}",
        f"description: {json.dumps(description, ensure_ascii=False)}",
        "---",
        "",
        "# ReflACT Student Skill",
        "",
    ]
    if intro:
        chunks.extend([intro, ""])
    chunks.extend(["## Dynamic Guidance", "", body, ""])
    return "\n".join(chunks)
os.path.exists(image): + raise FileNotFoundError(image) + src = os.path.abspath(image) + base = os.path.basename(src) or f"image_{index}" + dst_name = f"{index:02d}_{base}" + dst = os.path.join(attachments_dir, dst_name) + if os.path.abspath(src) != os.path.abspath(dst): + shutil.copy2(src, dst) + rel_dst = os.path.relpath(dst, work_dir) + attachment_lines.append(f"- `{rel_dst}` (source: `{src}`)") + + if attachment_lines: + with open(os.path.join(work_dir, "ATTACHMENTS.md"), "w", encoding="utf-8") as f: + f.write( + "# Attachments\n\n" + "Use these local files when the task refers to attached images or documents.\n\n" + + "\n".join(attachment_lines) + + "\n" + ) + + return skill_path, task_path + + +def _build_codex_trace_summary(raw: str, response: str) -> str: + lines = [ln.rstrip() for ln in (raw or "").splitlines()] + + def _find(prefix: str) -> str: + for ln in lines: + if ln.startswith(prefix): + return ln[len(prefix):].strip() + return "" + + sandbox = _find("sandbox: ") + reasoning = _find("reasoning effort: ") + task_read = "unknown" + skill_read = "unknown" + exec_errors: list[str] = [] + tokens_used = "" + + for idx, ln in enumerate(lines): + if ln.startswith("exec"): + cmd = lines[idx + 1] if idx + 1 < len(lines) else "" + outcome = lines[idx + 2] if idx + 2 < len(lines) else "" + joined = f"{cmd}\n{outcome}" + if "task.md" in joined: + if "succeeded" in outcome: + task_read = "success" + elif "failed" in outcome or "ERROR" in outcome: + task_read = "failed" + if "SKILL.md" in joined: + if "succeeded" in outcome: + skill_read = "success" + elif "failed" in outcome or "ERROR" in outcome: + skill_read = "failed" + if ln.startswith("ERROR:"): + exec_errors.append(ln[len("ERROR:"):].strip()) + if ln == "tokens used" and idx + 1 < len(lines): + tokens_used = lines[idx + 1].strip() + + match = re.search(r"\s*([A-E])\s*", response or "", re.IGNORECASE) + if match: + answer_format = "well_formed" + answer_label = match.group(1).upper() + elif "" in (response 
def _build_claude_trace_summary(raw: str, response: str) -> str:
    """Build a short plain-text summary of a Claude Code run.

    The summary reports the detected answer format, the length of the final
    response, and up to three raw-output lines that mention "error" or
    "traceback" (case-insensitive).

    Args:
        raw: Raw backend output; ``None``/empty is treated as no output.
        response: Final response text; ``None`` is treated as empty.

    Returns:
        The summary lines joined with newlines.
    """
    answer_format = "missing"
    # NOTE(review): as written, this membership test uses an empty string,
    # which is contained in every str, so ``answer_format`` always ends up
    # "tagged" and the "plain_text"/"missing" outcomes are unreachable. It
    # looks like a marker literal was lost here — confirm the intended tag.
    if "" in (response or "").lower():
        answer_format = "tagged"
    elif (response or "").strip():
        answer_format = "plain_text"
    errors: list[str] = []
    for ln in (raw or "").splitlines():
        if "error" in ln.lower() or "traceback" in ln.lower():
            errors.append(ln.strip())
            # Cap the report at the first three matching lines.
            if len(errors) >= 3:
                break
    parts = ["Claude Code Trace Summary", f"- final answer format: {answer_format}"]
    parts.append(f"- final response chars: {len(response or '')}")
    parts.append(f"- errors: {' | '.join(errors) if errors else 'none'}")
    return "\n".join(parts)
def parse_codex_raw(raw: str) -> dict:
    """Split raw Codex CLI output into marker-delimited step sections.

    A step opens at every line that is exactly ``user``, ``codex`` or
    ``exec`` and runs up to the next such line (or the end of the output).
    Anything before the first marker is ignored.

    Returns a dict with:
      - ``steps``: ordered list of dicts holding ``index`` (1-based),
        ``type`` (the marker), ``start_line`` / ``end_line`` (indices into
        ``raw.splitlines()``, end exclusive) and the stripped ``content``
      - ``trace_body``: the stripped raw text from the first marker onward
    """
    step_markers = {"user", "codex", "exec"}
    raw_lines = (raw or "").splitlines()

    start_at = next(
        (pos for pos, text in enumerate(raw_lines) if text in step_markers),
        None,
    )
    if start_at is None:
        return {"steps": [], "trace_body": ""}

    # Every marker position from the first one on opens a section; a section
    # ends where the next one begins, and the last ends with the output.
    openings = [
        pos for pos in range(start_at, len(raw_lines)) if raw_lines[pos] in step_markers
    ]
    closings = openings[1:] + [len(raw_lines)]

    sections: list[dict] = []
    for number, (begin, end) in enumerate(zip(openings, closings), 1):
        sections.append(
            {
                "index": number,
                "type": raw_lines[begin],
                "start_line": begin,
                "end_line": end,
                "content": "\n".join(raw_lines[begin + 1:end]).strip(),
            }
        )

    return {
        "steps": sections,
        "trace_body": "\n".join(raw_lines[start_at:]).strip(),
    }
format_codex_trace_steps(raw: str, *, max_chars: int = 4000) -> str: + """Render parsed Codex trace into numbered compact steps for teacher prompts.""" + parsed = parse_codex_raw(raw) + steps = parsed["steps"] + if not steps: + return "" + + rendered: list[str] = [] + for step in steps: + summary = "" + content = str(step.get("content") or "").strip() + if step["type"] == "exec": + body_lines = [ln.strip() for ln in content.splitlines() if ln.strip()] + cmd = body_lines[0] if body_lines else "" + status = "" + for ln in body_lines[1:]: + low = ln.lower() + if "succeeded in" in low or "failed in" in low or "timed out" in low or low.startswith("error"): + status = ln + break + summary = cmd + if status: + summary = f"{summary} | {status}" if summary else status + else: + summary = " ".join(content.splitlines()) + summary = summary[:500] if summary else "(empty)" + rendered.append(f"[{step['index']}] {step['type']}: {summary}") + + text = "\n".join(rendered) + if len(text) > max_chars: + text = text[:max_chars] + "\n...[trace steps truncated]..." + return text + + +def extract_codex_trace_prefix(raw: str, *, after_step: int) -> str: + """Return raw trace body up to and including ``after_step``. + + ``after_step <= 0`` yields an empty string. 
+ """ + if after_step <= 0: + return "" + parsed = parse_codex_raw(raw) + steps = parsed["steps"] + if not steps: + return "" + clamped = min(after_step, len(steps)) + lines = parsed["trace_body"].splitlines() + end_line = int(steps[clamped - 1]["end_line"]) - int(steps[0]["start_line"]) + return "\n".join(lines[:end_line]).strip() + + +_DENIED_DATA_DIR_NAMES = {"officeqa_split", "sealqa_split"} + + +def _normalize_tools(allowed_tools: list[str] | str | None) -> str: + if allowed_tools is None: + return "" + if isinstance(allowed_tools, str): + return ",".join(part.strip() for part in allowed_tools.split(",") if part.strip()) + return ",".join(str(tool).strip() for tool in allowed_tools if str(tool).strip()) + + +def _tools_list(allowed_tools: list[str] | str | None) -> list[str]: + tools = _normalize_tools(allowed_tools) + return [part.strip() for part in tools.split(",") if part.strip()] + + +def _validate_exec_path(path: str) -> str: + resolved = os.path.realpath(os.path.abspath(path)) + parts = set(resolved.split(os.sep)) + denied = parts & _DENIED_DATA_DIR_NAMES + if denied: + raise ValueError(f"Refusing to expose denied data directory to exec backend: {', '.join(sorted(denied))}") + return resolved + + +def _validated_add_dirs(work_dir: str, data_dirs: list[str] | None, images: list[str] | None) -> list[str]: + add_dirs = [_validate_exec_path(work_dir)] + for data_dir in data_dirs or []: + add_dirs.append(_validate_exec_path(data_dir)) + for image in images or []: + add_dirs.append(_validate_exec_path(os.path.dirname(image) or work_dir)) + deduped: list[str] = [] + for path in add_dirs: + if path not in deduped: + deduped.append(path) + return deduped + + +def _sdk_mode(value: Any) -> str: + mode = str(value or "auto").strip().lower() + if mode in {"1", "true", "yes", "on", "sdk"}: + return "sdk" + if mode in {"0", "false", "no", "off", "cli"}: + return "cli" + return "auto" + + +def _claude_effort(value: Any) -> str: + effort = str(value or 
"medium").strip().lower() + if effort in {"", "none", "off"}: + return "" + if effort == "xhigh": + return "max" + if effort not in {"low", "medium", "high", "max"}: + return "medium" + return effort + + +def _json_default(obj: Any) -> Any: + if isinstance(obj, (str, int, float, bool)) or obj is None: + return obj + if isinstance(obj, (list, tuple)): + return list(obj) + if isinstance(obj, dict): + return obj + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if hasattr(obj, "__dict__"): + return {k: v for k, v in vars(obj).items() if not k.startswith("_")} + return str(obj) + + +def _json_dumps(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, indent=2, default=_json_default) + + +def _run_async(coro): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + box: dict[str, Any] = {} + + def _target() -> None: + try: + box["result"] = asyncio.run(coro) + except BaseException as exc: # noqa: BLE001 + box["exception"] = exc + + thread = threading.Thread(target=_target, daemon=True) + thread.start() + thread.join() + if "exception" in box: + raise box["exception"] + return box.get("result") + + +def _exec_prompt(prompt: str, *, allow_file_edits: bool = False) -> str: + edit_instruction = ( + "You may modify files in the workspace when the task asks you to create an artifact. " + if allow_file_edits + else "Do not modify files. " + ) + return ( + "Use the workspace files to solve the task. Read task.md and the skill at " + ".agents/skills/skillopt-student/SKILL.md before answering. " + "If ATTACHMENTS.md exists, read it and inspect the listed local files. " + "Do not call a Skill tool; the ReflACT guidance is a local markdown file. " + f"Do not ask for permission. {edit_instruction}" + "Return only the final answer text, keeping any required ... 
tags exactly.\n\n" + f"{_normalize_student_exec_prompt(prompt)}" + ) + + +def _retry_prompt(prompt: str, attempt: int) -> str: + if attempt <= 0: + return prompt + return ( + f"{prompt}\n\n" + "Previous execution returned an empty final response. Re-read task.md and " + ".agents/skills/skillopt-student/SKILL.md. If ATTACHMENTS.md exists, use the listed files. " + "Then produce the final answer inside ...." + ) + + +def _normalize_student_exec_prompt(prompt: str) -> str: + """Avoid wording that makes Claude Code call an unregistered Skill tool.""" + text = prompt or "" + replacements = { + "Use the `skillopt-student` skill available in this workspace.": ( + "Read `.agents/skills/skillopt-student/SKILL.md` directly; do not call a Skill tool." + ), + "- Use the local `skillopt-student` skill before writing code.": ( + "- Read `.agents/skills/skillopt-student/SKILL.md` before writing code; do not call a Skill tool." + ), + } + for old, new in replacements.items(): + text = text.replace(old, new) + return text + + +def _strict_schema(schema: dict[str, Any]) -> dict[str, Any]: + strict = json.loads(json.dumps(schema)) + strict["additionalProperties"] = False + properties = strict.get("properties") or {} + strict["required"] = list(properties.keys()) + return strict + + +def _structured_response(data: Any) -> tuple[str, str]: + if not isinstance(data, dict): + return "", f"Structured output was not an object: {type(data).__name__}" + final_response = str(data.get("final_response") or "").strip() + final_answer = str(data.get("final_answer") or "").strip() + if final_response: + return final_response, "" + if final_answer: + if "" in final_answer.lower(): + return final_answer, "" + return f"{final_answer}", "" + return "", "Structured output did not contain a final response." 
+ + +def _extract_claude_structured_output(messages: list[Any]) -> Any: + """Claude Code SDK can finish with error_during_execution after StructuredOutput.""" + for msg in reversed(messages): + structured = getattr(msg, "structured_output", None) + if isinstance(structured, dict): + return structured + + content = getattr(msg, "content", None) + if content is None and isinstance(msg, dict): + content = msg.get("content") + if not isinstance(content, list): + continue + + for item in reversed(content): + name = getattr(item, "name", None) + payload = getattr(item, "input", None) + if isinstance(item, dict): + name = item.get("name", name) + payload = item.get("input", payload) + if name == "StructuredOutput" and isinstance(payload, dict): + return payload + return None + + +def _raw_exception(label: str, exc: BaseException) -> str: + return _json_dumps({ + "backend": label, + "is_error": True, + "error_type": type(exc).__name__, + "error": str(exc), + "traceback": traceback.format_exc(), + }) + + +def _run_claude_code_sdk_exec( + *, + work_dir: str, + prompt: str, + model: str, + timeout: int, + images: list[str] | None = None, + data_dirs: list[str] | None = None, + allowed_tools: list[str] | str | None = None, + permission_mode: str | None = None, + allow_file_edits: bool = False, +) -> tuple[str, str]: + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + + async def _query() -> tuple[str, str]: + system_prompt: dict[str, Any] = { + "type": "preset", + "preset": "claude_code", + "append": ( + "Use the workspace files to solve the task. Read task.md and the skill at " + ".agents/skills/skillopt-student/SKILL.md before answering. " + "If ATTACHMENTS.md exists, read it and inspect the listed local files. " + "Do not call a Skill tool; the ReflACT guidance is a local markdown file. " + + ( + "You may modify files in the workspace when the task asks you to create an artifact. " + if allow_file_edits + else "Do not modify files. 
" + ) + + "Return structured output whose final_response preserves required ... tags." + ), + } + kwargs: dict[str, Any] = { + "system_prompt": system_prompt, + "output_format": {"type": "json_schema", "schema": ANSWER_SCHEMA}, + "allowed_tools": _tools_list(allowed_tools) or ["Read", "Bash"], + "cwd": str(work_dir), + "permission_mode": permission_mode or "bypassPermissions", + "add_dirs": _validated_add_dirs(work_dir, data_dirs, images), + "max_buffer_size": 8 * 1024 * 1024, + } + config = get_claude_code_exec_config() + effort = _claude_effort(config.get("effort")) + if effort: + kwargs["effort"] = effort + max_thinking_tokens = int(config.get("max_thinking_tokens", 0) or 0) + if max_thinking_tokens > 0: + kwargs["max_thinking_tokens"] = max_thinking_tokens + options = ClaudeAgentOptions(**kwargs) + if model: + options.model = model.split("/", 1)[1] if model.startswith("anthropic/") else model + + messages = [] + async with ClaudeSDKClient(options) as client: + await client.query(_normalize_student_exec_prompt(prompt)) + messages = [msg async for msg in client.receive_response()] + last = messages[-1] if messages else None + raw_structured_output = _extract_claude_structured_output(messages) + response, parse_error = _structured_response(raw_structured_output) + first = messages[0] if messages else None + first_data = getattr(first, "data", {}) if first is not None else {} + terminal_is_error = bool(getattr(last, "is_error", False)) if last is not None else False + raw = _json_dumps({ + "backend": "claude_code_sdk", + "uuid": first_data.get("uuid", "") if isinstance(first_data, dict) else "", + "session_id": getattr(last, "session_id", "") if last is not None else "", + "model": first_data.get("model", model) if isinstance(first_data, dict) else model, + "tools": first_data.get("tools", _tools_list(allowed_tools)) if isinstance(first_data, dict) else _tools_list(allowed_tools), + "duration_ms": getattr(last, "duration_ms", 0) if last is not None else 0, + 
def _run_claude_code_cli_exec(
    *,
    work_dir: str,
    prompt: str,
    model: str,
    timeout: int,
    images: list[str] | None = None,
    data_dirs: list[str] | None = None,
    allowed_tools: list[str] | str | None = None,
    permission_mode: str | None = None,
    allow_file_edits: bool = False,
) -> tuple[str, str]:
    """Invoke the Claude Code CLI once and return ``(response, raw)``.

    Args:
        work_dir: Workspace directory; used as the process cwd and exposed
            with ``--add-dir``.
        prompt: Task prompt; wrapped by ``_exec_prompt`` before being passed.
        model: Model name; omitted from the command when empty.
        timeout: Seconds before the subprocess is abandoned.
        images: Image paths whose parent directories are exposed to the CLI.
        data_dirs: Extra directories exposed to the CLI.
        allowed_tools: Tool names (list or comma-separated string);
            ``None`` selects ``Read,Bash``.
        permission_mode: CLI permission mode; defaults to
            ``bypassPermissions``.
        allow_file_edits: Forwarded to ``_exec_prompt``.

    Returns:
        ``response`` is the stripped stdout (empty on timeout or when the
        CLI printed nothing); ``raw`` is stdout followed by a ``[stderr]``
        section when stderr is non-empty.

    Raises:
        ValueError: From ``_validate_exec_path`` when a data or image
            directory is on the deny list.
    """

    def _as_text(stream: str | bytes | None) -> str:
        # TimeoutExpired exposes captured output as bytes even when the
        # process was started with text=True, so decode before formatting.
        if isinstance(stream, bytes):
            return stream.decode("utf-8", errors="replace")
        return stream or ""

    def _combine(stdout: str, stderr: str) -> str:
        if not stderr:
            return stdout
        return f"{stdout}\n[stderr]\n{stderr}" if stdout else stderr

    config = get_claude_code_exec_config()
    tools = "Read,Bash" if allowed_tools is None else _normalize_tools(allowed_tools)
    cmd = [
        str(config["path"]),
        "-p",
        "--output-format",
        "text",
        "--permission-mode",
        permission_mode or "bypassPermissions",
        "--add-dir",
        work_dir,
        "--tools",
        tools,
        "--allowedTools",
        tools,
    ]
    if config.get("profile"):
        cmd.extend(["--settings", '{"env":{"CLAUDE_CODE_USE_BEDROCK":"0"}}'])
        cmd.extend(["--append-system-prompt", f"Profile: {config['profile']}"])
    if model:
        cmd.extend(["--model", model])
    effort = _claude_effort(config.get("effort"))
    if effort:
        cmd.extend(["--effort", effort])
    max_thinking_tokens = int(config.get("max_thinking_tokens", 0) or 0)
    if max_thinking_tokens > 0:
        cmd.extend(["--max-thinking-tokens", str(max_thinking_tokens)])
    for data_dir in data_dirs or []:
        cmd.extend(["--add-dir", _validate_exec_path(data_dir)])
    for image in images or []:
        cmd.extend(["--add-dir", _validate_exec_path(os.path.dirname(image) or work_dir)])
    cmd.extend(["--", _exec_prompt(prompt, allow_file_edits=allow_file_edits)])

    try:
        proc = subprocess.run(
            cmd,
            cwd=work_dir,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        # A timeout is reported as an empty response plus whatever partial
        # output was captured, leaving the retry decision to the caller.
        return "", _combine(_as_text(exc.stdout), _as_text(exc.stderr))

    stdout = proc.stdout or ""
    raw = _combine(stdout, proc.stderr or "")
    # An empty stdout yields an empty response regardless of the exit code.
    return stdout.strip(), raw
raise + except Exception as exc: # noqa: BLE001 + raw = _raw_exception("claude_code_sdk", exc) + all_raw.append(f"===== CLAUDE SDK ATTEMPT {attempt + 1} =====\n{raw}") + if mode == "sdk" and attempt >= retries: + _persist_claude_artifacts(work_dir, "\n\n".join(all_raw), "") + raise + if mode != "sdk": + response, raw = _run_claude_code_cli_exec( + work_dir=work_dir, + prompt=attempt_prompt, + model=model, + timeout=timeout, + images=images, + data_dirs=data_dirs, + allowed_tools=allowed_tools, + permission_mode=permission_mode, + allow_file_edits=allow_file_edits, + ) + all_raw.append(f"===== CLAUDE CLI ATTEMPT {attempt + 1} =====\n{raw}") + last_response = response + if response.strip(): + combined = "\n\n".join(all_raw) + _persist_claude_artifacts(work_dir, combined, response) + return response, combined + + combined = "\n\n".join(all_raw) + _persist_claude_artifacts(work_dir, combined, last_response) + return last_response, combined + + +def _run_codex_sdk_exec( + *, + work_dir: str, + prompt: str, + model: str, + timeout: int, + images: list[str] | None = None, + data_dirs: list[str] | None = None, +) -> tuple[str, str]: + from openai_codex_sdk import Codex + + for data_dir in data_dirs or []: + _validate_exec_path(data_dir) + for image in images or []: + _validate_exec_path(os.path.dirname(image) or work_dir) + + async def _query() -> tuple[str, str]: + config = get_codex_exec_config() + reasoning_effort = str(config.get("reasoning_effort", "") or "").strip() + thread_options: dict[str, Any] = { + "working_directory": work_dir, + "skip_git_repo_check": True, + "sandbox_mode": str(config.get("sandbox") or "workspace-write"), + "network_access_enabled": bool(config.get("network_access", False)), + "web_search_enabled": bool(config.get("web_search", False)), + "approval_policy": str(config.get("approval_policy") or "never"), + } + if model: + thread_options["model"] = model + if data_dirs: + thread_options["additional_directories"] = data_dirs + if 
def _run_codex_cli_exec(
    *,
    work_dir: str,
    prompt: str,
    model: str,
    timeout: int,
    images: list[str] | None = None,
    data_dirs: list[str] | None = None,
    sandbox: str | None = None,
    full_auto: bool | None = None,
) -> tuple[str, str]:
    """Run one ``codex exec`` subprocess and return its last message and raw log.

    The command line is assembled from ``get_codex_exec_config()`` plus the
    arguments: optional profile (``-p``), reasoning effort (``-c``), either
    ``--full-auto`` or ``--sandbox <mode>``, optional model (``-m``), one
    ``-i`` per image, and ``--output-last-message`` pointing at
    ``codex_last_message.txt`` inside ``work_dir``.

    Args:
        work_dir: Working directory for the subprocess (also passed as ``-C``).
        prompt: Prompt text, passed as the final positional argument.
        model: Model name; omitted from the command line when empty.
        timeout: Subprocess timeout in seconds.
        images: Image paths; each parent directory is validated first.
        data_dirs: Extra directories; only validated here.
        sandbox: Sandbox mode override; falls back to ``config["sandbox"]``.
        full_auto: Overrides the ``full_auto`` config entry when not ``None``.

    Returns:
        ``(last_message, raw)`` where ``last_message`` is the content of the
        last-message file (``""`` if it does not exist) and ``raw`` is stdout
        followed by a ``[stderr]`` section when stderr is non-empty.

    Raises:
        subprocess.TimeoutExpired: re-raised after persisting partial output.
        RuntimeError: when ``codex exec`` exits with a non-zero status.
    """

    def _as_text(captured: str | bytes | None) -> str:
        # TimeoutExpired carries captured output as bytes even with text=True.
        if isinstance(captured, bytes):
            return captured.decode("utf-8", errors="replace")
        return captured or ""

    config = get_codex_exec_config()
    last_message_path = os.path.join(work_dir, "codex_last_message.txt")
    cmd = [
        str(config["path"]),
        "exec",
        "--skip-git-repo-check",
        "--color",
        "never",
        "-C",
        work_dir,
    ]
    if config.get("profile"):
        cmd.extend(["-p", str(config["profile"])])
    # ``or ""`` keeps an explicit ``None`` in the config from turning into the
    # literal string "None" (same guard as the SDK code path).
    reasoning_effort = str(config.get("reasoning_effort", "") or "").strip()
    if reasoning_effort:
        cmd.extend(["-c", f'model_reasoning_effort="{reasoning_effort}"'])
    actual_full_auto = bool(config.get("full_auto", True)) if full_auto is None else bool(full_auto)
    actual_sandbox = str(sandbox or config["sandbox"])
    if actual_full_auto:
        cmd.append("--full-auto")
    else:
        cmd.extend(["--sandbox", actual_sandbox])
    if model:
        cmd.extend(["-m", model])
    for data_dir in data_dirs or []:
        _validate_exec_path(data_dir)
    for image in images or []:
        _validate_exec_path(os.path.dirname(image) or work_dir)
        cmd.extend(["-i", image])
    cmd.extend(["--output-last-message", last_message_path, prompt])

    try:
        proc = subprocess.run(
            cmd,
            cwd=work_dir,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as exc:
        stdout = _as_text(exc.stdout)
        stderr = _as_text(exc.stderr)
        raw = stdout
        if stderr:
            raw = f"{raw}\n[stderr]\n{stderr}" if raw else stderr
        _persist_codex_artifacts(work_dir, raw, "")
        raise
    # Best-effort call accounting; token counts are not available from the CLI.
    try:
        from skillopt.model import azure_openai as _openai
        _openai.tracker.record("rollout", 0, 0)
    except Exception:
        pass
    stdout = proc.stdout or ""
    stderr = proc.stderr or ""
    last_message = ""
    if os.path.exists(last_message_path):
        with open(last_message_path, encoding="utf-8") as f:
            last_message = f.read()
    raw = stdout
    if stderr:
        raw = f"{raw}\n[stderr]\n{stderr}" if raw else stderr
    if proc.returncode != 0:
        _persist_codex_artifacts(work_dir, raw, last_message)
        detail = (stderr or stdout).strip()
        raise RuntimeError(
            f"codex exec failed with exit code {proc.returncode}: {detail[:4000]}"
        )
    return last_message, raw
def run_student_exec(
    *,
    work_dir: str,
    prompt: str,
    model: str,
    timeout: int,
    images: list[str] | None = None,
    data_dirs: list[str] | None = None,
    allowed_tools: list[str] | str | None = None,
    permission_mode: str | None = None,
    sandbox: str | None = None,
    full_auto: bool | None = None,
    allow_file_edits: bool = False,
) -> tuple[str, str]:
    """Dispatch an exec-style student run to the configured backend.

    The backend name comes from ``get_student_backend()``. ``"codex_exec"``
    receives the Codex-specific options (``sandbox``, ``full_auto``);
    ``"claude_code_exec"`` receives the Claude-specific ones
    (``allowed_tools``, ``permission_mode``, ``allow_file_edits``).

    Returns:
        The ``(response, raw)`` tuple produced by the selected runner.

    Raises:
        ValueError: if the configured backend is neither of the two above.
    """
    selected = get_student_backend()

    # Arguments understood by both runners.
    shared_kwargs = {
        "work_dir": work_dir,
        "prompt": prompt,
        "model": model,
        "timeout": timeout,
        "images": images,
        "data_dirs": data_dirs,
    }

    if selected == "claude_code_exec":
        return run_claude_code_exec(
            **shared_kwargs,
            allowed_tools=allowed_tools,
            permission_mode=permission_mode,
            allow_file_edits=allow_file_edits,
        )
    if selected == "codex_exec":
        return run_codex_exec(
            **shared_kwargs,
            sandbox=sandbox,
            full_auto=full_auto,
        )
    raise ValueError(f"Unsupported exec backend: {selected}")
class TokenTracker:
    """Thread-safe accumulator of call and token counts, keyed by stage name."""

    def __init__(self) -> None:
        # One lock guards every read and write of ``_data``.
        self._lock = threading.Lock()
        self._data: dict[str, dict[str, int]] = {}

    def record(self, stage: str, prompt_tokens: int, completion_tokens: int) -> None:
        """Count one call for ``stage`` and add its token usage."""
        with self._lock:
            bucket = self._data.setdefault(
                stage,
                {"calls": 0, "prompt_tokens": 0, "completion_tokens": 0},
            )
            bucket["calls"] += 1
            bucket["prompt_tokens"] += prompt_tokens
            bucket["completion_tokens"] += completion_tokens

    def summary(self) -> dict[str, dict[str, int]]:
        """Return per-stage totals (sorted by stage) plus a ``"_total"`` entry."""
        with self._lock:
            report: dict[str, dict[str, int]] = {
                stage: {
                    "calls": bucket["calls"],
                    "prompt_tokens": bucket["prompt_tokens"],
                    "completion_tokens": bucket["completion_tokens"],
                    "total_tokens": bucket["prompt_tokens"] + bucket["completion_tokens"],
                }
                for stage, bucket in sorted(self._data.items())
            }
            # Aggregate before inserting "_total" so it is not counted itself.
            all_calls = sum(row["calls"] for row in report.values())
            all_prompt = sum(row["prompt_tokens"] for row in report.values())
            all_completion = sum(row["completion_tokens"] for row in report.values())
            report["_total"] = {
                "calls": all_calls,
                "prompt_tokens": all_prompt,
                "completion_tokens": all_completion,
                "total_tokens": all_prompt + all_completion,
            }
            return report

    def reset(self) -> None:
        """Forget everything recorded so far."""
        with self._lock:
            self._data.clear()
def usage_from_openai_usage(usage: Any) -> dict[str, int]:
    """Convert a chat-completions style usage object into a plain dict.

    Reads ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` via
    ``getattr``; missing or falsy values count as 0. When no total is
    reported, it is computed as prompt + completion. A falsy ``usage``
    yields all zeros.
    """
    counts = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    if usage:
        for field_name in ("prompt_tokens", "completion_tokens"):
            counts[field_name] = getattr(usage, field_name, 0) or 0
        reported_total = getattr(usage, "total_tokens", 0)
        if reported_total:
            counts["total_tokens"] = reported_total
        else:
            counts["total_tokens"] = counts["prompt_tokens"] + counts["completion_tokens"]
    return counts
def compat_message_from_responses_output(output: list[Any]) -> CompatAssistantMessage:
    """Fold a list of Responses-style output items into one assistant message.

    Items are inspected through ``getattr`` only:

    * ``type == "message"``: the ``text`` of every content part whose type is
      ``"output_text"`` or ``"text"`` is appended to the message content;
    * ``type == "function_call"``: becomes a ``CompatToolCall``. Arguments come
      from ``item.arguments``, or from ``item.input`` serialized with
      ``json.dumps`` when ``arguments`` is ``None``; the id prefers
      ``call_id`` over ``id``;
    * any other type is ignored.
    """
    fragments: list[str] = []
    calls: list[CompatToolCall] = []

    for entry in output:
        kind = getattr(entry, "type", "") or ""

        if kind == "message":
            fragments.extend(
                getattr(piece, "text", "") or ""
                for piece in getattr(entry, "content", []) or []
                if (getattr(piece, "type", "") or "") in {"output_text", "text"}
            )
        elif kind == "function_call":
            payload = getattr(entry, "arguments", None)
            if payload is None:
                payload = json.dumps(getattr(entry, "input", {}) or {})
            call_identifier = getattr(entry, "call_id", "") or getattr(entry, "id", "") or ""
            tool_function = CompatToolFunction(
                name=getattr(entry, "name", "") or "",
                arguments=str(payload or "{}"),
            )
            calls.append(CompatToolCall(id=call_identifier, function=tool_function))

    return CompatAssistantMessage(content="".join(fragments), tool_calls=calls)
def set_backend(name: str | None) -> str:
    """Select the active model backend for subsequent calls.

    ``name`` is canonicalised with ``normalize_backend_name`` and must resolve
    to one of the backends this router can dispatch to: ``"azure_openai"``,
    ``"codex"`` or ``"claude"``.

    ``normalize_backend_name`` maps the aliases ``"claude"`` and
    ``"anthropic"`` to ``"claude_chat"``, whereas this router keys its Claude
    module as ``"claude"``. The canonical name is therefore folded back to the
    router key before validation; without that step the Claude backend could
    never be selected even though the error message advertises it.

    Args:
        name: Backend name or alias; ``None`` or empty selects the default
            chosen by ``normalize_backend_name``.

    Returns:
        The router key that is now active.

    Raises:
        ValueError: if the name does not resolve to a routable backend.

    Side effects:
        Updates the module-level active backend and the
        ``REFLACT_MODEL_BACKEND`` environment variable.
    """
    global _ACTIVE_BACKEND
    routable = ("azure_openai", "codex", "claude")
    normalized = normalize_backend_name(name)
    if normalized == "claude_chat":
        # Canonical chat-backend name -> key understood by _backend_module().
        normalized = "claude"
    if normalized not in routable:
        valid = ", ".join(sorted(routable))
        raise ValueError(f"Unknown backend {name!r}. Expected one of: {valid}")
    _ACTIVE_BACKEND = normalized
    os.environ["REFLACT_MODEL_BACKEND"] = normalized
    return _ACTIVE_BACKEND
def chat_teacher_messages(
    messages: list[dict[str, Any]],
    max_completion_tokens: int = 16384,
    retries: int = 5,
    stage: str = "teacher",
    *,
    tools: list[dict[str, Any]] | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    return_message: bool = False,
    timeout: int | None = None,
) -> tuple[Any, dict[str, int]]:
    """Forward a teacher chat request (message-list form) to the active backend.

    Every argument is passed through by keyword, unchanged, to the
    ``chat_teacher_messages`` function of the currently selected backend
    module, and that function's result is returned as-is.
    """
    forwarded = {
        "messages": messages,
        "max_completion_tokens": max_completion_tokens,
        "retries": retries,
        "stage": stage,
        "tools": tools,
        "tool_choice": tool_choice,
        "return_message": return_message,
        "timeout": timeout,
    }
    backend = _backend_module(_ACTIVE_BACKEND)
    return backend.chat_teacher_messages(**forwarded)
def configure_azure_openai(
    *,
    endpoint: str | None = None,
    api_version: str | None = None,
    api_key: str | None = None,
    auth_mode: str | None = None,
    ad_scope: str | None = None,
    managed_identity_client_id: str | None = None,
    teacher_endpoint: str | None = None,
    teacher_api_version: str | None = None,
    teacher_api_key: str | None = None,
    teacher_auth_mode: str | None = None,
    teacher_ad_scope: str | None = None,
    teacher_managed_identity_client_id: str | None = None,
    student_endpoint: str | None = None,
    student_api_version: str | None = None,
    student_api_key: str | None = None,
    student_auth_mode: str | None = None,
    student_ad_scope: str | None = None,
    student_managed_identity_client_id: str | None = None,
) -> None:
    """Pass Azure OpenAI connection settings through to the Azure backend.

    The settings come in three groups with identical fields: shared values,
    ``teacher_*`` values and ``student_*`` values. All of them are forwarded
    by keyword, unchanged, to ``azure_openai.configure_azure_openai``.
    """
    shared_settings = {
        "endpoint": endpoint,
        "api_version": api_version,
        "api_key": api_key,
        "auth_mode": auth_mode,
        "ad_scope": ad_scope,
        "managed_identity_client_id": managed_identity_client_id,
    }
    teacher_settings = {
        "teacher_endpoint": teacher_endpoint,
        "teacher_api_version": teacher_api_version,
        "teacher_api_key": teacher_api_key,
        "teacher_auth_mode": teacher_auth_mode,
        "teacher_ad_scope": teacher_ad_scope,
        "teacher_managed_identity_client_id": teacher_managed_identity_client_id,
    }
    student_settings = {
        "student_endpoint": student_endpoint,
        "student_api_version": student_api_version,
        "student_api_key": student_api_key,
        "student_auth_mode": student_auth_mode,
        "student_ad_scope": student_ad_scope,
        "student_managed_identity_client_id": student_managed_identity_client_id,
    }
    azure_openai.configure_azure_openai(
        **shared_settings,
        **teacher_settings,
        **student_settings,
    )
def rank_and_select(
    skill_content: str,
    patch: dict,
    max_edits: int,
    meta_skill_context: str = "",
    update_mode: str = "patch",
) -> dict:
    """Use a teacher LLM to rank edits by importance, then keep top-L.

    If the edit pool is within budget, returns the patch unchanged.
    Otherwise, calls the teacher to rank and select the most impactful edits.
    If the teacher call fails or yields no usable indices, the pool is simply
    truncated to the first ``max_edits`` items.

    Parameters
    ----------
    skill_content : str
        Current skill document.
    patch : dict
        Merged :class:`~skillopt.types.Patch` dict with ``edits`` list.
    max_edits : int
        Maximum number of edits to keep (the "edit budget"). Values below
        zero are treated as zero; with a zero budget nothing is selected and
        the teacher is not called.
    meta_skill_context : str
        Optional teacher-side memory, rendered with
        ``format_meta_skill_context`` and prepended to the prompt.
    update_mode : str
        Update mode; normalized with ``normalize_update_mode`` and used to
        pick the payload key, labels and ranking prompt.

    Returns
    -------
    dict
        :class:`~skillopt.types.Patch` dict with selected edits and
        optional ``ranking_details``.
    """
    # Function-scope import: used only to report a failed teacher call.
    import traceback

    update_mode = normalize_update_mode(update_mode)
    edits = get_payload_items(patch, update_mode)
    # A negative budget would otherwise slice ``edits[:-k]`` in the fallback.
    budget = max(0, max_edits)
    if len(edits) <= budget:
        return patch

    # With a zero budget there is nothing to rank; asking the teacher would
    # also let one edit slip through before the budget check in the loop.
    if budget > 0:
        # Build the edit pool description for the teacher
        edits_desc = [
            f"[{i}] {describe_item(edit, update_mode, max_chars=500)}"
            for i, edit in enumerate(edits)
        ]

        user = (
            f"## Current Skill\n{skill_content}\n\n"
            f"## {payload_label(update_mode, title=True)} Pool ({len(edits)} {payload_label(update_mode)}, budget={budget})\n"
            + "\n".join(edits_desc)
            + f"\n\nSelect the {budget} most important {payload_label(update_mode)}. "
            f"Return their 0-based indices in priority order."
        )
        teacher_ctx = format_meta_skill_context(meta_skill_context)
        if teacher_ctx:
            user = f"{teacher_ctx}\n\n{user}"
        prompt_name = "ranking_rewrite" if is_rewrite_mode(update_mode) else "ranking"

        try:
            response, _ = chat_teacher(
                system=load_prompt(prompt_name), user=user,
                max_completion_tokens=2048, retries=3, stage="ranking",
            )
            result = extract_json(response)
            if result and "selected_indices" in result:
                selected = []
                seen: set[int] = set()
                for idx in result["selected_indices"]:
                    if len(selected) >= budget:
                        break
                    # bool is a subclass of int; JSON true/false is not an index.
                    if (
                        isinstance(idx, int)
                        and not isinstance(idx, bool)
                        and 0 <= idx < len(edits)
                        and idx not in seen
                    ):
                        selected.append(edits[idx])
                        seen.add(idx)
                if selected:
                    return {
                        "reasoning": patch.get("reasoning", "")
                        + f" [teacher-ranked: selected {len(selected)}/{len(edits)} {payload_label(update_mode)}]",
                        payload_key(update_mode): selected,
                        "ranking_details": result,
                    }
        except Exception:  # noqa: BLE001
            # Ranking is best-effort: report the failure, then fall back below.
            traceback.print_exc()

    # Fallback: simple truncation
    return {
        "reasoning": patch.get("reasoning", "")
        + f" [fallback truncated {len(edits)}->{budget} {payload_label(update_mode)}]",
        payload_key(update_mode): edits[:budget],
    }
def decide_autonomous_learning_rate(
    *,
    skill_content: str,
    merged_patch: dict,
    update_mode: str,
    rollout_hard: float,
    rollout_soft: float,
    rollout_n: int,
    step_buffer_context: str = "",
    meta_skill_context: str = "",
) -> dict:
    """Ask the teacher to choose the number of update items for this step.

    The prompt intentionally avoids default budgets, candidate budget lists, or
    scheduler history. The only hard post-processing is validity: the returned
    integer is clamped to the available item count.

    Parameters
    ----------
    skill_content : str
        Current skill document, embedded verbatim in the prompt.
    merged_patch : dict
        Patch whose items (looked up with ``get_payload_items``) are the
        candidates the teacher chooses among.
    update_mode : str
        Update mode used to extract, describe and label the items.
    rollout_hard, rollout_soft : float
        Rollout scores, printed in the prompt with six decimals.
    rollout_n : int
        Rollout count, printed in the prompt.
    step_buffer_context : str
        Optional text about earlier steps; appended when non-blank.
    meta_skill_context : str
        Optional teacher memory; prepended via ``format_meta_skill_context``.

    Returns
    -------
    dict
        Decision record with ``learning_rate`` (the clamped choice),
        ``raw_learning_rate``, ``available_update_items``, ``clamped``,
        ``fallback``, ``reasoning``, ``confidence``, ``risk_notes``,
        ``raw_response`` and, if the teacher call or parsing raised,
        ``error``.
    """
    items = get_payload_items(merged_patch, update_mode)
    available = len(items)
    # One indexed line per candidate item so the teacher can see the pool.
    item_lines = [
        f"[{idx}] {describe_item(item, update_mode, max_chars=700)}"
        for idx, item in enumerate(items)
    ]
    user = (
        f"## Current Skill\n{skill_content}\n\n"
        f"## Current Step Evidence\n"
        f"rollout_n={rollout_n}\n"
        f"rollout_hard={rollout_hard:.6f}\n"
        f"rollout_soft={rollout_soft:.6f}\n"
        f"proposed_update_items={available}\n"
        f"update_item_type={payload_label(update_mode)}\n\n"
        f"## Proposed Update Items\n"
        + "\n".join(item_lines)
        + "\n\nDecide how many proposed update items should be applied now."
    )
    if step_buffer_context.strip():
        user += f"\n\n## Previous Steps in This Epoch\n{step_buffer_context}"
    teacher_ctx = format_meta_skill_context(meta_skill_context)
    if teacher_ctx:
        user = f"{teacher_ctx}\n\n{user}"

    response = ""
    parsed: dict | None = None
    decision: int | None = None
    try:
        response, _ = chat_teacher(
            system=load_prompt("lr_autonomous"),
            user=user,
            max_completion_tokens=2048,
            retries=3,
            stage="lr_autonomous",
        )
        parsed = extract_json(response)
        if parsed:
            # Stays None when no integer can be extracted from the field.
            decision = _coerce_nonnegative_int(parsed.get("learning_rate"))
    except Exception as exc:  # noqa: BLE001
        # Keep the failure text so it can be surfaced in the returned record.
        parsed = {"error": str(exc)}

    # No usable decision: apply nothing this step and flag it as a fallback.
    fallback = False
    if decision is None:
        decision = 0
        fallback = True

    # Validity clamp: never apply more items than were proposed.
    chosen = min(decision, available)
    record = {
        "learning_rate": chosen,
        "raw_learning_rate": decision,
        "available_update_items": available,
        "clamped": chosen != decision,
        "fallback": fallback,
        "reasoning": (parsed or {}).get("reasoning", ""),
        "confidence": (parsed or {}).get("confidence", ""),
        "risk_notes": (parsed or {}).get("risk_notes", []),
        "raw_response": response,
    }
    if parsed and "error" in parsed:
        record["error"] = parsed["error"]
    return record
def build_epoch_history(
    epoch_step_records: list[dict],
    out_root: str,
    *,
    update_mode: str = "patch",
) -> str:
    """Format an epoch's step records into text for the meta-reflect teacher.

    For each step, includes the exact edits applied (read from
    ``ranked_edits.json``) and the gate evaluation result. When a
    ``trajectory_digest.json`` with ``failure_patterns`` exists for the step,
    those patterns are appended to the step's section.

    Parameters
    ----------
    epoch_step_records : list[dict]
        Step record dicts from ``history.json`` belonging to this epoch.
        Each must contain ``"step"``; ``"action"``, ``"selection_hard"`` /
        ``"current_score"`` and ``"best_score"`` are optional.
    out_root : str
        Training output root directory (to locate ``ranked_edits.json``).
    update_mode : str
        Update mode; normalized and then used to extract, describe and label
        the selected items.

    Returns
    -------
    str
        Formatted epoch history text, one section per step separated by
        blank lines.
    """
    update_mode = normalize_update_mode(update_mode)
    parts: list[str] = []
    for rec in epoch_step_records:
        step = rec["step"]
        action = rec.get("action", "unknown")
        # Prefer the selection score; fall back to the current score, then "?".
        gate_score = rec.get("selection_hard", rec.get("current_score", "?"))
        best_score = rec.get("best_score", "?")

        header = (
            f"### Step {step} — "
            f"gate: {gate_score}, {action.upper()}, "
            f"best_so_far: {best_score}"
        )

        # Read the actual applied edits
        ranked_path = os.path.join(
            out_root, "steps", f"step_{step:04d}", "ranked_edits.json",
        )
        edits_text = ""
        if os.path.exists(ranked_path):
            try:
                # NOTE(review): opened with the locale's default encoding --
                # confirm it matches how ranked_edits.json is written.
                with open(ranked_path) as f:
                    ranked = json.load(f)
                edits = get_payload_items(ranked, update_mode)
                if edits:
                    lines = [f"Selected {payload_label(update_mode)}:"]
                    for i, edit in enumerate(edits, 1):
                        lines.append(f" {i}. {describe_item(edit, update_mode, max_chars=220)}")
                    edits_text = "\n".join(lines)
                else:
                    edits_text = f"Selected {payload_label(update_mode)}: (none)"
            except Exception:
                # Unreadable or malformed file: note it and keep going.
                edits_text = f"Selected {payload_label(update_mode)}: (could not read)"
        else:
            # Step may have been skipped
            if "skip" in action:
                edits_text = f"Selected {payload_label(update_mode)}: (skipped)"
            else:
                edits_text = f"Selected {payload_label(update_mode)}: (file not found)"

        parts.append(f"{header}\n{edits_text}")

        # Append trajectory failure digest if available
        digest_path = os.path.join(
            out_root, "steps", f"step_{step:04d}", "trajectory_digest.json",
        )
        if os.path.exists(digest_path):
            try:
                with open(digest_path) as f:
                    digest = json.load(f)
                patterns = digest.get("failure_patterns", [])
                if patterns:
                    n_fail = digest.get("n_fail", "?")
                    n_total = digest.get("n_total", "?")
                    lines = [f"Failure patterns ({n_fail}/{n_total} tasks failed):"]
                    for p in patterns:
                        lines.append(
                            f' - "{p["pattern"]}" (×{p["count"]})'
                        )
                    # Extend the section that was just appended for this step.
                    parts[-1] += "\n" + "\n".join(lines)
            except Exception:
                # The digest is optional context; any problem with it is ignored.
                pass

    return "\n\n".join(parts)
call ──────────────────────────────────────────────── + + +def run_meta_reflect( + skill_content: str, + epoch_history_text: str, + prev_meta_summary: str, + meta_edit_budget: int = 4, + *, + system_prompt: str | None = None, + update_mode: str = "patch", +) -> dict | None: + """Run one meta-reflect teacher call for an epoch. + + Parameters + ---------- + skill_content : str + Current skill document (after the epoch's fast updates). + epoch_history_text : str + Formatted epoch history from :func:`build_epoch_history`. + prev_meta_summary : str + Meta summary from the previous epoch ("" if first epoch). + meta_edit_budget : int + Maximum number of high-level edits. + system_prompt : str | None + Custom system prompt. ``None`` = use generic default. + + Returns + ------- + dict | None + Conforms to :class:`~skillopt.types.MetaReflectResult`: + ``"meta_summary"`` (str) and ``"patch"`` (:class:`~skillopt.types.Patch` + dict), or ``None`` on failure. + """ + mode = normalize_update_mode(update_mode) + actual_system = system_prompt if system_prompt is not None else load_prompt( + "meta_reflect_rewrite" if mode == "rewrite_from_suggestions" else "meta_reflect" + ) + + prev_section = prev_meta_summary.strip() if prev_meta_summary else "(First epoch — no previous summary)" + + user = ( + f"## Previous Meta Summary\n{prev_section}\n\n" + f"## Current Skill Document\n{skill_content}\n\n" + f"## {payload_label(mode, title=True)} Budget\n" + f"Produce at most {meta_edit_budget} high-level {payload_label(mode)}.\n\n" + f"## This Epoch's Step History\n{epoch_history_text}" + ) + + try: + response, _ = chat_teacher( + system=actual_system, + user=user, + max_completion_tokens=4096, + retries=3, + stage="meta_reflect", + ) + result = extract_json(response) + if result and "patch" in result: + truncate_payload(result["patch"], meta_edit_budget, mode) + if "meta_summary" not in result: + result["meta_summary"] = "" + return result + except Exception: # noqa: BLE001 + 
traceback.print_exc() + + return None diff --git a/skillopt/optimizer/meta_skill.py b/skillopt/optimizer/meta_skill.py new file mode 100644 index 0000000..04494c7 --- /dev/null +++ b/skillopt/optimizer/meta_skill.py @@ -0,0 +1,87 @@ +"""Teacher-side meta skill memory for cross-epoch optimization guidance. + +This module maintains a compact teacher-facing memory distilled from +adjacent-epoch skill comparisons. Unlike ``slow_update``, it does not +modify the student skill document. Instead, it produces guidance meant to +improve future teacher behavior when proposing, merging, and ranking edits. +""" +from __future__ import annotations + +import traceback + +from skillopt.model import chat_teacher +from skillopt.optimizer.slow_update import format_comparison_text +from skillopt.prompts import load_prompt +from skillopt.utils import extract_json + + +def format_meta_skill_context(meta_skill_content: str) -> str: + """Render teacher memory into a prompt-ready context block.""" + content = (meta_skill_content or "").strip() + if not content: + return "" + return ( + "## Teacher Meta Skill\n" + "This is teacher-side memory distilled from prior epoch transitions in " + "this environment. Use it to improve how you propose, merge, and rank " + "skill edits. Prefer it when the current evidence is ambiguous, but do " + "not force it if the current trajectories clearly contradict it.\n\n" + f"{content}" + ) + + +def run_meta_skill( + prev_skill: str, + curr_skill: str, + comparison_pairs: list[dict], + *, + prev_meta_skill_content: str = "", + system_prompt: str | None = None, +) -> dict | None: + """Produce updated teacher-side meta skill from adjacent epochs.""" + actual_system = system_prompt if system_prompt is not None else load_prompt("meta_skill") + + prev_skill_display = prev_skill + if len(prev_skill_display) > 6000: + prev_skill_display = prev_skill_display[:6000] + "\n...[truncated]..." 
+ + curr_skill_display = curr_skill + if len(curr_skill_display) > 6000: + curr_skill_display = curr_skill_display[:6000] + "\n...[truncated]..." + + prev_meta_section = ( + prev_meta_skill_content.strip() + if prev_meta_skill_content and prev_meta_skill_content.strip() + else "(No previous teacher meta skill — this is the first update.)" + ) + + comparison_text = format_comparison_text(comparison_pairs) + user = ( + f"## Previous Epoch Last-Step Skill\n{prev_skill_display}\n\n" + f"## Current Epoch Last-Step Skill\n{curr_skill_display}\n\n" + f"## Previous Teacher Meta Skill\n" + f"The following teacher memory was available during the current epoch. " + f"Reflect on whether it improved or harmed the quality of edits.\n\n" + f"{prev_meta_section}\n\n" + f"## Longitudinal Comparison (same tasks, two last-step skills)\n" + f"{comparison_text}" + ) + + try: + response, _ = chat_teacher( + system=actual_system, + user=user, + max_completion_tokens=3072, + retries=3, + stage="meta_skill", + ) + result = extract_json(response) + if result and result.get("meta_skill_content"): + return { + "reasoning": str(result.get("reasoning", "")).strip(), + "meta_skill_content": str(result["meta_skill_content"]).strip(), + } + except Exception: # noqa: BLE001 + traceback.print_exc() + + return None diff --git a/skillopt/optimizer/rewrite.py b/skillopt/optimizer/rewrite.py new file mode 100644 index 0000000..23bf075 --- /dev/null +++ b/skillopt/optimizer/rewrite.py @@ -0,0 +1,59 @@ +"""Teacher-driven full skill rewrite from selected revise_suggestions.""" +from __future__ import annotations + +import json + +from skillopt.model import chat_teacher +from skillopt.prompts import load_prompt +from skillopt.optimizer.update_modes import get_payload_items +from skillopt.utils import extract_json + + +def rewrite_skill_from_suggestions( + skill_content: str, + patch: dict, + *, + system_prompt: str | None = None, + step_buffer_context: str = "", + env: str | None = None, + reasoning_effort: 
str | None = "high", + max_completion_tokens: int = 64000, +) -> dict | None: + suggestions = get_payload_items(patch, "rewrite_from_suggestions") + if not suggestions: + return None + + user = ( + f"## Current Skill\n{skill_content}\n\n" + f"## Selected Revise Suggestions ({len(suggestions)} total)\n" + f"{json.dumps(suggestions, ensure_ascii=False, indent=2)}\n\n" + ) + if step_buffer_context.strip(): + user += f"## Previous Steps in This Epoch\n{step_buffer_context}\n\n" + user += ( + "Rewrite the full skill document so it integrates the selected suggestions. " + "Return the complete new skill in `new_skill`." + ) + + actual_system = system_prompt if system_prompt is not None else load_prompt( + "rewrite_skill", env=env, + ) + + try: + response, _ = chat_teacher( + system=actual_system, + user=user, + max_completion_tokens=max_completion_tokens, + retries=3, + stage="rewrite", + reasoning_effort=reasoning_effort, + ) + result = extract_json(response) + if result and str(result.get("new_skill", "")).strip(): + result["new_skill"] = str(result["new_skill"]).rstrip() + "\n" + if "change_summary" not in result or not isinstance(result["change_summary"], list): + result["change_summary"] = [] + return result + except Exception: # noqa: BLE001 + return None + return None diff --git a/skillopt/optimizer/scheduler.py b/skillopt/optimizer/scheduler.py new file mode 100644 index 0000000..63b944e --- /dev/null +++ b/skillopt/optimizer/scheduler.py @@ -0,0 +1,127 @@ +"""Learning-rate (edit budget) schedulers for ReflACT. + +The "learning rate" in ReflACT is the maximum number of skill edits allowed +per optimization step. A scheduler controls how this budget changes over +the course of training. + +Supported modes +--------------- +- ``constant`` : Fixed budget throughout training. +- ``linear`` : Linear decay from ``max_lr`` to ``min_lr``. +- ``cosine`` : Cosine annealing from ``max_lr`` to ``min_lr``. 
+- ``autonomous`` : No limit — the model decides how many edits to make. + +Usage:: + + scheduler = build_scheduler(cfg) + for step in range(1, total_steps + 1): + lr = scheduler.step() # returns edit budget for this step + # ... use lr as max_edits ... +""" +from __future__ import annotations + +import math +from abc import ABC, abstractmethod + + +class LRScheduler(ABC): + """Base class for edit-budget schedulers.""" + + def __init__(self, max_lr: int, min_lr: int, total_steps: int) -> None: + self.max_lr = max_lr + self.min_lr = min_lr + self.total_steps = total_steps + self._current_step = 0 + + @abstractmethod + def _compute_lr(self, step: int) -> int: + """Return the edit budget for the given 1-indexed step.""" + + def step(self) -> int: + """Advance one step and return the edit budget.""" + self._current_step += 1 + return self._compute_lr(self._current_step) + + def get_lr(self, step: int) -> int: + """Return the edit budget for an arbitrary step (1-indexed).""" + return self._compute_lr(step) + + def state_dict(self) -> dict: + return {"current_step": self._current_step} + + def load_state_dict(self, state: dict) -> None: + self._current_step = state.get("current_step", 0) + + +class ConstantScheduler(LRScheduler): + """Fixed edit budget throughout training.""" + + def _compute_lr(self, step: int) -> int: + return self.max_lr + + +class LinearScheduler(LRScheduler): + """Linear decay from ``max_lr`` to ``min_lr`` over ``total_steps``.""" + + def _compute_lr(self, step: int) -> int: + if self.total_steps <= 1: + return self.max_lr + t = min(step, self.total_steps) / self.total_steps + lr = self.max_lr + (self.min_lr - self.max_lr) * t + return max(self.min_lr, round(lr)) + + +class CosineScheduler(LRScheduler): + """Cosine annealing from ``max_lr`` to ``min_lr`` over ``total_steps``.""" + + def _compute_lr(self, step: int) -> int: + if self.total_steps <= 1: + return self.max_lr + t = min(step, self.total_steps) / self.total_steps + lr = self.min_lr + 0.5 * 
(self.max_lr - self.min_lr) * (1 + math.cos(math.pi * t)) + return max(self.min_lr, round(lr)) + + +class AutonomousScheduler(LRScheduler): + """No edit limit — the model decides freely.""" + + NO_LIMIT = 999 + + def _compute_lr(self, step: int) -> int: + return self.NO_LIMIT + + +# ── Factory ────────────────────────────────────────────────────────────── + +_REGISTRY: dict[str, type[LRScheduler]] = { + "constant": ConstantScheduler, + "linear": LinearScheduler, + "cosine": CosineScheduler, + "autonomous": AutonomousScheduler, +} + + +def build_scheduler( + mode: str = "constant", + max_lr: int = 8, + min_lr: int = 2, + total_steps: int = 8, +) -> LRScheduler: + """Build a scheduler from config parameters. + + Parameters + ---------- + mode : str + One of ``constant``, ``linear``, ``cosine``, ``autonomous``. + max_lr : int + Initial / maximum edit budget. + min_lr : int + Minimum edit budget (for decay modes). + total_steps : int + Total number of optimization steps in training. + """ + if mode not in _REGISTRY: + raise ValueError( + f"Unknown scheduler mode '{mode}'. Available: {list(_REGISTRY.keys())}" + ) + return _REGISTRY[mode](max_lr=max_lr, min_lr=min_lr, total_steps=total_steps) diff --git a/skillopt/optimizer/select.py b/skillopt/optimizer/select.py new file mode 100644 index 0000000..fc49eeb --- /dev/null +++ b/skillopt/optimizer/select.py @@ -0,0 +1,4 @@ +"""Backward-compat stub — moved to skillopt.optimizer.clip.""" +from skillopt.optimizer.clip import rank_and_select # noqa: F401 + +__all__ = ["rank_and_select"] diff --git a/skillopt/optimizer/skill.py b/skillopt/optimizer/skill.py new file mode 100644 index 0000000..5230dda --- /dev/null +++ b/skillopt/optimizer/skill.py @@ -0,0 +1,154 @@ +"""ReflACT skill operations — edit application and patch processing. + +The Update stage (⑤) of the ReflACT pipeline: apply a ranked set of +edits to the current skill document, producing an updated candidate. 
+Analogous to optimizer.step() in neural network training. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from skillopt.types import Edit as EditType, Patch as PatchType + +SLOW_UPDATE_START = "" +SLOW_UPDATE_END = "" + + +def _is_in_slow_update_region(skill: str, target: str) -> bool: + """Check if *target* text falls within the protected slow update region.""" + start_idx = skill.find(SLOW_UPDATE_START) + end_idx = skill.find(SLOW_UPDATE_END) + if start_idx == -1 or end_idx == -1: + return False + target_idx = skill.find(target) + if target_idx == -1: + return False + region_end = end_idx + len(SLOW_UPDATE_END) + return start_idx <= target_idx < region_end + + +def _edit_fields(edit: EditType | dict) -> tuple[str, str, str]: + op = edit.op if hasattr(edit, "op") else edit.get("op", "") + content = (edit.content if hasattr(edit, "content") else edit.get("content", "")).strip() + target = edit.target if hasattr(edit, "target") else edit.get("target", "") + return op, content, target + + +def _apply_edit_with_report(skill: str, edit: EditType | dict) -> tuple[str, dict]: + op, content, target = _edit_fields(edit) + report = { + "op": op, + "target": target[:200], + "content_preview": content[:200], + "status": "unknown", + } + + if target and _is_in_slow_update_region(skill, target): + report["status"] = "skipped_protected_slow_update_region" + return skill, report + + if op == "append": + su_start = skill.find(SLOW_UPDATE_START) + if su_start != -1: + before = skill[:su_start].rstrip() + after = skill[su_start:] + report["status"] = "applied_append_before_slow_update" + return before + "\n\n" + content + "\n\n" + after, report + report["status"] = "applied_append" + return skill.rstrip() + "\n\n" + content + "\n", report + + if op == "insert_after": + if not target or target not in skill: + su_start = skill.find(SLOW_UPDATE_START) + if su_start != -1: + before = skill[:su_start].rstrip() + after = 
skill[su_start:] + report["status"] = "applied_insert_after_fallback_before_slow_update" + return before + "\n\n" + content + "\n\n" + after, report + report["status"] = "applied_insert_after_fallback_append" + return skill.rstrip() + "\n\n" + content + "\n", report + idx = skill.index(target) + len(target) + newline = skill.find("\n", idx) + insert_at = newline + 1 if newline != -1 else len(skill) + report["status"] = "applied_insert_after" + return skill[:insert_at] + "\n" + content + "\n" + skill[insert_at:], report + + if op == "replace": + if not target: + report["status"] = "skipped_replace_missing_target" + return skill, report + if target not in skill: + report["status"] = "skipped_replace_target_not_found" + return skill, report + report["status"] = "applied_replace" + return skill.replace(target, content, 1), report + + if op == "delete": + if not target: + report["status"] = "skipped_delete_missing_target" + return skill, report + if target not in skill: + report["status"] = "skipped_delete_target_not_found" + return skill, report + report["status"] = "applied_delete" + return skill.replace(target, "", 1), report + + report["status"] = "skipped_unknown_op" + return skill, report + + +def apply_edit(skill: str, edit: EditType | dict) -> str: + """Apply a single edit operation to the skill document. + + Parameters + ---------- + skill : str + Current skill document content. + edit : Edit | dict + An :class:`~skillopt.types.Edit` instance or a plain dict with + keys ``op``, ``content``, ``target``. + + Edits targeting the protected slow-update region are silently skipped. 
+ """ + updated_skill, _ = _apply_edit_with_report(skill, edit) + return updated_skill + + +def apply_patch_with_report( + skill: str, + patch: PatchType | dict, +) -> tuple[str, list[dict]]: + """Apply a patch and return a per-edit report for observability.""" + edits = patch.edits if hasattr(patch, "edits") else patch.get("edits", []) + reports: list[dict] = [] + for idx, edit in enumerate(edits, 1): + try: + skill, report = _apply_edit_with_report(skill, edit) + report["index"] = idx + except Exception as exc: # noqa: BLE001 + report = { + "index": idx, + "op": "", + "target": "", + "content_preview": "", + "status": "error", + "error": str(exc), + } + reports.append(report) + return skill, reports + + +def apply_patch(skill: str, patch: PatchType | dict) -> str: + """Apply a patch (list of edits) to the skill document sequentially. + + Parameters + ---------- + skill : str + Current skill document content. + patch : Patch | dict + A :class:`~skillopt.types.Patch` instance or a plain dict with + key ``edits`` containing a list of edit operations. + """ + updated_skill, _ = apply_patch_with_report(skill, patch) + return updated_skill diff --git a/skillopt/optimizer/slow_update.py b/skillopt/optimizer/slow_update.py new file mode 100644 index 0000000..f28140b --- /dev/null +++ b/skillopt/optimizer/slow_update.py @@ -0,0 +1,374 @@ +"""ReflACT Slow Update — epoch-level longitudinal skill refinement. + +At the end of each epoch, the slow update compares rollout performance of the +same sample set under the previous epoch's skill vs. the current epoch's skill +(Markov: only adjacent epochs). A teacher analyzes regressions, improvements, +and persistent failures, then writes a free-form guidance block into a +**protected** section of the skill document. This section cannot be modified by +step-level analyst edits — only the slow update process overwrites it. 
+ +Public API +---------- +- :func:`inject_empty_slow_update_field` — add empty placeholder (epoch 1) +- :func:`extract_slow_update_field` — read current content +- :func:`replace_slow_update_field` — overwrite content +- :func:`has_slow_update_field` — check if markers are present +- :func:`build_comparison_text` — format side-by-side rollout results +- :func:`run_slow_update` — teacher call to produce guidance +""" +from __future__ import annotations + +import json +import os +import traceback + +from skillopt.model import chat_teacher +from skillopt.prompts import load_prompt +from skillopt.utils import extract_json + +# ── Protected field markers ───────────────────────────────────────────────── + +SLOW_UPDATE_START = "" +SLOW_UPDATE_END = "" + +# ── Field manipulation helpers ────────────────────────────────────────────── + + +def has_slow_update_field(skill: str) -> bool: + return SLOW_UPDATE_START in skill and SLOW_UPDATE_END in skill + + +def inject_empty_slow_update_field(skill: str) -> str: + if has_slow_update_field(skill): + return skill + block = ( + f"\n\n{SLOW_UPDATE_START}\n" + f"{SLOW_UPDATE_END}\n" + ) + return skill.rstrip() + block + + +def extract_slow_update_field(skill: str) -> str: + start = skill.find(SLOW_UPDATE_START) + end = skill.find(SLOW_UPDATE_END) + if start == -1 or end == -1: + return "" + inner_start = start + len(SLOW_UPDATE_START) + return skill[inner_start:end].strip() + + +def replace_slow_update_field(skill: str, new_content: str) -> str: + start = skill.find(SLOW_UPDATE_START) + end = skill.find(SLOW_UPDATE_END) + if start == -1 or end == -1: + skill = inject_empty_slow_update_field(skill) + start = skill.find(SLOW_UPDATE_START) + end = skill.find(SLOW_UPDATE_END) + before = skill[:start + len(SLOW_UPDATE_START)] + after = skill[end:] + return before + "\n" + new_content.strip() + "\n" + after + + +# ── Comparison text builder ───────────────────────────────────────────────── + + +_MAX_TRAJ_CHARS = 3000 + + +def 
_clip_text(value, limit: int) -> str: + if value is None: + return "" + return str(value)[:limit] + + +def _read_trajectory(rollout_dir: str, task_id: str) -> str: + """Read and format a single trajectory from a rollout directory.""" + conv_path = os.path.join(rollout_dir, "predictions", task_id, "conversation.json") + if not os.path.exists(conv_path): + return "(trajectory not available)" + try: + with open(conv_path) as f: + conversation = json.load(f) + except Exception: + return "(trajectory read error)" + if not conversation: + return "(empty trajectory)" + + lines: list[str] = [] + for entry in conversation: + if not isinstance(entry, dict): + continue + if entry.get("type") == "tool_call": + cmd = _clip_text(entry.get("cmd"), 500) + obs = _clip_text(entry.get("obs"), 800) + lines.append(f"[action] {cmd}") + lines.append(f"[obs] {obs}") + elif "action" in entry and "env_feedback" in entry: + step = entry.get("step", "?") + reasoning = _clip_text(entry.get("reasoning"), 300) + action = _clip_text(entry.get("action"), 200) + feedback = _clip_text(entry.get("env_feedback"), 500) + if reasoning: + lines.append(f"[step {step} think] {reasoning}") + lines.append(f"[step {step} action] {action}") + lines.append(f"[step {step} obs] {feedback}") + elif entry.get("role") == "system": + msg = _clip_text(entry.get("content"), 1000) + lines.append(f"[verification] {msg}") + else: + msg = _clip_text(entry.get("content"), 500) + role = entry.get("role", "agent") + lines.append(f"[{role}] {msg}") + + text = "\n".join(lines) + if len(text) > _MAX_TRAJ_CHARS: + half = _MAX_TRAJ_CHARS // 2 + text = text[:half] + "\n...[truncated]...\n" + text[-half:] + return text + + +# ── Structured comparison pairs ───────────────────────────────────────────── + + +def build_comparison_pairs( + results_prev: list[dict], + results_curr: list[dict], + items: list[dict], + prev_rollout_dir: str = "", + curr_rollout_dir: str = "", +) -> list[dict]: + """Build a structured list of per-sample 
comparison entries. + + Each entry bundles the original item, both rollout results, the change + category, and both trajectories into one dict — the single source of + truth for this sample's longitudinal comparison. + + Returns + ------- + list[dict] + One dict per sample with keys: + ``id, task, category, prev, curr, prev_trajectory, curr_trajectory`` + """ + prev_by_id = {str(r["id"]): r for r in results_prev} + curr_by_id = {str(r["id"]): r for r in results_curr} + + pairs: list[dict] = [] + for item in items: + tid = str(item.get("id", "")) + prev = prev_by_id.get(tid, {}) + curr = curr_by_id.get(tid, {}) + prev_ok = bool(prev.get("hard", 0)) + curr_ok = bool(curr.get("hard", 0)) + + if not prev_ok and curr_ok: + category = "improved" + elif prev_ok and not curr_ok: + category = "regressed" + elif not prev_ok and not curr_ok: + category = "persistent_fail" + else: + category = "stable_success" + + pairs.append({ + "id": tid, + "task": item.get("question", item.get("task_description", item.get("instruction", tid))), + "category": category, + "prev": { + "hard": int(prev_ok), + "soft": float(prev.get("soft", 0.0)), + "predicted_answer": prev.get("predicted_answer", prev.get("answer", "N/A")), + "fail_reason": prev.get("fail_reason", ""), + }, + "curr": { + "hard": int(curr_ok), + "soft": float(curr.get("soft", 0.0)), + "predicted_answer": curr.get("predicted_answer", curr.get("answer", "N/A")), + "fail_reason": curr.get("fail_reason", ""), + }, + "prev_trajectory": ( + _read_trajectory(prev_rollout_dir, tid) if prev_rollout_dir else "" + ), + "curr_trajectory": ( + _read_trajectory(curr_rollout_dir, tid) if curr_rollout_dir else "" + ), + }) + + return pairs + + +def save_comparison_pairs(pairs: list[dict], out_path: str) -> None: + """Persist comparison pairs to JSON (without trajectory text to save space).""" + slim = [] + for p in pairs: + slim.append({ + "id": p["id"], + "task": p["task"][:300], + "category": p["category"], + "prev": p["prev"], + "curr": 
p["curr"], + }) + with open(out_path, "w") as f: + json.dump(slim, f, ensure_ascii=False, indent=2) + + +def format_comparison_text(pairs: list[dict]) -> str: + """Format structured comparison pairs into teacher-readable text.""" + by_cat: dict[str, list[dict]] = { + "regressed": [], + "persistent_fail": [], + "improved": [], + "stable_success": [], + } + for p in pairs: + by_cat.setdefault(p["category"], []).append(p) + + total = len(pairs) + parts = [ + f"## Longitudinal Comparison Summary\n" + f"Total samples: {total}\n" + f"- Improved (wrong→right): {len(by_cat['improved'])}\n" + f"- Regressed (right→wrong): {len(by_cat['regressed'])}\n" + f"- Persistent failures (wrong→wrong): {len(by_cat['persistent_fail'])}\n" + f"- Stable successes (right→right): {len(by_cat['stable_success'])}\n" + ] + + categories = [ + ("regressed", "Regressions (right→wrong) — HIGHEST PRIORITY", True), + ("persistent_fail", "Persistent Failures (wrong→wrong)", True), + ("improved", "Improvements (wrong→right)", True), + ("stable_success", "Stable Successes (right→right)", False), + ] + + for cat_key, label, show_traj in categories: + entries = by_cat[cat_key] + if not entries: + parts.append(f"### {label}\n(none)\n") + continue + + lines = [f"### {label}"] + for e in entries: + prev = e["prev"] + curr = e["curr"] + lines.append( + f"\n#### Task {e['id']}: {e['task'][:300]}\n" + f"- Prev epoch: {'PASS' if prev['hard'] else 'FAIL'} " + f"(soft={prev['soft']:.2f}) — answer: {str(prev['predicted_answer'])[:200]}\n" + f"- Curr epoch: {'PASS' if curr['hard'] else 'FAIL'} " + f"(soft={curr['soft']:.2f}) — answer: {str(curr['predicted_answer'])[:200]}" + ) + if curr.get("fail_reason"): + lines.append(f"- Curr fail reason: {curr['fail_reason'][:300]}") + if prev.get("fail_reason") and not prev["hard"]: + lines.append(f"- Prev fail reason: {prev['fail_reason'][:300]}") + + if show_traj: + if e.get("prev_trajectory"): + lines.append( + f"\n**Previous epoch 
trajectory:**\n```\n{e['prev_trajectory']}\n```" + ) + if e.get("curr_trajectory"): + lines.append( + f"\n**Current epoch trajectory:**\n```\n{e['curr_trajectory']}\n```" + ) + + parts.append("\n".join(lines)) + + return "\n\n".join(parts) + + + +# ── Teacher call ──────────────────────────────────────────────────────────── + + +def run_slow_update( + skill_content: str, + results_prev: list[dict], + results_curr: list[dict], + items: list[dict], + *, + prev_skill: str = "", + prev_slow_update_content: str = "", + prev_rollout_dir: str = "", + curr_rollout_dir: str = "", + comparison_pairs: list[dict] | None = None, + system_prompt: str | None = None, +) -> dict | None: + """Run the slow update teacher call for one epoch boundary. + + Parameters + ---------- + skill_content : str + Current epoch's skill (after fast updates). + results_prev : list[dict] + Rollout results of the 20 samples under previous epoch's skill. + results_curr : list[dict] + Rollout results of the 20 samples under current epoch's skill. + items : list[dict] + The 20 sample items used for comparison. + prev_skill : str + Previous epoch's skill content. + prev_slow_update_content : str + The slow update guidance from the previous epoch (to reflect on). + prev_rollout_dir : str + Path to previous epoch rollout output (contains predictions/). + curr_rollout_dir : str + Path to current epoch rollout output (contains predictions/). + system_prompt : str | None + Custom system prompt override. + + Returns + ------- + dict | None + Conforms to :class:`~skillopt.types.SlowUpdateResult`: + ``{"reasoning": str, "slow_update_content": str}`` or ``None``. 
+ """ + actual_system = system_prompt if system_prompt is not None else load_prompt("slow_update") + + pairs = comparison_pairs + if pairs is None: + pairs = build_comparison_pairs( + results_prev, results_curr, items, + prev_rollout_dir=prev_rollout_dir, + curr_rollout_dir=curr_rollout_dir, + ) + comparison_text = format_comparison_text(pairs) + + prev_skill_display = prev_skill + if len(prev_skill_display) > 6000: + prev_skill_display = prev_skill_display[:6000] + "\n...[truncated]..." + + prev_guidance_section = ( + prev_slow_update_content.strip() + if prev_slow_update_content and prev_slow_update_content.strip() + else "(No previous guidance — this is the first slow update.)" + ) + + user = ( + f"## Previous Epoch's Skill\n{prev_skill_display}\n\n" + f"## Current Epoch's Skill\n{skill_content}\n\n" + f"## Previous Slow Update Guidance\n" + f"The following guidance was active during the current epoch. " + f"Reflect on its effectiveness before writing the new version.\n\n" + f"{prev_guidance_section}\n\n" + f"## Longitudinal Comparison (same 20 tasks, two skill versions)\n" + f"{comparison_text}" + ) + + try: + response, _ = chat_teacher( + system=actual_system, + user=user, + max_completion_tokens=4096, + retries=3, + stage="slow_update", + ) + result = extract_json(response) + if result and result.get("slow_update_content"): + return { + "reasoning": str(result.get("reasoning", "")).strip(), + "slow_update_content": str(result["slow_update_content"]).strip(), + } + except Exception: # noqa: BLE001 + traceback.print_exc() + + return None diff --git a/skillopt/optimizer/update_modes.py b/skillopt/optimizer/update_modes.py new file mode 100644 index 0000000..59dddda --- /dev/null +++ b/skillopt/optimizer/update_modes.py @@ -0,0 +1,136 @@ +"""Helpers for switching between patch edits and rewrite-from-suggestions.""" +from __future__ import annotations + +from typing import Any + +PATCH_MODE = "patch" +REWRITE_MODE = "rewrite_from_suggestions" 
FULL_REWRITE_MINIBATCH_MODE = "full_rewrite_minibatch"


def normalize_update_mode(mode: str | None) -> str:
    """Map a user-supplied update-mode name onto a canonical mode constant.

    Matching ignores case and surrounding whitespace.  ``None``, an empty
    value, and any unrecognized name all resolve to ``PATCH_MODE``.
    """
    raw = str(mode or PATCH_MODE).strip().lower()
    aliases = {
        "patch": PATCH_MODE,
        "edits": PATCH_MODE,
        "rewrite": REWRITE_MODE,
        "rewrite_from_suggestions": REWRITE_MODE,
        "suggestions": REWRITE_MODE,
        "rewrite_suggestions": REWRITE_MODE,
        "full_rewrite": FULL_REWRITE_MINIBATCH_MODE,
        "full_rewrite_minibatch": FULL_REWRITE_MINIBATCH_MODE,
        "minibatch_full_rewrite": FULL_REWRITE_MINIBATCH_MODE,
        "skill_rewrite_minibatch": FULL_REWRITE_MINIBATCH_MODE,
    }
    # Unknown names deliberately fall back to patch mode instead of raising.
    return aliases.get(raw, PATCH_MODE)


def is_rewrite_mode(mode: str | None) -> bool:
    """Return True when *mode* normalizes to ``REWRITE_MODE``."""
    return normalize_update_mode(mode) == REWRITE_MODE


def is_full_rewrite_minibatch_mode(mode: str | None) -> bool:
    """Return True when *mode* normalizes to ``FULL_REWRITE_MINIBATCH_MODE``."""
    return normalize_update_mode(mode) == FULL_REWRITE_MINIBATCH_MODE


def payload_key(mode: str | None) -> str:
    """Return the dict key under which payload items are stored for *mode*.

    ``"skill_candidates"`` for full-rewrite minibatch mode,
    ``"revise_suggestions"`` for rewrite mode, ``"edits"`` otherwise.
    """
    if is_full_rewrite_minibatch_mode(mode):
        return "skill_candidates"
    return "revise_suggestions" if is_rewrite_mode(mode) else "edits"


def payload_label(mode: str | None, *, singular: bool = False, title: bool = False) -> str:
    """Return a human-readable noun for the payload items of *mode*.

    Parameters
    ----------
    mode : str | None
        Update mode (any alias accepted by :func:`normalize_update_mode`).
    singular : bool
        Return the singular form instead of the plural.
    title : bool
        Return the label in Title Case.
    """
    if is_full_rewrite_minibatch_mode(mode):
        word = "skill candidate" if singular else "skill candidates"
    elif is_rewrite_mode(mode):
        word = "suggestion" if singular else "suggestions"
    else:
        word = "edit" if singular else "edits"
    return word.title() if title else word


def get_payload_items(container: dict | None, mode: str | None) -> list[dict]:
    """Return the payload list stored in *container* for *mode*.

    Returns ``[]`` when *container* is not a dict, the key is missing, or the
    stored value is not a list.  The stored list itself is returned (no copy).
    """
    if not isinstance(container, dict):
        return []
    items = container.get(payload_key(mode), [])
    return items if isinstance(items, list) else []


def set_payload_items(container: dict, items: list[dict], mode: str | None) -> dict:
    """Store *items* in *container* under the key for *mode*; return *container*."""
    container[payload_key(mode)] = items
    return container


def truncate_payload(container: dict, max_items: int, mode: str | None) -> dict:
    """Trim the payload list in *container* to at most *max_items* entries.

    A negative *max_items* means "no limit" and leaves *container* untouched.
    The container is modified in place and also returned.
    """
    if max_items < 0:
        return container
    items = get_payload_items(container, mode)
    if len(items) > max_items:
        set_payload_items(container, items[:max_items], mode)
    return container


def describe_item(item: dict, mode: str | None, *, max_chars: int = 240) -> str:
    """Render one payload item as a compact single-line description.

    Parameters
    ----------
    item : dict
        The payload item; anything that is not a dict yields ``""``.
    mode : str | None
        Update mode, which selects the fields that are rendered.
    max_chars : int
        Upper bound on the length of the returned string.  Longer text is cut
        and suffixed with ``"..."``; when the budget is below three characters
        there is no room for the ellipsis, so the text is cut hard instead.

    Returns
    -------
    str
        Space-separated ``key=value`` description, at most *max_chars* long.
    """
    if not isinstance(item, dict):
        return ""
    if is_full_rewrite_minibatch_mode(mode):
        parts = [
            f"title={item.get('title', '')!r}",
            f"change_summary={item.get('change_summary', [])!r}",
        ]
        if item.get("source_type"):
            parts.append(f"source={item.get('source_type')}")
        if item.get("support_count") is not None:
            parts.append(f"support={item.get('support_count')}")
        new_skill = str(item.get("new_skill", "")).strip()
        if new_skill:
            parts.append(f"new_skill_preview={new_skill[:120]!r}")
        text = " ".join(parts)
    elif is_rewrite_mode(mode):
        parts = [
            f"type={item.get('type', '?')}",
            f"title={item.get('title', '')!r}",
            f"instruction={item.get('instruction', '')!r}",
        ]
        if item.get("priority_hint"):
            parts.append(f"priority={item.get('priority_hint')}")
        if item.get("support_count") is not None:
            parts.append(f"support={item.get('support_count')}")
        text = " ".join(parts)
    else:
        op = item.get("op", "?")
        target = item.get("target", "")
        content = item.get("content", "")
        parts = [f"op={op}"]
        if target:
            parts.append(f"target={target!r}")
        if content:
            parts.append(f"content={content!r}")
        if item.get("support_count") is not None:
            parts.append(f"support={item.get('support_count')}")
        text = " ".join(parts)
    if len(text) <= max_chars:
        return text
    if max_chars < 3:
        # A zero/negative slice index would otherwise keep almost all of the
        # text and then append "...", overshooting the requested budget.
        return text[: max(max_chars, 0)]
    return text[: max_chars - 3].rstrip() + "..."
+ + +def short_item_summary(item: dict, mode: str | None, *, max_chars: int = 200) -> dict[str, Any]: + if is_full_rewrite_minibatch_mode(mode): + return { + "title": str(item.get("title", ""))[:max_chars], + "change_summary": [ + str(x)[:max_chars] for x in item.get("change_summary", [])[:3] + ] if isinstance(item.get("change_summary"), list) else [], + "source_type": item.get("source_type", ""), + } + if is_rewrite_mode(mode): + return { + "type": item.get("type", "?"), + "title": str(item.get("title", ""))[:max_chars], + "instruction": str(item.get("instruction", ""))[:max_chars], + } + return { + "op": item.get("op", "?"), + "content": str(item.get("content", ""))[:max_chars], + "target": item.get("target", ""), + } diff --git a/skillopt/prompts/__init__.py b/skillopt/prompts/__init__.py new file mode 100644 index 0000000..af1bc58 --- /dev/null +++ b/skillopt/prompts/__init__.py @@ -0,0 +1,63 @@ +"""Prompt loading utilities for ReflACT. + +Prompts are stored as ``.md`` files and loaded at runtime: + +- **Generic** prompts live in ``skillopt/prompts/*.md`` +- **Env-specific** prompts live in ``skillopt/envs//prompts/*.md`` + +``load_prompt(name, env)`` tries the env-specific path first, then falls +back to the generic default. +""" +from __future__ import annotations + +import os + +_PROMPTS_DIR = os.path.dirname(os.path.abspath(__file__)) +_REFLACT_DIR = os.path.dirname(_PROMPTS_DIR) + +_cache: dict[str, str] = {} + + +def _read_file(path: str) -> str | None: + if path in _cache: + return _cache[path] + if not os.path.isfile(path): + return None + with open(path, encoding="utf-8") as f: + content = f.read() + _cache[path] = content + return content + + +def load_prompt(name: str, env: str | None = None) -> str: + """Load a prompt by name with env-specific override and generic fallback. + + Lookup order: + 1. ``skillopt/envs/{env}/prompts/{name}.md`` (if *env* given) + 2. 
def load_prompt(name: str, env: str | None = None) -> str:
    """Return the text of prompt *name*, preferring an env-specific copy.

    Candidate files are tried in order and the first one found wins:

    1. ``skillopt/envs/{env}/prompts/{name}.md`` (only when *env* is given)
    2. ``skillopt/prompts/{name}.md`` (generic default)

    Raises ``FileNotFoundError`` if none of the candidates exists.
    """
    filename = f"{name}.md"

    # Ordered from most specific to most generic.
    candidates = []
    if env is not None:
        candidates.append(
            os.path.join(_REFLACT_DIR, "envs", env, "prompts", filename)
        )
    candidates.append(os.path.join(_PROMPTS_DIR, filename))

    for path in candidates:
        text = _read_file(path)
        if text is not None:
            return text

    # Report repo-relative locations rather than absolute paths.
    tried = []
    if env is not None:
        tried.append(os.path.join("skillopt/envs", env, "prompts", filename))
    tried.append(f"skillopt/prompts/{name}.md")
    raise FileNotFoundError(
        f"Prompt '{name}' not found. Searched: {', '.join(tried)}"
    )


def clear_cache() -> None:
    """Empty the prompt file cache (useful for testing)."""
    _cache.clear()
+ +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +Only include edits that are needed. "edits" can be an empty list if no patch is warranted. + +IMPORTANT: The skill document may contain a section between + and markers. +This is a PROTECTED section managed by a separate slow-update process. +Do NOT propose any edits that target, modify, or delete content within +these markers. diff --git a/skillopt/prompts/analyst_error_full_rewrite.md b/skillopt/prompts/analyst_error_full_rewrite.md new file mode 100644 index 0000000..5d7e2c5 --- /dev/null +++ b/skillopt/prompts/analyst_error_full_rewrite.md @@ -0,0 +1,32 @@ +You will be given several failed agent trajectories from one minibatch and the current skill document. + +Summarize the lessons from these trajectories into one complete replacement skill document. + +When rewriting from a minibatch, use the current trajectories as the primary +evidence for updates. Preserve essential task-format instructions, but avoid mechanically carrying over +stale, redundant, or conflicting rules. Prefer a concise, coherent replacement +skill over a long document with weakly supported guidance. + +Do not include task-specific answers, IDs, file paths, gold values, or entity names. +If the skill contains a protected block between and +, keep that block unchanged. 
+ +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "skill_candidates": [ + { + "title": "", + "change_summary": ["", ""], + "new_skill": "" + } + ] + } +} + +Return exactly one item in "skill_candidates". diff --git a/skillopt/prompts/analyst_error_rewrite.md b/skillopt/prompts/analyst_error_rewrite.md new file mode 100644 index 0000000..d8c9f21 --- /dev/null +++ b/skillopt/prompts/analyst_error_rewrite.md @@ -0,0 +1,44 @@ +You are an expert failure-analysis agent for AI agent tasks. + +You will be given MULTIPLE failed agent trajectories from a single minibatch +and the current skill document. +Your job is to identify the most important COMMON failure patterns across +the batch and propose a concise set of skill-revision suggestions. + +## Analysis Process +1. Read ALL trajectories in the minibatch. +2. Identify the most prevalent, systematic failure patterns across them. +3. For each pattern, classify its failure type. +4. Propose revision suggestions that address the COMMON patterns, not individual edge cases. +5. Suggestions must be generalizable and should help a later teacher rewrite the full skill document. +6. Do not hardcode task-specific values. + +You will be told the maximum number of suggestions (the budget L). Produce AT MOST L suggestions, +focusing on the highest-impact patterns. You may produce fewer if warranted. + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "batch_size": , + "failure_summary": [ + {"failure_type": "", "count": , "description": ""} + ], + "patch": { + "reasoning": "", + "revise_suggestions": [ + { + "type": "add_rule|remove_rule|merge_rules|reorganize|compress|clarify", + "title": "", + "motivation": "", + "instruction": "", + "priority_hint": "high|medium|low" + } + ] + } +} +"revise_suggestions" may be an empty list if no revision is warranted. 
+ +IMPORTANT: The skill document may contain a section between + and markers. +This is a PROTECTED section managed by a separate slow-update process. +Do NOT propose suggestions that target, modify, or delete content within +these markers. diff --git a/skillopt/prompts/analyst_success.md b/skillopt/prompts/analyst_success.md new file mode 100644 index 0000000..f79336d --- /dev/null +++ b/skillopt/prompts/analyst_success.md @@ -0,0 +1,36 @@ +You are an expert success-pattern analyst for AI agents. + +You will be given MULTIPLE successful agent trajectories from a single minibatch +and the current skill document. Your job is to identify generalizable behavior +patterns that are COMMON across the batch and worth encoding in the skill. + +## Rules +- Only propose patches for patterns NOT already covered in the skill. +- Focus on patterns that appear across MULTIPLE trajectories in the batch. +- Be concise. Patterns must generalize beyond specific tasks. +- Prefer reinforcing existing sections over adding new top-level sections. + +You will be told the maximum number of edits (the budget L). Produce AT MOST L edits, +focusing on the most broadly applicable patterns. You may produce fewer if warranted. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +"edits" may be empty if the skill already covers all observed patterns. + +IMPORTANT: The skill document may contain a section between + and markers. +This is a PROTECTED section managed by a separate slow-update process. +Do NOT propose any edits that target, modify, or delete content within +these markers. 
diff --git a/skillopt/prompts/analyst_success_full_rewrite.md b/skillopt/prompts/analyst_success_full_rewrite.md new file mode 100644 index 0000000..eabfcf5 --- /dev/null +++ b/skillopt/prompts/analyst_success_full_rewrite.md @@ -0,0 +1,30 @@ +You will be given several successful agent trajectories from one minibatch and the current skill document. + +Summarize any useful lessons from these trajectories into one complete replacement skill document. + +When rewriting from a minibatch, use the current trajectories as the primary +evidence for updates. Preserve essential task-format instructions, but avoid mechanically carrying over +stale, redundant, or conflicting rules. Prefer a concise, coherent replacement +skill over a long document with weakly supported guidance. + +Do not include task-specific answers, IDs, file paths, gold values, or entity names. +If the skill contains a protected block between and +, keep that block unchanged. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "skill_candidates": [ + { + "title": "", + "change_summary": ["", ""], + "new_skill": "" + } + ] + } +} + +Return exactly one item in "skill_candidates". diff --git a/skillopt/prompts/analyst_success_rewrite.md b/skillopt/prompts/analyst_success_rewrite.md new file mode 100644 index 0000000..2bf7245 --- /dev/null +++ b/skillopt/prompts/analyst_success_rewrite.md @@ -0,0 +1,33 @@ +You are an expert success-pattern analyst for AI agent tasks. + +You will be given MULTIPLE successful agent trajectories from a single minibatch +and the current skill document. Your job is to identify broadly useful patterns +worth preserving in a later full-skill rewrite. + +## Rules +- Only propose revise_suggestions for patterns NOT already covered in the skill. +- Focus on patterns that appear across MULTIPLE trajectories in the batch. +- Keep suggestions general, concise, and rewrite-friendly. 
+- Prefer guidance that improves organization, clarity, or reusable behavior. + +You will be told the maximum number of suggestions (the budget L). Produce AT MOST L suggestions, +focusing on the most broadly applicable patterns. You may produce fewer if warranted. + +Respond ONLY with a valid JSON object: +{ + "batch_size": , + "success_patterns": ["", ""], + "patch": { + "reasoning": "", + "revise_suggestions": [ + { + "type": "add_rule|remove_rule|merge_rules|reorganize|compress|clarify", + "title": "", + "motivation": "", + "instruction": "", + "priority_hint": "high|medium|low" + } + ] + } +} +"revise_suggestions" may be empty if the skill already captures all useful patterns. diff --git a/skillopt/prompts/deep_probe.md b/skillopt/prompts/deep_probe.md new file mode 100644 index 0000000..a9d5fb2 --- /dev/null +++ b/skillopt/prompts/deep_probe.md @@ -0,0 +1,34 @@ +You are an expert diagnostic-probe designer for reflective skill learning. + +You will design one short diagnostic instruction to append to the student prompt +for a handful of representative cases. + +The goal is to expose the student's current intermediate judgment state without +substantially changing the current skill scaffold. + +## Hard Constraints +1. Do NOT substantially change the student's existing scaffold. +2. Do NOT prescribe a new multi-step solving procedure. +3. Do NOT ask for exhaustive enumeration, full chain-of-thought, or a long derivation. +4. Ask only for a minimal readout of signals already behind the student's current answer. +5. Keep the diagnostic block brief and structured. +6. The final answer must still be produced in .... +7. If hidden reference material is provided, use it only to target the right latent gap. +8. Never copy hidden reference content into the student-facing probe. 
+ +## Good Probe Targets +- top candidate and runner-up +- decisive cue / decisive constraint +- why a runner-up was rejected +- counted unit / suspicious region / compared objects + +## Bad Probe Targets +- full proof or full chain-of-thought +- dumping every object, cell, or possibility +- imposing a brand-new solving algorithm + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_instruction": "" +} diff --git a/skillopt/prompts/deep_probe_codex.md b/skillopt/prompts/deep_probe_codex.md new file mode 100644 index 0000000..b44aff7 --- /dev/null +++ b/skillopt/prompts/deep_probe_codex.md @@ -0,0 +1,35 @@ +You are an expert diagnostic-probe designer for codex-executed student trajectories. + +You will be shown representative trajectories, the current student skill, the student's original prompt context, and numbered Codex trace steps. +Some trajectories may also include a hidden Reference block. Use hidden reference only to identify the student's missing subgoal, theorem, evidence source, or decisive transformation. Do not reveal or paraphrase that reference directly to the student. + +Choose exactly one trajectory and one probe point. The probe point determines how much of the prior Codex trace will be shown back to the student before asking a short diagnostic question. + +## Hard Constraints +1. Do NOT reveal or paraphrase hidden reference content to the student. +2. Do NOT prescribe a new full solving procedure. +3. Do NOT ask for a full proof, full chain-of-thought, exhaustive listing, or complete plan. +4. Ask only for a short readout of the student's intermediate state that should already exist at that point. +5. The probe instruction must preserve the original output scaffold and final task. +6. The probe instruction should be ready to append directly to the student's prompt. + +## Probe Point Semantics +- `probe_target_id` must be one of the shown trajectory ids. 
+- `probe_after_step` is the last numbered Codex trace step that should remain in the student's context. +- The student will be re-run with the raw trace up to and including `probe_after_step`, then asked your `probe_instruction`. +- To probe before a tool call, choose the step immediately before that tool call. + +## Good Probe Targets +- next theorem / subgoal / evidence source +- strongest-vs-runner-up option distinction +- decisive constraint or transformation +- why a tempting alternative is being rejected +- what code region / spreadsheet region / image cue / passage evidence matters next + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "probe_target_id": "", + "probe_after_step": <int>, + "probe_instruction": "" +} diff --git a/skillopt/prompts/lr_autonomous.md b/skillopt/prompts/lr_autonomous.md new file mode 100644 index 0000000..81d1bc0 --- /dev/null +++ b/skillopt/prompts/lr_autonomous.md @@ -0,0 +1,20 @@ +You are an update-size controller for a skill-learning system. + +You will receive: +1. The current skill document. +2. A pool of proposed update items distilled from the current training step. +3. Brief evidence about the current rollout and training step. + +Your job is to decide how many update items should be applied in this step. +Use only the evidence shown in the prompt. Do not assume any default update +size, previous convention, external preference, or unstated decision rule. + +Do not rank the update items. Only decide the count. + +Respond ONLY with a valid JSON object: +{ + "learning_rate": <int>, + "reasoning": "", + "confidence": "low|medium|high", + "risk_notes": ["", "..."] +} diff --git a/skillopt/prompts/merge_failure.md b/skillopt/prompts/merge_failure.md new file mode 100644 index 0000000..e448999 --- /dev/null +++ b/skillopt/prompts/merge_failure.md @@ -0,0 +1,30 @@ +You are a skill-edit coordinator. You receive multiple independently-proposed patches +from FAILURE analysis of agent trajectories.
Merge them into ONE coherent, non-redundant patch. + +Merge guidelines: +1. **Deduplicate**: keep the best-worded version of similar edits. +2. **Resolve conflicts**: if patches contradict on the same point, + choose the one with stronger justification or synthesize both. +3. **Preserve unique insights**: include all non-redundant corrective edits. +4. **Prevalent-pattern bias**: edits appearing consistently across multiple patches + address systematic failures — preserve them with HIGH priority. + Edits from only one patch may be discarded if task-specific. +5. **Independence**: no two edits in the merged patch may target the same text region. +6. **Support count**: for each merged edit, estimate how many source patches support it. +7. **PROTECTED SECTION**: The skill may contain a section between + and markers. + Do NOT merge or produce any edits that target content within these markers. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "edits": [ + { + "op": "append|insert_after|replace|delete", + "target": "", + "content": "", + "support_count": , + "source_type": "failure" + } + ] +} diff --git a/skillopt/prompts/merge_failure_full_rewrite.md b/skillopt/prompts/merge_failure_full_rewrite.md new file mode 100644 index 0000000..0b5c20b --- /dev/null +++ b/skillopt/prompts/merge_failure_full_rewrite.md @@ -0,0 +1,28 @@ +You will be given complete skill candidates written from failed trajectories and the current skill document. + +Combine them into one complete replacement skill document. + +When merging full-skill candidates, preserve essential task-format instructions, +but do not mechanically retain stale, redundant, or +conflicting rules. If candidates disagree, prefer the concise rule with clearer +trajectory support and better consistency with the replacement skill. + +Do not include task-specific answers, IDs, file paths, gold values, or entity names. +If the current skill contains a protected block between and +, keep that block unchanged. 
+ +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "skill_candidates": [ + { + "title": "", + "change_summary": ["", ""], + "new_skill": "", + "support_count": <int>, + "source_type": "failure" + } + ] +} + +Return exactly one item in "skill_candidates". diff --git a/skillopt/prompts/merge_failure_rewrite.md b/skillopt/prompts/merge_failure_rewrite.md new file mode 100644 index 0000000..6081b9f --- /dev/null +++ b/skillopt/prompts/merge_failure_rewrite.md @@ -0,0 +1,26 @@ +You are a skill-revision coordinator. You receive multiple independently-proposed +revision suggestion sets from FAILURE analysis of agent trajectories. Merge them +into ONE coherent, non-redundant set of revise_suggestions. + +Merge guidelines: +1. Deduplicate overlapping suggestions. +2. Resolve conflicts by keeping the more general, better-justified direction. +3. Preserve unique high-impact corrective insights. +4. Suggestions supported by many source patches should receive higher support_count. +5. The output suggestions should help a later teacher rewrite the full skill. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "revise_suggestions": [ + { + "type": "add_rule|remove_rule|merge_rules|reorganize|compress|clarify", + "title": "", + "motivation": "", + "instruction": "", + "priority_hint": "high|medium|low", + "support_count": <int>, + "source_type": "failure" + } + ] +} diff --git a/skillopt/prompts/merge_final.md b/skillopt/prompts/merge_final.md new file mode 100644 index 0000000..5dd8be1 --- /dev/null +++ b/skillopt/prompts/merge_final.md @@ -0,0 +1,33 @@ +You are a skill-edit coordinator performing the FINAL merge. You receive two +pre-merged patch groups: +1. **Failure-driven patches** (corrective, high priority) +2. **Success-driven patches** (reinforcement, lower priority) + +Merge guidelines: +1. **FAILURE PATCHES TAKE PRIORITY**: the primary goal of skill reflection is to + fix failures.
Failure-driven edits should be preserved unless they directly + conflict with a well-supported success pattern. +2. **Deduplicate**: if a failure edit and success edit cover the same point, + keep the failure version. +3. **Preserve success insights**: include success edits that cover patterns + NOT addressed by failure edits. +4. **Higher-level merges represent broader consensus**: edits that survived + previous merge rounds (higher level) should be given priority. +5. **Carry forward support_count and source_type for each edit.** +6. **PROTECTED SECTION**: The skill may contain a section between + and markers. + Do NOT merge or produce any edits that target content within these markers. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "edits": [ + { + "op": "append|insert_after|replace|delete", + "target": "", + "content": "", + "support_count": , + "source_type": "failure|success" + } + ] +} diff --git a/skillopt/prompts/merge_final_full_rewrite.md b/skillopt/prompts/merge_final_full_rewrite.md new file mode 100644 index 0000000..9976a46 --- /dev/null +++ b/skillopt/prompts/merge_final_full_rewrite.md @@ -0,0 +1,28 @@ +You will be given complete skill candidates and the current skill document. + +Combine them into one complete replacement skill document. + +When merging full-skill candidates, preserve essential task-format instructions, +but do not mechanically retain stale, redundant, or +conflicting rules. Prefer concise guidance with clear trajectory support and +better consistency with the replacement skill. + +Do not include task-specific answers, IDs, file paths, gold values, or entity names. +If the current skill contains a protected block between and +, keep that block unchanged. 
+ +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "skill_candidates": [ + { + "title": "", + "change_summary": ["", ""], + "new_skill": "", + "support_count": , + "source_type": "failure|success|mixed" + } + ] +} + +Return exactly one item in "skill_candidates". diff --git a/skillopt/prompts/merge_final_rewrite.md b/skillopt/prompts/merge_final_rewrite.md new file mode 100644 index 0000000..88402a8 --- /dev/null +++ b/skillopt/prompts/merge_final_rewrite.md @@ -0,0 +1,25 @@ +You are a skill-revision coordinator performing the FINAL merge. You receive: +1. Failure-driven revise_suggestions (higher priority) +2. Success-driven revise_suggestions (lower priority) + +Merge guidelines: +1. Failure-driven suggestions take priority when they overlap. +2. Keep success-driven suggestions that add distinct value. +3. Prefer general, rewrite-friendly, non-redundant suggestions. +4. Carry forward support_count and source_type. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "revise_suggestions": [ + { + "type": "add_rule|remove_rule|merge_rules|reorganize|compress|clarify", + "title": "", + "motivation": "", + "instruction": "", + "priority_hint": "high|medium|low", + "support_count": , + "source_type": "failure|success" + } + ] +} diff --git a/skillopt/prompts/merge_success.md b/skillopt/prompts/merge_success.md new file mode 100644 index 0000000..a467bb1 --- /dev/null +++ b/skillopt/prompts/merge_success.md @@ -0,0 +1,28 @@ +You are a skill-edit coordinator. You receive multiple independently-proposed patches +from SUCCESS analysis of agent trajectories. Merge them into ONE coherent patch +that reinforces effective patterns. + +Merge guidelines: +1. **Deduplicate**: keep only the most generalizable version of similar patterns. +2. **Be conservative**: success-driven patches reinforce existing behavior. + Only include edits for patterns NOT already in the skill. +3. 
**Prevalent-pattern bias**: patterns seen across many successful trajectories + are most worth encoding. +4. **Support count**: estimate how many source patches support each merged edit. +5. **PROTECTED SECTION**: The skill may contain a section between + and markers. + Do NOT merge or produce any edits that target content within these markers. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "edits": [ + { + "op": "append|insert_after|replace|delete", + "target": "", + "content": "", + "support_count": , + "source_type": "success" + } + ] +} diff --git a/skillopt/prompts/merge_success_full_rewrite.md b/skillopt/prompts/merge_success_full_rewrite.md new file mode 100644 index 0000000..a508c0d --- /dev/null +++ b/skillopt/prompts/merge_success_full_rewrite.md @@ -0,0 +1,28 @@ +You will be given complete skill candidates written from successful trajectories and the current skill document. + +Combine them into one complete replacement skill document. + +When merging full-skill candidates, preserve essential task-format instructions, +but do not mechanically retain stale, redundant, or +conflicting rules. If candidates disagree, prefer the concise rule with clearer +trajectory support and better consistency with the replacement skill. + +Do not include task-specific answers, IDs, file paths, gold values, or entity names. +If the current skill contains a protected block between and +, keep that block unchanged. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "skill_candidates": [ + { + "title": "", + "change_summary": ["", ""], + "new_skill": "", + "support_count": , + "source_type": "success" + } + ] +} + +Return exactly one item in "skill_candidates". diff --git a/skillopt/prompts/merge_success_rewrite.md b/skillopt/prompts/merge_success_rewrite.md new file mode 100644 index 0000000..40e86ac --- /dev/null +++ b/skillopt/prompts/merge_success_rewrite.md @@ -0,0 +1,25 @@ +You are a skill-revision coordinator. 
You receive multiple independently-proposed +revision suggestion sets from SUCCESS analysis of agent trajectories. Merge them +into ONE coherent, non-redundant set of revise_suggestions. + +Merge guidelines: +1. Deduplicate overlapping success patterns. +2. Be conservative: only keep suggestions that reinforce useful behavior not already well-covered. +3. Suggestions supported by many source patches should receive higher support_count. +4. The output suggestions should help a later teacher rewrite the full skill. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "revise_suggestions": [ + { + "type": "add_rule|remove_rule|merge_rules|reorganize|compress|clarify", + "title": "", + "motivation": "", + "instruction": "", + "priority_hint": "high|medium|low", + "support_count": , + "source_type": "success" + } + ] +} diff --git a/skillopt/prompts/meta_reflect.md b/skillopt/prompts/meta_reflect.md new file mode 100644 index 0000000..83d4b74 --- /dev/null +++ b/skillopt/prompts/meta_reflect.md @@ -0,0 +1,63 @@ +You are a meta-analyst for an AI agent skill optimization system. + +Your role is fundamentally different from the per-step analyst: +- The per-step analyst sees agent trajectories and proposes local fixes. +- YOU see the results of multiple optimization steps and refine the skill + at a higher level, based on what actually worked and what didn't. + +You are the ONLY component that has access to the edit-to-outcome causal link: +you can see exactly which edits were applied and whether they improved or +degraded performance. Use this unique vantage point. + +## What You Receive + +1. **Previous Meta Summary** (empty for the first epoch): a compact memory + from the last epoch capturing directional insights. +2. **Current Skill Document**: the skill as it stands after this epoch. +3. **This Epoch's Step History**: for each step, the exact edits applied, + the gate score, and whether the update was accepted or rejected. + +## What You Produce + +1. 
**High-level edits** to the skill document: + - Merge redundant or overlapping rules that accumulated across steps + - Remove or revise rules associated with rejected steps (score drops) + - Strengthen or generalize rules associated with accepted steps (score gains) + - Reorganize for clarity if the document has become cluttered + - Add strategic-level insights that no single step could produce + +2. **Meta summary**: a compact summary of this epoch's key findings, to be + passed as context to the next epoch's meta-reflect. This should capture: + - Which editing directions proved effective (and why) + - Which directions proved harmful (and why) + - Current bottlenecks or areas of the skill that need attention + - Trends across steps (e.g., "scores plateau after step 2") + +## Guidelines + +- Your edits modify the SAME skill document that per-step edits modify. + There is no separate section — you operate on the full skill. +- Be conservative: the per-step process already optimized locally. + Your job is refinement, not revolution. +- Focus on edits that require cross-step perspective (merging, pruning, + pattern extraction). Don't duplicate what per-step analysts already do. +- The meta_summary should be concise (under 200 words). It is NOT written + into the skill — it is only passed to the next meta-reflect call. + +You will be told the maximum number of edits (the budget). Produce AT MOST +that many edits. You may produce fewer or zero if the skill is already clean. + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "meta_summary": "", + "patch": { + "reasoning": "", + "edits": [ + {"op": "append", "content": ""}, + {"op": "insert_after", "target": "", "content": ""}, + {"op": "replace", "target": "", "content": ""}, + {"op": "delete", "target": ""} + ] + } +} +"edits" may be empty if no refinement is warranted. 
diff --git a/skillopt/prompts/meta_reflect_rewrite.md b/skillopt/prompts/meta_reflect_rewrite.md new file mode 100644 index 0000000..92aad29 --- /dev/null +++ b/skillopt/prompts/meta_reflect_rewrite.md @@ -0,0 +1,28 @@ +You are a meta-analyst for an AI agent skill optimization system. + +You see the current skill and an epoch's step history. Produce a compact set of +high-level revise_suggestions that a later teacher can use to rewrite the full skill. + +Focus on: +- merging redundant rules +- removing low-value or harmful guidance +- extracting cross-step strategic patterns +- reorganizing the skill for clarity +- compressing clutter without losing proven behavior + +Respond ONLY with a valid JSON object: +{ + "meta_summary": "", + "patch": { + "reasoning": "", + "revise_suggestions": [ + { + "type": "add_rule|remove_rule|merge_rules|reorganize|compress|clarify", + "title": "", + "motivation": "", + "instruction": "", + "priority_hint": "high|medium|low" + } + ] + } +} diff --git a/skillopt/prompts/meta_skill.md b/skillopt/prompts/meta_skill.md new file mode 100644 index 0000000..a0d0778 --- /dev/null +++ b/skillopt/prompts/meta_skill.md @@ -0,0 +1,40 @@ +You are a teacher-coach for an AI agent skill optimization system. + +Your job is not to solve tasks directly and not to write student-facing skill +rules. Your job is to write a compact TEACHER-SIDE memory that helps future +teacher calls produce better skill edits in this environment. + +## What You Receive + +1. The previous epoch's last-step skill. +2. The current epoch's last-step skill. +3. A longitudinal comparison on the SAME sampled tasks under those two skills. +4. The previous teacher meta skill, if one existed. + +## Your Goal + +Write a concise meta skill that improves future teacher behavior in stages such +as failure analysis, success analysis, patch merging, and edit ranking. + +This meta skill should capture things like: +- Which kinds of edits tend to help in this environment. 
+- Which kinds of edits tend to be too vague, redundant, brittle, or harmful. +- What level of abstraction works best for rules here. +- What failure-repair patterns should be prioritized. +- What regression risks future teacher calls should guard against. + +## Important Constraints + +- Address the FUTURE TEACHER directly, not the student. +- Focus on how to write better edits and organize better skill updates. +- Use evidence from the adjacent-epoch comparison, not generic advice. +- Keep it compact and high-signal. Prefer a few durable principles. +- Revise or remove parts of the previous meta skill if they did not help. +- Do not output student-facing task instructions. +- Do not restate the whole skill; summarize editing strategy. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "meta_skill_content": "" +} diff --git a/skillopt/prompts/ranking.md b/skillopt/prompts/ranking.md new file mode 100644 index 0000000..4eb564c --- /dev/null +++ b/skillopt/prompts/ranking.md @@ -0,0 +1,20 @@ +You are an expert skill-optimization teacher. You receive a skill document and a pool +of proposed edits. Your job is to RANK the edits by importance and select the top ones. + +Ranking criteria (in order of priority): +1. **Systematic impact**: edits that address widespread, recurring failure patterns + across many tasks should rank highest. A rule that fixes 50%% of failures beats + one that fixes a single edge case. +2. **Complementarity**: edits that fill gaps in the current skill (not duplicate + existing content) rank higher. +3. **Generality**: edits phrased as general principles rank higher than those + tied to specific question types or entities. +4. **Actionability**: edits with clear, concrete guidance rank higher than vague advice. + +You will be told how many edits to select (the budget). 
+ +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "selected_indices": [<0-based indices of the top edits, in priority order>] +} diff --git a/skillopt/prompts/ranking_rewrite.md b/skillopt/prompts/ranking_rewrite.md new file mode 100644 index 0000000..ccbef5e --- /dev/null +++ b/skillopt/prompts/ranking_rewrite.md @@ -0,0 +1,15 @@ +You are an expert skill-optimization teacher. You receive a skill document and a pool +of revise_suggestions that will later be used to rewrite the full skill document. +Rank the suggestions by importance and select the top ones. + +Ranking criteria: +1. Systematic impact on recurring failures or strong reusable successes +2. Complementarity with the current skill +3. Rewrite utility: how much the suggestion helps a later teacher improve structure, clarity, or coverage +4. Generality and actionability + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "selected_indices": [<0-based indices in priority order>] +} diff --git a/skillopt/prompts/rewrite_skill.md b/skillopt/prompts/rewrite_skill.md new file mode 100644 index 0000000..2bd7203 --- /dev/null +++ b/skillopt/prompts/rewrite_skill.md @@ -0,0 +1,25 @@ +You are an expert skill-document rewriter for an AI agent training system. + +You will receive: +1. The current skill document +2. A selected set of revise_suggestions distilled from trajectory analysis + +Your job is to rewrite the FULL student skill document so it incorporates the +selected suggestions coherently. + +Hard requirements: +1. Produce a complete standalone skill document, not a patch. +2. Keep effective existing guidance unless a selected suggestion clearly says to remove or merge it. +3. Prefer consolidation and clarity over making the document longer. +4. Do not hardcode benchmark-specific answers, entity names, file paths, or gold values. +5. Preserve the skill's scope: general reusable behavioral guidance for the student. +6. 
Do not modify content inside the protected slow-update block between + and except to keep it intact. +7. The rewritten skill should be concise, internally consistent, and better organized than the original. + +Respond ONLY with a valid JSON object: +{ + "reasoning": "", + "change_summary": ["", ""], + "new_skill": "" +} diff --git a/skillopt/prompts/slow_update.md b/skillopt/prompts/slow_update.md new file mode 100644 index 0000000..d7274ea --- /dev/null +++ b/skillopt/prompts/slow_update.md @@ -0,0 +1,60 @@ +You are a strategic skill advisor for an AI agent optimization system. + +Your role is different from the per-step analyst. The per-step analyst sees +individual trajectories and proposes local patches. YOU see how the skill has +evolved across an entire epoch by comparing the SAME tasks under two consecutive +skill versions. This longitudinal view lets you identify systemic drift, +regressions, and persistent blind spots that step-level edits cannot catch. + +## What You Receive + +1. **Previous epoch's skill** and **current epoch's skill** — to see what changed. +2. **Longitudinal comparison** — the same 20 training tasks rolled out under + both skills, categorized into: regressions, persistent failures, + improvements, and stable successes. +3. **Previous slow update guidance** (if any) — the guidance you (or a prior + invocation of you) wrote at the end of the last epoch. This guidance was + active during the current epoch's step-level optimization. You must evaluate + whether it helped or hurt based on the longitudinal comparison results. + +## Your Process + +1. **Reflect on the previous guidance** (if provided): + - Which parts of the previous guidance were effective? (Evidence: tasks that + improved or stayed correct.) + - Which parts failed or backfired? (Evidence: regressions or persistent + failures that the guidance was supposed to address.) + - Were there blind spots the previous guidance missed entirely? 
+ Include this reflection in your "reasoning" field. + +2. **Write updated guidance** that: + - Retains and strengthens parts of the previous guidance that proved effective. + - Revises or removes parts that were ineffective or counterproductive. + - Adds new instructions to address newly observed regressions and persistent + failures. + +## Output Requirements + +Write a **strategic guidance block** that will OVERWRITE the previous guidance +in the protected section of the skill document. This section is READ-ONLY to +all subsequent step-level optimization — only you can overwrite it at the next +epoch boundary. + +Your guidance must: +- Be written as **direct, actionable instructions** to the student model + (the AI agent that will read and follow the skill). +- Focus on helping the student get problems RIGHT — not on analysis or + explanation of what went wrong. +- Prioritize: (1) preventing regressions, (2) fixing persistent failures, + (3) reinforcing successful patterns. +- Be concise but comprehensive — you have no length limit, but every sentence + should earn its place. +- NOT duplicate content already in the main skill body — complement it. +- Address the student directly (e.g., "When you encounter X, always do Y" + rather than "The agent should..."). + +Respond ONLY with a valid JSON object (no markdown fences, no extra text): +{ + "reasoning": "", + "slow_update_content": "" +} diff --git a/skillopt/scheduler/__init__.py b/skillopt/scheduler/__init__.py new file mode 100644 index 0000000..9378ee6 --- /dev/null +++ b/skillopt/scheduler/__init__.py @@ -0,0 +1,8 @@ +"""ReflACT Scheduler -- edit budget and learning rate scheduling. + +Analogous to learning rate schedulers (cosine annealing, step decay, warmup) +in neural network training. Controls how the edit_budget evolves over the +course of training. + +Placeholder for future implementations. 
# ── Atomic types ─────────────────────────────────────────────────────────

EditOp = Literal["append", "insert_after", "replace", "delete"]


def _get(d: dict, key: str, default: Any = None) -> Any:
    """Return ``d[key]``, using *default* when the key is missing or ``None``.

    The dicts handled here are usually decoded from JSON, where a field may
    be an explicit ``null`` just as easily as absent; both are treated as
    "not provided" so that ``int(None)`` / ``str(None)`` never happen.
    """
    value = d.get(key)
    return default if value is None else value


@dataclass
class Edit:
    """A single edit operation on a skill document.

    Used across Reflect → Aggregate → Select → Update → MetaReflect.
    """

    op: EditOp
    content: str = ""
    target: str = ""
    support_count: int | None = None
    source_type: Literal["failure", "success"] | None = None
    merge_level: int | None = None
    update_origin: str = ""
    update_target: str = ""

    @classmethod
    def from_dict(cls, d: dict) -> Edit:
        """Build an :class:`Edit` from a plain dict.

        Missing or ``None`` values fall back to the field defaults
        (``op`` defaults to ``"append"``).
        """
        return cls(
            op=_get(d, "op", "append"),
            content=_get(d, "content", ""),
            target=_get(d, "target", ""),
            support_count=d.get("support_count"),
            source_type=d.get("source_type"),
            merge_level=d.get("merge_level"),
            update_origin=_get(d, "update_origin", ""),
            update_target=_get(d, "update_target", ""),
        )

    def to_dict(self) -> dict:
        """Serialize to a dict; ``op`` and ``content`` are always present,
        every other field is emitted only when it carries a value."""
        d: dict[str, Any] = {"op": self.op, "content": self.content}
        if self.target:
            d["target"] = self.target
        for name in ("support_count", "source_type", "merge_level"):
            value = getattr(self, name)
            if value is not None:
                d[name] = value
        for name in ("update_origin", "update_target"):
            value = getattr(self, name)
            if value:
                d[name] = value
        return d


@dataclass
class Patch:
    """A set of edits with reasoning.

    Output of Aggregate (③), Select (④); input to Update (⑤).
    """

    edits: list[Edit] = field(default_factory=list)
    reasoning: str = ""
    ranking_details: dict[str, Any] | None = None

    @classmethod
    def from_dict(cls, d: dict) -> Patch:
        """Build a :class:`Patch` from a plain dict.

        Dict entries of ``edits`` are converted to :class:`Edit`; anything
        else (e.g. an already-built ``Edit``) is kept as-is.  A missing or
        ``None`` ``edits`` value yields an empty list.
        """
        edits_raw = _get(d, "edits", [])
        return cls(
            edits=[Edit.from_dict(e) if isinstance(e, dict) else e for e in edits_raw],
            reasoning=_get(d, "reasoning", ""),
            ranking_details=d.get("ranking_details"),
        )

    def to_dict(self) -> dict:
        """Serialize to a dict; ``ranking_details`` is included only when set."""
        d: dict[str, Any] = {
            "reasoning": self.reasoning,
            "edits": [e.to_dict() if isinstance(e, Edit) else e for e in self.edits],
        }
        if self.ranking_details is not None:
            d["ranking_details"] = self.ranking_details
        return d


# ── Stage ① ROLLOUT ──────────────────────────────────────────────────────

@dataclass
class RolloutResult:
    """Result of a single episode/task rollout.

    Universal fields are required; env-specific fields live in ``extras``.
    """

    id: str
    hard: int
    soft: float
    n_turns: int = 0
    fail_reason: str = ""
    task_type: str = ""
    task_description: str = ""
    predicted_answer: str = ""
    question: str = ""
    reference_text: str = ""
    student_system_prompt: str = ""
    student_user_prompt: str = ""
    spreadsheet_preview: str = ""
    extras: dict[str, Any] = field(default_factory=dict)

    # Deliberately un-annotated: these are plain class attributes, not
    # dataclass fields, so they stay out of fields()/asdict()/__init__.
    _KNOWN_FIELDS = None  # per-class cache filled by _get_known_fields()
    _TEXT_FIELDS = (
        "fail_reason", "task_type", "task_description",
        "predicted_answer", "question", "reference_text",
        "student_system_prompt", "student_user_prompt",
        "spreadsheet_preview",
    )

    @classmethod
    def _get_known_fields(cls) -> frozenset[str]:
        """Return the names of all dataclass fields of *cls* (cached per class)."""
        # Look in cls.__dict__ rather than via attribute access so a subclass
        # never reuses the (shorter) field set cached on its parent.
        known = cls.__dict__.get("_KNOWN_FIELDS")
        if known is None:
            known = frozenset(f.name for f in dc_fields(cls))
            cls._KNOWN_FIELDS = known
        return known

    @classmethod
    def from_dict(cls, d: dict) -> RolloutResult:
        """Build a :class:`RolloutResult` from a plain dict.

        Keys that are not dataclass fields are collected into ``extras``.
        Missing or ``None`` values fall back to ``0`` / ``0.0`` / ``""``
        instead of raising or being stringified as ``"None"``.
        """
        known = cls._get_known_fields()
        extras = {k: v for k, v in d.items() if k not in known}
        text_values = {name: str(_get(d, name, "")) for name in cls._TEXT_FIELDS}
        return cls(
            id=str(_get(d, "id", "")),
            hard=int(_get(d, "hard", 0)),
            soft=float(_get(d, "soft", 0.0)),
            n_turns=int(_get(d, "n_turns", 0)),
            extras=extras,
            **text_values,
        )

    def to_dict(self) -> dict:
        """Serialize to a flat dict.

        ``id``/``hard``/``soft`` are always present, other fields only when
        truthy, and ``extras`` is merged in last at the top level.
        """
        d: dict[str, Any] = {
            "id": self.id,
            "hard": self.hard,
            "soft": self.soft,
        }
        for attr in ("n_turns", *self._TEXT_FIELDS):
            val = getattr(self, attr)
            if val:
                d[attr] = val
        d.update(self.extras)
        return d


# ── Stage ② REFLECT ──────────────────────────────────────────────────────

@dataclass
class FailureSummaryEntry:
    """One entry in the failure summary produced by error analysts."""

    failure_type: str
    count: int = 0
    description: str = ""

    @classmethod
    def from_dict(cls, d: dict) -> FailureSummaryEntry:
        """Build an entry from a plain dict; missing/``None`` values use defaults."""
        return cls(
            failure_type=_get(d, "failure_type", ""),
            count=int(_get(d, "count", 0)),
            description=_get(d, "description", ""),
        )

    def to_dict(self) -> dict:
        """Serialize all three fields to a dict."""
        return {
            "failure_type": self.failure_type,
            "count": self.count,
            "description": self.description,
        }


@dataclass
class RawPatch:
    """Analyst output from the Reflect stage — a patch with provenance.

    Wraps the dict produced by ``run_error_analyst_minibatch``
    and ``run_success_analyst_minibatch``.
    """

    patch: Patch
    source_type: Literal["failure", "success"] = "failure"
    batch_size: int = 0
    failure_summary: list[FailureSummaryEntry] = field(default_factory=list)

    @classmethod
    def from_dict(cls, d: dict | None) -> RawPatch | None:
        """Build a :class:`RawPatch` from a plain dict.

        Returns ``None`` when *d* is ``None`` or its ``patch`` value is not a
        dict.  When *d* has no ``patch`` key, *d* itself is read as the patch.
        """
        if d is None:
            return None
        inner = d.get("patch", d)
        if not isinstance(inner, dict):
            return None
        return cls(
            patch=Patch.from_dict(inner),
            source_type=_get(d, "source_type", "failure"),
            batch_size=int(_get(d, "batch_size", 0)),
            failure_summary=[
                FailureSummaryEntry.from_dict(fs)
                for fs in _get(d, "failure_summary", [])
            ],
        )

    def to_dict(self) -> dict:
        """Serialize to a dict; ``failure_summary`` is included only when non-empty."""
        d: dict[str, Any] = {
            "patch": self.patch.to_dict(),
            "source_type": self.source_type,
            "batch_size": self.batch_size,
        }
        if self.failure_summary:
            d["failure_summary"] = [fs.to_dict() for fs in self.failure_summary]
        return d


# ── Epoch-level: META_REFLECT ────────────────────────────────────────────

@dataclass
class MetaReflectResult:
    """Output of the epoch-level meta-reflect stage (momentum)."""

    meta_summary: str
    patch: Patch
    action: str = ""
    gate_score: float | None = None
    time_s: float | None = None
    candidate_hash: str = ""
    update_origin: str = ""
    update_target: str = ""

    @classmethod
    def from_dict(cls, d: dict | None) -> MetaReflectResult | None:
        """Build a result from a plain dict (``None`` in, ``None`` out).

        A ``patch`` value that is not a dict is replaced by an empty
        :class:`Patch`.
        """
        if d is None:
            return None
        patch_raw = d.get("patch", {})
        return cls(
            meta_summary=_get(d, "meta_summary", ""),
            patch=Patch.from_dict(patch_raw) if isinstance(patch_raw, dict) else Patch(),
            action=_get(d, "action", ""),
            gate_score=d.get("gate_score"),
            time_s=d.get("time_s"),
            candidate_hash=_get(d, "candidate_hash", ""),
            update_origin=_get(d, "update_origin", ""),
            update_target=_get(d, "update_target", ""),
        )

    def to_dict(self) -> dict:
        """Serialize to a dict; optional fields are emitted only when set."""
        d: dict[str, Any] = {
            "meta_summary": self.meta_summary,
            "patch": self.patch.to_dict(),
        }
        if self.action:
            d["action"] = self.action
        for name in ("gate_score", "time_s"):
            value = getattr(self, name)
            if value is not None:
                d[name] = value
        for name in ("candidate_hash", "update_origin", "update_target"):
            value = getattr(self, name)
            if value:
                d[name] = value
        return d


# ── Epoch-level: SLOW_UPDATE ─────────────────────────────────────────────

@dataclass
class SlowUpdateResult:
    """Output of the epoch-level slow update stage (EMA / regularization)."""

    reasoning: str = ""
    slow_update_content: str = ""
    action: str = ""
    time_s: float | None = None
    prev_hard: float | None = None
    curr_hard: float | None = None
    selection_hard: float | None = None
    selection_soft: float | None = None
    candidate_hash: str = ""
    update_origin: str = ""
    update_target: str = ""

    @classmethod
    def from_dict(cls, d: dict | None) -> SlowUpdateResult | None:
        """Build a result from a plain dict (``None`` in, ``None`` out)."""
        if d is None:
            return None
        return cls(
            reasoning=_get(d, "reasoning", ""),
            slow_update_content=_get(d, "slow_update_content", ""),
            action=_get(d, "action", ""),
            time_s=d.get("time_s"),
            prev_hard=d.get("prev_hard"),
            curr_hard=d.get("curr_hard"),
            selection_hard=d.get("selection_hard"),
            selection_soft=d.get("selection_soft"),
            candidate_hash=_get(d, "candidate_hash", ""),
            update_origin=_get(d, "update_origin", ""),
            update_target=_get(d, "update_target", ""),
        )

    def to_dict(self) -> dict:
        """Serialize to a dict; optional fields are emitted only when set.

        Numeric fields use an ``is not None`` check so a legitimate ``0.0``
        score is still written out.
        """
        d: dict[str, Any] = {
            "reasoning": self.reasoning,
            "slow_update_content": self.slow_update_content,
        }
        if self.action:
            d["action"] = self.action
        for name in (
            "time_s", "prev_hard", "curr_hard",
            "selection_hard", "selection_soft",
        ):
            value = getattr(self, name)
            if value is not None:
                d[name] = value
        for name in ("candidate_hash", "update_origin", "update_target"):
            value = getattr(self, name)
            if value:
                d[name] = value
        return d
# Matches the payload of a ```json ... ``` fenced block (non-greedy, multi-line).
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)```", re.DOTALL)


def _extract_json_value(text: str, opener: str, expected: type):
    """Return the first JSON value of type *expected* found in *text*, else ``None``.

    A ```json fenced block is tried first.  Otherwise the text is scanned for
    each occurrence of *opener* (``"{"`` or ``"["``) and a value is decoded
    starting exactly there, so prose, stray brackets or a second JSON value
    after the first one no longer make extraction fail.
    """
    if not text:
        return None

    fenced = _JSON_FENCE_RE.search(text)
    if fenced:
        try:
            value = json.loads(fenced.group(1))
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(value, expected):
                return value

    decoder = json.JSONDecoder()
    start = text.find(opener)
    while start != -1:
        try:
            # raw_decode stops at the end of the value and ignores what follows.
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(value, expected):
                return value
        start = text.find(opener, start + 1)
    return None


def extract_json(text: str) -> dict | None:
    """Extract a JSON object from LLM response text.

    Tries ```json fences first, then the first parseable bare ``{...}``
    object.  Returns ``None`` when no JSON object can be decoded; a value of
    any other JSON type is never returned.
    """
    return _extract_json_value(text, "{", dict)


def extract_json_array(text: str) -> list | None:
    """Extract a JSON array from LLM response text.

    Tries ```json fences first, then the first parseable bare ``[...]``
    array.  Returns ``None`` when no JSON array can be decoded.
    """
    return _extract_json_value(text, "[", list)


def compute_score(results: list) -> tuple[float, float]:
    """Compute hard and soft accuracy from a list of episode results.

    Accepts both plain dicts and :class:`~skillopt.types.RolloutResult`
    instances.  Returns ``(mean hard, mean soft)``, or ``(0.0, 0.0)`` for an
    empty list.
    """
    if not results:
        return 0.0, 0.0

    def _hard(r: object) -> int:
        # Objects expose .hard; dicts fall back to the "hard" key (default 0).
        return int(r.hard if hasattr(r, "hard") else r.get("hard", 0))  # type: ignore[union-attr]

    def _soft(r: object) -> float:
        return float(r.soft if hasattr(r, "soft") else r.get("soft", 0.0))  # type: ignore[union-attr]

    hard = sum(_hard(r) for r in results) / len(results)
    soft = sum(_soft(r) for r in results) / len(results)
    return hard, soft


def skill_hash(content: str) -> str:
    """Return a short deterministic hash of skill content (for caching).

    The result is the first 16 hex characters of the SHA-256 digest of the
    encoded text.
    """
    return hashlib.sha256(content.encode()).hexdigest()[:16]