mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
Initial commit
This commit is contained in:
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
data/
|
||||
outputs/
|
||||
logs/
|
||||
external/
|
||||
/BabyVision/
|
||||
/MMRB/
|
||||
/SpreadsheetBench/
|
||||
/dl4ir-searchQA/
|
||||
configs/local/
|
||||
configs/**/*.local.yaml
|
||||
*.local.md
|
||||
.secrets/
|
||||
.codex_azure*/
|
||||
557
README.md
Normal file
557
README.md
Normal file
@@ -0,0 +1,557 @@
|
||||
# ReflACT: Reflective Agent Tuning
|
||||
|
||||
ReflACT is a framework for optimizing an external skill document through iterative rollout, reflection, editing, and gated validation.
|
||||
|
||||
It does **not** fine-tune model weights. Instead, it treats the skill document as the optimization target:
|
||||
|
||||
- the **student** model executes tasks with the current skill
|
||||
- the **teacher** model analyzes trajectories and proposes edits
|
||||
- the framework merges, ranks, applies, and validates those edits
|
||||
- only validated skill updates are kept
|
||||
|
||||
This branch implements a full training loop with step-level skill optimization and optional epoch-level memory mechanisms (`slow_update`, `meta_skill`, `meta_reflect`).
|
||||
|
||||
## Method Overview
|
||||
|
||||
### Optimization Target
|
||||
|
||||
Each run maintains a mutable markdown skill document. The framework repeatedly improves that document instead of changing model parameters.
|
||||
|
||||
This gives a training-style loop for prompt / policy optimization:
|
||||
|
||||
1. Roll out the current skill on a batch of tasks.
|
||||
2. Reflect on failures and successes.
|
||||
3. Merge patch proposals into a coherent candidate update.
|
||||
4. Rank and select a bounded number of edits.
|
||||
5. Apply those edits to produce a candidate skill.
|
||||
6. Validate the candidate skill on a held-out selection split.
|
||||
7. Keep the update only if the gate accepts it.
|
||||
|
||||
### Per-Step Pipeline
|
||||
|
||||
Every training step executes the following pipeline in `skillopt/engine/trainer.py`:
|
||||
|
||||
1. **Rollout**
|
||||
The student model runs a batch of tasks using the current skill.
|
||||
|
||||
2. **Reflect**
|
||||
The teacher analyzes minibatches of trajectories and emits raw patches.
|
||||
Failure-driven and success-driven patches are tracked separately.
|
||||
|
||||
3. **Aggregate**
|
||||
Raw patches are merged hierarchically. Metadata such as `support_count` and `source_type` is carried into the merged patch so later ranking can use it.
|
||||
|
||||
4. **Select**
|
||||
The teacher ranks the merged edit pool and keeps up to `edit_budget` edits.
|
||||
|
||||
5. **Update**
|
||||
The selected edits are applied to the skill document. The framework records an `edit_apply_report.json` so you can see which edits actually landed, which were skipped, and why.
|
||||
|
||||
6. **Evaluate / Gate**
|
||||
The candidate skill is evaluated on the selection split. Gate validation is mandatory in this branch. A candidate update is accepted only if it improves over the current selection score; a new global best is tracked separately.
|
||||
|
||||
### Within-Epoch Memory
|
||||
|
||||
Inside an epoch, the trainer maintains a step buffer containing:
|
||||
|
||||
- compact failure-pattern summaries from previous steps
|
||||
- rejected edits and their score deltas
|
||||
|
||||
That context is fed back into later reflection calls so the teacher can avoid repeating ineffective edits and can focus on unsolved error patterns.
|
||||
|
||||
### Epoch-Level Mechanisms
|
||||
|
||||
This branch supports three optional epoch-level mechanisms.
|
||||
|
||||
#### Slow Update
|
||||
|
||||
At the end of each epoch, `slow_update` compares the previous epoch’s terminal skill and current epoch’s terminal skill on a sampled train subset. It then writes longitudinal guidance into a protected slow-update region inside the skill document.
|
||||
|
||||
Importantly, this guidance is **not** blindly written through. It is converted into a candidate skill and sent through the same selection gate as step-level updates.
|
||||
|
||||
#### Meta Skill
|
||||
|
||||
`meta_skill` is teacher-side cross-epoch memory. It does not directly edit the current skill. Instead, it writes a compact memory artifact describing longer-term patterns across adjacent epochs. That memory is loaded into later reflection / merge / ranking calls as extra context.
|
||||
|
||||
#### Meta Reflect
|
||||
|
||||
`meta_reflect` runs at epoch end over the step history of the current epoch. It looks at accepted and rejected directions from the whole epoch, proposes higher-level patch edits, applies them to a meta candidate, and then sends that candidate through the same selection gate.
|
||||
|
||||
## What This Branch Guarantees
|
||||
|
||||
The current implementation assumes the following as the mainline method contract:
|
||||
|
||||
- gate validation is always on
|
||||
- the current skill, current score, best skill, and best score stay aligned
|
||||
- `slow_update` is gated before being committed
|
||||
- patch provenance (`source_type`, `support_count`) reaches selection
|
||||
- patch application is observable through per-edit reports
|
||||
- resume state is restored from `runtime_state.json` rather than inferred only from history
|
||||
- all benchmark model calls go through the unified backend router
|
||||
|
||||
## Model Backends
|
||||
|
||||
All model access now goes through the split teacher/student model layer in `skillopt.model`.
|
||||
|
||||
Supported teacher backends:
|
||||
|
||||
- `openai_chat`
|
||||
- `claude_chat`
|
||||
|
||||
Supported student backends:
|
||||
|
||||
- `openai_chat`
|
||||
- `claude_chat`
|
||||
- `codex_exec`
|
||||
- `claude_code_exec`
|
||||
|
||||
Recommended config shape:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
teacher_backend: openai_chat
|
||||
student_backend: codex_exec
|
||||
teacher: gpt-5.4
|
||||
student: gpt-5.4-codex
|
||||
reasoning_effort: medium
|
||||
```
|
||||
|
||||
Legacy `model.backend` and CLI flags like `--backend codex` still work. They are mapped onto the split backend model for backward compatibility.
|
||||
|
||||
The same routing is used by:
|
||||
|
||||
- training (`scripts/train.py`)
|
||||
- eval-only runs (`scripts/eval_only.py`)
|
||||
- SpreadsheetBench standalone prompt eval scripts
|
||||
- LiveMathematicianBench baseline eval script
|
||||
- benchmark rollout code inside the main framework
|
||||
|
||||
### Azure OpenAI
|
||||
|
||||
If you use `openai_chat`, configure either environment variables or config values:
|
||||
|
||||
```bash
|
||||
export AZURE_OPENAI_ENDPOINT="https://your-endpoint.openai.azure.com/"
|
||||
export AZURE_OPENAI_API_KEY="your-api-key"
|
||||
export AZURE_OPENAI_API_VERSION="2025-04-01-preview"
|
||||
```
|
||||
|
||||
The config supports both the old keys and the new explicit names:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
azure_openai_endpoint: "..."
|
||||
azure_openai_api_version: "..."
|
||||
azure_openai_api_key: ""
|
||||
azure_openai_auth_mode: api_key
|
||||
azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
azure_openai_managed_identity_client_id: ""
|
||||
```
|
||||
|
||||
`azure_openai_auth_mode` can be used for API-key auth or Azure AD / managed identity flows.
|
||||
|
||||
### Exec Harness
|
||||
|
||||
`codex_exec` and `claude_code_exec` run the student inside a workspace harness instead of a plain chat call. The harness writes task files, renders a dynamic `SKILL.md`, runs the student CLI, and saves raw execution artifacts such as:
|
||||
|
||||
- `codex_raw.txt`
|
||||
- `codex_trace_summary.txt`
|
||||
- workspace-local task / skill files
|
||||
|
||||
This branch keeps `meta_skill` and `apply_patch_with_report`, while upgrading the student path to the more realistic workspace-exec setup.
|
||||
|
||||
### Trace-Aware Deep Reflect
|
||||
|
||||
When `student_backend=codex_exec` and `gradient.use_deep_reflect=true`, deep reflection can probe a specific earlier Codex attempt:
|
||||
|
||||
- the teacher sees a compact Codex trace summary
|
||||
- deep probe can target `probe_target_id`
|
||||
- the follow-up rollout can resume from `probe_after_step`
|
||||
|
||||
This is wired for the dataset-backed environments in this branch.
|
||||
|
||||
### Rewrite Mode
|
||||
|
||||
Skill updates support two modes:
|
||||
|
||||
- `optimizer.skill_update_mode=patch`
|
||||
- `optimizer.skill_update_mode=rewrite_from_suggestions`
|
||||
|
||||
`patch` keeps the existing fine-grained edit application path and still records `edit_apply_report.json`.
|
||||
|
||||
`rewrite_from_suggestions` asks the teacher to emit higher-level rewrite suggestions, then rewrites the whole skill in one pass. This is useful when patch edits become too fragmented.
|
||||
|
||||
## Repository Layout
|
||||
|
||||
```text
|
||||
skillopt/
|
||||
engine/
|
||||
trainer.py main training loop
|
||||
gradient/
|
||||
reflect.py minibatch reflection
|
||||
aggregate.py hierarchical patch merge
|
||||
deep_probe.py diagnostic probing for deep reflect
|
||||
optimizer/
|
||||
clip.py edit ranking / selection
|
||||
skill.py patch application + apply report
|
||||
slow_update.py epoch-level longitudinal guidance
|
||||
meta_skill.py teacher-side cross-epoch memory
|
||||
meta_reflect.py epoch-level macro editing
|
||||
evaluation/
|
||||
gate.py pure gate decision logic
|
||||
model/
|
||||
backend_config.py teacher/student backend routing
|
||||
azure_openai.py Azure backend
|
||||
codex_harness.py workspace exec harness + Codex trace parsing
|
||||
claude_backend.py Claude backend
|
||||
envs/
|
||||
... environment adapters and rollout logic
|
||||
scripts/
|
||||
train.py unified training entry
|
||||
eval_only.py evaluate one skill without training
|
||||
configs/
|
||||
_base_/default.yaml shared defaults
|
||||
<env>/default.yaml environment-specific configs
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Configs use structured YAML with `_base_` inheritance.
|
||||
|
||||
The base config is `configs/_base_/default.yaml`. Key defaults in this branch are:
|
||||
|
||||
- `model.teacher_backend = openai_chat`
|
||||
- `model.student_backend = openai_chat`
|
||||
- `model.reasoning_effort = medium`
|
||||
- `optimizer.use_slow_update = true`
|
||||
- `optimizer.use_meta_skill = true`
|
||||
- `optimizer.use_meta_reflect = false`
|
||||
- `gradient.use_deep_reflect = false`
|
||||
- `optimizer.skill_update_mode = patch`
|
||||
|
||||
Default setting snapshot:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
backend: azure_openai
|
||||
teacher: gpt-5.4
|
||||
student: gpt-5.4
|
||||
teacher_backend: openai_chat
|
||||
student_backend: openai_chat
|
||||
reasoning_effort: medium
|
||||
rewrite_reasoning_effort: ""
|
||||
rewrite_max_completion_tokens: 64000
|
||||
codex_exec_path: codex
|
||||
codex_exec_sandbox: workspace-write
|
||||
codex_exec_profile: ""
|
||||
codex_exec_full_auto: false
|
||||
codex_exec_reasoning_effort: none
|
||||
claude_code_exec_path: claude
|
||||
claude_code_exec_profile: ""
|
||||
codex_trace_to_teacher: true
|
||||
|
||||
train:
|
||||
num_epochs: 4
|
||||
train_size: 0
|
||||
batch_size: 80
|
||||
accumulation: 1
|
||||
seed: 42
|
||||
|
||||
gradient:
|
||||
minibatch_size: 16
|
||||
merge_batch_size: 16
|
||||
analyst_workers: 16
|
||||
max_analyst_rounds: 3
|
||||
failure_only: false
|
||||
use_deep_reflect: false
|
||||
deep_reflect_failures: 4
|
||||
deep_reflect_successes: 2
|
||||
|
||||
optimizer:
|
||||
learning_rate: 8
|
||||
min_learning_rate: 2
|
||||
lr_scheduler: cosine
|
||||
skill_update_mode: patch
|
||||
use_meta_reflect: false
|
||||
meta_learning_rate: 8
|
||||
use_slow_update: true
|
||||
slow_update_samples: 20
|
||||
use_meta_skill: true
|
||||
|
||||
evaluation:
|
||||
use_gate: true
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
eval_test: true
|
||||
|
||||
env:
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_seed: 42
|
||||
```
|
||||
|
||||
For the full source of truth, see [configs/_base_/default.yaml](/home/azureuser/workspace-yqh/skillopt_final/configs/_base_/default.yaml).
|
||||
|
||||
Selected fields:
|
||||
|
||||
| Section | Key | Meaning |
|
||||
|---|---|---|
|
||||
| `model` | `teacher_backend` | teacher backend: `openai_chat` or `claude_chat` |
|
||||
| `model` | `student_backend` | student backend: chat backend or exec backend |
|
||||
| `model` | `teacher` | teacher model / deployment |
|
||||
| `model` | `student` | student model / deployment |
|
||||
| `model` | `reasoning_effort` | reasoning budget passed to the backend when supported |
|
||||
| `model` | `codex_trace_to_teacher` | include Codex trace summaries in teacher reflection context |
|
||||
| `train` | `num_epochs` | number of epochs |
|
||||
| `train` | `train_size` | expected train split size, or `0` to infer |
|
||||
| `train` | `batch_size` | tasks per rollout batch |
|
||||
| `train` | `accumulation` | number of rollout/reflect minibatches per step |
|
||||
| `gradient` | `minibatch_size` | trajectories per analyst minibatch |
|
||||
| `gradient` | `merge_batch_size` | patches per aggregate batch |
|
||||
| `gradient` | `use_deep_reflect` | enable diagnostic probe rollouts |
|
||||
| `gradient` | `max_analyst_rounds` | teacher reflection retries / refinement budget |
|
||||
| `optimizer` | `learning_rate` | max edits kept after selection |
|
||||
| `optimizer` | `lr_scheduler` | edit-budget scheduler |
|
||||
| `optimizer` | `use_slow_update` | epoch-level longitudinal guidance |
|
||||
| `optimizer` | `use_meta_skill` | teacher-side epoch memory |
|
||||
| `optimizer` | `use_meta_reflect` | epoch-level macro editing |
|
||||
| `optimizer` | `skill_update_mode` | `patch` or `rewrite_from_suggestions` |
|
||||
| `evaluation` | `sel_env_num` | selection set size (`0` means full split) |
|
||||
| `evaluation` | `test_env_num` | test set size (`0` means full split) |
|
||||
|
||||
### Important Branch Rule
|
||||
|
||||
`use_gate=false` is intentionally not supported in this branch. Gate validation is part of the method contract here.
|
||||
|
||||
If an old config still contains `evaluation.use_gate: false`, the loader / trainer will raise instead of silently continuing.
|
||||
|
||||
## Supported Environments
|
||||
|
||||
The main training entry and eval-only entry now register 11 environments:
|
||||
|
||||
| Env | Default rollout shape | Current default split / data setting | Branch alignment |
|
||||
|---|---|---|---|
|
||||
| `alfworld` | environment-backed episodic rollout | native ALFWorld train/eval splits | in `skillopt_new_zzw` |
|
||||
| `babyvision` | single-round multimodal QA | `split_mode=ratio` from raw metadata/images, or prepared `split_dir` | in `skillopt_new_zzw` |
|
||||
| `docvqa` | single-round multimodal QA | `split_dir: data/docvqa_split` | in `skillopt_new_zzw` |
|
||||
| `livemathematicianbench` | single-round QA | `split_mode=ratio` or prepared `split_dir` | in `skillopt_new_zzw` |
|
||||
| `mathverse` | single-round multimodal math QA | `data_root: data/MathVerse`, split files loaded from `split_dir` when provided | in `skillopt_new_zzw` |
|
||||
| `mmrb` | single-round multimodal reasoning QA | `split_mode=ratio` or prepared `split_dir` | in `skillopt_new_zzw` |
|
||||
| `officeqa` | multi-turn tool loop | `split_dir: data/officeqa_split` plus `data_dirs: [data/officeqa_docs_official]` | in `skillopt_new_zzw` |
|
||||
| `sealqa` | multi-turn tool loop | `split_dir: data/sealqa_split` | in `skillopt_new_zzw` |
|
||||
| `searchqa` | single-round QA (`max_turns=1`) | `split_dir: data/searchqa_split` | in `skillopt_new_zzw` |
|
||||
| `spreadsheetbench` | codegen loop, default `mode=multi`, `max_turns=30` | `split_dir: data/spreadsheetbench_split`, `data_root: data/spreadsheetbench_verified_400` | in `skillopt_new_zzw`, default adjusted here to multi-round |
|
||||
| `swebench` | mini-swe-agent multi-step bug-fixing rollout | `split_mode=ratio`, `dataset_name=lite`, repo-stratified `2:1:7` split materialized under `out_root/_generated_splits/...` unless `split_dir` is provided | added here, aligned to `swe-bench-old` |
|
||||
|
||||
## Data Expectations
|
||||
|
||||
The standard two-mode dataset entry path is:
|
||||
|
||||
- `split_mode: ratio`
|
||||
- load raw data from `env.data_path`
|
||||
- build a deterministic `train/`, `val/`, `test/` split under `env.split_output_dir` (or under `out_root/_generated_splits/` if unset)
|
||||
- default ratio is explicitly `2:1:7`
|
||||
- `split_mode: split_dir`
|
||||
- load an existing `env.split_dir` with `train/`, `val/`, `test/` subdirectories
|
||||
|
||||
This currently applies to:
|
||||
|
||||
- `searchqa`
|
||||
- `spreadsheetbench`
|
||||
- `babyvision`
|
||||
- `livemathematicianbench`
|
||||
- `mmrb`
|
||||
- `swebench`
|
||||
|
||||
`ALFWorld` is the exception: it is environment-backed rather than JSON split-backed.
|
||||
|
||||
The following environments currently expect prepared split directories or extra rooted assets rather than the generic ratio-split path:
|
||||
|
||||
- `docvqa`
|
||||
- `mathverse`
|
||||
- `officeqa`
|
||||
- `sealqa`
|
||||
|
||||
At a high level:
|
||||
|
||||
- `SearchQA`: raw QA json / jsonl or pre-split QA json files
|
||||
- `SpreadsheetBench`: raw task manifest json plus spreadsheet task directory, or a pre-split task manifest
|
||||
- `ALFWorld`: installed game environment and configured eval/train splits
|
||||
- `BabyVision`: raw `meta_data.jsonl` plus images, or a pre-split directory
|
||||
- `DocVQA`: pre-split CSV / JSON data under `split_dir`
|
||||
- `LiveMathematicianBench`: raw monthly QA json files, or a pre-split directory
|
||||
- `MathVerse`: split files plus `data_root` image assets
|
||||
- `MMRB`: raw extracted dataset json files, or a pre-split directory
|
||||
- `OfficeQA`: pre-split metadata plus resolved office document directories
|
||||
- `SealQA`: pre-split metadata for tool-augmented QA tasks
|
||||
- `SWEBench`: HuggingFace SWE-bench dataset alias (`lite` / `verified` / `full`) or a prepared split directory
|
||||
|
||||
### Split References Across Branches
|
||||
|
||||
The split-related defaults are not identical across `skillopt-final`, `skillopt_new_zzw`, `gepa`, and `swe-bench-old`. The practical reference points are:
|
||||
|
||||
| Source branch | Explicit split settings / dirs |
|
||||
|---|---|
|
||||
| `skillopt-final` | `searchqa -> data/searchqa_split`; `spreadsheetbench -> data/spreadsheetbench_split`; `docvqa -> data/docvqa_split`; `officeqa -> data/officeqa_split`; `sealqa -> data/sealqa_split`; `swebench -> ratio split 2:1:7 over the default lite dataset, materialized under out_root/_generated_splits/...` |
|
||||
| `skillopt_new_zzw` | Same 10-benchmark env set as above except no `swebench`; explicit split dirs are `data/searchqa_split`, `data/spreadsheetbench_split`, `data/docvqa_split`, `data/officeqa_split`, `data/sealqa_split`; `spreadsheetbench` there defaults to `mode=single`; `officeqa` uses `max_tool_turns=24`; `sealqa` uses `max_tool_turns=12` |
|
||||
| `gepa` | `configs/spreadsheetbench.yaml` uses `data.splits_dir = data/spreadsheetbench/splits`, `eval.mode = react`, `eval.max_turns = 20`; `configs/swebench.yaml` uses `dataset = SWE-bench/SWE-bench_Verified` with `train_size = 100`, `val_size = 50`, `test_size = 350` |
|
||||
| `swe-bench-old` | Repo-stratified `2:1:7` split over `SWE-Bench_Lite`, persisted as `outputs/.../split/train.json`, `selection.json`, `test.json`; the example split in that branch is `train=60`, `selection=33`, `test=207` |
|
||||
|
||||
For the 10 benches shared with `skillopt_new_zzw`, the current branch is now aligned on env coverage. The main intentional delta is `spreadsheetbench`: this branch defaults to multi-round codegen, while `skillopt_new_zzw` kept `mode=single` by default.
|
||||
|
||||
## Running Training
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python scripts/train.py --config configs/searchqa/default.yaml
|
||||
```
|
||||
|
||||
Explicit 2:1:7 split from raw data:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--split_mode ratio \
|
||||
--data_path /path/to/searchqa_train_2000.json
|
||||
```
|
||||
|
||||
Directly consume a prepared split directory:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--split_mode split_dir \
|
||||
--split_dir /path/to/searchqa_split
|
||||
```
|
||||
|
||||
You can override structured config keys from the CLI:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/spreadsheetbench/default.yaml \
|
||||
--cfg-options model.teacher_backend=openai_chat model.student_backend=codex_exec train.batch_size=40 optimizer.learning_rate=4
|
||||
```
|
||||
|
||||
Legacy flat overrides still work for common keys:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--backend azure_openai \
|
||||
--teacher_model gpt-5.4 \
|
||||
--student_model gpt-5.4 \
|
||||
--reasoning_effort medium
|
||||
```
|
||||
|
||||
Exec harness example:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--teacher_backend openai_chat \
|
||||
--student_backend codex_exec \
|
||||
--teacher_model gpt-5.4 \
|
||||
--student_model gpt-5.4-codex \
|
||||
--use_deep_reflect true \
|
||||
--skill_update_mode rewrite_from_suggestions
|
||||
```
|
||||
|
||||
SWEBench example:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/swebench/default.yaml \
|
||||
--cfg-options env.dataset_name=lite env.split_ratio=2:1:7
|
||||
```
|
||||
|
||||
## Eval-Only and Standalone Evaluation
|
||||
|
||||
Evaluate a specific skill without training:
|
||||
|
||||
```bash
|
||||
python scripts/eval_only.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--skill skillopt/envs/searchqa/skills/initial.md
|
||||
```
|
||||
|
||||
The same dataset entry modes apply in eval-only runs:
|
||||
|
||||
- `--split_mode ratio --data_path ...`
|
||||
- `--split_mode split_dir --split_dir ...`
|
||||
|
||||
Standalone scripts also exist for benchmark-specific comparisons, including:
|
||||
|
||||
- `scripts/eval_prompt_custom.py`
|
||||
- `scripts/eval_prompt_official.py`
|
||||
- `scripts/eval_livemathematicianbench_baseline.py`
|
||||
|
||||
These scripts now also support backend selection through the unified model layer.
|
||||
|
||||
## Output Structure
|
||||
|
||||
Each run writes a structured output directory under `out_root`.
|
||||
|
||||
Important top-level artifacts:
|
||||
|
||||
- `config.json` — flattened runtime config
|
||||
- `history.json` — per-step history records
|
||||
- `runtime_state.json` — resume state for current/best skill tracking
|
||||
- `best_skill.md` — current best validated skill
|
||||
- `skills/skill_vXXXX.md` — persisted skill snapshot per step
|
||||
|
||||
Per-step artifacts live under `steps/step_XXXX/`, including:
|
||||
|
||||
- `merged_patch.json`
|
||||
- `ranked_edits.json`
|
||||
- `candidate_skill.md`
|
||||
- `edit_apply_report.json`
|
||||
- `rewrite_result.json` when rewrite mode is enabled
|
||||
- `selection_eval/`
|
||||
- `trajectory_digest.json`
|
||||
- rollout and patch subdirectories
|
||||
|
||||
Epoch-level artifacts live under:
|
||||
|
||||
- `slow_update/epoch_XX/`
|
||||
- `meta_skill/epoch_XX/`
|
||||
- `meta_reflect/epoch_XX/`
|
||||
|
||||
## Resume Behavior
|
||||
|
||||
The trainer resumes from `runtime_state.json` when present. That state tracks:
|
||||
|
||||
- last completed step
|
||||
- current skill path
|
||||
- current score
|
||||
- best skill path
|
||||
- best score
|
||||
- origin tags for current and best skill
|
||||
|
||||
This is important because skill state can change at both step level and epoch level; resuming only from `history.json` is not sufficient for this branch’s method logic.
|
||||
|
||||
## Notes
|
||||
|
||||
- This repository focuses on skill optimization logic; datasets are not included.
|
||||
- Patch application is intentionally observable. Inspect `edit_apply_report.json` when candidate skills do not behave as expected.
|
||||
- `SpreadsheetBench` now defaults to `mode=multi`. If you run an exec student backend there, override back to `env.mode=single` because exec backends are still only wired for SpreadsheetBench single-mode rollout.
|
||||
- `SWEBench` follows the older mini-swe-agent + `swebench.harness.run_evaluation` path, so it requires the SWE-bench / Docker toolchain rather than the generic chat-only stack.
|
||||
- `slow_update` writes into a protected skill region and normal edits are prevented from overwriting that region directly.
|
||||
- `meta_skill` is context memory, not a direct skill edit.
|
||||
- `meta_reflect` is a gated skill edit stage, not just logging.
|
||||
|
||||
## Minimal Setup
|
||||
|
||||
```bash
|
||||
conda create -n skillopt python=3.11
|
||||
conda activate skillopt
|
||||
pip install openai pyyaml openpyxl
|
||||
```
|
||||
|
||||
Depending on the environment, you may also need:
|
||||
|
||||
```bash
|
||||
pip install datasets gymnasium numpy ray regex
|
||||
```
|
||||
|
||||
For `SWEBench`, you also need a working Docker environment plus the SWE-bench / mini-swe-agent dependencies used in `swe-bench-old`.
|
||||
93
configs/_base_/default.yaml
Normal file
93
configs/_base_/default.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
# ReflACT default configuration — base for all environments.
|
||||
# Environment configs should inherit via: _base_: default.yaml
|
||||
|
||||
model:
|
||||
backend: azure_openai
|
||||
teacher: gpt-5.5
|
||||
student: gpt-5.5
|
||||
teacher_backend: openai_chat
|
||||
student_backend: openai_chat
|
||||
reasoning_effort: medium
|
||||
rewrite_reasoning_effort: ""
|
||||
rewrite_max_completion_tokens: 64000
|
||||
codex_exec_path: codex
|
||||
codex_exec_sandbox: workspace-write
|
||||
codex_exec_profile: ""
|
||||
codex_exec_full_auto: false
|
||||
codex_exec_reasoning_effort: none
|
||||
codex_exec_use_sdk: auto
|
||||
codex_exec_network_access: false
|
||||
codex_exec_web_search: false
|
||||
codex_exec_approval_policy: never
|
||||
claude_code_exec_path: claude
|
||||
claude_code_exec_profile: ""
|
||||
claude_code_exec_use_sdk: auto
|
||||
claude_code_exec_effort: medium
|
||||
claude_code_exec_max_thinking_tokens: 16384
|
||||
codex_trace_to_teacher: true
|
||||
azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
azure_openai_api_version: "2024-12-01-preview"
|
||||
azure_openai_api_key: "" # Fill locally if you do not export AZURE_OPENAI_API_KEY
|
||||
azure_openai_auth_mode: azure_cli
|
||||
azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
azure_openai_managed_identity_client_id: ""
|
||||
teacher_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
teacher_azure_openai_api_version: "2024-12-01-preview"
|
||||
teacher_azure_openai_api_key: ""
|
||||
teacher_azure_openai_auth_mode: azure_cli
|
||||
teacher_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
teacher_azure_openai_managed_identity_client_id: ""
|
||||
student_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
student_azure_openai_api_version: "2024-12-01-preview"
|
||||
student_azure_openai_api_key: ""
|
||||
student_azure_openai_auth_mode: azure_cli
|
||||
student_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
student_azure_openai_managed_identity_client_id: ""
|
||||
|
||||
train:
|
||||
num_epochs: 4
|
||||
train_size: 0 # 0 = derive from dataset split when available
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
seed: 42
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
analyst_workers: 16
|
||||
max_analyst_rounds: 3
|
||||
failure_only: false
|
||||
use_deep_reflect: false
|
||||
deep_reflect_failures: 4
|
||||
deep_reflect_successes: 2
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4 # max edits per step (edit_budget)
|
||||
min_learning_rate: 2 # min edits for decay schedulers
|
||||
lr_scheduler: cosine # constant / linear / cosine / autonomous
|
||||
lr_control_mode: fixed # fixed / autonomous / none
|
||||
skill_update_mode: patch # patch / rewrite_from_suggestions / full_rewrite_minibatch
|
||||
use_meta_reflect: false
|
||||
meta_learning_rate: 4 # max edits per epoch-level meta-reflect
|
||||
use_slow_update: true
|
||||
slow_update_samples: 20
|
||||
longitudinal_pair_policy: mixed # mixed / changed / unchanged
|
||||
use_meta_skill: true
|
||||
|
||||
evaluation:
|
||||
use_gate: true
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
eval_test: true
|
||||
|
||||
env:
|
||||
name: ""
|
||||
skill_init: ""
|
||||
split_mode: ratio # ratio = build deterministic split from data_path; split_dir = use pre-split train/val/test
|
||||
split_ratio: "2:1:7" # explicit default for dataset-backed benchmarks: train:val:test
|
||||
split_seed: 42
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
exec_timeout: 120 # per student model/code-agent call timeout in seconds
|
||||
out_root: ""
|
||||
305
configs/ablation_study/README.md
Normal file
305
configs/ablation_study/README.md
Normal file
@@ -0,0 +1,305 @@
|
||||
# Ablation Study Configuration Manifest
|
||||
|
||||
This folder records the final, reproducible settings for the ablation runs used
|
||||
in `docs/ablation_paper_tables.md`.
|
||||
|
||||
It is intentionally separate from the benchmark default configs. The benchmark
|
||||
configs under `configs/<benchmark>/default.yaml` remain the source task configs;
|
||||
this folder records the exact matrix-level overrides, run roots, launch commands,
|
||||
and validation rules used for the paper ablations.
|
||||
|
||||
## Files
|
||||
|
||||
- `matrix.yaml`: canonical ablation matrix, common overrides, benchmark splits,
|
||||
token/output caps, and invalid-run rules.
|
||||
- `launch_commands.sh`: exact launcher commands for the valid run roots.
|
||||
- `validation.md`: monitoring, result extraction, and invalidation checklist.
|
||||
|
||||
## Source Of Truth
|
||||
|
||||
Use the matrix launcher:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python scripts/run_ablation_matrix.py
|
||||
```
|
||||
|
||||
The launcher builds runs from the same defaults and values recorded in
|
||||
`matrix.yaml`. It skips completed runs by checking `summary.json` and skips
|
||||
active runs by checking `env.out_root` in active `scripts/train.py` processes.
|
||||
|
||||
Do not manually rerun a completed run into the same `env.out_root`. If a run is
|
||||
invalid, archive or remove its output directory first, then let the launcher
|
||||
start it cleanly.
|
||||
|
||||
## Current Correct Run Roots
|
||||
|
||||
- SearchQA / SpreadsheetBench original ablations:
|
||||
`outputs/ablation_20260502_040604_unique48`
|
||||
- SearchQA / SpreadsheetBench batch-size ablations:
|
||||
`outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run`
|
||||
- LiveMathBench / ALFWorld clean ablations:
|
||||
`outputs/ablation_livemath_alfworld_clean_20260503_155155_run`
|
||||
- DocVQA ablations:
|
||||
`outputs/ablation_docvqa_20260503_160225_run`
|
||||
|
||||
Archived, superseded, misaligned, dry-run, or pre-fix directories must not be
|
||||
used for paper tables.
|
||||
|
||||
## End-To-End Runbook
|
||||
|
||||
### Environment
|
||||
|
||||
Run from the repository root:
|
||||
|
||||
```bash
|
||||
cd /home/azureuser/workspace-gzy/SkillReflection
|
||||
```
|
||||
|
||||
Always use:
|
||||
|
||||
```bash
|
||||
PY=/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python
|
||||
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
|
||||
```
|
||||
|
||||
Default model/auth settings are generated by `scripts/run_ablation_matrix.py`:
|
||||
|
||||
```text
|
||||
teacher=gpt-5.5
|
||||
student=gpt-5.5
|
||||
teacher_backend=openai_chat
|
||||
student_backend=openai_chat
|
||||
reasoning_effort=medium
|
||||
teacher/student endpoint=https://t2vgoaigpt4o3.openai.azure.com/
|
||||
teacher/student api_version=2024-12-01-preview
|
||||
teacher/student auth_mode=azure_cli
|
||||
```
|
||||
|
||||
Core training settings:
|
||||
|
||||
```text
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_scheduler=cosine
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
optimizer.longitudinal_pair_policy=mixed
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
```
|
||||
|
||||
`train.train_size=0` is intentional. The dataloader derives the train size from
|
||||
the fixed split. Batch-size ablations rely on the default `ceil(train_size /
|
||||
batch_size)` behavior; the last batch can be smaller than `train.batch_size`.
|
||||
|
||||
### Fixed Splits
|
||||
|
||||
Default split directories:
|
||||
|
||||
```text
|
||||
searchqa: data/ablation_splits/searchqa/2-1-7_seed42
|
||||
spreadsheetbench: data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
livemathematicianbench: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
alfworld: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
docvqa: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
```
|
||||
|
||||
Default train/val/test sizes:
|
||||
|
||||
| Benchmark | Train | Val | Test |
|
||||
| --- | ---: | ---: | ---: |
|
||||
| SearchQA | 400 | 200 | 1400 |
|
||||
| SpreadsheetBench | 80 | 40 | 280 |
|
||||
| LiveMathBench | 35 | 18 | 124 |
|
||||
| ALFWorld | 39 | 18 | 134 |
|
||||
| DocVQA | 1070 | 535 | 3744 |
|
||||
|
||||
DocVQA images are not copied. The valid setup uses:
|
||||
|
||||
```text
|
||||
data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images
|
||||
```
|
||||
|
||||
2026-05-05 DocVQA data correction: all DocVQA final reruns should use the zisu
|
||||
10% split above and a fresh output root such as
|
||||
`outputs/ablation_docvqa_zisu10pct_20260505_run`. The older local
|
||||
`data/ablation_splits/docvqa/2-1-7_seed42` contains the same 5349 questionId
|
||||
pool but a different train/val/test assignment, so its completed summaries are
|
||||
historical only.
|
||||
|
||||
### Matrix Groups
|
||||
|
||||
Use these group names with `scripts/run_ablation_matrix.py`:
|
||||
|
||||
```text
|
||||
default split batch mbs lr sched slown mod smodel longpair lrctrl
|
||||
```
|
||||
|
||||
`longpair` is the slow-update/meta-skill comparison-example ablation. It keeps
|
||||
all prompts and training settings unchanged and only overrides:
|
||||
|
||||
```text
|
||||
optimizer.longitudinal_pair_policy=changed
|
||||
optimizer.longitudinal_pair_policy=unchanged
|
||||
```
|
||||
|
||||
The default paper setting remains `mixed`.
|
||||
|
||||
`lrctrl` contains the two learning-rate-control baselines:
|
||||
|
||||
```text
|
||||
optimizer.lr_control_mode=autonomous
|
||||
optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch
|
||||
```
|
||||
|
||||
The autonomous run logs the chosen integer per step in `lr_decision.json` and
|
||||
`lr_history.jsonl`. The full-rewrite run removes the LR/edit-selection concept:
|
||||
each minibatch analyst produces a complete skill candidate, and aggregate/merge
|
||||
produces the candidate skill directly.
|
||||
|
||||
Batch-size values are:
|
||||
|
||||
```text
|
||||
8 / 24 / 40 / 56 / full
|
||||
```
|
||||
|
||||
`40` is the default point. `full` expands to the benchmark train size.
|
||||
|
||||
### Launch Commands Used In This Session
|
||||
|
||||
The exact commands are recorded in `launch_commands.sh` and in
|
||||
`docs/ablation_plan.md`. The important current policy is:
|
||||
|
||||
- SearchQA / SpreadsheetBench batch-only matrix can run at `--max-parallel 8`.
|
||||
- DocVQA matrix can run with its launcher at `--max-parallel 8`; later top-up used `--max-parallel 16` only because completed runs were skipped and active roots were checked.
|
||||
- LiveMathBench is safe as API-only benchmark after the token cap fix.
|
||||
- ALFWorld must not be mixed into a 24-way run on this shared machine. Use `--bench alfworld --max-parallel 1` only after memory is available.
|
||||
|
||||
### Token And Timeout Fixes
|
||||
|
||||
LiveMathBench must use a large student completion cap:
|
||||
|
||||
```text
|
||||
max_completion_tokens=16384
|
||||
timeout=300
|
||||
```
|
||||
|
||||
The old 768/512 cap produced many empty visible responses because hidden
|
||||
reasoning consumed the budget.
|
||||
|
||||
ALFWorld must use:
|
||||
|
||||
```text
|
||||
max_completion_tokens=2048
|
||||
empty response fallback -> <action>look</action>
|
||||
missing action fallback -> <action>look</action>
|
||||
```
|
||||
|
||||
### Invalid Runs
|
||||
|
||||
Never fill paper tables from these archive directories:
|
||||
|
||||
```text
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_serial_lowmem_20260504_1300/
|
||||
```
|
||||
|
||||
### Current Resource Notes
|
||||
|
||||
ALFWorld model calls are API calls. The ablation branch now creates local
|
||||
ALFWorld/TextWorld environments through multiprocessing workers ported from
|
||||
`skillopt_final_zzw`, not through Ray actors. Old Ray-based archived runs are
|
||||
not valid for table fill. The observed historical failure mode in this session
|
||||
was system RAM pressure and Ray OOM prevention, not model GPU memory.
|
||||
|
||||
GPU memory currently shown by `nvidia-smi` came from unrelated Ray Serve visual
|
||||
models under:
|
||||
|
||||
```text
|
||||
/home/azureuser/workspace-gzy/zyf/gca-skill
|
||||
```
|
||||
|
||||
Those processes are `GroundingDINOModel` / `DA3Model`, not the
|
||||
SkillReflection ablation ALFWorld run.
|
||||
|
||||
There are also unrelated ALFWorld jobs under:
|
||||
|
||||
```text
|
||||
/home/azureuser/zisu/skill_distill
|
||||
```
|
||||
|
||||
Do not confuse those with this repository's ablation outputs.
|
||||
|
||||
### Monitoring
|
||||
|
||||
Active run and duplicate output-root check:
|
||||
|
||||
```bash
|
||||
$PY - <<'PY'
|
||||
import subprocess, re, collections, time
|
||||
try:
|
||||
raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
|
||||
except subprocess.CalledProcessError:
|
||||
raw = ""
|
||||
roots = []
|
||||
for line in raw.splitlines():
|
||||
m = re.search(r"env\.out_root=([^\s]+)", line)
|
||||
if m:
|
||||
roots.append(m.group(1))
|
||||
ctr = collections.Counter(roots)
|
||||
print("time", time.strftime("%F %T"))
|
||||
print("active_count", len(roots))
|
||||
print("duplicates", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1])
|
||||
for root in sorted(roots):
|
||||
print(root.rsplit("/", 1)[-1])
|
||||
PY
|
||||
```
|
||||
|
||||
Error scan:
|
||||
|
||||
```bash
|
||||
rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|CUDA out of memory|\\[FAIL\\]|LLM call failed" \
|
||||
outputs/ablation_docvqa_20260503_160225_run/logs \
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \
|
||||
outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \
|
||||
-g '*.log' | tail -100 || true
|
||||
```
|
||||
|
||||
Resource checks:
|
||||
|
||||
```bash
|
||||
free -h | sed -n '1,3p'
|
||||
df -h /tmp
|
||||
du -sh /tmp/ray 2>/dev/null || true
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
### Filling Tables
|
||||
|
||||
Only use top-level `summary.json` from valid run roots. Fill
|
||||
`docs/ablation_paper_tables.md` from:
|
||||
|
||||
```text
|
||||
best_selection_hard
|
||||
baseline_test_hard
|
||||
test_hard
|
||||
test_delta_hard
|
||||
token_summary._total.total_tokens
|
||||
```
|
||||
81
configs/ablation_study/launch_commands.sh
Executable file
81
configs/ablation_study/launch_commands.sh
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd /home/azureuser/workspace-gzy/SkillReflection
|
||||
|
||||
PY=/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python
|
||||
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
|
||||
|
||||
# Original SearchQA / SpreadsheetBench full matrix reproduction command.
|
||||
# Do not run this into the existing root unless intentionally reproducing from
|
||||
# scratch; the current valid root is already populated:
|
||||
# outputs/ablation_20260502_040604_unique48
|
||||
#
|
||||
# setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
# --groups default split mbs lr sched slown mod smodel \
|
||||
# --bench searchqa spreadsheetbench \
|
||||
# --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48 \
|
||||
# --max-parallel 24 \
|
||||
# --execute \
|
||||
# > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48/launcher_reproduce_full_matrix.log 2>&1 < /dev/null &
|
||||
#
|
||||
# SearchQA / SpreadsheetBench batch-size ablations only.
|
||||
# Original non-batch SearchQA/SpreadsheetBench ablations live in:
|
||||
# outputs/ablation_20260502_040604_unique48
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups batch \
|
||||
--bench searchqa spreadsheetbench \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/launcher_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# DocVQA full matrix.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups default split batch mbs lr sched slown mod smodel \
|
||||
--bench docvqa \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# LiveMathBench clean matrix. ALFWorld should be launched separately at lower
|
||||
# concurrency because Ray OOM occurred when many ALFWorld runs were mixed into a
|
||||
# 24-way run.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups default split batch mbs lr sched slown mod smodel \
|
||||
--bench livemathematicianbench \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# ALFWorld clean matrix. Increase to 2 only after checking memory, /tmp/ray,
|
||||
# and that no other ALFWorld run is active. Do not use 8/16/24 for ALFWorld on
|
||||
# the current shared machine unless resources are explicitly reserved.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups default split batch mbs lr sched slown mod smodel \
|
||||
--bench alfworld \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \
|
||||
--max-parallel 1 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>&1 < /dev/null &
|
||||
|
||||
# Longitudinal comparison-example policy ablations. This intentionally excludes
|
||||
# ALFWorld. The only varied setting is optimizer.longitudinal_pair_policy.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups longpair \
|
||||
--bench searchqa spreadsheetbench livemathematicianbench docvqa \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run/launcher_longpair_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# Learning-rate-control baselines. This intentionally excludes ALFWorld.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups lrctrl \
|
||||
--bench searchqa spreadsheetbench livemathematicianbench docvqa \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run/launcher_lrctrl_parallel8.log 2>&1 < /dev/null &
|
||||
257
configs/ablation_study/matrix.yaml
Normal file
257
configs/ablation_study/matrix.yaml
Normal file
@@ -0,0 +1,257 @@
|
||||
version: 2026-05-04
|
||||
purpose: "Canonical paper ablation settings matching the valid current runs."
|
||||
|
||||
launcher:
|
||||
script: scripts/run_ablation_matrix.py
|
||||
python: /home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python
|
||||
skip_completed_by: summary.json
|
||||
skip_active_by: "active scripts/train.py env.out_root"
|
||||
|
||||
environment:
|
||||
working_directory: /home/azureuser/workspace-gzy/SkillReflection
|
||||
required_env:
|
||||
ALFWORLD_DATA: /home/azureuser/.cache/alfworld
|
||||
docvqa_images_symlink: "data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images"
|
||||
|
||||
common_overrides:
|
||||
model.teacher_backend: openai_chat
|
||||
model.student_backend: openai_chat
|
||||
model.teacher: gpt-5.5
|
||||
model.student: gpt-5.5
|
||||
model.teacher_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.teacher_azure_openai_api_version: 2024-12-01-preview
|
||||
model.teacher_azure_openai_auth_mode: azure_cli
|
||||
model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.student_azure_openai_api_version: 2024-12-01-preview
|
||||
model.student_azure_openai_auth_mode: azure_cli
|
||||
model.reasoning_effort: medium
|
||||
train.num_epochs: 4
|
||||
train.train_size: 0
|
||||
train.batch_size: 40
|
||||
train.accumulation: 1
|
||||
train.seed: 42
|
||||
gradient.minibatch_size: 8
|
||||
gradient.merge_batch_size: 8
|
||||
gradient.analyst_workers: 16
|
||||
gradient.use_deep_reflect: false
|
||||
optimizer.learning_rate: 4
|
||||
optimizer.min_learning_rate: 2
|
||||
optimizer.lr_scheduler: cosine
|
||||
optimizer.skill_update_mode: patch
|
||||
optimizer.use_slow_update: true
|
||||
optimizer.slow_update_samples: 20
|
||||
optimizer.use_meta_skill: true
|
||||
optimizer.use_meta_reflect: false
|
||||
evaluation.use_gate: true
|
||||
evaluation.eval_test: true
|
||||
env.split_mode: split_dir
|
||||
|
||||
benchmarks:
|
||||
searchqa:
|
||||
config: configs/searchqa/default.yaml
|
||||
run_roots:
|
||||
original_matrix: outputs/ablation_20260502_040604_unique48
|
||||
batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run
|
||||
default_split: data/ablation_splits/searchqa/2-1-7_seed42
|
||||
train: 400
|
||||
val: 200
|
||||
test: 1400
|
||||
student_rollout:
|
||||
function: skillopt/envs/searchqa/rollout.py::chat_student
|
||||
max_completion_tokens:
|
||||
first_turn: 512
|
||||
refinement: 512
|
||||
rationale: "Short-answer QA; sampled empties are low and not LiveMath-like."
|
||||
|
||||
spreadsheetbench:
|
||||
config: configs/spreadsheetbench/default.yaml
|
||||
run_roots:
|
||||
original_matrix: outputs/ablation_20260502_040604_unique48
|
||||
batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run
|
||||
default_split: data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
train: 80
|
||||
val: 40
|
||||
test: 280
|
||||
student_rollout:
|
||||
function: skillopt/envs/spreadsheetbench/codegen_agent.py::run_multi
|
||||
max_output_tokens: 16384
|
||||
result_note: "results.jsonl stores execution fields, not a response field."
|
||||
|
||||
livemathematicianbench:
|
||||
config: configs/livemathematicianbench/default.yaml
|
||||
run_roots:
|
||||
clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run
|
||||
default_split: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
train: 35
|
||||
val: 18
|
||||
test: 124
|
||||
student_rollout:
|
||||
function: skillopt/envs/livemathematicianbench/rollout.py::chat_student
|
||||
max_completion_tokens:
|
||||
first_turn: 16384
|
||||
refinement: 16384
|
||||
timeout_seconds: 300
|
||||
invalid_old_caps:
|
||||
first_turn: 768
|
||||
refinement: 512
|
||||
invalid_archive: outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258
|
||||
rationale: "GPT-5 reasoning consumed small budgets and produced many empty visible responses."
|
||||
|
||||
alfworld:
|
||||
config: configs/alfworld/default.yaml
|
||||
run_roots:
|
||||
clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run
|
||||
default_split: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
train: 39
|
||||
val: 18
|
||||
test: 134
|
||||
student_rollout:
|
||||
function: skillopt/envs/alfworld/rollout.py::chat_student
|
||||
max_completion_tokens: 2048
|
||||
timeout_seconds: 120
|
||||
max_steps: 50
|
||||
fallback_action: look
|
||||
invalid_old_cap: 512
|
||||
invalid_archives:
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517
|
||||
concurrency_note: "Do not mix many ALFWorld runs into 24-way total concurrency; Ray OOM occurred. Prefer 1-2 ALFWorld runs at a time unless resources are clearly free."
|
||||
|
||||
docvqa:
|
||||
config: configs/docvqa/default.yaml
|
||||
run_roots:
|
||||
matrix: outputs/ablation_docvqa_zisu10pct_20260505_run
|
||||
default_split: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
train: 1070
|
||||
val: 535
|
||||
test: 3744
|
||||
data_note: "2026-05-05: use zisu-provided 5349-item DocVQA split directly; previous local 2-1-7_seed42 used the same item pool but a different train/val/test assignment and must not be used for final DocVQA reruns."
|
||||
student_rollout:
|
||||
function: skillopt/envs/docvqa/rollout.py::chat_student_messages
|
||||
max_completion_tokens:
|
||||
first_turn: 768
|
||||
refinement: 512
|
||||
rationale: "Short-answer VQA output; preserve current setting for alignment unless explicitly rerunning all affected DocVQA."
|
||||
|
||||
splits:
|
||||
tags:
|
||||
1shot:
|
||||
extra_overrides:
|
||||
optimizer.slow_update_samples: 1
|
||||
1-1-8: {}
|
||||
2-1-7:
|
||||
default: true
|
||||
4-1-5: {}
|
||||
paths:
|
||||
searchqa:
|
||||
1shot: data/ablation_splits/searchqa/1shot_seed42
|
||||
1-1-8: data/ablation_splits/searchqa/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/searchqa/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/searchqa/4-1-5_seed42
|
||||
spreadsheetbench:
|
||||
1shot: data/ablation_splits/spreadsheetbench/1shot_seed42
|
||||
1-1-8: data/ablation_splits/spreadsheetbench/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/spreadsheetbench/4-1-5_seed42
|
||||
livemathematicianbench:
|
||||
1shot: data/ablation_splits/livemathematicianbench/1shot_seed42
|
||||
1-1-8: data/ablation_splits/livemathematicianbench/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/livemathematicianbench/4-1-5_seed42
|
||||
alfworld:
|
||||
1shot: data/ablation_splits/alfworld/1shot_seed42
|
||||
1-1-8: data/ablation_splits/alfworld/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/alfworld/4-1-5_seed42
|
||||
docvqa:
|
||||
1shot: data/ablation_splits/docvqa/1shot_seed42
|
||||
1-1-8: data/ablation_splits/docvqa/1-1-8_seed42
|
||||
2-1-7: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
4-1-5: data/ablation_splits/docvqa/4-1-5_seed42
|
||||
|
||||
groups:
|
||||
default:
|
||||
run_id: "DEFAULT-{benchmark}-5.5"
|
||||
overrides: {}
|
||||
split:
|
||||
values: [1shot, 1-1-8, 4-1-5]
|
||||
skip_default_2_1_7: true
|
||||
override_template: "env.split_dir={split_path}"
|
||||
batch:
|
||||
values: [8, 24, 56, full]
|
||||
default_value_reused: 40
|
||||
full_values:
|
||||
searchqa: 400
|
||||
spreadsheetbench: 80
|
||||
livemathematicianbench: 35
|
||||
alfworld: 39
|
||||
docvqa: 1070
|
||||
fixed_overrides:
|
||||
gradient.minibatch_size: 8
|
||||
mbs:
|
||||
values: [1, 2, 4, 16, 32]
|
||||
default_value_reused: 8
|
||||
override_template: "gradient.minibatch_size={value}"
|
||||
lr:
|
||||
values: [1, 2, 4, 8, 16]
|
||||
fixed_overrides:
|
||||
optimizer.lr_scheduler: constant
|
||||
optimizer.min_learning_rate: 1
|
||||
override_template: "optimizer.learning_rate={value}"
|
||||
sched:
|
||||
values: [constant, linear]
|
||||
default_value_reused: cosine
|
||||
override_template: "optimizer.lr_scheduler={value}"
|
||||
slown:
|
||||
values: [5, 10, 40]
|
||||
default_value_reused: 20
|
||||
override_template: "optimizer.slow_update_samples={value}"
|
||||
mod:
|
||||
values:
|
||||
slow-only:
|
||||
optimizer.use_slow_update: true
|
||||
optimizer.use_meta_skill: false
|
||||
meta-only:
|
||||
optimizer.use_slow_update: false
|
||||
optimizer.use_meta_skill: true
|
||||
none:
|
||||
optimizer.use_slow_update: false
|
||||
optimizer.use_meta_skill: false
|
||||
default_value_reused: slow-meta
|
||||
longpair:
|
||||
values: [changed, unchanged]
|
||||
default_value_reused: mixed
|
||||
override_template: "optimizer.longitudinal_pair_policy={value}"
|
||||
note: "Only changes slow-update/meta-skill comparison examples; prompts and other settings remain unchanged."
|
||||
lrctrl:
|
||||
values:
|
||||
autonomous:
|
||||
optimizer.lr_control_mode: autonomous
|
||||
full-rewrite:
|
||||
optimizer.lr_control_mode: none
|
||||
optimizer.skill_update_mode: full_rewrite_minibatch
|
||||
default_value_reused: "fixed patch learning_rate=4"
|
||||
note: "autonomous records lr_decision.json/lr_history.jsonl; full-rewrite removes LR/select/apply-edit and uses full skill candidates."
|
||||
smodel:
|
||||
values:
|
||||
"5.4":
|
||||
model.student: gpt-5.4-pro
|
||||
model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.student_azure_openai_api_version: 2025-03-01-preview
|
||||
model.student_azure_openai_auth_mode: azure_cli
|
||||
"5.4-mini":
|
||||
model.student: gpt-5.4-mini
|
||||
model.student_azure_openai_endpoint: https://searchagent5.cognitiveservices.azure.com/
|
||||
model.student_azure_openai_api_version: 2024-12-01-preview
|
||||
model.student_azure_openai_auth_mode: azure_cli
|
||||
default_value_reused: "5.5"
|
||||
|
||||
validity_rules:
|
||||
use_for_tables:
|
||||
- "Only runs with summary.json in valid run roots."
|
||||
- "Do not use archive, archived, MISALIGNED, SUPERSEDED, dryrun, smoke, or debug directories."
|
||||
- "Do not use ALFWorld runs started before empty/missing-action fallback."
|
||||
- "Do not use old LiveMath runs with 768/512 token caps."
|
||||
rerun_rule: "Archive or remove invalid out_root before relaunch; never write a rerun into a polluted output directory."
|
||||
141
configs/ablation_study/validation.md
Normal file
141
configs/ablation_study/validation.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Ablation Validation Checklist
|
||||
|
||||
Use this checklist before launch, during monitoring, and before filling
|
||||
`docs/ablation_paper_tables.md`.
|
||||
|
||||
## Before Launch
|
||||
|
||||
Run from repo root:
|
||||
|
||||
```bash
|
||||
cd /home/azureuser/workspace-gzy/SkillReflection
|
||||
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
|
||||
```
|
||||
|
||||
Verify syntax for edited files:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python -m py_compile \
|
||||
scripts/run_ablation_matrix.py \
|
||||
scripts/train.py \
|
||||
skillopt/model/azure_openai.py \
|
||||
skillopt/envs/searchqa/rollout.py \
|
||||
skillopt/envs/spreadsheetbench/rollout.py \
|
||||
skillopt/envs/livemathematicianbench/rollout.py \
|
||||
skillopt/envs/alfworld/rollout.py \
|
||||
skillopt/envs/docvqa/rollout.py
|
||||
```
|
||||
|
||||
Check active runs and duplicate `env.out_root` before starting more:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python - <<'PY'
|
||||
import subprocess, re, collections
|
||||
try:
|
||||
raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
|
||||
except subprocess.CalledProcessError:
|
||||
raw = ""
|
||||
roots = []
|
||||
for line in raw.splitlines():
|
||||
m = re.search(r"env\.out_root=([^\s]+)", line)
|
||||
if m:
|
||||
roots.append(m.group(1))
|
||||
ctr = collections.Counter(roots)
|
||||
print("train_count", len(roots))
|
||||
print("duplicate_roots", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1])
|
||||
for root in sorted(roots):
|
||||
print(root.rsplit("/", 1)[-1])
|
||||
PY
|
||||
```
|
||||
|
||||
## During Monitoring
|
||||
|
||||
Check launchers:
|
||||
|
||||
```bash
|
||||
pgrep -af 'scripts/run_ablation_matrix.py' || true
|
||||
tail -80 outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>/dev/null || true
|
||||
tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>/dev/null || true
|
||||
tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>/dev/null || true
|
||||
```
|
||||
|
||||
Scan current logs for new hard failures:
|
||||
|
||||
```bash
|
||||
rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|\\[FAIL\\]|\\[RETRY\\]" \
|
||||
outputs/ablation_docvqa_20260503_160225_run/logs \
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \
|
||||
outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \
|
||||
-g '*.log' | tail -160 || true
|
||||
```
|
||||
|
||||
Check resource pressure:
|
||||
|
||||
```bash
|
||||
df -h /tmp
|
||||
du -sh /tmp/ray 2>/dev/null || true
|
||||
free -h | sed -n '1,3p'
|
||||
```
|
||||
|
||||
## Quality Checks
|
||||
|
||||
LiveMathBench current valid runs should not look like old 768/512 runs:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python - <<'PY'
|
||||
import json, pathlib
|
||||
root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run")
|
||||
for run in sorted(root.glob("*livemathematicianbench*")):
|
||||
if not run.is_dir() or "archive" in str(run):
|
||||
continue
|
||||
for rel in ["test_eval_baseline/results.jsonl", "test_eval/results.jsonl"]:
|
||||
p = run / rel
|
||||
if not p.exists():
|
||||
continue
|
||||
rows = [json.loads(l) for l in p.open(errors="ignore") if l.strip()]
|
||||
empty = sum(1 for r in rows if not str(r.get("response", "")).strip())
|
||||
answer = sum(1 for r in rows if "<answer>" in str(r.get("response", "")).lower())
|
||||
if empty:
|
||||
print(run.name, rel, "empty", empty, "answer", answer, "n", len(rows))
|
||||
PY
|
||||
```
|
||||
|
||||
ALFWorld valid runs must not contain empty action or missing action:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python - <<'PY'
|
||||
import json, pathlib
|
||||
root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run")
|
||||
for run in sorted(root.glob("*alfworld*")):
|
||||
if not run.is_dir() or "archive" in str(run):
|
||||
continue
|
||||
bad = []
|
||||
fallback = 0
|
||||
for c in run.glob("**/conversation.json"):
|
||||
data = json.load(c.open(errors="ignore"))
|
||||
for step in data:
|
||||
if step.get("step") is None:
|
||||
continue
|
||||
if not step.get("action"):
|
||||
bad.append(str(c.relative_to(run)))
|
||||
break
|
||||
mr = str(step.get("model_response", ""))
|
||||
if "empty model response" in mr or "missing action tag" in mr:
|
||||
fallback += 1
|
||||
print(run.name, "bad_action_files", len(bad), "fallback", fallback)
|
||||
PY
|
||||
```
|
||||
|
||||
## Filling Tables
|
||||
|
||||
Use only `summary.json` fields:
|
||||
|
||||
- `best_selection_hard` -> Best Sel
|
||||
- `baseline_test_hard` -> Base Test
|
||||
- `test_hard` -> Best Test
|
||||
- `test_delta_hard` -> Delta
|
||||
- `total_accepts` -> Accept
|
||||
- `total_rejects` -> Reject
|
||||
- `token_summary._total.total_tokens` -> Tokens
|
||||
|
||||
Do not fill table rows from logs alone.
|
||||
30
configs/alfworld/default.yaml
Normal file
30
configs/alfworld/default.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
train_size: 0
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
use_meta_reflect: false
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: alfworld
|
||||
skill_init: skillopt/envs/alfworld/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_steps: 50
|
||||
workers: 8
|
||||
max_api_workers: 8
|
||||
limit: 0
|
||||
4
configs/alfworld/meta_reflect.yaml
Normal file
4
configs/alfworld/meta_reflect.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
_base_: default.yaml
|
||||
|
||||
optimizer:
|
||||
use_meta_reflect: true
|
||||
21
configs/babyvision/default.yaml
Normal file
21
configs/babyvision/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: babyvision
|
||||
skill_init: skillopt/envs/babyvision/skills/initial.md
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
limit: 0
|
||||
image_detail: auto
|
||||
judge_model: gpt-5.4
|
||||
judge_max_completion_tokens: 256
|
||||
judge_retries: 5
|
||||
28
configs/docvqa/default.yaml
Normal file
28
configs/docvqa/default.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
env:
|
||||
name: docvqa
|
||||
skill_init: skillopt/envs/docvqa/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
image_detail: auto
|
||||
limit: 0
|
||||
22
configs/livemathematicianbench/default.yaml
Normal file
22
configs/livemathematicianbench/default.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
train_size: 0
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: livemathematicianbench
|
||||
skill_init: skillopt/envs/livemathematicianbench/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
exec_timeout: 300
|
||||
workers: 64
|
||||
limit: 0
|
||||
shuffle_choices: true
|
||||
use_theorem: false
|
||||
use_sketch: false
|
||||
23
configs/mathverse/default.yaml
Normal file
23
configs/mathverse/default.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
codex_exec_sandbox: danger-full-access
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: mathverse
|
||||
skill_init: skillopt/envs/mathverse/skills/initial.md
|
||||
split_dir: ""
|
||||
data_root: data/MathVerse
|
||||
problem_version: Text Lite
|
||||
use_text_dominant_reference: false
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
limit: 0
|
||||
image_detail: auto
|
||||
judge_model: gpt-5.4
|
||||
judge_max_completion_tokens: 256
|
||||
judge_retries: 5
|
||||
18
configs/mmrb/default.yaml
Normal file
18
configs/mmrb/default.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
batch_size: 128
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: mmrb
|
||||
skill_init: skillopt/envs/mmrb/skills/initial.md
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
limit: 0
|
||||
image_detail: auto
|
||||
25
configs/officeqa/default.yaml
Normal file
25
configs/officeqa/default.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
env:
|
||||
name: officeqa
|
||||
skill_init: skillopt/envs/officeqa/skills/initial.md
|
||||
split_dir: data/officeqa_split
|
||||
data_dirs:
|
||||
- data/officeqa_docs_official
|
||||
workers: 4
|
||||
max_tool_turns: 24
|
||||
limit: 0
|
||||
23
configs/sealqa/default.yaml
Normal file
23
configs/sealqa/default.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 10
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
env:
|
||||
name: sealqa
|
||||
skill_init: skillopt/envs/sealqa/skills/initial.md
|
||||
split_dir: data/sealqa_split
|
||||
workers: 4
|
||||
max_tool_turns: 12
|
||||
limit: 0
|
||||
32
configs/searchqa/default.yaml
Normal file
32
configs/searchqa/default.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
train_size: 400
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: searchqa
|
||||
skill_init: skillopt/envs/searchqa/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/searchqa_split
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 24
|
||||
limit: 0
|
||||
34
configs/spreadsheetbench/default.yaml
Normal file
34
configs/spreadsheetbench/default.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
train_size: 80
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: spreadsheetbench
|
||||
skill_init: skillopt/envs/spreadsheetbench/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/spreadsheetbench_split
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
data_root: data/spreadsheetbench_verified_400
|
||||
mode: multi
|
||||
max_turns: 30
|
||||
exec_timeout: 600
|
||||
workers: 24
|
||||
36
configs/swebench/default.yaml
Normal file
36
configs/swebench/default.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 20
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 4
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: swebench
|
||||
skill_init: skillopt/envs/swebench/skills/initial.md
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
dataset_name: lite
|
||||
hf_split: test
|
||||
workers: 8
|
||||
eval_workers: 8
|
||||
step_limit: 50
|
||||
cost_limit: 3.0
|
||||
timeout_per_instance: 600
|
||||
limit: 0
|
||||
20
scripts/codex_azure_mi.sh
Executable file
20
scripts/codex_azure_mi.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PY="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
REAL_CODEX="/home/azureuser/.nvm/versions/node/v18.20.8/bin/codex"
|
||||
CLIENT_ID="8cafa2b1-a2a7-4ad9-814a-ffe4aed7e800"
|
||||
SCOPE="https://cognitiveservices.azure.com/.default"
|
||||
|
||||
token="$("$PY" - <<PY
|
||||
from azure.identity import ManagedIdentityCredential, get_bearer_token_provider
|
||||
cred = ManagedIdentityCredential(client_id="$CLIENT_ID")
|
||||
print(get_bearer_token_provider(cred, "$SCOPE")())
|
||||
PY
|
||||
)"
|
||||
|
||||
export CODEX_HOME="${CODEX_HOME:-$ROOT/.codex_azure}"
|
||||
export AZURE_OPENAI_AUTH_HEADER="Bearer $token"
|
||||
|
||||
exec "$REAL_CODEX" "$@"
|
||||
53
scripts/download_babyvision.py
Normal file
53
scripts/download_babyvision.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Download BabyVision from Hugging Face and convert it to local meta_data.jsonl + images/ format."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--out_dir", type=str, required=True)
|
||||
p.add_argument("--dataset", type=str, default="UnipatAI/BabyVision")
|
||||
p.add_argument("--split", type=str, default="train")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise SystemExit("Please install `datasets` first: pip install datasets pillow") from exc
|
||||
|
||||
out_dir = Path(args.out_dir).resolve()
|
||||
images_dir = out_dir / "images"
|
||||
meta_path = out_dir / "meta_data.jsonl"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
dataset = load_dataset(args.dataset, split=args.split)
|
||||
with open(meta_path, "w", encoding="utf-8") as outf:
|
||||
for idx, row in enumerate(dataset):
|
||||
image = row.get("image")
|
||||
if image is None:
|
||||
continue
|
||||
task_id = str(row.get("taskId") or row.get("id") or idx + 1)
|
||||
image_name = f"{task_id}.png"
|
||||
image_path = images_dir / image_name
|
||||
image.save(image_path)
|
||||
|
||||
record = dict(row)
|
||||
record["image"] = image_name
|
||||
outf.write(json.dumps(record, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"Saved BabyVision to {out_dir}")
|
||||
print(f"Metadata: {meta_path}")
|
||||
print(f"Images: {images_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
333
scripts/eval_livemathematicianbench_baseline.py
Normal file
333
scripts/eval_livemathematicianbench_baseline.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Evaluate LiveMathematicianBench under current or official-style prompts."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from skillopt.envs.livemathematicianbench.dataloader import load_items
|
||||
from skillopt.envs.livemathematicianbench.evaluator import evaluate as current_evaluate
|
||||
from skillopt.envs.livemathematicianbench.rollout import _build_system, _build_user
|
||||
from skillopt.model import (
|
||||
chat_with_deployment,
|
||||
configure_azure_openai,
|
||||
set_backend,
|
||||
set_reasoning_effort,
|
||||
)
|
||||
|
||||
_LABELS = ["A", "B", "C", "D", "E"]
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--data_path", type=str, required=True)
|
||||
p.add_argument("--model", type=str, default="gpt-5.4")
|
||||
p.add_argument("--backend", type=str, choices=["azure_openai", "codex", "claude"], default="azure_openai")
|
||||
p.add_argument("--mode", type=str, choices=["current", "official"], required=True)
|
||||
p.add_argument("--reasoning_effort", type=str, default=None)
|
||||
p.add_argument("--azure_endpoint", type=str, default="")
|
||||
p.add_argument("--azure_api_version", type=str, default="")
|
||||
p.add_argument("--azure_api_key", type=str, default="")
|
||||
p.add_argument("--max_completion_tokens", type=int, default=0)
|
||||
p.add_argument("--workers", type=int, default=8)
|
||||
p.add_argument("--seed", type=int, default=20260227)
|
||||
p.add_argument("--skill_path", type=str, default="skillopt/envs/livemathematicianbench/skills/initial.md")
|
||||
p.add_argument("--limit", type=int, default=0)
|
||||
p.add_argument("--resume", action="store_true")
|
||||
p.add_argument("--output_json", type=str, required=True)
|
||||
return p.parse_args()
|
||||
|
||||
def read_skill(skill_path: str) -> str:
|
||||
with open(skill_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def official_extract_answer(response_text: str) -> str | None:
|
||||
if not response_text:
|
||||
return None
|
||||
boxed_match = re.search(r"\\boxed\{([A-Ea-e])\}", response_text)
|
||||
if boxed_match:
|
||||
return boxed_match.group(1).upper()
|
||||
boxed_match = re.search(r"boxed\{([A-Ea-e])\}", response_text)
|
||||
if boxed_match:
|
||||
return boxed_match.group(1).upper()
|
||||
answer_match = re.search(r"answer is[:\s]*([A-Ea-e])", response_text, re.IGNORECASE)
|
||||
if answer_match:
|
||||
return answer_match.group(1).upper()
|
||||
answer_match = re.search(r"Answer[:\s]*\(?([A-Ea-e])\)?", response_text)
|
||||
if answer_match:
|
||||
return answer_match.group(1).upper()
|
||||
final_match = re.search(r"\b([A-Ea-e])\b\s*[.)]?\s*$", response_text.strip())
|
||||
if final_match:
|
||||
return final_match.group(1).upper()
|
||||
return None
|
||||
|
||||
|
||||
def official_format_mcq_prompt(question: str, choices: list[dict]) -> str:
|
||||
lines = [
|
||||
"Answer the following multiple-choice question.",
|
||||
"Think carefully, then provide your final answer in the format: \\boxed{X} where X is A, B, C, D, or E.",
|
||||
"",
|
||||
"Question:",
|
||||
question,
|
||||
"",
|
||||
"Choices:",
|
||||
]
|
||||
for choice in choices:
|
||||
lines.append(f"{choice['label']}. {choice['text']}")
|
||||
lines.append("")
|
||||
lines.append("Your answer:")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def shuffle_choices(item: dict, seed: int) -> tuple[list[dict], dict]:
|
||||
correct_choice = dict(item["correct_choice"])
|
||||
all_choices = [dict(choice) for choice in item["choices"]]
|
||||
rng = random.Random(f"{seed}:{item['id']}")
|
||||
rng.shuffle(all_choices)
|
||||
|
||||
shuffled: list[dict] = []
|
||||
new_correct = dict(correct_choice)
|
||||
correct_text = correct_choice["text"]
|
||||
|
||||
for idx, choice in enumerate(all_choices[: len(_LABELS)]):
|
||||
relabeled = {"label": _LABELS[idx], "text": choice["text"]}
|
||||
shuffled.append(relabeled)
|
||||
if choice["text"] == correct_text:
|
||||
new_correct = dict(relabeled)
|
||||
|
||||
return shuffled, new_correct
|
||||
|
||||
|
||||
def load_existing(output_path: Path) -> dict[str, dict]:
|
||||
if not output_path.exists():
|
||||
return {}
|
||||
with open(output_path, encoding="utf-8") as f:
|
||||
payload = json.load(f)
|
||||
existing = {}
|
||||
for row in payload.get("results", []):
|
||||
existing[str(row["id"])] = row
|
||||
return existing
|
||||
|
||||
|
||||
def save_results(output_path: Path, meta: dict, results: list[dict]) -> None:
|
||||
correct = sum(1 for row in results if row.get("is_correct"))
|
||||
total = len(results)
|
||||
payload = {
|
||||
**meta,
|
||||
"summary": {
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
"accuracy": (correct / total) if total else 0.0,
|
||||
},
|
||||
"results": sorted(results, key=lambda row: str(row["id"])),
|
||||
}
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def call_model(
|
||||
*,
|
||||
model: str,
|
||||
system: str,
|
||||
user: str,
|
||||
max_completion_tokens: int | None,
|
||||
reasoning_effort: str | None,
|
||||
) -> str:
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(5):
|
||||
try:
|
||||
set_reasoning_effort(reasoning_effort)
|
||||
raw, _ = chat_with_deployment(
|
||||
deployment=model,
|
||||
system=system,
|
||||
user=user,
|
||||
max_completion_tokens=max_completion_tokens if max_completion_tokens and max_completion_tokens > 0 else 4096,
|
||||
retries=1,
|
||||
stage="rollout",
|
||||
)
|
||||
return str(raw or "")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_error = exc
|
||||
if attempt == 4:
|
||||
break
|
||||
time.sleep(min(2 ** attempt, 10))
|
||||
raise RuntimeError(f"LLM call failed after retries: {last_error}")
|
||||
|
||||
|
||||
def evaluate_one(
|
||||
item: dict,
|
||||
*,
|
||||
mode: str,
|
||||
model: str,
|
||||
skill_content: str,
|
||||
max_completion_tokens: int,
|
||||
reasoning_effort: str | None,
|
||||
seed: int,
|
||||
) -> dict:
|
||||
shuffled_choices, correct_choice = shuffle_choices(item, seed)
|
||||
|
||||
if mode == "official":
|
||||
system = "You are an expert mathematician. Answer accurately."
|
||||
user = official_format_mcq_prompt(item["question"], shuffled_choices)
|
||||
effective_max_completion_tokens = max_completion_tokens if max_completion_tokens > 0 else None
|
||||
else:
|
||||
materialized = dict(item)
|
||||
materialized["choices"] = shuffled_choices
|
||||
materialized["correct_choice"] = correct_choice
|
||||
system = _build_system(skill_content)
|
||||
user = _build_user(materialized, use_theorem=False, use_sketch=False)
|
||||
effective_max_completion_tokens = max_completion_tokens if max_completion_tokens > 0 else 768
|
||||
|
||||
t0 = time.time()
|
||||
response = call_model(
|
||||
model=model,
|
||||
system=system,
|
||||
user=user,
|
||||
max_completion_tokens=effective_max_completion_tokens,
|
||||
reasoning_effort=reasoning_effort,
|
||||
)
|
||||
elapsed = time.time() - t0
|
||||
|
||||
if mode == "official":
|
||||
predicted = official_extract_answer(response)
|
||||
predicted_text = ""
|
||||
for choice in shuffled_choices:
|
||||
if choice["label"] == predicted:
|
||||
predicted_text = choice["text"]
|
||||
break
|
||||
is_correct = predicted == correct_choice["label"]
|
||||
return {
|
||||
"id": item["id"],
|
||||
"question": item["question"],
|
||||
"correct_label": correct_choice["label"],
|
||||
"correct_text": correct_choice["text"],
|
||||
"predicted_label": predicted,
|
||||
"predicted_text": predicted_text,
|
||||
"is_correct": is_correct,
|
||||
"elapsed_seconds": elapsed,
|
||||
"response": response,
|
||||
}
|
||||
|
||||
eval_result = current_evaluate(response, correct_choice, shuffled_choices)
|
||||
return {
|
||||
"id": item["id"],
|
||||
"question": item["question"],
|
||||
"correct_label": correct_choice["label"],
|
||||
"correct_text": correct_choice["text"],
|
||||
"predicted_label": eval_result["predicted_label"],
|
||||
"predicted_text": eval_result["predicted_text"],
|
||||
"is_correct": bool(eval_result["em"]),
|
||||
"elapsed_seconds": elapsed,
|
||||
"response": response,
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
set_backend(args.backend)
|
||||
configure_azure_openai(
|
||||
endpoint=args.azure_endpoint or None,
|
||||
api_version=args.azure_api_version or None,
|
||||
api_key=args.azure_api_key or None,
|
||||
)
|
||||
set_reasoning_effort(args.reasoning_effort)
|
||||
output_path = Path(args.output_json).resolve()
|
||||
skill_content = read_skill(args.skill_path) if args.mode == "current" else ""
|
||||
|
||||
items = load_items(args.data_path)
|
||||
if args.limit:
|
||||
items = items[:args.limit]
|
||||
|
||||
existing = load_existing(output_path) if args.resume else {}
|
||||
pending = [item for item in items if str(item["id"]) not in existing]
|
||||
results = list(existing.values())
|
||||
|
||||
print("=" * 72, flush=True)
|
||||
print("LiveMathematicianBench baseline eval", flush=True)
|
||||
print("=" * 72, flush=True)
|
||||
print(f"Mode: {args.mode}", flush=True)
|
||||
print(f"Model: {args.model}", flush=True)
|
||||
print(f"Reasoning effort: {args.reasoning_effort}", flush=True)
|
||||
print(f"Items: {len(items)} total, {len(pending)} pending, {len(existing)} resumed", flush=True)
|
||||
print(f"Output: {output_path}", flush=True)
|
||||
print("=" * 72, flush=True)
|
||||
|
||||
meta = {
|
||||
"mode": args.mode,
|
||||
"model": args.model,
|
||||
"reasoning_effort": args.reasoning_effort,
|
||||
"seed": args.seed,
|
||||
"max_completion_tokens": args.max_completion_tokens,
|
||||
}
|
||||
|
||||
if not pending:
|
||||
save_results(output_path, meta, results)
|
||||
summary = json.loads(output_path.read_text(encoding="utf-8"))["summary"]
|
||||
print(f"Accuracy: {summary['correct']}/{summary['total']} = {summary['accuracy']:.4f}", flush=True)
|
||||
return
|
||||
|
||||
with ThreadPoolExecutor(max_workers=args.workers) as ex:
|
||||
futs = {
|
||||
ex.submit(
|
||||
evaluate_one,
|
||||
item,
|
||||
mode=args.mode,
|
||||
model=args.model,
|
||||
skill_content=skill_content,
|
||||
max_completion_tokens=args.max_completion_tokens,
|
||||
reasoning_effort=args.reasoning_effort,
|
||||
seed=args.seed,
|
||||
): item
|
||||
for item in pending
|
||||
}
|
||||
completed = 0
|
||||
for fut in as_completed(futs):
|
||||
item = futs[fut]
|
||||
try:
|
||||
row = fut.result()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
row = {
|
||||
"id": item["id"],
|
||||
"question": item["question"],
|
||||
"correct_label": None,
|
||||
"correct_text": item["correct_choice"]["text"],
|
||||
"predicted_label": None,
|
||||
"predicted_text": "",
|
||||
"is_correct": False,
|
||||
"elapsed_seconds": 0.0,
|
||||
"response": "",
|
||||
"error": str(exc),
|
||||
}
|
||||
results.append(row)
|
||||
completed += 1
|
||||
correct = sum(1 for result in results if result.get("is_correct"))
|
||||
total = len(results)
|
||||
print(
|
||||
f"[{completed}/{len(pending)}] id={row['id']} "
|
||||
f"pred={row['predicted_label']} gold={row['correct_label']} "
|
||||
f"acc={correct}/{total}={correct/total:.4f}",
|
||||
flush=True,
|
||||
)
|
||||
save_results(output_path, meta, results)
|
||||
|
||||
summary = json.loads(output_path.read_text(encoding="utf-8"))["summary"]
|
||||
print("=" * 72, flush=True)
|
||||
print(f"Accuracy: {summary['correct']}/{summary['total']} = {summary['accuracy']:.4f}", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
451
scripts/eval_only.py
Normal file
451
scripts/eval_only.py
Normal file
@@ -0,0 +1,451 @@
|
||||
#!/usr/bin/env python3
|
||||
"""ReflACT eval-only: run a single skill on a dataset without training.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python scripts/eval_only.py \
|
||||
--config configs/spreadsheetbench/default.yaml \
|
||||
--skill skillopt/envs/spreadsheetbench/skills/initial.md \
|
||||
--split_dir /path/to/split \
|
||||
--out_root outputs/eval_skill0
|
||||
|
||||
All YAML keys can be overridden from the CLI, same as train.py.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from skillopt.model import (
|
||||
configure_azure_openai,
|
||||
configure_claude_code_exec,
|
||||
configure_codex_exec,
|
||||
set_reasoning_effort,
|
||||
set_student_backend,
|
||||
set_student_deployment,
|
||||
set_teacher_backend,
|
||||
set_teacher_deployment,
|
||||
)
|
||||
from skillopt.model.common import default_model_for_backend, normalize_backend_name
|
||||
|
||||
_OPENAI_DEFAULT_MODEL_SENTINELS = {"gpt-5.4", "gpt-5.5"}
|
||||
from skillopt.utils import compute_score
|
||||
|
||||
|
||||
# ── Reuse registry from train.py ───────────────────────────────────────────
|
||||
|
||||
_ENV_REGISTRY: dict[str, type] = {}
|
||||
|
||||
|
||||
def _register_builtins() -> None:
|
||||
try:
|
||||
from skillopt.envs.alfworld.adapter import ALFWorldAdapter
|
||||
_ENV_REGISTRY["alfworld"] = ALFWorldAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.searchqa.adapter import SearchQAAdapter
|
||||
_ENV_REGISTRY["searchqa"] = SearchQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.livemathematicianbench.adapter import LiveMathematicianBenchAdapter
|
||||
_ENV_REGISTRY["livemathematicianbench"] = LiveMathematicianBenchAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.babyvision.adapter import BabyVisionAdapter
|
||||
_ENV_REGISTRY["babyvision"] = BabyVisionAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.spreadsheetbench.adapter import SpreadsheetBenchAdapter
|
||||
_ENV_REGISTRY["spreadsheetbench"] = SpreadsheetBenchAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.mmrb.adapter import MMRBAdapter
|
||||
_ENV_REGISTRY["mmrb"] = MMRBAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.docvqa.adapter import DocVQAAdapter
|
||||
_ENV_REGISTRY["docvqa"] = DocVQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.mathverse.adapter import MathVerseAdapter
|
||||
_ENV_REGISTRY["mathverse"] = MathVerseAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.officeqa.adapter import OfficeQAAdapter
|
||||
_ENV_REGISTRY["officeqa"] = OfficeQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.sealqa.adapter import SealQAAdapter
|
||||
_ENV_REGISTRY["sealqa"] = SealQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.swebench.adapter import SWEBenchAdapter
|
||||
_ENV_REGISTRY["swebench"] = SWEBenchAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_adapter(cfg: dict):
|
||||
_register_builtins()
|
||||
env_name = cfg.get("env", "alfworld")
|
||||
if env_name not in _ENV_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Unknown environment '{env_name}'. "
|
||||
f"Available: {list(_ENV_REGISTRY.keys())}"
|
||||
)
|
||||
adapter_cls = _ENV_REGISTRY[env_name]
|
||||
|
||||
import inspect
|
||||
sig = inspect.signature(adapter_cls.__init__)
|
||||
accepted = set(sig.parameters.keys()) - {"self"}
|
||||
adapter_kwargs = {k: cfg[k] for k in accepted if k in cfg}
|
||||
return adapter_cls(**adapter_kwargs)
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────────
|
||||
|
||||
_BOOL = lambda x: str(x).lower() in ("true", "1", "yes") # noqa: E731
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description="ReflACT eval-only")
|
||||
p.add_argument("--config", type=str, required=True)
|
||||
p.add_argument("--skill", type=str, required=True,
|
||||
help="Path to skill .md file to evaluate")
|
||||
p.add_argument("--split", type=str, default="all",
|
||||
help="Which split to eval: train/valid_seen/valid_unseen/all (default: all)")
|
||||
p.add_argument("--cfg-options", nargs="+", default=[],
|
||||
help="Override config: section.key=value")
|
||||
# Legacy flat overrides
|
||||
p.add_argument("--env", type=str)
|
||||
p.add_argument("--backend", type=str,
|
||||
choices=["azure_openai", "codex", "codex_exec", "claude", "claude_chat", "claude_code_exec"])
|
||||
p.add_argument("--teacher_model", type=str)
|
||||
p.add_argument("--student_model", type=str)
|
||||
p.add_argument("--teacher_backend", type=str)
|
||||
p.add_argument("--student_backend", type=str)
|
||||
p.add_argument("--reasoning_effort", type=str,
|
||||
choices=["", "low", "medium", "high", "xhigh", "max"])
|
||||
p.add_argument("--azure_endpoint", type=str)
|
||||
p.add_argument("--azure_api_version", type=str)
|
||||
p.add_argument("--azure_api_key", type=str)
|
||||
p.add_argument("--azure_openai_endpoint", type=str)
|
||||
p.add_argument("--azure_openai_api_version", type=str)
|
||||
p.add_argument("--azure_openai_api_key", type=str)
|
||||
p.add_argument("--azure_openai_auth_mode", type=str)
|
||||
p.add_argument("--azure_openai_ad_scope", type=str)
|
||||
p.add_argument("--azure_openai_managed_identity_client_id", type=str)
|
||||
p.add_argument("--teacher_azure_openai_endpoint", type=str)
|
||||
p.add_argument("--teacher_azure_openai_api_version", type=str)
|
||||
p.add_argument("--teacher_azure_openai_api_key", type=str)
|
||||
p.add_argument("--teacher_azure_openai_auth_mode", type=str)
|
||||
p.add_argument("--teacher_azure_openai_ad_scope", type=str)
|
||||
p.add_argument("--teacher_azure_openai_managed_identity_client_id", type=str)
|
||||
p.add_argument("--student_azure_openai_endpoint", type=str)
|
||||
p.add_argument("--student_azure_openai_api_version", type=str)
|
||||
p.add_argument("--student_azure_openai_api_key", type=str)
|
||||
p.add_argument("--student_azure_openai_auth_mode", type=str)
|
||||
p.add_argument("--student_azure_openai_ad_scope", type=str)
|
||||
p.add_argument("--student_azure_openai_managed_identity_client_id", type=str)
|
||||
p.add_argument("--codex_exec_path", type=str)
|
||||
p.add_argument("--codex_exec_sandbox", type=str)
|
||||
p.add_argument("--codex_exec_profile", type=str)
|
||||
p.add_argument("--codex_exec_full_auto", type=_BOOL)
|
||||
p.add_argument("--codex_exec_reasoning_effort", type=str)
|
||||
p.add_argument("--codex_exec_use_sdk", type=str)
|
||||
p.add_argument("--codex_exec_network_access", type=_BOOL)
|
||||
p.add_argument("--codex_exec_web_search", type=_BOOL)
|
||||
p.add_argument("--codex_exec_approval_policy", type=str)
|
||||
p.add_argument("--claude_code_exec_path", type=str)
|
||||
p.add_argument("--claude_code_exec_profile", type=str)
|
||||
p.add_argument("--claude_code_exec_use_sdk", type=str)
|
||||
p.add_argument("--claude_code_exec_effort", type=str)
|
||||
p.add_argument("--claude_code_exec_max_thinking_tokens", type=int)
|
||||
p.add_argument("--out_root", type=str)
|
||||
p.add_argument("--data_path", type=str)
|
||||
p.add_argument("--split_mode", type=str,
|
||||
choices=["ratio", "split_dir"])
|
||||
p.add_argument("--split_ratio", type=str)
|
||||
p.add_argument("--split_seed", type=int)
|
||||
p.add_argument("--split_dir", type=str)
|
||||
p.add_argument("--split_output_dir", type=str)
|
||||
p.add_argument("--data_root", type=str)
|
||||
p.add_argument("--max_turns", type=int)
|
||||
p.add_argument("--workers", type=int)
|
||||
p.add_argument("--max_api_workers", type=int)
|
||||
p.add_argument("--seed", type=int)
|
||||
p.add_argument("--test_env_num", type=int)
|
||||
p.add_argument("--mode", type=str,
|
||||
help="SpreadsheetBench: single/multi/react (default comes from config)")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
from skillopt.config import load_config as _load, flatten_config, is_structured
|
||||
|
||||
cfg = _load(args.config, overrides=args.cfg_options)
|
||||
structured = is_structured(cfg)
|
||||
|
||||
# Apply legacy --key value overrides
|
||||
cli = {k: v for k, v in vars(args).items()
|
||||
if v is not None and k not in ("config", "skill", "split", "cfg_options")}
|
||||
if cli:
|
||||
if structured:
|
||||
from skillopt.config import apply_overrides
|
||||
_MAP = {
|
||||
"backend": "model.backend",
|
||||
"teacher_model": "model.teacher",
|
||||
"student_model": "model.student",
|
||||
"teacher_backend": "model.teacher_backend",
|
||||
"student_backend": "model.student_backend",
|
||||
"reasoning_effort": "model.reasoning_effort",
|
||||
"azure_endpoint": "model.azure_endpoint",
|
||||
"azure_api_version": "model.azure_api_version",
|
||||
"azure_api_key": "model.azure_api_key",
|
||||
"azure_openai_endpoint": "model.azure_openai_endpoint",
|
||||
"azure_openai_api_version": "model.azure_openai_api_version",
|
||||
"azure_openai_api_key": "model.azure_openai_api_key",
|
||||
"azure_openai_auth_mode": "model.azure_openai_auth_mode",
|
||||
"azure_openai_ad_scope": "model.azure_openai_ad_scope",
|
||||
"azure_openai_managed_identity_client_id": "model.azure_openai_managed_identity_client_id",
|
||||
"teacher_azure_openai_endpoint": "model.teacher_azure_openai_endpoint",
|
||||
"teacher_azure_openai_api_version": "model.teacher_azure_openai_api_version",
|
||||
"teacher_azure_openai_api_key": "model.teacher_azure_openai_api_key",
|
||||
"teacher_azure_openai_auth_mode": "model.teacher_azure_openai_auth_mode",
|
||||
"teacher_azure_openai_ad_scope": "model.teacher_azure_openai_ad_scope",
|
||||
"teacher_azure_openai_managed_identity_client_id": "model.teacher_azure_openai_managed_identity_client_id",
|
||||
"student_azure_openai_endpoint": "model.student_azure_openai_endpoint",
|
||||
"student_azure_openai_api_version": "model.student_azure_openai_api_version",
|
||||
"student_azure_openai_api_key": "model.student_azure_openai_api_key",
|
||||
"student_azure_openai_auth_mode": "model.student_azure_openai_auth_mode",
|
||||
"student_azure_openai_ad_scope": "model.student_azure_openai_ad_scope",
|
||||
"student_azure_openai_managed_identity_client_id": "model.student_azure_openai_managed_identity_client_id",
|
||||
"codex_exec_path": "model.codex_exec_path",
|
||||
"codex_exec_sandbox": "model.codex_exec_sandbox",
|
||||
"codex_exec_profile": "model.codex_exec_profile",
|
||||
"codex_exec_full_auto": "model.codex_exec_full_auto",
|
||||
"codex_exec_reasoning_effort": "model.codex_exec_reasoning_effort",
|
||||
"codex_exec_use_sdk": "model.codex_exec_use_sdk",
|
||||
"codex_exec_network_access": "model.codex_exec_network_access",
|
||||
"codex_exec_web_search": "model.codex_exec_web_search",
|
||||
"codex_exec_approval_policy": "model.codex_exec_approval_policy",
|
||||
"claude_code_exec_path": "model.claude_code_exec_path",
|
||||
"claude_code_exec_profile": "model.claude_code_exec_profile",
|
||||
"claude_code_exec_use_sdk": "model.claude_code_exec_use_sdk",
|
||||
"claude_code_exec_effort": "model.claude_code_exec_effort",
|
||||
"claude_code_exec_max_thinking_tokens": "model.claude_code_exec_max_thinking_tokens",
|
||||
"seed": "train.seed",
|
||||
"test_env_num": "evaluation.test_env_num",
|
||||
"env": "env.name",
|
||||
"out_root": "env.out_root",
|
||||
}
|
||||
mapped = []
|
||||
for k, v in cli.items():
|
||||
dotted = _MAP.get(k)
|
||||
if dotted:
|
||||
mapped.append(f"{dotted}={v}")
|
||||
else:
|
||||
mapped.append(f"env.{k}={v}")
|
||||
apply_overrides(cfg, mapped)
|
||||
else:
|
||||
cfg.update(cli)
|
||||
|
||||
cfg = flatten_config(cfg) if structured else cfg
|
||||
|
||||
for new_key, old_key in (
|
||||
("azure_openai_endpoint", "azure_endpoint"),
|
||||
("azure_openai_api_version", "azure_api_version"),
|
||||
("azure_openai_api_key", "azure_api_key"),
|
||||
):
|
||||
if cfg.get(new_key) in (None, "") and cfg.get(old_key) not in (None, ""):
|
||||
cfg[new_key] = cfg[old_key]
|
||||
|
||||
explicit_backend = getattr(args, "backend", None)
|
||||
if explicit_backend is None:
|
||||
for option in args.cfg_options or []:
|
||||
key = str(option).split("=", 1)[0].strip()
|
||||
if key == "model.backend":
|
||||
explicit_backend = str(option).split("=", 1)[1].strip()
|
||||
break
|
||||
|
||||
backend = normalize_backend_name(cfg.get("model_backend") or cfg.get("student_backend") or "azure_openai")
|
||||
|
||||
def _has_model_override(dotted_key: str, legacy_key: str) -> bool:
|
||||
if getattr(args, legacy_key, None) is not None:
|
||||
return True
|
||||
for option in args.cfg_options or []:
|
||||
key = str(option).split("=", 1)[0].strip()
|
||||
if key == dotted_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
if explicit_backend is not None:
|
||||
backend = normalize_backend_name(explicit_backend)
|
||||
cfg["model_backend"] = backend
|
||||
if backend in {"claude", "claude_chat"}:
|
||||
cfg.setdefault("teacher_backend", "claude_chat")
|
||||
cfg.setdefault("student_backend", "claude_chat")
|
||||
elif backend in {"codex", "codex_exec"}:
|
||||
cfg.setdefault("teacher_backend", "openai_chat")
|
||||
cfg.setdefault("student_backend", "codex_exec")
|
||||
elif backend == "claude_code_exec":
|
||||
cfg.setdefault("teacher_backend", "openai_chat")
|
||||
cfg.setdefault("student_backend", "claude_code_exec")
|
||||
else:
|
||||
cfg.setdefault("teacher_backend", "openai_chat")
|
||||
cfg.setdefault("student_backend", "openai_chat")
|
||||
else:
|
||||
cfg.setdefault("teacher_backend", "openai_chat")
|
||||
cfg.setdefault("student_backend", "openai_chat")
|
||||
|
||||
if cfg.get("teacher_backend") == "claude_chat":
|
||||
if (
|
||||
str(cfg.get("teacher_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
|
||||
and not _has_model_override("model.teacher", "teacher_model")
|
||||
):
|
||||
cfg["teacher_model"] = default_model_for_backend("claude_chat")
|
||||
if cfg.get("student_backend") == "claude_chat":
|
||||
if (
|
||||
str(cfg.get("student_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
|
||||
and not _has_model_override("model.student", "student_model")
|
||||
):
|
||||
cfg["student_model"] = default_model_for_backend("claude_chat")
|
||||
if cfg.get("student_backend") == "claude_code_exec":
|
||||
if (
|
||||
str(cfg.get("student_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
|
||||
and not _has_model_override("model.student", "student_model")
|
||||
):
|
||||
cfg["student_model"] = default_model_for_backend("claude_chat")
|
||||
|
||||
if not cfg.get("out_root"):
|
||||
env = cfg.get("env", "unknown")
|
||||
model = cfg.get("student_model", "unknown").replace("/", "-")
|
||||
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
cfg["out_root"] = os.path.join("outputs", f"eval_{env}_{model}_{ts}")
|
||||
|
||||
cfg["out_root"] = os.path.abspath(cfg["out_root"])
|
||||
|
||||
out_root = cfg["out_root"]
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
# Load skill
|
||||
skill_path = os.path.abspath(args.skill)
|
||||
with open(skill_path) as f:
|
||||
skill_content = f.read()
|
||||
print(f" [skill] {skill_path} ({len(skill_content)} chars)")
|
||||
|
||||
# Configure models
|
||||
configure_azure_openai(
|
||||
endpoint=(cfg.get("azure_openai_endpoint") or cfg.get("azure_endpoint") or None),
|
||||
api_version=(cfg.get("azure_openai_api_version") or cfg.get("azure_api_version") or None),
|
||||
api_key=(cfg.get("azure_openai_api_key") or cfg.get("azure_api_key") or None),
|
||||
auth_mode=cfg.get("azure_openai_auth_mode") or None,
|
||||
ad_scope=cfg.get("azure_openai_ad_scope") or None,
|
||||
managed_identity_client_id=cfg.get("azure_openai_managed_identity_client_id") or None,
|
||||
teacher_endpoint=cfg.get("teacher_azure_openai_endpoint") or None,
|
||||
teacher_api_version=cfg.get("teacher_azure_openai_api_version") or None,
|
||||
teacher_api_key=cfg.get("teacher_azure_openai_api_key") or None,
|
||||
teacher_auth_mode=cfg.get("teacher_azure_openai_auth_mode") or None,
|
||||
teacher_ad_scope=cfg.get("teacher_azure_openai_ad_scope") or None,
|
||||
teacher_managed_identity_client_id=(
|
||||
cfg.get("teacher_azure_openai_managed_identity_client_id") or None
|
||||
),
|
||||
student_endpoint=cfg.get("student_azure_openai_endpoint") or None,
|
||||
student_api_version=cfg.get("student_azure_openai_api_version") or None,
|
||||
student_api_key=cfg.get("student_azure_openai_api_key") or None,
|
||||
student_auth_mode=cfg.get("student_azure_openai_auth_mode") or None,
|
||||
student_ad_scope=cfg.get("student_azure_openai_ad_scope") or None,
|
||||
student_managed_identity_client_id=(
|
||||
cfg.get("student_azure_openai_managed_identity_client_id") or None
|
||||
),
|
||||
)
|
||||
set_teacher_backend(cfg.get("teacher_backend", "openai_chat"))
|
||||
set_student_backend(cfg.get("student_backend", "openai_chat"))
|
||||
set_teacher_deployment(cfg.get("teacher_model", default_model_for_backend(backend)))
|
||||
set_student_deployment(cfg.get("student_model", default_model_for_backend(backend)))
|
||||
configure_codex_exec(
|
||||
path=cfg.get("codex_exec_path", "codex"),
|
||||
sandbox=cfg.get("codex_exec_sandbox", "workspace-write"),
|
||||
profile=cfg.get("codex_exec_profile", ""),
|
||||
full_auto=cfg.get("codex_exec_full_auto", False),
|
||||
reasoning_effort=cfg.get("codex_exec_reasoning_effort", "none"),
|
||||
use_sdk=cfg.get("codex_exec_use_sdk", None),
|
||||
network_access=cfg.get("codex_exec_network_access", False),
|
||||
web_search=cfg.get("codex_exec_web_search", False),
|
||||
approval_policy=cfg.get("codex_exec_approval_policy", "never"),
|
||||
)
|
||||
configure_claude_code_exec(
|
||||
path=cfg.get("claude_code_exec_path", "claude"),
|
||||
profile=cfg.get("claude_code_exec_profile", ""),
|
||||
use_sdk=cfg.get("claude_code_exec_use_sdk", None),
|
||||
effort=cfg.get("claude_code_exec_effort", cfg.get("reasoning_effort", "medium")),
|
||||
max_thinking_tokens=cfg.get("claude_code_exec_max_thinking_tokens", 16384),
|
||||
)
|
||||
set_reasoning_effort(cfg.get("reasoning_effort", "") or None)
|
||||
|
||||
# Build adapter
|
||||
adapter = get_adapter(cfg)
|
||||
adapter.setup(cfg)
|
||||
|
||||
seed = cfg.get("seed", 42)
|
||||
split = args.split or "all"
|
||||
|
||||
if split == "all":
|
||||
items = (
|
||||
adapter.build_eval_env(0, "train", seed)
|
||||
+ adapter.build_eval_env(0, "valid_seen", seed)
|
||||
+ adapter.build_eval_env(0, "valid_unseen", seed)
|
||||
)
|
||||
else:
|
||||
env_num = cfg.get("test_env_num", 0)
|
||||
items = adapter.build_eval_env(env_num, split, seed)
|
||||
|
||||
print(f"\n [eval] split={split} items={len(items)}")
|
||||
print(f" [eval] out_root={out_root}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Run rollout
|
||||
results = adapter.rollout(items, skill_content, out_root)
|
||||
|
||||
# Score
|
||||
hard, soft = compute_score(results)
|
||||
print(f"\n{'='*60}")
|
||||
print(f" Results: hard={hard:.4f} soft={soft:.4f} (n={len(results)})")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Save summary
|
||||
summary = {
|
||||
"skill": skill_path,
|
||||
"split": split,
|
||||
"n_items": len(results),
|
||||
"hard": hard,
|
||||
"soft": soft,
|
||||
}
|
||||
with open(os.path.join(out_root, "eval_summary.json"), "w") as f:
|
||||
json.dump(summary, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f" Saved to: {out_root}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
361
scripts/eval_prompt_custom.py
Normal file
361
scripts/eval_prompt_custom.py
Normal file
@@ -0,0 +1,361 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Standalone eval: CUSTOM prompt (with Critical Rules) on verified-400.
|
||||
|
||||
Usage:
|
||||
python scripts/eval_prompt_custom.py --workers 8
|
||||
python scripts/eval_prompt_custom.py --workers 32 --limit 20
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
|
||||
|
||||
import openpyxl
|
||||
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from skillopt.model import (
|
||||
chat_messages_with_deployment,
|
||||
configure_azure_openai,
|
||||
set_backend,
|
||||
set_student_deployment,
|
||||
)
|
||||
from skillopt.envs.spreadsheetbench.evaluator import evaluate
|
||||
|
||||
|
||||
# ── Config ──────────────────────────────────────────────────────────────────
|
||||
|
||||
DATA_ROOT = "/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH = os.path.join(DATA_ROOT, "dataset.json")
|
||||
MODEL = "gpt-5-mini"
|
||||
|
||||
# ── Custom Prompt (with Critical Rules) ─────────────────────────────────────
|
||||
|
||||
_SYSTEM_TEMPLATE = """\
|
||||
You are an expert Python programmer specializing in spreadsheet manipulation.
|
||||
You will be given a user instruction together with a preview of an input .xlsx file.
|
||||
Your job is to write a single self-contained Python script that reads the input file
|
||||
at the path stored in the variable INPUT_PATH, performs the requested manipulation,
|
||||
and saves the result to OUTPUT_PATH.
|
||||
|
||||
## Critical Rules
|
||||
1. NEVER write Excel formulas to cells. openpyxl does NOT compute formulas —
|
||||
the evaluator will see None. Compute results in Python and write literal values.
|
||||
2. Use only: standard library, openpyxl, pandas.
|
||||
3. Do NOT hardcode cell values from the preview — iterate over actual rows.
|
||||
4. The script must define INPUT_PATH and OUTPUT_PATH at the top.
|
||||
|
||||
{skill_section}\
|
||||
Return ONLY the Python code inside a single ```python ... ``` fenced block.
|
||||
"""
|
||||
|
||||
|
||||
def build_system(skill_content: str = "") -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return _SYSTEM_TEMPLATE.format(skill_section=skill_section)
|
||||
|
||||
|
||||
def build_user(instruction, input_xlsx, instruction_type="", answer_position=""):
|
||||
try:
|
||||
preview = _preview_workbook(input_xlsx)
|
||||
except Exception as e:
|
||||
preview = f"(failed to preview: {e})"
|
||||
extra = ""
|
||||
if instruction_type:
|
||||
extra += f"\nInstruction type: {instruction_type}"
|
||||
if answer_position:
|
||||
extra += f"\nExpected answer position: {answer_position}"
|
||||
return (
|
||||
f"# Instruction\n{instruction}\n{extra}\n\n"
|
||||
f"# Input spreadsheet preview\n{preview}\n\n"
|
||||
"# Task\n"
|
||||
"Write a Python script that reads the workbook from the variable `INPUT_PATH`, "
|
||||
"applies the instruction, and writes the modified workbook to `OUTPUT_PATH`. "
|
||||
"Preserve all other cells unchanged. "
|
||||
"The preview may be truncated — do not hardcode row counts; "
|
||||
"iterate over all actual rows in the workbook instead.\n"
|
||||
"Return only a ```python``` code block."
|
||||
)
|
||||
|
||||
|
||||
# ── Shared utilities ────────────────────────────────────────────────────────
|
||||
|
||||
def _preview_workbook(path, max_rows=5, max_cols=20):
|
||||
wb = openpyxl.load_workbook(path, data_only=False)
|
||||
chunks = []
|
||||
for sn in wb.sheetnames:
|
||||
ws = wb[sn]
|
||||
chunks.append(f"## Sheet: {sn} (dim={ws.dimensions}, max_row={ws.max_row}, max_col={ws.max_column})")
|
||||
for row in ws.iter_rows(min_row=1, max_row=min(ws.max_row, max_rows),
|
||||
max_col=min(ws.max_column, max_cols), values_only=False):
|
||||
cells = []
|
||||
for c in row:
|
||||
v = c.value
|
||||
s = "" if v is None else str(v)
|
||||
if len(s) > 40: s = s[:37] + "..."
|
||||
cells.append(f"{c.coordinate}={s}")
|
||||
chunks.append(" | ".join(cells))
|
||||
if ws.max_row > max_rows:
|
||||
chunks.append(f"... ({ws.max_row - max_rows} more rows)")
|
||||
chunks.append("")
|
||||
wb.close()
|
||||
return "\n".join(chunks)
|
||||
|
||||
|
||||
def extract_code(text):
|
||||
if "```" not in text:
|
||||
return text.strip()
|
||||
start = text.find("```")
|
||||
nl = text.find("\n", start)
|
||||
end = text.find("```", nl + 1)
|
||||
if nl == -1 or end == -1:
|
||||
return text.strip()
|
||||
return text[nl + 1:end].strip()
|
||||
|
||||
|
||||
_PATH_RE = re.compile(r'^\s*(INPUT_PATH|OUTPUT_PATH)\s*=\s*.+$', re.MULTILINE)
|
||||
|
||||
def strip_paths(code):
|
||||
return _PATH_RE.sub("", code)
|
||||
|
||||
|
||||
RUNNER_TEMPLATE = textwrap.dedent("""
|
||||
import os, sys, traceback
|
||||
INPUT_PATH = {input_path!r}
|
||||
OUTPUT_PATH = {output_path!r}
|
||||
try:
|
||||
{code_indented}
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
sys.exit(2)
|
||||
""")
|
||||
|
||||
|
||||
def run_code(code, input_path, output_path, timeout=120):
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
cleaned = strip_paths(code)
|
||||
indented = textwrap.indent(cleaned, " ")
|
||||
script = RUNNER_TEMPLATE.format(input_path=input_path, output_path=output_path, code_indented=indented)
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(script)
|
||||
tmp = f.name
|
||||
try:
|
||||
proc = subprocess.run([sys.executable, tmp], capture_output=True, text=True, timeout=timeout)
|
||||
if proc.returncode != 0:
|
||||
return False, (proc.stdout + "\n" + proc.stderr).strip()
|
||||
if not os.path.exists(output_path):
|
||||
return False, "output file was not created"
|
||||
return True, ""
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, f"timeout after {timeout}s"
|
||||
finally:
|
||||
try: os.unlink(tmp)
|
||||
except OSError: pass
|
||||
|
||||
|
||||
def find_test_cases(task_dir):
|
||||
cases = []
|
||||
for ip in sorted(glob.glob(os.path.join(task_dir, "*_input.xlsx"))):
|
||||
no = os.path.basename(ip).split("_", 1)[0]
|
||||
ap = ip.replace("_input.xlsx", "_answer.xlsx")
|
||||
if os.path.exists(ap): cases.append((no, ip, ap))
|
||||
for ip in sorted(glob.glob(os.path.join(task_dir, "*_init.xlsx"))):
|
||||
no = os.path.basename(ip).split("_", 1)[0]
|
||||
ap = ip.replace("_init.xlsx", "_golden.xlsx")
|
||||
if os.path.exists(ap): cases.append((no, ip, ap))
|
||||
if not cases:
|
||||
bare_init = os.path.join(task_dir, "initial.xlsx")
|
||||
bare_gold = os.path.join(task_dir, "golden.xlsx")
|
||||
if os.path.exists(bare_init) and os.path.exists(bare_gold):
|
||||
cases.append(("1", bare_init, bare_gold))
|
||||
return cases
|
||||
|
||||
|
||||
def load_items(path):
|
||||
if path.endswith(".json"):
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
data = data.get("data") or list(data.values())
|
||||
return list(data)
|
||||
items = []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line: items.append(json.loads(line))
|
||||
return items
|
||||
|
||||
|
||||
# ── LLM call ────────────────────────────────────────────────────────────────
|
||||
|
||||
def llm_call(messages, deployment, max_tokens=16384, retries=5, llm_timeout=120):
|
||||
raw, _ = chat_messages_with_deployment(
|
||||
deployment=deployment,
|
||||
messages=messages,
|
||||
max_completion_tokens=max_tokens,
|
||||
retries=retries,
|
||||
stage="rollout",
|
||||
timeout=llm_timeout,
|
||||
)
|
||||
return str(raw or "")
|
||||
|
||||
|
||||
# ── Process one task ────────────────────────────────────────────────────────
|
||||
|
||||
def process_one(item, data_root, out_root, model):
|
||||
task_id = str(item["id"])
|
||||
instruction = item["instruction"]
|
||||
instruction_type = item.get("instruction_type", "")
|
||||
answer_position = item.get("answer_position", "")
|
||||
answer_sheet = item.get("answer_sheet", "")
|
||||
if answer_position and answer_sheet and "!" not in answer_position:
|
||||
answer_position = f"{answer_sheet}!{answer_position}"
|
||||
|
||||
sp = item.get("spreadsheet_path", f"spreadsheet/{task_id}")
|
||||
task_dir = sp if os.path.isabs(sp) else os.path.join(data_root, sp)
|
||||
|
||||
result = {"id": task_id, "ok": False, "hard": 0, "soft": 0.0,
|
||||
"n_cases": 0, "n_pass": 0, "fail_reason": "", "error": ""}
|
||||
try:
|
||||
cases = find_test_cases(task_dir)
|
||||
result["n_cases"] = len(cases)
|
||||
if not cases:
|
||||
result["fail_reason"] = "no-test-cases"
|
||||
return result
|
||||
|
||||
task_out = os.path.join(out_root, "predictions", task_id)
|
||||
os.makedirs(task_out, exist_ok=True)
|
||||
|
||||
# LLM call
|
||||
system = build_system("")
|
||||
user = build_user(instruction, cases[0][1], instruction_type, answer_position)
|
||||
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
||||
|
||||
raw = llm_call(messages, model)
|
||||
time.sleep(3)
|
||||
code = extract_code(raw)
|
||||
|
||||
with open(os.path.join(task_out, "code.py"), "w") as f: f.write(code)
|
||||
with open(os.path.join(task_out, "raw.txt"), "w") as f: f.write(raw)
|
||||
|
||||
if not code.strip():
|
||||
result["fail_reason"] = "empty-code"
|
||||
return result
|
||||
|
||||
# Execute + evaluate each test case
|
||||
for no, ip, ap in cases:
|
||||
pred = os.path.join(task_out, f"{no}_pred.xlsx")
|
||||
ok_exec, err = run_code(code, ip, pred)
|
||||
if not ok_exec:
|
||||
if not result["fail_reason"]:
|
||||
result["fail_reason"] = f"exec: {err[:200]}"
|
||||
continue
|
||||
try:
|
||||
ev = evaluate(pred, ap, instruction_type, answer_position)
|
||||
except Exception as e:
|
||||
ev = {"ok": False, "reason": str(e)}
|
||||
if ev["ok"]:
|
||||
result["n_pass"] += 1
|
||||
|
||||
nc, np = result["n_cases"], result["n_pass"]
|
||||
result["soft"] = np / nc if nc else 0.0
|
||||
result["hard"] = 1 if nc > 0 and np == nc else 0
|
||||
result["ok"] = bool(result["hard"])
|
||||
if result["ok"]: result["fail_reason"] = ""
|
||||
return result
|
||||
except Exception as e:
|
||||
result["fail_reason"] = f"unexpected: {e}"
|
||||
result["error"] = traceback.format_exc()
|
||||
return result
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Eval CUSTOM prompt on verified-400")
|
||||
ap.add_argument("--model", default=MODEL)
|
||||
ap.add_argument("--backend", choices=["azure_openai", "codex", "claude"], default="azure_openai")
|
||||
ap.add_argument("--azure_endpoint", default="")
|
||||
ap.add_argument("--azure_api_version", default="")
|
||||
ap.add_argument("--azure_api_key", default="")
|
||||
ap.add_argument("--workers", type=int, default=8)
|
||||
ap.add_argument("--limit", type=int, default=0)
|
||||
ap.add_argument("--out_root", default="")
|
||||
args = ap.parse_args()
|
||||
|
||||
set_backend(args.backend)
|
||||
configure_azure_openai(
|
||||
endpoint=args.azure_endpoint or None,
|
||||
api_version=args.azure_api_version or None,
|
||||
api_key=args.azure_api_key or None,
|
||||
)
|
||||
set_student_deployment(args.model)
|
||||
ts = time.strftime("%Y%m%d_%H%M%S")
|
||||
out_root = args.out_root or os.path.join(_PROJECT_ROOT, "outputs", f"prompt_custom_{args.model}_{ts}")
|
||||
out_root = os.path.abspath(out_root)
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
items = load_items(JSONL_PATH)
|
||||
if args.limit: items = items[:args.limit]
|
||||
|
||||
print(f"{'='*60}")
|
||||
print(f" Prompt: CUSTOM (Critical Rules)")
|
||||
print(f" Model: {args.model}")
|
||||
print(f" Items: {len(items)}")
|
||||
print(f" Output: {out_root}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
t0 = time.time()
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=args.workers) as ex:
|
||||
futs = {ex.submit(process_one, it, DATA_ROOT, out_root, args.model): it for it in items}
|
||||
for i, fut in enumerate(as_completed(futs), 1):
|
||||
item = futs[fut]
|
||||
try:
|
||||
res = fut.result(timeout=300)
|
||||
except FuturesTimeoutError:
|
||||
res = {"id": str(item["id"]), "ok": False, "hard": 0, "soft": 0.0,
|
||||
"n_cases": 0, "n_pass": 0, "fail_reason": "timeout"}
|
||||
except Exception as e:
|
||||
res = {"id": str(item["id"]), "ok": False, "hard": 0, "soft": 0.0,
|
||||
"n_cases": 0, "n_pass": 0, "fail_reason": str(e)}
|
||||
results.append(res)
|
||||
status = "PASS" if res.get("hard") else "FAIL"
|
||||
dt = time.time() - t0
|
||||
print(f" {i}/{len(items)} id={res['id']:<10} {status} cases={res.get('n_pass',0)}/{res.get('n_cases',0)} dt={dt:.0f}s")
|
||||
|
||||
# Summary
|
||||
hard_sum = sum(r.get("hard", 0) for r in results)
|
||||
soft_sum = sum(r.get("soft", 0.0) for r in results)
|
||||
n = len(results)
|
||||
print(f"\n{'='*60}")
|
||||
print(f" CUSTOM prompt: hard={hard_sum}/{n}={hard_sum/n:.4f} soft={soft_sum/n:.4f}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
with open(os.path.join(out_root, "results.jsonl"), "w") as f:
|
||||
for r in results:
|
||||
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||
with open(os.path.join(out_root, "summary.json"), "w") as f:
|
||||
json.dump({"prompt": "custom", "model": args.model, "n": n,
|
||||
"hard": hard_sum/n, "soft": soft_sum/n}, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
352
scripts/eval_prompt_official.py
Normal file
352
scripts/eval_prompt_official.py
Normal file
@@ -0,0 +1,352 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Standalone eval: OFFICIAL prompt (SpreadsheetBench original) on verified-400.
|
||||
|
||||
Usage:
|
||||
python scripts/eval_prompt_official.py --workers 8
|
||||
python scripts/eval_prompt_official.py --workers 32 --limit 20
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError as FuturesTimeoutError
|
||||
|
||||
import openpyxl
|
||||
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from skillopt.model import (
|
||||
chat_messages_with_deployment,
|
||||
configure_azure_openai,
|
||||
set_backend,
|
||||
set_student_deployment,
|
||||
)
|
||||
from skillopt.envs.spreadsheetbench.evaluator import evaluate
|
||||
|
||||
|
||||
# ── Config ──────────────────────────────────────────────────────────────────
|
||||
|
||||
DATA_ROOT = "/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH = os.path.join(DATA_ROOT, "dataset.json")
|
||||
MODEL = "gpt-5-mini"
|
||||
|
||||
# ── Official Prompt (from SpreadsheetBench src/prompt.py) ───────────────────
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are an expert Python programmer specializing in spreadsheet manipulation. "
|
||||
"You will be given a user instruction together with a preview of an input .xlsx file. "
|
||||
"Your job is to write a single self-contained Python script that reads the input file "
|
||||
"at the path stored in the variable INPUT_PATH, performs the requested manipulation, "
|
||||
"and saves the result to OUTPUT_PATH. Use only the standard library, openpyxl, and pandas. "
|
||||
"Do not print anything. Do not use input(). Do not hardcode file paths. "
|
||||
"Return ONLY the Python code inside a single ```python ... ``` fenced block."
|
||||
)
|
||||
|
||||
|
||||
def build_system(skill_content: str = "") -> str:
|
||||
base = _SYSTEM_PROMPT
|
||||
if skill_content.strip():
|
||||
base += f"\n\n## Skill\n{skill_content.strip()}"
|
||||
return base
|
||||
|
||||
|
||||
def build_user(instruction, input_xlsx, instruction_type="", answer_position=""):
|
||||
try:
|
||||
preview = _preview_workbook(input_xlsx)
|
||||
except Exception as e:
|
||||
preview = f"(failed to preview: {e})"
|
||||
extra = ""
|
||||
if instruction_type:
|
||||
extra += f"\nInstruction type: {instruction_type}"
|
||||
if answer_position:
|
||||
extra += f"\nExpected answer position: {answer_position}"
|
||||
return (
|
||||
f"# Instruction\n{instruction}\n{extra}\n\n"
|
||||
f"# Input spreadsheet preview\n{preview}\n\n"
|
||||
"# Task\n"
|
||||
"Write a Python script that reads the workbook from the variable `INPUT_PATH`, "
|
||||
"applies the instruction, and writes the modified workbook to `OUTPUT_PATH`. "
|
||||
"Preserve all other cells unchanged. "
|
||||
"The preview may be truncated — do not hardcode row counts or assume the data ends at the last previewed row; "
|
||||
"iterate over all actual rows in the workbook instead. "
|
||||
"Return only a ```python``` code block."
|
||||
)
|
||||
|
||||
|
||||
# ── Shared utilities (identical to custom version) ──────────────────────────
|
||||
|
||||
def _preview_workbook(path, max_rows=5, max_cols=20):
|
||||
wb = openpyxl.load_workbook(path, data_only=False)
|
||||
chunks = []
|
||||
for sn in wb.sheetnames:
|
||||
ws = wb[sn]
|
||||
chunks.append(f"## Sheet: {sn} (dim={ws.dimensions}, max_row={ws.max_row}, max_col={ws.max_column})")
|
||||
for row in ws.iter_rows(min_row=1, max_row=min(ws.max_row, max_rows),
|
||||
max_col=min(ws.max_column, max_cols), values_only=False):
|
||||
cells = []
|
||||
for c in row:
|
||||
v = c.value
|
||||
s = "" if v is None else str(v)
|
||||
if len(s) > 40: s = s[:37] + "..."
|
||||
cells.append(f"{c.coordinate}={s}")
|
||||
chunks.append(" | ".join(cells))
|
||||
if ws.max_row > max_rows:
|
||||
chunks.append(f"... ({ws.max_row - max_rows} more rows)")
|
||||
chunks.append("")
|
||||
wb.close()
|
||||
return "\n".join(chunks)
|
||||
|
||||
|
||||
def extract_code(text):
|
||||
if "```" not in text:
|
||||
return text.strip()
|
||||
start = text.find("```")
|
||||
nl = text.find("\n", start)
|
||||
end = text.find("```", nl + 1)
|
||||
if nl == -1 or end == -1:
|
||||
return text.strip()
|
||||
return text[nl + 1:end].strip()
|
||||
|
||||
|
||||
_PATH_RE = re.compile(r'^\s*(INPUT_PATH|OUTPUT_PATH)\s*=\s*.+$', re.MULTILINE)
|
||||
|
||||
def strip_paths(code):
|
||||
return _PATH_RE.sub("", code)
|
||||
|
||||
|
||||
RUNNER_TEMPLATE = textwrap.dedent("""
|
||||
import os, sys, traceback
|
||||
INPUT_PATH = {input_path!r}
|
||||
OUTPUT_PATH = {output_path!r}
|
||||
try:
|
||||
{code_indented}
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
sys.exit(2)
|
||||
""")
|
||||
|
||||
|
||||
def run_code(code, input_path, output_path, timeout=120):
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
cleaned = strip_paths(code)
|
||||
indented = textwrap.indent(cleaned, " ")
|
||||
script = RUNNER_TEMPLATE.format(input_path=input_path, output_path=output_path, code_indented=indented)
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(script)
|
||||
tmp = f.name
|
||||
try:
|
||||
proc = subprocess.run([sys.executable, tmp], capture_output=True, text=True, timeout=timeout)
|
||||
if proc.returncode != 0:
|
||||
return False, (proc.stdout + "\n" + proc.stderr).strip()
|
||||
if not os.path.exists(output_path):
|
||||
return False, "output file was not created"
|
||||
return True, ""
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, f"timeout after {timeout}s"
|
||||
finally:
|
||||
try: os.unlink(tmp)
|
||||
except OSError: pass
|
||||
|
||||
|
||||
def find_test_cases(task_dir):
|
||||
cases = []
|
||||
for ip in sorted(glob.glob(os.path.join(task_dir, "*_input.xlsx"))):
|
||||
no = os.path.basename(ip).split("_", 1)[0]
|
||||
ap = ip.replace("_input.xlsx", "_answer.xlsx")
|
||||
if os.path.exists(ap): cases.append((no, ip, ap))
|
||||
for ip in sorted(glob.glob(os.path.join(task_dir, "*_init.xlsx"))):
|
||||
no = os.path.basename(ip).split("_", 1)[0]
|
||||
ap = ip.replace("_init.xlsx", "_golden.xlsx")
|
||||
if os.path.exists(ap): cases.append((no, ip, ap))
|
||||
if not cases:
|
||||
bare_init = os.path.join(task_dir, "initial.xlsx")
|
||||
bare_gold = os.path.join(task_dir, "golden.xlsx")
|
||||
if os.path.exists(bare_init) and os.path.exists(bare_gold):
|
||||
cases.append(("1", bare_init, bare_gold))
|
||||
return cases
|
||||
|
||||
|
||||
def load_items(path):
|
||||
if path.endswith(".json"):
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
data = data.get("data") or list(data.values())
|
||||
return list(data)
|
||||
items = []
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line: items.append(json.loads(line))
|
||||
return items
|
||||
|
||||
|
||||
# ── LLM call ────────────────────────────────────────────────────────────────
|
||||
|
||||
def llm_call(messages, deployment, max_tokens=16384, retries=5, llm_timeout=120):
|
||||
raw, _ = chat_messages_with_deployment(
|
||||
deployment=deployment,
|
||||
messages=messages,
|
||||
max_completion_tokens=max_tokens,
|
||||
retries=retries,
|
||||
stage="rollout",
|
||||
timeout=llm_timeout,
|
||||
)
|
||||
return str(raw or "")
|
||||
|
||||
|
||||
# ── Process one task ────────────────────────────────────────────────────────
|
||||
|
||||
def process_one(item, data_root, out_root, model):
|
||||
task_id = str(item["id"])
|
||||
instruction = item["instruction"]
|
||||
instruction_type = item.get("instruction_type", "")
|
||||
answer_position = item.get("answer_position", "")
|
||||
answer_sheet = item.get("answer_sheet", "")
|
||||
if answer_position and answer_sheet and "!" not in answer_position:
|
||||
answer_position = f"{answer_sheet}!{answer_position}"
|
||||
|
||||
sp = item.get("spreadsheet_path", f"spreadsheet/{task_id}")
|
||||
task_dir = sp if os.path.isabs(sp) else os.path.join(data_root, sp)
|
||||
|
||||
result = {"id": task_id, "ok": False, "hard": 0, "soft": 0.0,
|
||||
"n_cases": 0, "n_pass": 0, "fail_reason": "", "error": ""}
|
||||
try:
|
||||
cases = find_test_cases(task_dir)
|
||||
result["n_cases"] = len(cases)
|
||||
if not cases:
|
||||
result["fail_reason"] = "no-test-cases"
|
||||
return result
|
||||
|
||||
task_out = os.path.join(out_root, "predictions", task_id)
|
||||
os.makedirs(task_out, exist_ok=True)
|
||||
|
||||
# LLM call
|
||||
system = build_system("")
|
||||
user = build_user(instruction, cases[0][1], instruction_type, answer_position)
|
||||
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
|
||||
|
||||
raw = llm_call(messages, model)
|
||||
time.sleep(3)
|
||||
code = extract_code(raw)
|
||||
|
||||
with open(os.path.join(task_out, "code.py"), "w") as f: f.write(code)
|
||||
with open(os.path.join(task_out, "raw.txt"), "w") as f: f.write(raw)
|
||||
|
||||
if not code.strip():
|
||||
result["fail_reason"] = "empty-code"
|
||||
return result
|
||||
|
||||
# Execute + evaluate each test case
|
||||
for no, ip, ap in cases:
|
||||
pred = os.path.join(task_out, f"{no}_pred.xlsx")
|
||||
ok_exec, err = run_code(code, ip, pred)
|
||||
if not ok_exec:
|
||||
if not result["fail_reason"]:
|
||||
result["fail_reason"] = f"exec: {err[:200]}"
|
||||
continue
|
||||
try:
|
||||
ev = evaluate(pred, ap, instruction_type, answer_position)
|
||||
except Exception as e:
|
||||
ev = {"ok": False, "reason": str(e)}
|
||||
if ev["ok"]:
|
||||
result["n_pass"] += 1
|
||||
|
||||
nc, np = result["n_cases"], result["n_pass"]
|
||||
result["soft"] = np / nc if nc else 0.0
|
||||
result["hard"] = 1 if nc > 0 and np == nc else 0
|
||||
result["ok"] = bool(result["hard"])
|
||||
if result["ok"]: result["fail_reason"] = ""
|
||||
return result
|
||||
except Exception as e:
|
||||
result["fail_reason"] = f"unexpected: {e}"
|
||||
result["error"] = traceback.format_exc()
|
||||
return result
|
||||
|
||||
|
||||
# ── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Eval OFFICIAL prompt on verified-400")
|
||||
ap.add_argument("--model", default=MODEL)
|
||||
ap.add_argument("--backend", choices=["azure_openai", "codex", "claude"], default="azure_openai")
|
||||
ap.add_argument("--azure_endpoint", default="")
|
||||
ap.add_argument("--azure_api_version", default="")
|
||||
ap.add_argument("--azure_api_key", default="")
|
||||
ap.add_argument("--workers", type=int, default=8)
|
||||
ap.add_argument("--limit", type=int, default=0)
|
||||
ap.add_argument("--out_root", default="")
|
||||
args = ap.parse_args()
|
||||
|
||||
set_backend(args.backend)
|
||||
configure_azure_openai(
|
||||
endpoint=args.azure_endpoint or None,
|
||||
api_version=args.azure_api_version or None,
|
||||
api_key=args.azure_api_key or None,
|
||||
)
|
||||
set_student_deployment(args.model)
|
||||
ts = time.strftime("%Y%m%d_%H%M%S")
|
||||
out_root = args.out_root or os.path.join(_PROJECT_ROOT, "outputs", f"prompt_official_{args.model}_{ts}")
|
||||
out_root = os.path.abspath(out_root)
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
items = load_items(JSONL_PATH)
|
||||
if args.limit: items = items[:args.limit]
|
||||
|
||||
print(f"{'='*60}")
|
||||
print(f" Prompt: OFFICIAL (SpreadsheetBench original)")
|
||||
print(f" Model: {args.model}")
|
||||
print(f" Items: {len(items)}")
|
||||
print(f" Output: {out_root}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
t0 = time.time()
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=args.workers) as ex:
|
||||
futs = {ex.submit(process_one, it, DATA_ROOT, out_root, args.model): it for it in items}
|
||||
for i, fut in enumerate(as_completed(futs), 1):
|
||||
item = futs[fut]
|
||||
try:
|
||||
res = fut.result(timeout=300)
|
||||
except FuturesTimeoutError:
|
||||
res = {"id": str(item["id"]), "ok": False, "hard": 0, "soft": 0.0,
|
||||
"n_cases": 0, "n_pass": 0, "fail_reason": "timeout"}
|
||||
except Exception as e:
|
||||
res = {"id": str(item["id"]), "ok": False, "hard": 0, "soft": 0.0,
|
||||
"n_cases": 0, "n_pass": 0, "fail_reason": str(e)}
|
||||
results.append(res)
|
||||
status = "PASS" if res.get("hard") else "FAIL"
|
||||
dt = time.time() - t0
|
||||
print(f" {i}/{len(items)} id={res['id']:<10} {status} cases={res.get('n_pass',0)}/{res.get('n_cases',0)} dt={dt:.0f}s")
|
||||
|
||||
# Summary
|
||||
hard_sum = sum(r.get("hard", 0) for r in results)
|
||||
soft_sum = sum(r.get("soft", 0.0) for r in results)
|
||||
n = len(results)
|
||||
print(f"\n{'='*60}")
|
||||
print(f" OFFICIAL prompt: hard={hard_sum}/{n}={hard_sum/n:.4f} soft={soft_sum/n:.4f}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
with open(os.path.join(out_root, "results.jsonl"), "w") as f:
|
||||
for r in results:
|
||||
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||
with open(os.path.join(out_root, "summary.json"), "w") as f:
|
||||
json.dump({"prompt": "official", "model": args.model, "n": n,
|
||||
"hard": hard_sum/n, "soft": soft_sum/n}, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
37
scripts/eval_searchqa_val500.sh
Executable file
37
scripts/eval_searchqa_val500.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT3 — SearchQA Eval-Only (验证集 500)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/eval_searchqa_val500.sh --skill_path outputs/xxx/best_skill.md
|
||||
# bash scripts/eval_searchqa_val500.sh --skill_path outputs/xxx/best_skill.md --workers 32
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5-mini}"
|
||||
|
||||
VAL_PATH="/home/azureuser/workspace-yqh/refleAct/search-qa/data/searchqa_val_500.json"
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/searchqa_eval_val500_${STUDENT_DEPLOYMENT}_${TIMESTAMP}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " ReflACT3 — SearchQA Eval-Only (val-500)"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo " Data: ${VAL_PATH}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/eval_only.py \
|
||||
--config configs/searchqa_default.yaml \
|
||||
--data_path "${VAL_PATH}" \
|
||||
--out_root "${DEFAULT_OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
|
||||
41
scripts/eval_verified400.sh
Executable file
41
scripts/eval_verified400.sh
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Eval skill0 on full SpreadsheetBench verified-400
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/eval_verified400.sh
|
||||
# bash scripts/eval_verified400.sh --workers 64
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
# ── Paths ────────────────────────────────────────────────────────────────────
|
||||
DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH="${DATA_ROOT}/dataset.json"
|
||||
SKILL_PATH="${PROJECT_ROOT}/skillopt/envs/spreadsheetbench/skills/initial.md"
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUT_ROOT="${PROJECT_ROOT}/outputs/eval_verified400_${TIMESTAMP}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " Eval skill0 on verified-400 (full)"
|
||||
echo "============================================================"
|
||||
echo " data_root: ${DATA_ROOT}"
|
||||
echo " skill: ${SKILL_PATH}"
|
||||
echo " out_root: ${OUT_ROOT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/eval_only.py \
|
||||
--config configs/spreadsheetbench_default.yaml \
|
||||
--skill "${SKILL_PATH}" \
|
||||
--split all \
|
||||
--data_root "${DATA_ROOT}" \
|
||||
--jsonl_path "${JSONL_PATH}" \
|
||||
--out_root "${OUT_ROOT}" \
|
||||
"$@"
|
||||
42
scripts/eval_verified400_multi.sh
Executable file
42
scripts/eval_verified400_multi.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Eval skill0 on full SpreadsheetBench verified-400 (MULTI-ROUND codegen)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/eval_verified400_multi.sh
|
||||
# bash scripts/eval_verified400_multi.sh --workers 64 --max_turns 5
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH="${DATA_ROOT}/dataset.json"
|
||||
SKILL_PATH="${PROJECT_ROOT}/skillopt/envs/spreadsheetbench/skills/initial.md"
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUT_ROOT="${PROJECT_ROOT}/outputs/eval_multi_verified400_${TIMESTAMP}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " Eval skill0 — MULTI-ROUND codegen — verified-400"
|
||||
echo "============================================================"
|
||||
echo " data_root: ${DATA_ROOT}"
|
||||
echo " skill: ${SKILL_PATH}"
|
||||
echo " mode: multi"
|
||||
echo " out_root: ${OUT_ROOT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/eval_only.py \
|
||||
--config configs/spreadsheetbench_default.yaml \
|
||||
--skill "${SKILL_PATH}" \
|
||||
--split all \
|
||||
--mode multi \
|
||||
--data_root "${DATA_ROOT}" \
|
||||
--jsonl_path "${JSONL_PATH}" \
|
||||
--out_root "${OUT_ROOT}" \
|
||||
"$@"
|
||||
42
scripts/eval_verified400_single.sh
Executable file
42
scripts/eval_verified400_single.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# Eval skill0 on full SpreadsheetBench verified-400 (SINGLE-ROUND codegen)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/eval_verified400_single.sh
|
||||
# bash scripts/eval_verified400_single.sh --workers 64
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH="${DATA_ROOT}/dataset.json"
|
||||
SKILL_PATH="${PROJECT_ROOT}/skillopt/envs/spreadsheetbench/skills/initial.md"
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUT_ROOT="${PROJECT_ROOT}/outputs/eval_single_verified400_${TIMESTAMP}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " Eval skill0 — SINGLE-ROUND codegen — verified-400"
|
||||
echo "============================================================"
|
||||
echo " data_root: ${DATA_ROOT}"
|
||||
echo " skill: ${SKILL_PATH}"
|
||||
echo " mode: single"
|
||||
echo " out_root: ${OUT_ROOT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/eval_only.py \
|
||||
--config configs/spreadsheetbench_default.yaml \
|
||||
--skill "${SKILL_PATH}" \
|
||||
--split all \
|
||||
--mode single \
|
||||
--data_root "${DATA_ROOT}" \
|
||||
--jsonl_path "${JSONL_PATH}" \
|
||||
--out_root "${OUT_ROOT}" \
|
||||
"$@"
|
||||
120
scripts/launch_harness_bestsetting_from_scratch.sh
Executable file
120
scripts/launch_harness_bestsetting_from_scratch.sh
Executable file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
PY="${PY:-python}"
|
||||
RUN_ROOT="${RUN_ROOT:-$ROOT/outputs/harness_bestsetting_fromscratch_$(date -u +%Y%m%d_%H%M%S)_run}"
|
||||
MAX_PARALLEL="${MAX_PARALLEL:-2}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs"
|
||||
cd "$ROOT"
|
||||
export PYTHONPATH="$ROOT:${PYTHONPATH:-}"
|
||||
|
||||
COMMON=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint=https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.teacher_azure_openai_api_version=2024-12-01-preview
|
||||
model.teacher_azure_openai_auth_mode=azure_cli
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.skill_update_mode=patch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
)
|
||||
|
||||
CODEX=(
|
||||
model.student_backend=codex_exec
|
||||
model.student=gpt-5.5
|
||||
model.codex_exec_use_sdk=auto
|
||||
model.codex_exec_sandbox=workspace-write
|
||||
model.codex_exec_approval_policy=never
|
||||
model.codex_trace_to_teacher=true
|
||||
)
|
||||
|
||||
CLAUDE=(
|
||||
model.student_backend=claude_code_exec
|
||||
model.student=claude-sonnet-4-6
|
||||
model.claude_code_exec_use_sdk=auto
|
||||
model.codex_trace_to_teacher=false
|
||||
)
|
||||
|
||||
active=0
|
||||
launch() {
|
||||
local run_id="$1"; shift
|
||||
local config="$1"; shift
|
||||
local out="$RUN_ROOT/$run_id"
|
||||
local log="$RUN_ROOT/logs/$run_id.log"
|
||||
echo "START $run_id"
|
||||
setsid "$PY" -u scripts/train.py \
|
||||
--config "$config" \
|
||||
--cfg-options "${COMMON[@]}" "$@" "env.out_root=$out" \
|
||||
> "$log" 2>&1 < /dev/null &
|
||||
active=$((active + 1))
|
||||
if (( active >= MAX_PARALLEL )); then
|
||||
wait -n
|
||||
active=$((active - 1))
|
||||
fi
|
||||
}
|
||||
|
||||
# SearchQA best openai-chat setting: optimizer.lr_scheduler=constant.
|
||||
launch HARNESS-BESTSETTING-searchqa-codex configs/searchqa/default.yaml \
|
||||
"${CODEX[@]}" \
|
||||
train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \
|
||||
env.split_dir=data/searchqa/splits
|
||||
|
||||
launch HARNESS-BESTSETTING-searchqa-claude configs/searchqa/default.yaml \
|
||||
"${CLAUDE[@]}" \
|
||||
train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \
|
||||
env.split_dir=data/searchqa/splits
|
||||
|
||||
# SpreadsheetBench best openai-chat setting: optimizer.lr_scheduler=constant.
|
||||
# Must stay env.mode=multi; exec-backend multi support is fixed on this branch.
|
||||
launch HARNESS-BESTSETTING-spreadsheetbench-codex configs/spreadsheetbench/default.yaml \
|
||||
"${CODEX[@]}" \
|
||||
train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi env.workers=4
|
||||
|
||||
launch HARNESS-BESTSETTING-spreadsheetbench-claude configs/spreadsheetbench/default.yaml \
|
||||
"${CLAUDE[@]}" \
|
||||
train.batch_size=40 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi env.workers=4
|
||||
|
||||
# LiveMathBench best openai-chat setting: optimizer.learning_rate=8.
|
||||
launch HARNESS-BESTSETTING-livemathematicianbench-codex configs/livemathematicianbench/default.yaml \
|
||||
"${CODEX[@]}" \
|
||||
train.batch_size=40 optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant \
|
||||
env.split_dir=data/livemathbench/splits
|
||||
|
||||
launch HARNESS-BESTSETTING-livemathematicianbench-claude configs/livemathematicianbench/default.yaml \
|
||||
"${CLAUDE[@]}" \
|
||||
train.batch_size=40 optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant \
|
||||
env.split_dir=data/livemathbench/splits
|
||||
|
||||
# DocVQA best openai-chat setting was full batch. On 10% harness split, train=107.
|
||||
launch HARNESS-BESTSETTING-docvqa10pct-codex configs/docvqa/default.yaml \
|
||||
"${CODEX[@]}" \
|
||||
train.batch_size=107 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct
|
||||
|
||||
launch HARNESS-BESTSETTING-docvqa10pct-claude configs/docvqa/default.yaml \
|
||||
"${CLAUDE[@]}" \
|
||||
train.batch_size=107 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct
|
||||
|
||||
wait
|
||||
echo "All launched runs finished or exited. RUN_ROOT=$RUN_ROOT"
|
||||
178
scripts/launch_harness_canonical_claude18.sh
Executable file
178
scripts/launch_harness_canonical_claude18.sh
Executable file
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
# Claude Code on this machine must use the local copilot-api proxy.
|
||||
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
|
||||
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
|
||||
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
|
||||
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
|
||||
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/harness_canonical_claude18_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-harness_canon_claude18_${stamp}}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.skill_update_mode=patch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
model.student_backend=claude_code_exec
|
||||
model.student=claude-sonnet-4-6
|
||||
model.claude_code_exec_use_sdk=sdk
|
||||
model.claude_code_exec_effort=medium
|
||||
model.claude_code_exec_max_thinking_tokens=16384
|
||||
model.codex_trace_to_teacher=false
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
local config="$2"
|
||||
local skill="$3"
|
||||
shift 3
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config "$config"
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
env.skill_init="$skill"
|
||||
env.out_root="$out_root"
|
||||
"$@"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
|
||||
printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
|
||||
printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
|
||||
printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
|
||||
printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
|
||||
printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
SEARCHQA_SKILL="docs/harness_source_skills/searchqa_best_skill.md"
|
||||
LIVEMATH_SKILL="docs/harness_source_skills/livemathematicianbench_best_skill.md"
|
||||
DOCVQA_SKILL="docs/harness_source_skills/docvqa_best_skill.md"
|
||||
SPREADSHEET_SKILL="docs/harness_source_skills/spreadsheetbench_best_skill.md"
|
||||
|
||||
launch_run "HARNESS-Claude-SearchQA-sched-constant" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \
|
||||
env.split_dir=data/searchqa/splits \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-SearchQA-sched-linear" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \
|
||||
env.split_dir=data/searchqa/splits \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=linear
|
||||
launch_run "HARNESS-Claude-SearchQA-batch-full" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \
|
||||
env.split_dir=data/searchqa/splits \
|
||||
train.batch_size=400 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
launch_run "HARNESS-Claude-SearchQA-lr8" "configs/searchqa/default.yaml" "$SEARCHQA_SKILL" \
|
||||
env.split_dir=data/searchqa/splits \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-Claude-LiveMath-lr8" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-LiveMath-lr16" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-LiveMath-slow10" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine optimizer.slow_update_samples=10
|
||||
launch_run "HARNESS-Claude-LiveMath-sched-linear" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=linear
|
||||
launch_run "HARNESS-Claude-LiveMath-minibatch4" "configs/livemathematicianbench/default.yaml" "$LIVEMATH_SKILL" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
gradient.minibatch_size=4 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
|
||||
launch_run "HARNESS-Claude-DocVQA10-batch-full" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
train.batch_size=107 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
launch_run "HARNESS-Claude-DocVQA10-lr16" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-DocVQA10-lr8" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-DocVQA10-minibatch32" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
gradient.minibatch_size=32 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
launch_run "HARNESS-Claude-DocVQA10-batch24" "configs/docvqa/default.yaml" "$DOCVQA_SKILL" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
train.batch_size=24 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
|
||||
launch_run "HARNESS-Claude-Spreadsheet-sched-constant-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-Spreadsheet-lr4-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-Spreadsheet-lr16-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \
|
||||
optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "HARNESS-Claude-Spreadsheet-minibatch16-multi" "configs/spreadsheetbench/default.yaml" "$SPREADSHEET_SKILL" \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \
|
||||
gradient.minibatch_size=16 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
130
scripts/launch_harness_canonical_claude4_smoke.sh
Executable file
130
scripts/launch_harness_canonical_claude4_smoke.sh
Executable file
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
|
||||
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
|
||||
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
|
||||
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
|
||||
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/harness_canonical_claude4_smoke_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-harness_canon_claude4_${stamp}}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.skill_update_mode=patch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
model.student_backend=claude_code_exec
|
||||
model.student=claude-sonnet-4-6
|
||||
model.claude_code_exec_use_sdk=sdk
|
||||
model.claude_code_exec_effort=medium
|
||||
model.claude_code_exec_max_thinking_tokens=16384
|
||||
model.codex_trace_to_teacher=false
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
local config="$2"
|
||||
local skill="$3"
|
||||
shift 3
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config "$config"
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
env.skill_init="$skill"
|
||||
env.out_root="$out_root"
|
||||
"$@"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
|
||||
printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
|
||||
printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
|
||||
printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
|
||||
printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
|
||||
printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
launch_run "HARNESS-Claude-SearchQA-sched-constant" "configs/searchqa/default.yaml" "docs/harness_source_skills/searchqa_best_skill.md" \
|
||||
env.split_dir=data/searchqa/splits \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-Claude-LiveMath-lr8" "configs/livemathematicianbench/default.yaml" "docs/harness_source_skills/livemathematicianbench_best_skill.md" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-Claude-DocVQA10-lr8" "configs/docvqa/default.yaml" "docs/harness_source_skills/docvqa_best_skill.md" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-Claude-Spreadsheet-lr4-multi" "configs/spreadsheetbench/default.yaml" "docs/harness_source_skills/spreadsheetbench_best_skill.md" \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
168
scripts/launch_harness_canonical_wave1.sh
Executable file
168
scripts/launch_harness_canonical_wave1.sh
Executable file
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
CODEX_WRAPPER="$REPO/scripts/codex_azure_mi.sh"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
# Claude Code is routed through the local copilot-api proxy on this machine.
|
||||
# Do not rely on interactive Claude login state inside tmux/train workers.
|
||||
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
|
||||
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
|
||||
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
|
||||
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
|
||||
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/harness_canonical_step12_wave1_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-harness_canon_wave1_${stamp}}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.skill_update_mode=patch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
local backend="$2"
|
||||
local config="$3"
|
||||
local skill="$4"
|
||||
local split_dir="$5"
|
||||
shift 5
|
||||
|
||||
local -a backend_cfg=()
|
||||
if [[ "$backend" == "codex" ]]; then
|
||||
backend_cfg=(
|
||||
model.student_backend=codex_exec
|
||||
model.student=gpt-5.5
|
||||
model.codex_exec_path="$CODEX_WRAPPER"
|
||||
model.codex_exec_use_sdk=auto
|
||||
model.codex_exec_sandbox=workspace-write
|
||||
model.codex_exec_approval_policy=never
|
||||
model.codex_exec_reasoning_effort=medium
|
||||
model.codex_trace_to_teacher=true
|
||||
)
|
||||
elif [[ "$backend" == "claude" ]]; then
|
||||
backend_cfg=(
|
||||
model.student_backend=claude_code_exec
|
||||
model.student=claude-sonnet-4-6
|
||||
model.claude_code_exec_use_sdk=sdk
|
||||
model.claude_code_exec_effort=medium
|
||||
model.claude_code_exec_max_thinking_tokens=16384
|
||||
model.codex_trace_to_teacher=false
|
||||
)
|
||||
else
|
||||
echo "unknown backend: $backend" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config "$config"
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
"${backend_cfg[@]}"
|
||||
env.split_dir="$split_dir"
|
||||
env.skill_init="$skill"
|
||||
env.out_root="$out_root"
|
||||
"$@"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
|
||||
printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
|
||||
printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
|
||||
printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
|
||||
printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
|
||||
printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
SEARCHQA_SKILL="docs/harness_source_skills/searchqa_best_skill.md"
|
||||
LIVEMATH_SKILL="docs/harness_source_skills/livemathematicianbench_best_skill.md"
|
||||
|
||||
SEARCHQA_CFG="configs/searchqa/default.yaml"
|
||||
LIVEMATH_CFG="configs/livemathematicianbench/default.yaml"
|
||||
|
||||
SEARCHQA_SPLIT="data/searchqa/splits"
|
||||
LIVEMATH_SPLIT="data/livemathbench/splits"
|
||||
|
||||
for backend in codex claude; do
|
||||
prefix="HARNESS-Codex"
|
||||
[[ "$backend" == "claude" ]] && prefix="HARNESS-Claude"
|
||||
|
||||
launch_run "${prefix}-SearchQA-sched-constant" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant
|
||||
launch_run "${prefix}-SearchQA-sched-linear" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=linear
|
||||
launch_run "${prefix}-SearchQA-batch-full" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \
|
||||
train.batch_size=400 optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=cosine
|
||||
launch_run "${prefix}-SearchQA-lr8" "$backend" "$SEARCHQA_CFG" "$SEARCHQA_SKILL" "$SEARCHQA_SPLIT" \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "${prefix}-LiveMath-lr8" "$backend" "$LIVEMATH_CFG" "$LIVEMATH_SKILL" "$LIVEMATH_SPLIT" \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
launch_run "${prefix}-LiveMath-lr16" "$backend" "$LIVEMATH_CFG" "$LIVEMATH_SKILL" "$LIVEMATH_SPLIT" \
|
||||
optimizer.learning_rate=16 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
done
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
128
scripts/launch_harness_initial_claude4.sh
Executable file
128
scripts/launch_harness_initial_claude4.sh
Executable file
@@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
|
||||
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
|
||||
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
|
||||
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
|
||||
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/harness_initial_claude4_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-harness_initial_claude4_${stamp}}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.skill_update_mode=patch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
model.student_backend=claude_code_exec
|
||||
model.student=claude-sonnet-4-6
|
||||
model.claude_code_exec_use_sdk=sdk
|
||||
model.claude_code_exec_effort=medium
|
||||
model.claude_code_exec_max_thinking_tokens=16384
|
||||
model.codex_trace_to_teacher=false
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
local config="$2"
|
||||
shift 2
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config "$config"
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
env.out_root="$out_root"
|
||||
"$@"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
|
||||
printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
|
||||
printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
|
||||
printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
|
||||
printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
|
||||
printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
launch_run "HARNESS-ClaudeInit-SearchQA-sched-constant" "configs/searchqa/default.yaml" \
|
||||
env.split_dir=data/searchqa/splits \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=2 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-ClaudeInit-LiveMath-lr8" "configs/livemathematicianbench/default.yaml" \
|
||||
env.split_dir=data/livemathbench/splits \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-ClaudeInit-DocVQA10-lr8" "configs/docvqa/default.yaml" \
|
||||
env.split_dir=data/harness_splits/docvqa_zisu_first10pct \
|
||||
optimizer.learning_rate=8 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
launch_run "HARNESS-ClaudeInit-Spreadsheet-lr4-multi" "configs/spreadsheetbench/default.yaml" \
|
||||
env.split_dir=data/spreadsheetbench env.data_root=data/spreadsheetbench/files env.mode=multi \
|
||||
optimizer.learning_rate=4 optimizer.min_learning_rate=1 optimizer.lr_scheduler=constant
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
103
scripts/launch_harness_initial_spreadsheet_clean.sh
Executable file
103
scripts/launch_harness_initial_spreadsheet_clean.sh
Executable file
@@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
|
||||
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
|
||||
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
|
||||
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
|
||||
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/harness_initial_spreadsheet_clean_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-harness_initial_spreadsheet_clean_${stamp}}"
|
||||
RUN_ID="HARNESS-ClaudeInit-Spreadsheet-lr4-multi-clean"
|
||||
SPLIT_DIR="${SPREADSHEET_SPLIT_DIR:-data/harness_splits/spreadsheetbench_full}"
|
||||
DATA_ROOT="${SPREADSHEET_DATA_ROOT:-data/spreadsheetbench/files}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
cmd_file="$RUN_ROOT/commands/${RUN_ID}.sh"
|
||||
log_file="$RUN_ROOT/logs/${RUN_ID}.log"
|
||||
out_root="$RUN_ROOT/$RUN_ID"
|
||||
|
||||
cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config configs/spreadsheetbench/default.yaml
|
||||
--cfg-options
|
||||
model.teacher_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.skill_update_mode=patch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
model.student_backend=claude_code_exec
|
||||
model.student=claude-sonnet-4-6
|
||||
model.claude_code_exec_use_sdk=sdk
|
||||
model.claude_code_exec_effort=medium
|
||||
model.claude_code_exec_max_thinking_tokens=16384
|
||||
model.codex_trace_to_teacher=false
|
||||
env.out_root="$out_root"
|
||||
env.split_dir="$SPLIT_DIR"
|
||||
env.data_root="$DATA_ROOT"
|
||||
env.mode=multi
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=1
|
||||
optimizer.lr_scheduler=constant
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
|
||||
printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
|
||||
printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
|
||||
printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
|
||||
printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
|
||||
printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
tmux new-session -d -s "$SESSION" -n "$RUN_ID" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
echo "RUN_ID=$RUN_ID"
|
||||
116
scripts/launch_lrctrl_fullrewrite_neutral3.sh
Executable file
116
scripts/launch_lrctrl_fullrewrite_neutral3.sh
Executable file
@@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/lrctrl_fullrewrite_neutral3_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-lrctrl_fullrewrite_neutral3_${stamp}}"
|
||||
SEED="${3:-42}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.student_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.student=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed="${SEED}"
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_scheduler=cosine
|
||||
optimizer.lr_control_mode=none
|
||||
optimizer.skill_update_mode=full_rewrite_minibatch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
local config="$2"
|
||||
shift 2
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config "$config"
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
env.out_root="$out_root"
|
||||
"$@"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
launch_run "LRCTRL-searchqa-full-rewrite-neutral3-seed${SEED}" "configs/searchqa/default.yaml" \
|
||||
env.split_dir=data/ablation_splits/searchqa/2-1-7_seed42
|
||||
|
||||
launch_run "LRCTRL-spreadsheetbench-full-rewrite-neutral3-seed${SEED}" "configs/spreadsheetbench/default.yaml" \
|
||||
env.split_dir=data/ablation_splits/spreadsheetbench/2-1-7_seed42 \
|
||||
env.data_root=data/spreadsheetbench_verified_400 \
|
||||
env.mode=multi
|
||||
|
||||
launch_run "LRCTRL-livemathematicianbench-full-rewrite-neutral3-seed${SEED}" "configs/livemathematicianbench/default.yaml" \
|
||||
env.split_dir=data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
echo "SEED=$SEED"
|
||||
111
scripts/launch_lrctrl_fullrewrite_neutral3_spreadsheet_repeats.sh
Executable file
111
scripts/launch_lrctrl_fullrewrite_neutral3_spreadsheet_repeats.sh
Executable file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/lrctrl_fullrewrite_neutral3_spreadsheet_promptweak_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-lrctrl_fr_spreadsheet_promptweak_${stamp}}"
|
||||
START_INDEX="${3:-4}"
|
||||
N_REPEATS="${4:-3}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.student_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.student=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_scheduler=cosine
|
||||
optimizer.lr_control_mode=none
|
||||
optimizer.skill_update_mode=full_rewrite_minibatch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
env.split_dir=data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
env.data_root=data/spreadsheetbench_verified_400
|
||||
env.mode=multi
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config configs/spreadsheetbench/default.yaml
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
env.out_root="$out_root"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
for ((i=START_INDEX; i<START_INDEX+N_REPEATS; i++)); do
|
||||
launch_run "LRCTRL-spreadsheetbench-full-rewrite-neutral3-promptweak-r${i}"
|
||||
done
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
echo "START_INDEX=$START_INDEX"
|
||||
echo "N_REPEATS=$N_REPEATS"
|
||||
115
scripts/launch_lrctrl_fullrewrite_neutral3_sq_lm_repeats.sh
Executable file
115
scripts/launch_lrctrl_fullrewrite_neutral3_sq_lm_repeats.sh
Executable file
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/lrctrl_fullrewrite_neutral3_sq_lm_extra_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-lrctrl_fr_sq_lm_extra_${stamp}}"
|
||||
START_INDEX="${3:-4}"
|
||||
N_REPEATS="${4:-3}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
COMMON_CFG=(
|
||||
model.teacher_backend=openai_chat
|
||||
model.student_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.student=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_scheduler=cosine
|
||||
optimizer.lr_control_mode=none
|
||||
optimizer.skill_update_mode=full_rewrite_minibatch
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
)
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
local config="$2"
|
||||
shift 2
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
local out_root="$RUN_ROOT/$run_id"
|
||||
|
||||
local -a cmd=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config "$config"
|
||||
--cfg-options
|
||||
"${COMMON_CFG[@]}"
|
||||
env.out_root="$out_root"
|
||||
"$@"
|
||||
)
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf '%q ' "${cmd[@]}"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
for ((i=START_INDEX; i<START_INDEX+N_REPEATS; i++)); do
|
||||
launch_run "LRCTRL-searchqa-full-rewrite-neutral3-extra-r${i}" "configs/searchqa/default.yaml" \
|
||||
env.split_dir=data/ablation_splits/searchqa/2-1-7_seed42
|
||||
|
||||
launch_run "LRCTRL-livemathematicianbench-full-rewrite-neutral3-extra-r${i}" "configs/livemathematicianbench/default.yaml" \
|
||||
env.split_dir=data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
done
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
echo "START_INDEX=$START_INDEX"
|
||||
echo "N_REPEATS=$N_REPEATS"
|
||||
175
scripts/launch_spreadsheet_full_replacements.sh
Executable file
175
scripts/launch_spreadsheet_full_replacements.sh
Executable file
@@ -0,0 +1,175 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
REPO="/home/azureuser/workspace-gzy/SkillReflection"
|
||||
PYTHON="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
|
||||
cd "$REPO"
|
||||
|
||||
export ANTHROPIC_BASE_URL="${ANTHROPIC_BASE_URL:-http://127.0.0.1:4343}"
|
||||
export ANTHROPIC_AUTH_TOKEN="${ANTHROPIC_AUTH_TOKEN:-dummy}"
|
||||
export ANTHROPIC_MODEL="${ANTHROPIC_MODEL:-claude-sonnet-4-6}"
|
||||
export ANTHROPIC_SMALL_FAST_MODEL="${ANTHROPIC_SMALL_FAST_MODEL:-claude-sonnet-4-6}"
|
||||
export DISABLE_NON_ESSENTIAL_MODEL_CALLS="${DISABLE_NON_ESSENTIAL_MODEL_CALLS:-1}"
|
||||
export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC="${CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC:-1}"
|
||||
|
||||
if [[ -f ".secrets/teacher_oaidr9.env" ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source ".secrets/teacher_oaidr9.env"
|
||||
else
|
||||
echo "missing .secrets/teacher_oaidr9.env" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
stamp="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${1:-outputs/spreadsheet_full_replacements_workers2_timeout1020_${stamp}_run}"
|
||||
SESSION="${2:-spreadsheet_full_replacements_${stamp}}"
|
||||
|
||||
mkdir -p "$RUN_ROOT/logs" "$RUN_ROOT/commands"
|
||||
|
||||
tmux_started=0
|
||||
|
||||
launch_run() {
|
||||
local run_id="$1"
|
||||
shift
|
||||
|
||||
local cmd_file="$RUN_ROOT/commands/${run_id}.sh"
|
||||
local log_file="$RUN_ROOT/logs/${run_id}.log"
|
||||
|
||||
{
|
||||
echo "#!/usr/bin/env bash"
|
||||
echo "set -euo pipefail"
|
||||
echo "cd '$REPO'"
|
||||
printf 'export ANTHROPIC_BASE_URL=%q\n' "$ANTHROPIC_BASE_URL"
|
||||
printf 'export ANTHROPIC_AUTH_TOKEN=%q\n' "$ANTHROPIC_AUTH_TOKEN"
|
||||
printf 'export ANTHROPIC_MODEL=%q\n' "$ANTHROPIC_MODEL"
|
||||
printf 'export ANTHROPIC_SMALL_FAST_MODEL=%q\n' "$ANTHROPIC_SMALL_FAST_MODEL"
|
||||
printf 'export DISABLE_NON_ESSENTIAL_MODEL_CALLS=%q\n' "$DISABLE_NON_ESSENTIAL_MODEL_CALLS"
|
||||
printf 'export CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=%q\n' "$CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"
|
||||
printf '%q ' "$@"
|
||||
printf ' >%q 2>&1 < /dev/null\n' "$log_file"
|
||||
} > "$cmd_file"
|
||||
chmod +x "$cmd_file"
|
||||
|
||||
if [[ "$tmux_started" -eq 0 ]]; then
|
||||
tmux new-session -d -s "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
tmux_started=1
|
||||
else
|
||||
tmux new-window -t "$SESSION" -n "$run_id" "bash '$cmd_file'; code=\$?; echo EXIT:\$code; sleep 3600"
|
||||
fi
|
||||
echo "$run_id"
|
||||
}
|
||||
|
||||
OPENAI_COMMON=(
|
||||
"$PYTHON" -u scripts/train.py
|
||||
--config configs/spreadsheetbench/default.yaml
|
||||
--cfg-options
|
||||
model.teacher_backend=openai_chat
|
||||
model.student_backend=openai_chat
|
||||
model.teacher=gpt-5.5
|
||||
model.student=gpt-5.5
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.student_azure_openai_endpoint="${STUDENT_AZURE_OPENAI_ENDPOINT:-$TEACHER_AZURE_OPENAI_ENDPOINT}"
|
||||
model.student_azure_openai_api_version="${STUDENT_AZURE_OPENAI_API_VERSION:-$TEACHER_AZURE_OPENAI_API_VERSION}"
|
||||
model.student_azure_openai_auth_mode="${STUDENT_AZURE_OPENAI_AUTH_MODE:-$TEACHER_AZURE_OPENAI_AUTH_MODE}"
|
||||
model.student_azure_openai_managed_identity_client_id="${STUDENT_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID:-$TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}"
|
||||
model.student_azure_openai_ad_scope="${STUDENT_AZURE_OPENAI_AD_SCOPE:-$TEACHER_AZURE_OPENAI_AD_SCOPE}"
|
||||
model.reasoning_effort=medium
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_scheduler=cosine
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_reflect=false
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
env.workers=2
|
||||
env.exec_timeout=1020
|
||||
env.split_dir=data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
env.data_root=data/spreadsheetbench_verified_400
|
||||
env.mode=multi
|
||||
)
|
||||
|
||||
HARNESS_RUN="HARNESS-ClaudeInit-Spreadsheet-lr4-multi-full"
|
||||
launch_run "$HARNESS_RUN" \
|
||||
"$PYTHON" -u scripts/train.py \
|
||||
--config configs/spreadsheetbench/default.yaml \
|
||||
--cfg-options \
|
||||
model.teacher_backend=openai_chat \
|
||||
model.teacher=gpt-5.5 \
|
||||
model.teacher_azure_openai_endpoint="${TEACHER_AZURE_OPENAI_ENDPOINT}" \
|
||||
model.teacher_azure_openai_api_version="${TEACHER_AZURE_OPENAI_API_VERSION}" \
|
||||
model.teacher_azure_openai_auth_mode="${TEACHER_AZURE_OPENAI_AUTH_MODE}" \
|
||||
model.teacher_azure_openai_managed_identity_client_id="${TEACHER_AZURE_OPENAI_MANAGED_IDENTITY_CLIENT_ID}" \
|
||||
model.teacher_azure_openai_ad_scope="${TEACHER_AZURE_OPENAI_AD_SCOPE}" \
|
||||
model.reasoning_effort=medium \
|
||||
train.num_epochs=4 \
|
||||
train.train_size=0 \
|
||||
train.batch_size=40 \
|
||||
train.accumulation=1 \
|
||||
train.seed=42 \
|
||||
gradient.minibatch_size=8 \
|
||||
gradient.merge_batch_size=8 \
|
||||
gradient.analyst_workers=16 \
|
||||
gradient.use_deep_reflect=false \
|
||||
optimizer.lr_control_mode=fixed \
|
||||
optimizer.skill_update_mode=patch \
|
||||
optimizer.use_slow_update=true \
|
||||
optimizer.slow_update_samples=20 \
|
||||
optimizer.use_meta_skill=true \
|
||||
optimizer.use_meta_reflect=false \
|
||||
evaluation.use_gate=true \
|
||||
evaluation.eval_test=true \
|
||||
env.split_mode=split_dir \
|
||||
env.workers=2 \
|
||||
env.exec_timeout=1020 \
|
||||
model.student_backend=claude_code_exec \
|
||||
model.student=claude-sonnet-4-6 \
|
||||
model.claude_code_exec_use_sdk=sdk \
|
||||
model.claude_code_exec_effort=medium \
|
||||
model.claude_code_exec_max_thinking_tokens=16384 \
|
||||
model.codex_trace_to_teacher=false \
|
||||
env.out_root="$RUN_ROOT/$HARNESS_RUN" \
|
||||
env.split_dir=data/spreadsheetbench/splits \
|
||||
env.data_root=data/spreadsheetbench/files \
|
||||
env.mode=multi \
|
||||
optimizer.learning_rate=4 \
|
||||
optimizer.min_learning_rate=1 \
|
||||
optimizer.lr_scheduler=constant
|
||||
|
||||
for repeat in r1 r2 r3; do
|
||||
run_id="LRCTRL-spreadsheetbench-full-rewrite-neutral3-full-${repeat}"
|
||||
launch_run "$run_id" \
|
||||
"${OPENAI_COMMON[@]}" \
|
||||
optimizer.lr_control_mode=none \
|
||||
optimizer.skill_update_mode=full_rewrite_minibatch \
|
||||
optimizer.use_meta_skill=true \
|
||||
env.out_root="$RUN_ROOT/$run_id"
|
||||
done
|
||||
|
||||
for repeat in r1 r2 r3; do
|
||||
run_id="SLOWMETA-spreadsheetbench-true-false-full-${repeat}"
|
||||
launch_run "$run_id" \
|
||||
"${OPENAI_COMMON[@]}" \
|
||||
optimizer.lr_control_mode=fixed \
|
||||
optimizer.skill_update_mode=patch \
|
||||
optimizer.use_meta_skill=false \
|
||||
env.out_root="$RUN_ROOT/$run_id"
|
||||
done
|
||||
|
||||
echo "RUN_ROOT=$RUN_ROOT"
|
||||
echo "SESSION=$SESSION"
|
||||
61
scripts/monitor_harness_claude18.sh
Executable file
61
scripts/monitor_harness_claude18.sh
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ROOT="${1:?usage: monitor_harness_claude18.sh RUN_ROOT}"
|
||||
|
||||
while true; do
|
||||
ts="$(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "===== $ts CLAUDE18 ====="
|
||||
uptime | sed 's/^/uptime=/'
|
||||
active="$(pgrep -af "scripts/train.py.*${ROOT}" | grep -v pgrep | wc -l || true)"
|
||||
claude_child="$(pgrep -af 'claude.*--output-format stream-json' | grep -v pgrep | wc -l || true)"
|
||||
echo "active_train_total=$active"
|
||||
echo "active_codex_train=0"
|
||||
echo "active_claude_train=$active"
|
||||
echo "claude_child=$claude_child"
|
||||
|
||||
for d in "$ROOT"/HARNESS-Claude-*; do
|
||||
[[ -d "$d" ]] || continue
|
||||
rid="$(basename "$d")"
|
||||
read -r base best < <(python3 - "$d" <<'PY'
|
||||
import json, sys
|
||||
from pathlib import Path
|
||||
d = Path(sys.argv[1])
|
||||
s = d / "summary.json"
|
||||
if not s.exists():
|
||||
print("pending pending")
|
||||
raise SystemExit
|
||||
try:
|
||||
obj = json.loads(s.read_text())
|
||||
except Exception:
|
||||
print("pending pending")
|
||||
raise SystemExit
|
||||
base = obj.get("baseline_test_hard", obj.get("base_test", "pending"))
|
||||
best = obj.get("test_hard", obj.get("best_test", "pending"))
|
||||
def fmt(x):
|
||||
if isinstance(x, (int, float)):
|
||||
return f"{x:.4f}"
|
||||
return str(x)
|
||||
print(fmt(base), fmt(best))
|
||||
PY
|
||||
)
|
||||
scan_files="$(mktemp)"
|
||||
find "$d" \
|
||||
\( -path '*/codex_exec' -o -path '*/codex_multi' \) -prune -o \
|
||||
-maxdepth 6 -type f \
|
||||
\( -name 'claude_trace_summary.txt' -o -name 'codex_trace_summary.txt' -o -name '*.log' -o -name 'summary.json' \) \
|
||||
-print > "$scan_files" 2>/dev/null || true
|
||||
auth="$({ xargs -r rg -l 'Not logged in|authentication_failed' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')"
|
||||
e429="$({ xargs -r rg -l 'Too Many Requests|RateLimitError|Error code: 429|api_error_status.: 429|rate_limit|too_many_requests' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')"
|
||||
e401="$({ xargs -r rg -l '401 Unauthorized|Error code: 401|HTTP 401|AuthenticationTypeDisabled|PermissionDeniedError' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')"
|
||||
timeout="$({ xargs -r rg -l 'TimeoutError|Task timed out|timed out after|subprocess.TimeoutExpired|timeout_exceeded' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')"
|
||||
teacher="$({ xargs -r rg -l 'APITimeoutError|APIConnectionError|AuthenticationError|Azure OpenAI Responses API is enabled only|teacher.*error' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')"
|
||||
results="$(find "$d" -maxdepth 5 -path '*/results.jsonl' -type f -print0 2>/dev/null | xargs -0 -r wc -l | awk 'END{print $1+0}')"
|
||||
empty="$({ xargs -r rg -l 'final response chars: 0|\"final_response\"\\s*:\\s*\"\"|\"result\"\\s*:\\s*\"\"' < "$scan_files" 2>/dev/null || true; } | wc -l | tr -d ' ')"
|
||||
rm -f "$scan_files"
|
||||
errors=$((auth + e429 + e401 + timeout + teacher))
|
||||
echo "$rid Base=$base Best=$best Errors=$errors auth=$auth 429=$e429 401=$e401 timeout=$timeout teacher=$teacher Results=$results Empty=$empty"
|
||||
done | sort
|
||||
echo
|
||||
sleep 60
|
||||
done
|
||||
130
scripts/prepare_ablation_splits.py
Normal file
130
scripts/prepare_ablation_splits.py
Normal file
@@ -0,0 +1,130 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Prepare fixed data splits for ablation experiments."""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
DATASETS = {
|
||||
"searchqa": {
|
||||
"raw": PROJECT_ROOT / "data/searchqa_train_2000.json",
|
||||
"out": PROJECT_ROOT / "data/ablation_splits/searchqa",
|
||||
"filenames": {"train": "train.json", "val": "selection.json", "test": "test.json"},
|
||||
},
|
||||
"spreadsheetbench": {
|
||||
"raw": PROJECT_ROOT / "data/spreadsheetbench_verified_400/dataset.json",
|
||||
"out": PROJECT_ROOT / "data/ablation_splits/spreadsheetbench",
|
||||
"filenames": {"train": "train.json", "val": "sel.json", "test": "test.json"},
|
||||
},
|
||||
}
|
||||
|
||||
SPLITS = ("1shot", "1:1:8", "2:1:7", "4:1:5")
|
||||
|
||||
|
||||
def load_items(path: Path) -> list[dict]:
|
||||
with path.open(encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, list):
|
||||
raise TypeError(f"Expected JSON array in {path}, got {type(data).__name__}")
|
||||
return data
|
||||
|
||||
|
||||
def split_counts(total: int, split: str) -> tuple[int, int, int]:
|
||||
if split == "1shot":
|
||||
if total < 3:
|
||||
raise ValueError(f"Need at least 3 items for 1shot split, got {total}")
|
||||
return 1, 1, total - 2
|
||||
|
||||
ratio = split
|
||||
weights = [int(part) for part in ratio.split(":")]
|
||||
if len(weights) != 3 or min(weights) <= 0:
|
||||
raise ValueError(f"Invalid ratio: {ratio}")
|
||||
denom = sum(weights)
|
||||
raw = [total * weight / denom for weight in weights]
|
||||
counts = [int(value) for value in raw]
|
||||
remaining = total - sum(counts)
|
||||
order = sorted(
|
||||
range(3),
|
||||
key=lambda idx: (raw[idx] - counts[idx], weights[idx]),
|
||||
reverse=True,
|
||||
)
|
||||
for idx in order[:remaining]:
|
||||
counts[idx] += 1
|
||||
return counts[0], counts[1], counts[2]
|
||||
|
||||
|
||||
def split_tag(split: str) -> str:
|
||||
return "1shot" if split == "1shot" else split.replace(":", "-")
|
||||
|
||||
|
||||
def write_json(path: Path, items: list[dict]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w", encoding="utf-8") as f:
|
||||
json.dump(items, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def prepare_dataset(name: str, *, seed: int, force: bool) -> None:
|
||||
spec = DATASETS[name]
|
||||
raw_path = spec["raw"]
|
||||
out_root = spec["out"]
|
||||
filenames = spec["filenames"]
|
||||
|
||||
items = load_items(raw_path)
|
||||
for split in SPLITS:
|
||||
ratio_tag = split_tag(split)
|
||||
split_dir = out_root / f"{ratio_tag}_seed{seed}"
|
||||
manifest_path = split_dir / "split_manifest.json"
|
||||
if manifest_path.exists() and not force:
|
||||
print(f"skip {name} {split}: {split_dir} exists")
|
||||
continue
|
||||
|
||||
shuffled = list(items)
|
||||
random.Random(seed).shuffle(shuffled)
|
||||
train_n, val_n, test_n = split_counts(len(shuffled), split)
|
||||
train_items = shuffled[:train_n]
|
||||
val_items = shuffled[train_n: train_n + val_n]
|
||||
test_items = shuffled[train_n + val_n: train_n + val_n + test_n]
|
||||
|
||||
write_json(split_dir / "train" / filenames["train"], train_items)
|
||||
write_json(split_dir / "val" / filenames["val"], val_items)
|
||||
write_json(split_dir / "test" / filenames["test"], test_items)
|
||||
write_json(
|
||||
manifest_path,
|
||||
{
|
||||
"dataset": name,
|
||||
"source": str(raw_path),
|
||||
"split_mode": "precomputed_ratio",
|
||||
"split_name": split,
|
||||
"split_ratio": split if split != "1shot" else "1 train / 1 val / rest test",
|
||||
"split_seed": seed,
|
||||
"counts": {
|
||||
"train": len(train_items),
|
||||
"val": len(val_items),
|
||||
"test": len(test_items),
|
||||
},
|
||||
},
|
||||
)
|
||||
print(
|
||||
f"wrote {name} {split} -> {split_dir} "
|
||||
f"(train={len(train_items)}, val={len(val_items)}, test={len(test_items)})"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--force", action="store_true")
|
||||
parser.add_argument("--dataset", choices=sorted(DATASETS), action="append")
|
||||
args = parser.parse_args()
|
||||
|
||||
for name in args.dataset or sorted(DATASETS):
|
||||
prepare_dataset(name, seed=args.seed, force=args.force)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
680
scripts/run_ablation_matrix.py
Executable file
680
scripts/run_ablation_matrix.py
Executable file
@@ -0,0 +1,680 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Launch the SearchQA / SpreadsheetBench ablation matrix.
|
||||
|
||||
By default this script prints commands only. Pass --execute to actually start
|
||||
runs. Every run writes to a unique out_root under the run root and logs stdout
|
||||
/ stderr to logs/<run_id>.log.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
PYTHON_BIN = Path("/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python")
|
||||
|
||||
T2_ENDPOINT = "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
SEARCHAGENT5_ENDPOINT = "https://searchagent5.cognitiveservices.azure.com/"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Experiment:
|
||||
run_id: str
|
||||
benchmark: str
|
||||
config: str
|
||||
overrides: tuple[str, ...]
|
||||
|
||||
|
||||
BENCH_CONFIG = {
|
||||
"searchqa": "configs/searchqa/default.yaml",
|
||||
"spreadsheetbench": "configs/spreadsheetbench/default.yaml",
|
||||
"livemathematicianbench": "configs/livemathematicianbench/default.yaml",
|
||||
"alfworld": "configs/alfworld/default.yaml",
|
||||
"docvqa": "configs/docvqa/default.yaml",
|
||||
}
|
||||
|
||||
DEFAULT_SPLIT = {
|
||||
"searchqa": "data/ablation_splits/searchqa/2-1-7_seed42",
|
||||
"spreadsheetbench": "data/ablation_splits/spreadsheetbench/2-1-7_seed42",
|
||||
"livemathematicianbench": "data/ablation_splits/livemathematicianbench/2-1-7_seed42",
|
||||
"alfworld": "data/ablation_splits/alfworld/2-1-7_seed42",
|
||||
"docvqa": "/home/azureuser/zisu/SkillReflection/data/docvqa/splits",
|
||||
}
|
||||
|
||||
DEFAULT_TRAIN_SIZE = {
|
||||
"searchqa": 400,
|
||||
"spreadsheetbench": 80,
|
||||
"livemathematicianbench": 35,
|
||||
"alfworld": 39,
|
||||
"docvqa": 1070,
|
||||
}
|
||||
|
||||
BATCH_SIZE_VALUES: tuple[int | str, ...] = (8, 24, 40, 56, "full")
|
||||
|
||||
SPLITS = {
|
||||
"searchqa": {
|
||||
"1shot": ("data/ablation_splits/searchqa/1shot_seed42", ("optimizer.slow_update_samples=1",)),
|
||||
"1-1-8": ("data/ablation_splits/searchqa/1-1-8_seed42", ()),
|
||||
"2-1-7": ("data/ablation_splits/searchqa/2-1-7_seed42", ()),
|
||||
"4-1-5": ("data/ablation_splits/searchqa/4-1-5_seed42", ()),
|
||||
},
|
||||
"spreadsheetbench": {
|
||||
"1shot": ("data/ablation_splits/spreadsheetbench/1shot_seed42", ("optimizer.slow_update_samples=1",)),
|
||||
"1-1-8": ("data/ablation_splits/spreadsheetbench/1-1-8_seed42", ()),
|
||||
"2-1-7": ("data/ablation_splits/spreadsheetbench/2-1-7_seed42", ()),
|
||||
"4-1-5": ("data/ablation_splits/spreadsheetbench/4-1-5_seed42", ()),
|
||||
},
|
||||
"livemathematicianbench": {
|
||||
"1shot": ("data/ablation_splits/livemathematicianbench/1shot_seed42", ("optimizer.slow_update_samples=1",)),
|
||||
"1-1-8": ("data/ablation_splits/livemathematicianbench/1-1-8_seed42", ()),
|
||||
"2-1-7": ("data/ablation_splits/livemathematicianbench/2-1-7_seed42", ()),
|
||||
"4-1-5": ("data/ablation_splits/livemathematicianbench/4-1-5_seed42", ()),
|
||||
},
|
||||
"alfworld": {
|
||||
"1shot": ("data/ablation_splits/alfworld/1shot_seed42", ("optimizer.slow_update_samples=1",)),
|
||||
"1-1-8": ("data/ablation_splits/alfworld/1-1-8_seed42", ()),
|
||||
"2-1-7": ("data/ablation_splits/alfworld/2-1-7_seed42", ()),
|
||||
"4-1-5": ("data/ablation_splits/alfworld/4-1-5_seed42", ()),
|
||||
},
|
||||
"docvqa": {
|
||||
"1shot": ("data/ablation_splits/docvqa/1shot_seed42", ("optimizer.slow_update_samples=1",)),
|
||||
"1-1-8": ("data/ablation_splits/docvqa/1-1-8_seed42", ()),
|
||||
"2-1-7": ("/home/azureuser/zisu/SkillReflection/data/docvqa/splits", ()),
|
||||
"4-1-5": ("data/ablation_splits/docvqa/4-1-5_seed42", ()),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def common_overrides(benchmark: str, out_root: Path) -> list[str]:
|
||||
return [
|
||||
"model.teacher_backend=openai_chat",
|
||||
"model.student_backend=openai_chat",
|
||||
"model.teacher=gpt-5.5",
|
||||
"model.student=gpt-5.5",
|
||||
f"model.teacher_azure_openai_endpoint={T2_ENDPOINT}",
|
||||
"model.teacher_azure_openai_api_version=2024-12-01-preview",
|
||||
"model.teacher_azure_openai_auth_mode=azure_cli",
|
||||
f"model.student_azure_openai_endpoint={T2_ENDPOINT}",
|
||||
"model.student_azure_openai_api_version=2024-12-01-preview",
|
||||
"model.student_azure_openai_auth_mode=azure_cli",
|
||||
"model.reasoning_effort=medium",
|
||||
"train.num_epochs=4",
|
||||
"train.train_size=0",
|
||||
"train.batch_size=40",
|
||||
"train.accumulation=1",
|
||||
"train.seed=42",
|
||||
"gradient.minibatch_size=8",
|
||||
"gradient.merge_batch_size=8",
|
||||
"gradient.analyst_workers=16",
|
||||
"gradient.use_deep_reflect=false",
|
||||
"optimizer.learning_rate=4",
|
||||
"optimizer.min_learning_rate=2",
|
||||
"optimizer.lr_scheduler=cosine",
|
||||
"optimizer.skill_update_mode=patch",
|
||||
"optimizer.use_slow_update=true",
|
||||
"optimizer.slow_update_samples=20",
|
||||
"optimizer.use_meta_skill=true",
|
||||
"optimizer.use_meta_reflect=false",
|
||||
"evaluation.use_gate=true",
|
||||
"evaluation.eval_test=true",
|
||||
"env.split_mode=split_dir",
|
||||
f"env.split_dir={DEFAULT_SPLIT[benchmark]}",
|
||||
f"env.out_root={out_root}",
|
||||
]
|
||||
|
||||
|
||||
def make_experiment(
|
||||
group: str,
|
||||
benchmark: str,
|
||||
suffix: str,
|
||||
run_root: Path,
|
||||
overrides: list[str],
|
||||
) -> Experiment:
|
||||
run_id = f"{group}-{benchmark}-{suffix}"
|
||||
out_root = run_root / run_id
|
||||
all_overrides = common_overrides(benchmark, out_root)
|
||||
all_overrides.extend(overrides)
|
||||
return Experiment(
|
||||
run_id=run_id,
|
||||
benchmark=benchmark,
|
||||
config=BENCH_CONFIG[benchmark],
|
||||
overrides=tuple(all_overrides),
|
||||
)
|
||||
|
||||
|
||||
def build_matrix(
|
||||
groups: set[str],
|
||||
benchmarks: list[str],
|
||||
run_root: Path,
|
||||
*,
|
||||
include_duplicate_defaults: bool = False,
|
||||
) -> list[Experiment]:
|
||||
exps: list[Experiment] = []
|
||||
group_order = [
|
||||
"default",
|
||||
"split",
|
||||
"batch",
|
||||
"mbs",
|
||||
"lr",
|
||||
"sched",
|
||||
"slown",
|
||||
"mod",
|
||||
"smodel",
|
||||
"longpair",
|
||||
"lrctrl",
|
||||
]
|
||||
|
||||
for group in group_order:
|
||||
if group not in groups:
|
||||
continue
|
||||
for benchmark in benchmarks:
|
||||
if group == "default":
|
||||
exps.append(make_experiment("DEFAULT", benchmark, "5.5", run_root, []))
|
||||
continue
|
||||
|
||||
if group == "split":
|
||||
for tag, (split_dir, extra) in SPLITS[benchmark].items():
|
||||
if not include_duplicate_defaults and tag == "2-1-7":
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"SPLIT",
|
||||
benchmark,
|
||||
tag,
|
||||
run_root,
|
||||
[f"env.split_dir={split_dir}", *extra],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "mbs":
|
||||
for value in (1, 2, 4, 8, 16, 32):
|
||||
if not include_duplicate_defaults and value == 8:
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"MBS",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[f"gradient.minibatch_size={value}"],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "batch":
|
||||
for value in BATCH_SIZE_VALUES:
|
||||
if not include_duplicate_defaults and value == 40:
|
||||
continue
|
||||
batch_size = DEFAULT_TRAIN_SIZE[benchmark] if value == "full" else int(value)
|
||||
exps.append(make_experiment(
|
||||
"BATCH",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[
|
||||
f"train.batch_size={batch_size}",
|
||||
"gradient.minibatch_size=8",
|
||||
],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "lr":
|
||||
for value in (1, 2, 4, 8, 16):
|
||||
exps.append(make_experiment(
|
||||
"LR",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[
|
||||
"optimizer.lr_scheduler=constant",
|
||||
"optimizer.min_learning_rate=1",
|
||||
f"optimizer.learning_rate={value}",
|
||||
],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "sched":
|
||||
for value in ("constant", "cosine", "linear"):
|
||||
if not include_duplicate_defaults and value == "cosine":
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"SCHED",
|
||||
benchmark,
|
||||
value,
|
||||
run_root,
|
||||
[f"optimizer.lr_scheduler={value}"],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "slown":
|
||||
for value in (5, 10, 20, 40):
|
||||
if not include_duplicate_defaults and value == 20:
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"SLOWN",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[f"optimizer.slow_update_samples={value}"],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "mod":
|
||||
settings = {
|
||||
"slow-meta": ("true", "true"),
|
||||
"slow-only": ("true", "false"),
|
||||
"meta-only": ("false", "true"),
|
||||
"none": ("false", "false"),
|
||||
}
|
||||
for tag, (slow, meta) in settings.items():
|
||||
if not include_duplicate_defaults and tag == "slow-meta":
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"MOD",
|
||||
benchmark,
|
||||
tag,
|
||||
run_root,
|
||||
[
|
||||
f"optimizer.use_slow_update={slow}",
|
||||
f"optimizer.use_meta_skill={meta}",
|
||||
],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "smodel":
|
||||
student_settings = {
|
||||
"5.4": [
|
||||
"model.student=gpt-5.4-pro",
|
||||
f"model.student_azure_openai_endpoint={T2_ENDPOINT}",
|
||||
"model.student_azure_openai_api_version=2025-03-01-preview",
|
||||
"model.student_azure_openai_auth_mode=azure_cli",
|
||||
],
|
||||
"5.4-mini": [
|
||||
"model.student=gpt-5.4-mini",
|
||||
f"model.student_azure_openai_endpoint={SEARCHAGENT5_ENDPOINT}",
|
||||
"model.student_azure_openai_api_version=2024-12-01-preview",
|
||||
"model.student_azure_openai_auth_mode=azure_cli",
|
||||
],
|
||||
"5.5": [],
|
||||
}
|
||||
for tag, overrides in student_settings.items():
|
||||
if not include_duplicate_defaults and tag == "5.5":
|
||||
continue
|
||||
exps.append(make_experiment("SMODEL", benchmark, tag, run_root, overrides))
|
||||
continue
|
||||
|
||||
if group == "longpair":
|
||||
for value in ("changed", "unchanged"):
|
||||
exps.append(make_experiment(
|
||||
"LONGPAIR",
|
||||
benchmark,
|
||||
value,
|
||||
run_root,
|
||||
[f"optimizer.longitudinal_pair_policy={value}"],
|
||||
))
|
||||
continue
|
||||
|
||||
if group == "lrctrl":
|
||||
settings = {
|
||||
"autonomous": ["optimizer.lr_control_mode=autonomous"],
|
||||
"full-rewrite": [
|
||||
"optimizer.lr_control_mode=none",
|
||||
"optimizer.skill_update_mode=full_rewrite_minibatch",
|
||||
],
|
||||
}
|
||||
for tag, overrides in settings.items():
|
||||
exps.append(make_experiment("LRCTRL", benchmark, tag, run_root, overrides))
|
||||
continue
|
||||
|
||||
return exps
|
||||
|
||||
|
||||
def _build_matrix_legacy(
|
||||
groups: set[str],
|
||||
benchmarks: list[str],
|
||||
run_root: Path,
|
||||
*,
|
||||
include_duplicate_defaults: bool = False,
|
||||
) -> list[Experiment]:
|
||||
exps: list[Experiment] = []
|
||||
for benchmark in benchmarks:
|
||||
if "default" in groups:
|
||||
exps.append(make_experiment("DEFAULT", benchmark, "5.5", run_root, []))
|
||||
|
||||
if "split" in groups:
|
||||
for tag, (split_dir, extra) in SPLITS[benchmark].items():
|
||||
if not include_duplicate_defaults and tag == "2-1-7":
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"SPLIT",
|
||||
benchmark,
|
||||
tag,
|
||||
run_root,
|
||||
[f"env.split_dir={split_dir}", *extra],
|
||||
))
|
||||
|
||||
if "mbs" in groups:
|
||||
for value in (1, 2, 4, 8, 16, 32):
|
||||
if not include_duplicate_defaults and value == 8:
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"MBS",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[f"gradient.minibatch_size={value}"],
|
||||
))
|
||||
|
||||
if "batch" in groups:
|
||||
for value in BATCH_SIZE_VALUES:
|
||||
if not include_duplicate_defaults and value == 40:
|
||||
continue
|
||||
batch_size = DEFAULT_TRAIN_SIZE[benchmark] if value == "full" else int(value)
|
||||
exps.append(make_experiment(
|
||||
"BATCH",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[
|
||||
f"train.batch_size={batch_size}",
|
||||
"gradient.minibatch_size=8",
|
||||
],
|
||||
))
|
||||
|
||||
if "lr" in groups:
|
||||
for value in (1, 2, 4, 8, 16):
|
||||
exps.append(make_experiment(
|
||||
"LR",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[
|
||||
"optimizer.lr_scheduler=constant",
|
||||
"optimizer.min_learning_rate=1",
|
||||
f"optimizer.learning_rate={value}",
|
||||
],
|
||||
))
|
||||
|
||||
if "sched" in groups:
|
||||
for value in ("constant", "cosine", "linear"):
|
||||
if not include_duplicate_defaults and value == "cosine":
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"SCHED",
|
||||
benchmark,
|
||||
value,
|
||||
run_root,
|
||||
[f"optimizer.lr_scheduler={value}"],
|
||||
))
|
||||
|
||||
if "slown" in groups:
|
||||
for value in (5, 10, 20, 40):
|
||||
if not include_duplicate_defaults and value == 20:
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"SLOWN",
|
||||
benchmark,
|
||||
str(value),
|
||||
run_root,
|
||||
[f"optimizer.slow_update_samples={value}"],
|
||||
))
|
||||
|
||||
if "mod" in groups:
|
||||
settings = {
|
||||
"slow-meta": ("true", "true"),
|
||||
"slow-only": ("true", "false"),
|
||||
"meta-only": ("false", "true"),
|
||||
"none": ("false", "false"),
|
||||
}
|
||||
for tag, (slow, meta) in settings.items():
|
||||
if not include_duplicate_defaults and tag == "slow-meta":
|
||||
continue
|
||||
exps.append(make_experiment(
|
||||
"MOD",
|
||||
benchmark,
|
||||
tag,
|
||||
run_root,
|
||||
[
|
||||
f"optimizer.use_slow_update={slow}",
|
||||
f"optimizer.use_meta_skill={meta}",
|
||||
],
|
||||
))
|
||||
|
||||
if "smodel" in groups:
|
||||
student_settings = {
|
||||
"5.4": [
|
||||
"model.student=gpt-5.4-pro",
|
||||
f"model.student_azure_openai_endpoint={T2_ENDPOINT}",
|
||||
"model.student_azure_openai_api_version=2025-03-01-preview",
|
||||
"model.student_azure_openai_auth_mode=azure_cli",
|
||||
],
|
||||
"5.4-mini": [
|
||||
"model.student=gpt-5.4-mini",
|
||||
f"model.student_azure_openai_endpoint={SEARCHAGENT5_ENDPOINT}",
|
||||
"model.student_azure_openai_api_version=2024-12-01-preview",
|
||||
"model.student_azure_openai_auth_mode=azure_cli",
|
||||
],
|
||||
"5.5": [],
|
||||
}
|
||||
for tag, overrides in student_settings.items():
|
||||
if not include_duplicate_defaults and tag == "5.5":
|
||||
continue
|
||||
exps.append(make_experiment("SMODEL", benchmark, tag, run_root, overrides))
|
||||
|
||||
if "longpair" in groups:
|
||||
for value in ("changed", "unchanged"):
|
||||
exps.append(make_experiment(
|
||||
"LONGPAIR",
|
||||
benchmark,
|
||||
value,
|
||||
run_root,
|
||||
[f"optimizer.longitudinal_pair_policy={value}"],
|
||||
))
|
||||
|
||||
if "lrctrl" in groups:
|
||||
settings = {
|
||||
"autonomous": ["optimizer.lr_control_mode=autonomous"],
|
||||
"full-rewrite": [
|
||||
"optimizer.lr_control_mode=none",
|
||||
"optimizer.skill_update_mode=full_rewrite_minibatch",
|
||||
],
|
||||
}
|
||||
for tag, overrides in settings.items():
|
||||
exps.append(make_experiment("LRCTRL", benchmark, tag, run_root, overrides))
|
||||
|
||||
return exps
|
||||
|
||||
|
||||
def command_for(exp: Experiment) -> list[str]:
|
||||
return [
|
||||
str(PYTHON_BIN),
|
||||
"scripts/train.py",
|
||||
"--config",
|
||||
exp.config,
|
||||
"--cfg-options",
|
||||
*exp.overrides,
|
||||
]
|
||||
|
||||
|
||||
def active_run_ids(run_root: Path, valid_run_ids: set[str] | None = None) -> set[str]:
|
||||
try:
|
||||
raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
|
||||
except subprocess.CalledProcessError:
|
||||
return set()
|
||||
pattern = re.compile(re.escape(str(run_root)) + r"/([^\s]+)")
|
||||
active: set[str] = set()
|
||||
for line in raw.splitlines():
|
||||
for match in pattern.finditer(line):
|
||||
run_id = match.group(1).strip("'\"")
|
||||
if run_id.endswith(".log") or "/" in run_id:
|
||||
continue
|
||||
if valid_run_ids is not None and run_id not in valid_run_ids:
|
||||
continue
|
||||
active.add(run_id)
|
||||
return active
|
||||
|
||||
|
||||
def completed_run_ids(run_root: Path) -> set[str]:
|
||||
return {
|
||||
path.parent.name
|
||||
for path in run_root.glob("*/summary.json")
|
||||
if path.is_file()
|
||||
}
|
||||
|
||||
|
||||
def print_commands(exps: list[Experiment]) -> None:
|
||||
for exp in exps:
|
||||
cmd = command_for(exp)
|
||||
print(f"\n# {exp.run_id}")
|
||||
print(" ".join(subprocess.list2cmdline([part]) for part in cmd))
|
||||
|
||||
|
||||
def run_commands(
|
||||
exps: list[Experiment],
|
||||
run_root: Path,
|
||||
max_parallel: int,
|
||||
run_retries: int,
|
||||
) -> int:
|
||||
logs_dir = run_root / "logs"
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
active: list[tuple[Experiment, subprocess.Popen, object]] = []
|
||||
valid_run_ids = {exp.run_id for exp in exps}
|
||||
skipped_completed = completed_run_ids(run_root)
|
||||
skipped_active = active_run_ids(run_root, valid_run_ids)
|
||||
pending: list[tuple[Experiment, int]] = [
|
||||
(exp, 0)
|
||||
for exp in exps
|
||||
if exp.run_id not in skipped_completed and exp.run_id not in skipped_active
|
||||
]
|
||||
for run_id in sorted(skipped_completed):
|
||||
print(f"[SKIP_COMPLETED] {run_id}", flush=True)
|
||||
for run_id in sorted(skipped_active):
|
||||
print(f"[SKIP_ACTIVE] {run_id}", flush=True)
|
||||
failures = 0
|
||||
|
||||
while pending or active:
|
||||
external_active = active_run_ids(run_root, valid_run_ids) - {exp.run_id for exp, _, _ in active}
|
||||
while pending and len(active) + len(external_active) < max_parallel:
|
||||
exp, attempt = pending.pop(0)
|
||||
log_path = logs_dir / f"{exp.run_id}.log"
|
||||
if attempt:
|
||||
log_path = logs_dir / f"{exp.run_id}.retry{attempt}.log"
|
||||
log_f = open(log_path, "w", encoding="utf-8")
|
||||
print(f"[START] {exp.run_id} attempt={attempt + 1} log={log_path}", flush=True)
|
||||
proc = subprocess.Popen(
|
||||
command_for(exp),
|
||||
cwd=PROJECT_ROOT,
|
||||
stdout=log_f,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
setattr(proc, "_attempt", attempt)
|
||||
active.append((exp, proc, log_f))
|
||||
|
||||
time.sleep(5)
|
||||
still_active: list[tuple[Experiment, subprocess.Popen, object]] = []
|
||||
for exp, proc, log_f in active:
|
||||
rc = proc.poll()
|
||||
if rc is None:
|
||||
still_active.append((exp, proc, log_f))
|
||||
continue
|
||||
log_f.close()
|
||||
if rc == 0:
|
||||
print(f"[DONE] {exp.run_id}", flush=True)
|
||||
else:
|
||||
if getattr(proc, "_attempt", 0) < run_retries:
|
||||
next_attempt = getattr(proc, "_attempt", 0) + 1
|
||||
pending.append((exp, next_attempt))
|
||||
print(f"[RETRY] {exp.run_id} rc={rc} next_attempt={next_attempt + 1}", flush=True)
|
||||
else:
|
||||
failures += 1
|
||||
print(f"[FAIL] {exp.run_id} rc={rc}", flush=True)
|
||||
active = still_active
|
||||
|
||||
return failures
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--groups",
|
||||
nargs="+",
|
||||
default=["default"],
|
||||
choices=[
|
||||
"default",
|
||||
"split",
|
||||
"batch",
|
||||
"mbs",
|
||||
"lr",
|
||||
"sched",
|
||||
"slown",
|
||||
"mod",
|
||||
"smodel",
|
||||
"longpair",
|
||||
"lrctrl",
|
||||
"all",
|
||||
],
|
||||
help="Experiment groups to include. Default: default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bench",
|
||||
nargs="+",
|
||||
default=["searchqa", "spreadsheetbench"],
|
||||
choices=["searchqa", "spreadsheetbench", "livemathematicianbench", "alfworld", "docvqa"],
|
||||
)
|
||||
parser.add_argument("--run-root", default="", help="Output root. Default: outputs/ablation_<UTC timestamp>.")
|
||||
parser.add_argument("--max-parallel", type=int, default=1)
|
||||
parser.add_argument("--run-retries", type=int, default=1, help="Retry failed runs this many times. Default: 1.")
|
||||
parser.add_argument(
|
||||
"--include-duplicate-defaults",
|
||||
action="store_true",
|
||||
help="Also run ablation points that are exactly the default setting.",
|
||||
)
|
||||
parser.add_argument("--execute", action="store_true", help="Actually start runs. Without this, print commands only.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
groups = set(args.groups)
|
||||
if "all" in groups:
|
||||
groups = {"default", "split", "batch", "mbs", "lr", "sched", "slown", "mod", "smodel"}
|
||||
|
||||
ts = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
|
||||
run_root = Path(args.run_root) if args.run_root else PROJECT_ROOT / "outputs" / f"ablation_{ts}"
|
||||
if not run_root.is_absolute():
|
||||
run_root = PROJECT_ROOT / run_root
|
||||
run_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
exps = build_matrix(
|
||||
groups,
|
||||
args.bench,
|
||||
run_root,
|
||||
include_duplicate_defaults=args.include_duplicate_defaults,
|
||||
)
|
||||
print(f"run_root={run_root}")
|
||||
print(f"num_experiments={len(exps)}")
|
||||
print(f"groups={','.join(sorted(groups))}")
|
||||
print(f"bench={','.join(args.bench)}")
|
||||
|
||||
if not args.execute:
|
||||
print_commands(exps)
|
||||
return
|
||||
|
||||
max_parallel = max(1, int(args.max_parallel))
|
||||
failures = run_commands(
|
||||
exps,
|
||||
run_root,
|
||||
max_parallel=max_parallel,
|
||||
run_retries=max(0, int(args.run_retries)),
|
||||
)
|
||||
if failures:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
68
scripts/run_alfworld.sh
Executable file
68
scripts/run_alfworld.sh
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT — ALFWorld training launch script
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/run_alfworld.sh
|
||||
# bash scripts/run_alfworld.sh --num_epochs 2 --edit_budget 6
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
# ── Paths ────────────────────────────────────────────────────────────────────
|
||||
WORKSPACE="/home/azureuser/workspace-gzy"
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
# Activate conda environment
|
||||
export PATH="${WORKSPACE}/miniconda3/envs/skillopt/bin:${WORKSPACE}/miniconda3/bin:${PATH}"
|
||||
|
||||
# ALFWorld data — uses ~/.cache/alfworld by default (standard alfworld location)
|
||||
export ALFWORLD_DATA="${ALFWORLD_DATA:-${HOME}/.cache/alfworld}"
|
||||
|
||||
# Ensure ReflACT is importable
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
# ── Verify ALFWorld data exists ──────────────────────────────────────────────
|
||||
if [ ! -d "${ALFWORLD_DATA}/json_2.1.1" ]; then
|
||||
echo "ERROR: ALFWorld data not found at ${ALFWORLD_DATA}/json_2.1.1"
|
||||
echo ""
|
||||
echo "To download ALFWorld data, run:"
|
||||
echo " pip install alfworld[full]"
|
||||
echo " alfworld-download"
|
||||
echo ""
|
||||
echo "Or set ALFWORLD_DATA to the directory containing json_2.1.1/"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ── Azure OpenAI credentials ────────────────────────────────────────────────
|
||||
export AZURE_OPENAI_ENDPOINT="${AZURE_OPENAI_ENDPOINT:-https://agl-dev.cognitiveservices.azure.com/}"
|
||||
export AZURE_OPENAI_API_KEY="${AZURE_OPENAI_API_KEY:-<your-azure-openai-api-key>}"
|
||||
export AZURE_OPENAI_API_VERSION="${AZURE_OPENAI_API_VERSION:-2025-04-01-preview}"
|
||||
|
||||
# ── Model configuration ─────────────────────────────────────────────────────
|
||||
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
|
||||
|
||||
# ── Output directory ─────────────────────────────────────────────────────────
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/skillopt_alfworld_${STUDENT_DEPLOYMENT}_${TIMESTAMP}"
|
||||
|
||||
# ── Run ──────────────────────────────────────────────────────────────────────
|
||||
echo "============================================================"
|
||||
echo " ReflACT — Reflective Agent Tuning (ALFWorld)"
|
||||
echo "============================================================"
|
||||
echo " Teacher: ${TEACHER_DEPLOYMENT}"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo " ALFWORLD_DATA: ${ALFWORLD_DATA}"
|
||||
echo " Output: ${DEFAULT_OUT_ROOT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/train.py \
|
||||
--config configs/alfworld_default.yaml \
|
||||
--out_root "${DEFAULT_OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
|
||||
94
scripts/run_meta_skill_ablation.sh
Executable file
94
scripts/run_meta_skill_ablation.sh
Executable file
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
PROJECT_ROOT="/home/azureuser/workspace-gzy/SkillReflection_dev"
|
||||
PYTHON_BIN="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
TS="$(date -u +%Y%m%d_%H%M%S)"
|
||||
RUN_ROOT="${PROJECT_ROOT}/outputs/meta_skill_ablation_${TS}"
|
||||
|
||||
mkdir -p "${RUN_ROOT}"
|
||||
|
||||
run_train() {
|
||||
local benchmark="$1"
|
||||
local reasoning="$2"
|
||||
local condition="$3"
|
||||
local config_path=""
|
||||
local reasoning_override=""
|
||||
local meta_skill_flag=""
|
||||
|
||||
case "${benchmark}" in
|
||||
searchqa)
|
||||
config_path="configs/searchqa/default.yaml"
|
||||
;;
|
||||
spreadsheetbench)
|
||||
config_path="configs/spreadsheetbench/default.yaml"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown benchmark: ${benchmark}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
case "${reasoning}" in
|
||||
medium)
|
||||
reasoning_override="model.reasoning_effort=medium"
|
||||
;;
|
||||
none)
|
||||
reasoning_override="model.reasoning_effort="
|
||||
;;
|
||||
*)
|
||||
echo "Unknown reasoning setting: ${reasoning}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
case "${condition}" in
|
||||
slow)
|
||||
meta_skill_flag="optimizer.use_meta_skill=false"
|
||||
;;
|
||||
slow_meta)
|
||||
meta_skill_flag="optimizer.use_meta_skill=true"
|
||||
;;
|
||||
*)
|
||||
echo "Unknown condition: ${condition}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
local out_root="${RUN_ROOT}/${benchmark}_${reasoning}_${condition}"
|
||||
|
||||
echo
|
||||
echo "============================================================"
|
||||
echo "START ${benchmark} ${reasoning} ${condition}"
|
||||
echo "out_root=${out_root}"
|
||||
echo "============================================================"
|
||||
|
||||
(
|
||||
cd "${PROJECT_ROOT}"
|
||||
"${PYTHON_BIN}" scripts/train.py \
|
||||
--config "${config_path}" \
|
||||
--cfg-options \
|
||||
"${reasoning_override}" \
|
||||
"optimizer.use_slow_update=true" \
|
||||
"${meta_skill_flag}" \
|
||||
"optimizer.use_meta_reflect=false" \
|
||||
"gradient.use_deep_reflect=false" \
|
||||
"env.out_root=${out_root}"
|
||||
)
|
||||
|
||||
echo
|
||||
echo "============================================================"
|
||||
echo "DONE ${benchmark} ${reasoning} ${condition}"
|
||||
echo "============================================================"
|
||||
}
|
||||
|
||||
for benchmark in searchqa spreadsheetbench; do
|
||||
for reasoning in medium none; do
|
||||
run_train "${benchmark}" "${reasoning}" "slow"
|
||||
run_train "${benchmark}" "${reasoning}" "slow_meta"
|
||||
done
|
||||
done
|
||||
|
||||
echo
|
||||
echo "All runs completed."
|
||||
echo "Run root: ${RUN_ROOT}"
|
||||
43
scripts/run_missing_meta_parallel.sh
Normal file
43
scripts/run_missing_meta_parallel.sh
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
PROJECT_ROOT="/home/azureuser/workspace-gzy/SkillReflection_dev"
|
||||
PYTHON_BIN="/home/azureuser/workspace-gzy/miniconda3/envs/skillopt/bin/python"
|
||||
RUN_ROOT="${PROJECT_ROOT}/outputs/meta_skill_parallel_20260430_072356"
|
||||
LOG_DIR="${PROJECT_ROOT}/logs/meta_skill_parallel_20260430_072356"
|
||||
|
||||
mkdir -p "${RUN_ROOT}" "${LOG_DIR}"
|
||||
|
||||
start_run() {
|
||||
local name="$1"
|
||||
local config_path="$2"
|
||||
local meta_skill="$3"
|
||||
local out_root="${RUN_ROOT}/${name}"
|
||||
local log_path="${LOG_DIR}/${name}.log"
|
||||
|
||||
echo "[START] ${name}"
|
||||
echo " out_root=${out_root}"
|
||||
echo " log=${log_path}"
|
||||
|
||||
(
|
||||
cd "${PROJECT_ROOT}"
|
||||
PYTHONUNBUFFERED=1 "${PYTHON_BIN}" scripts/train.py \
|
||||
--config "${config_path}" \
|
||||
--cfg-options \
|
||||
"model.reasoning_effort=medium" \
|
||||
"optimizer.use_slow_update=true" \
|
||||
"optimizer.use_meta_skill=${meta_skill}" \
|
||||
"optimizer.use_meta_reflect=false" \
|
||||
"gradient.use_deep_reflect=false" \
|
||||
"env.out_root=${out_root}"
|
||||
) > "${log_path}" 2>&1 &
|
||||
|
||||
echo "$!" > "${LOG_DIR}/${name}.pid"
|
||||
}
|
||||
|
||||
start_run "searchqa_medium_slow_meta" "configs/searchqa/default.yaml" "true"
|
||||
start_run "spreadsheetbench_medium_slow" "configs/spreadsheetbench/default.yaml" "false"
|
||||
start_run "spreadsheetbench_medium_slow_meta" "configs/spreadsheetbench/default.yaml" "true"
|
||||
|
||||
echo "[WAIT] missing comparison runs are active"
|
||||
wait
|
||||
43
scripts/run_searchqa.sh
Executable file
43
scripts/run_searchqa.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT — SearchQA training launch script
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/run_searchqa.sh
|
||||
# bash scripts/run_searchqa.sh --data_path data/searchqa_train_2000.json
|
||||
# bash scripts/run_searchqa.sh --num_epochs 2 --edit_budget 6
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
# ── Paths ────────────────────────────────────────────────────────────────────
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
# Ensure ReflACT is importable
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
# ── Model configuration ─────────────────────────────────────────────────────
|
||||
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
|
||||
|
||||
# ── Output directory ─────────────────────────────────────────────────────────
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/skillopt_searchqa_${STUDENT_DEPLOYMENT}_${TIMESTAMP}"
|
||||
|
||||
# ── Run ──────────────────────────────────────────────────────────────────────
|
||||
echo "============================================================"
|
||||
echo " ReflACT — Reflective Agent Tuning (SearchQA)"
|
||||
echo "============================================================"
|
||||
echo " Teacher: ${TEACHER_DEPLOYMENT}"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa_default.yaml \
|
||||
--out_root "${DEFAULT_OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
|
||||
48
scripts/run_spreadsheetbench.sh
Executable file
48
scripts/run_spreadsheetbench.sh
Executable file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT — SpreadsheetBench training launch script
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/run_spreadsheetbench.sh \
|
||||
# --data_root /path/to/data \
|
||||
# --jsonl_path /path/to/benchmark.jsonl
|
||||
#
|
||||
# bash scripts/run_spreadsheetbench.sh \
|
||||
# --data_root /path/to/data \
|
||||
# --jsonl_path /path/to/benchmark.jsonl \
|
||||
# --num_epochs 2 --edit_budget 6
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
# ── Paths ────────────────────────────────────────────────────────────────────
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
# Ensure ReflACT is importable
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
# ── Model configuration ─────────────────────────────────────────────────────
|
||||
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
|
||||
|
||||
# ── Output directory ─────────────────────────────────────────────────────────
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/skillopt_spreadsheetbench_${STUDENT_DEPLOYMENT}_${TIMESTAMP}"
|
||||
|
||||
# ── Run ──────────────────────────────────────────────────────────────────────
|
||||
echo "============================================================"
|
||||
echo " ReflACT — Reflective Agent Tuning (SpreadsheetBench)"
|
||||
echo "============================================================"
|
||||
echo " Teacher: ${TEACHER_DEPLOYMENT}"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/train.py \
|
||||
--config configs/spreadsheetbench_default.yaml \
|
||||
--out_root "${DEFAULT_OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
|
||||
483
scripts/train.py
Normal file
483
scripts/train.py
Normal file
@@ -0,0 +1,483 @@
|
||||
#!/usr/bin/env python3
|
||||
"""ReflACT unified training entry point.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python scripts/train.py --config configs/alfworld/default.yaml
|
||||
|
||||
Any YAML key can be overridden from the command line::
|
||||
|
||||
python scripts/train.py --config configs/alfworld/default.yaml \\
|
||||
--batch_size 40 --num_epochs 2 --seed 123
|
||||
|
||||
Run ``python scripts/train.py --help`` for a full list of options.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Ensure the project root is on sys.path so ``import skillopt`` works
|
||||
# regardless of where the script is invoked from.
|
||||
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(_SCRIPT_DIR)
|
||||
if _PROJECT_ROOT not in sys.path:
|
||||
sys.path.insert(0, _PROJECT_ROOT)
|
||||
|
||||
from skillopt.model.common import default_model_for_backend, normalize_backend_name
|
||||
|
||||
_OPENAI_DEFAULT_MODEL_SENTINELS = {"gpt-5.4", "gpt-5.5"}
|
||||
|
||||
|
||||
# ── Environment registry ────────────────────────────────────────────────────
|
||||
|
||||
_ENV_REGISTRY: dict[str, type] = {}
|
||||
|
||||
|
||||
def _register_builtins() -> None:
|
||||
"""Lazy-import built-in adapters so we don't pull heavy deps at CLI parse time."""
|
||||
try:
|
||||
from skillopt.envs.alfworld.adapter import ALFWorldAdapter
|
||||
_ENV_REGISTRY["alfworld"] = ALFWorldAdapter
|
||||
except ImportError:
|
||||
pass # ALFWorld deps not installed — skip
|
||||
try:
|
||||
from skillopt.envs.searchqa.adapter import SearchQAAdapter
|
||||
_ENV_REGISTRY["searchqa"] = SearchQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.livemathematicianbench.adapter import LiveMathematicianBenchAdapter
|
||||
_ENV_REGISTRY["livemathematicianbench"] = LiveMathematicianBenchAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.babyvision.adapter import BabyVisionAdapter
|
||||
_ENV_REGISTRY["babyvision"] = BabyVisionAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.spreadsheetbench.adapter import SpreadsheetBenchAdapter
|
||||
_ENV_REGISTRY["spreadsheetbench"] = SpreadsheetBenchAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.mmrb.adapter import MMRBAdapter
|
||||
_ENV_REGISTRY["mmrb"] = MMRBAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.docvqa.adapter import DocVQAAdapter
|
||||
_ENV_REGISTRY["docvqa"] = DocVQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.mathverse.adapter import MathVerseAdapter
|
||||
_ENV_REGISTRY["mathverse"] = MathVerseAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.officeqa.adapter import OfficeQAAdapter
|
||||
_ENV_REGISTRY["officeqa"] = OfficeQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.sealqa.adapter import SealQAAdapter
|
||||
_ENV_REGISTRY["sealqa"] = SealQAAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
from skillopt.envs.swebench.adapter import SWEBenchAdapter
|
||||
_ENV_REGISTRY["swebench"] = SWEBenchAdapter
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_adapter(cfg: dict):
|
||||
"""Instantiate the environment adapter specified in ``cfg["env"]``."""
|
||||
_register_builtins()
|
||||
env_name = cfg.get("env", "alfworld")
|
||||
if env_name not in _ENV_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Unknown environment '{env_name}'. "
|
||||
f"Available: {list(_ENV_REGISTRY.keys())}"
|
||||
)
|
||||
adapter_cls = _ENV_REGISTRY[env_name]
|
||||
|
||||
# Inspect adapter __init__ signature and only pass accepted kwargs
|
||||
import inspect
|
||||
sig = inspect.signature(adapter_cls.__init__)
|
||||
accepted = set(sig.parameters.keys()) - {"self"}
|
||||
adapter_kwargs: dict = {}
|
||||
for key in accepted:
|
||||
if key in cfg:
|
||||
adapter_kwargs[key] = cfg[key]
|
||||
|
||||
return adapter_cls(**adapter_kwargs)
|
||||
|
||||
|
||||
# ── CLI ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
_BOOL = lambda x: x.lower() in ("true", "1", "yes") # noqa: E731
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(
|
||||
description="ReflACT: Reflective Agent Tuning",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
p.add_argument("--config", type=str, required=True,
|
||||
help="Path to YAML config file")
|
||||
p.add_argument("--cfg-options", nargs="+", default=[],
|
||||
help="Override config: section.key=value (e.g. train.batch_size=40)")
|
||||
|
||||
# Legacy flat CLI overrides (still work, prefer --cfg-options for new usage)
|
||||
p.add_argument("--env", type=str)
|
||||
p.add_argument("--backend", type=str,
|
||||
choices=["azure_openai", "codex", "codex_exec", "claude", "claude_chat", "claude_code_exec"])
|
||||
p.add_argument("--teacher_model", type=str)
|
||||
p.add_argument("--student_model", type=str)
|
||||
p.add_argument("--teacher_backend", type=str)
|
||||
p.add_argument("--student_backend", type=str)
|
||||
p.add_argument("--reasoning_effort", type=str,
|
||||
choices=["", "low", "medium", "high", "xhigh", "max"])
|
||||
p.add_argument("--rewrite_reasoning_effort", type=str)
|
||||
p.add_argument("--rewrite_max_completion_tokens", type=int)
|
||||
p.add_argument("--azure_endpoint", type=str)
|
||||
p.add_argument("--azure_api_version", type=str)
|
||||
p.add_argument("--azure_api_key", type=str)
|
||||
p.add_argument("--azure_openai_endpoint", type=str)
|
||||
p.add_argument("--azure_openai_api_version", type=str)
|
||||
p.add_argument("--azure_openai_api_key", type=str)
|
||||
p.add_argument("--azure_openai_auth_mode", type=str)
|
||||
p.add_argument("--azure_openai_ad_scope", type=str)
|
||||
p.add_argument("--azure_openai_managed_identity_client_id", type=str)
|
||||
p.add_argument("--teacher_azure_openai_endpoint", type=str)
|
||||
p.add_argument("--teacher_azure_openai_api_version", type=str)
|
||||
p.add_argument("--teacher_azure_openai_api_key", type=str)
|
||||
p.add_argument("--teacher_azure_openai_auth_mode", type=str)
|
||||
p.add_argument("--teacher_azure_openai_ad_scope", type=str)
|
||||
p.add_argument("--teacher_azure_openai_managed_identity_client_id", type=str)
|
||||
p.add_argument("--student_azure_openai_endpoint", type=str)
|
||||
p.add_argument("--student_azure_openai_api_version", type=str)
|
||||
p.add_argument("--student_azure_openai_api_key", type=str)
|
||||
p.add_argument("--student_azure_openai_auth_mode", type=str)
|
||||
p.add_argument("--student_azure_openai_ad_scope", type=str)
|
||||
p.add_argument("--student_azure_openai_managed_identity_client_id", type=str)
|
||||
p.add_argument("--codex_exec_path", type=str)
|
||||
p.add_argument("--codex_exec_sandbox", type=str)
|
||||
p.add_argument("--codex_exec_profile", type=str)
|
||||
p.add_argument("--codex_exec_full_auto", type=_BOOL)
|
||||
p.add_argument("--codex_exec_reasoning_effort", type=str)
|
||||
p.add_argument("--codex_exec_use_sdk", type=str)
|
||||
p.add_argument("--codex_exec_network_access", type=_BOOL)
|
||||
p.add_argument("--codex_exec_web_search", type=_BOOL)
|
||||
p.add_argument("--codex_exec_approval_policy", type=str)
|
||||
p.add_argument("--claude_code_exec_path", type=str)
|
||||
p.add_argument("--claude_code_exec_profile", type=str)
|
||||
p.add_argument("--claude_code_exec_use_sdk", type=str)
|
||||
p.add_argument("--claude_code_exec_effort", type=str)
|
||||
p.add_argument("--claude_code_exec_max_thinking_tokens", type=int)
|
||||
p.add_argument("--codex_trace_to_teacher", type=_BOOL)
|
||||
p.add_argument("--skill_init", type=str)
|
||||
p.add_argument("--num_epochs", type=int)
|
||||
p.add_argument("--train_size", type=int)
|
||||
p.add_argument("--steps_per_epoch", type=int)
|
||||
p.add_argument("--batch_size", type=int)
|
||||
p.add_argument("--accumulation", type=int)
|
||||
p.add_argument("--seed", type=int)
|
||||
p.add_argument("--edit_budget", type=int)
|
||||
p.add_argument("--min_edit_budget", type=int)
|
||||
p.add_argument("--lr_scheduler", type=str,
|
||||
choices=["constant", "linear", "cosine", "autonomous"])
|
||||
p.add_argument("--lr_control_mode", type=str,
|
||||
choices=["fixed", "autonomous", "none"])
|
||||
p.add_argument("--merge_batch_size", type=int)
|
||||
p.add_argument("--max_analyst_rounds", type=int)
|
||||
p.add_argument("--sel_env_num", type=int)
|
||||
p.add_argument("--test_env_num", type=int)
|
||||
p.add_argument("--eval_test", type=_BOOL)
|
||||
p.add_argument("--use_gate", type=_BOOL)
|
||||
p.add_argument("--max_steps", type=int)
|
||||
p.add_argument("--max_api_workers", type=int)
|
||||
p.add_argument("--analyst_workers", type=int)
|
||||
p.add_argument("--failure_only", type=_BOOL)
|
||||
p.add_argument("--minibatch_size", type=int)
|
||||
p.add_argument("--use_meta_reflect", type=_BOOL)
|
||||
p.add_argument("--meta_edit_budget", type=int)
|
||||
p.add_argument("--skill_update_mode", type=str,
|
||||
choices=[
|
||||
"patch",
|
||||
"rewrite_from_suggestions",
|
||||
"rewrite",
|
||||
"suggestions",
|
||||
"full_rewrite",
|
||||
"full_rewrite_minibatch",
|
||||
"minibatch_full_rewrite",
|
||||
])
|
||||
p.add_argument("--use_deep_reflect", type=_BOOL)
|
||||
p.add_argument("--deep_reflect_failures", type=int)
|
||||
p.add_argument("--deep_reflect_successes", type=int)
|
||||
p.add_argument("--use_slow_update", type=_BOOL)
|
||||
p.add_argument("--slow_update_samples", type=int)
|
||||
p.add_argument("--longitudinal_pair_policy", type=str,
|
||||
choices=["mixed", "changed", "unchanged"])
|
||||
p.add_argument("--use_meta_skill", type=_BOOL)
|
||||
p.add_argument("--data_path", type=str)
|
||||
p.add_argument("--split_mode", type=str,
|
||||
choices=["ratio", "split_dir"])
|
||||
p.add_argument("--split_ratio", type=str)
|
||||
p.add_argument("--split_seed", type=int)
|
||||
p.add_argument("--split_dir", type=str)
|
||||
p.add_argument("--split_output_dir", type=str)
|
||||
p.add_argument("--data_root", type=str)
|
||||
p.add_argument("--max_turns", type=int)
|
||||
p.add_argument("--workers", type=int)
|
||||
p.add_argument("--limit", type=int)
|
||||
p.add_argument("--shuffle_choices", type=_BOOL)
|
||||
p.add_argument("--use_theorem", type=_BOOL)
|
||||
p.add_argument("--use_sketch", type=_BOOL)
|
||||
p.add_argument("--image_detail", type=str)
|
||||
p.add_argument("--judge_model", type=str)
|
||||
p.add_argument("--judge_max_completion_tokens", type=int)
|
||||
p.add_argument("--judge_retries", type=int)
|
||||
p.add_argument("--out_root", type=str)
|
||||
p.add_argument("--mode", type=str)
|
||||
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
# ── Flat key → structured path mapping (for legacy CLI → structured config) ──
|
||||
|
||||
_LEGACY_TO_STRUCTURED: dict[str, str] = {
|
||||
"backend": "model.backend",
|
||||
"teacher_model": "model.teacher",
|
||||
"student_model": "model.student",
|
||||
"teacher_backend": "model.teacher_backend",
|
||||
"student_backend": "model.student_backend",
|
||||
"reasoning_effort": "model.reasoning_effort",
|
||||
"rewrite_reasoning_effort": "model.rewrite_reasoning_effort",
|
||||
"rewrite_max_completion_tokens": "model.rewrite_max_completion_tokens",
|
||||
"azure_endpoint": "model.azure_endpoint",
|
||||
"azure_api_version": "model.azure_api_version",
|
||||
"azure_api_key": "model.azure_api_key",
|
||||
"azure_openai_endpoint": "model.azure_openai_endpoint",
|
||||
"azure_openai_api_version": "model.azure_openai_api_version",
|
||||
"azure_openai_api_key": "model.azure_openai_api_key",
|
||||
"azure_openai_auth_mode": "model.azure_openai_auth_mode",
|
||||
"azure_openai_ad_scope": "model.azure_openai_ad_scope",
|
||||
"azure_openai_managed_identity_client_id": "model.azure_openai_managed_identity_client_id",
|
||||
"teacher_azure_openai_endpoint": "model.teacher_azure_openai_endpoint",
|
||||
"teacher_azure_openai_api_version": "model.teacher_azure_openai_api_version",
|
||||
"teacher_azure_openai_api_key": "model.teacher_azure_openai_api_key",
|
||||
"teacher_azure_openai_auth_mode": "model.teacher_azure_openai_auth_mode",
|
||||
"teacher_azure_openai_ad_scope": "model.teacher_azure_openai_ad_scope",
|
||||
"teacher_azure_openai_managed_identity_client_id": "model.teacher_azure_openai_managed_identity_client_id",
|
||||
"student_azure_openai_endpoint": "model.student_azure_openai_endpoint",
|
||||
"student_azure_openai_api_version": "model.student_azure_openai_api_version",
|
||||
"student_azure_openai_api_key": "model.student_azure_openai_api_key",
|
||||
"student_azure_openai_auth_mode": "model.student_azure_openai_auth_mode",
|
||||
"student_azure_openai_ad_scope": "model.student_azure_openai_ad_scope",
|
||||
"student_azure_openai_managed_identity_client_id": "model.student_azure_openai_managed_identity_client_id",
|
||||
"codex_exec_path": "model.codex_exec_path",
|
||||
"codex_exec_sandbox": "model.codex_exec_sandbox",
|
||||
"codex_exec_profile": "model.codex_exec_profile",
|
||||
"codex_exec_full_auto": "model.codex_exec_full_auto",
|
||||
"codex_exec_reasoning_effort": "model.codex_exec_reasoning_effort",
|
||||
"codex_exec_use_sdk": "model.codex_exec_use_sdk",
|
||||
"codex_exec_network_access": "model.codex_exec_network_access",
|
||||
"codex_exec_web_search": "model.codex_exec_web_search",
|
||||
"codex_exec_approval_policy": "model.codex_exec_approval_policy",
|
||||
"claude_code_exec_path": "model.claude_code_exec_path",
|
||||
"claude_code_exec_profile": "model.claude_code_exec_profile",
|
||||
"claude_code_exec_use_sdk": "model.claude_code_exec_use_sdk",
|
||||
"claude_code_exec_effort": "model.claude_code_exec_effort",
|
||||
"claude_code_exec_max_thinking_tokens": "model.claude_code_exec_max_thinking_tokens",
|
||||
"codex_trace_to_teacher": "model.codex_trace_to_teacher",
|
||||
"num_epochs": "train.num_epochs",
|
||||
"train_size": "train.train_size",
|
||||
"steps_per_epoch": "train.steps_per_epoch",
|
||||
"batch_size": "train.batch_size",
|
||||
"accumulation": "train.accumulation",
|
||||
"seed": "train.seed",
|
||||
"minibatch_size": "gradient.minibatch_size",
|
||||
"merge_batch_size": "gradient.merge_batch_size",
|
||||
"analyst_workers": "gradient.analyst_workers",
|
||||
"max_analyst_rounds": "gradient.max_analyst_rounds",
|
||||
"failure_only": "gradient.failure_only",
|
||||
"use_deep_reflect": "gradient.use_deep_reflect",
|
||||
"deep_reflect_failures": "gradient.deep_reflect_failures",
|
||||
"deep_reflect_successes": "gradient.deep_reflect_successes",
|
||||
"edit_budget": "optimizer.learning_rate",
|
||||
"min_edit_budget": "optimizer.min_learning_rate",
|
||||
"lr_scheduler": "optimizer.lr_scheduler",
|
||||
"lr_control_mode": "optimizer.lr_control_mode",
|
||||
"skill_update_mode": "optimizer.skill_update_mode",
|
||||
"use_meta_reflect": "optimizer.use_meta_reflect",
|
||||
"meta_edit_budget": "optimizer.meta_learning_rate",
|
||||
"use_slow_update": "optimizer.use_slow_update",
|
||||
"slow_update_samples": "optimizer.slow_update_samples",
|
||||
"longitudinal_pair_policy": "optimizer.longitudinal_pair_policy",
|
||||
"use_meta_skill": "optimizer.use_meta_skill",
|
||||
"use_gate": "evaluation.use_gate",
|
||||
"sel_env_num": "evaluation.sel_env_num",
|
||||
"test_env_num": "evaluation.test_env_num",
|
||||
"eval_test": "evaluation.eval_test",
|
||||
"env": "env.name",
|
||||
"skill_init": "env.skill_init",
|
||||
"out_root": "env.out_root",
|
||||
}
|
||||
|
||||
|
||||
def load_config(args: argparse.Namespace) -> dict:
|
||||
"""Load config with _base_ inheritance, then apply CLI overrides."""
|
||||
from skillopt.config import load_config as _load, flatten_config, is_structured
|
||||
|
||||
cfg = _load(args.config, overrides=args.cfg_options)
|
||||
structured = is_structured(cfg)
|
||||
|
||||
# Apply legacy --key value overrides
|
||||
cli = {k: v for k, v in vars(args).items()
|
||||
if v is not None and k not in ("config", "cfg_options")}
|
||||
if cli:
|
||||
if structured:
|
||||
from skillopt.config import apply_overrides
|
||||
mapped = []
|
||||
for k, v in cli.items():
|
||||
dotted = _LEGACY_TO_STRUCTURED.get(k)
|
||||
if dotted:
|
||||
mapped.append(f"{dotted}={v}")
|
||||
else:
|
||||
mapped.append(f"env.{k}={v}")
|
||||
apply_overrides(cfg, mapped)
|
||||
else:
|
||||
cfg.update(cli)
|
||||
|
||||
# Flatten structured config → flat dict for trainer/adapter
|
||||
flat = flatten_config(cfg) if structured else cfg
|
||||
|
||||
for new_key, old_key in (
|
||||
("azure_openai_endpoint", "azure_endpoint"),
|
||||
("azure_openai_api_version", "azure_api_version"),
|
||||
("azure_openai_api_key", "azure_api_key"),
|
||||
):
|
||||
if flat.get(new_key) in (None, "") and flat.get(old_key) not in (None, ""):
|
||||
flat[new_key] = flat[old_key]
|
||||
|
||||
explicit_backend = getattr(args, "backend", None)
|
||||
if explicit_backend is None:
|
||||
for option in args.cfg_options or []:
|
||||
key = str(option).split("=", 1)[0].strip()
|
||||
if key == "model.backend":
|
||||
explicit_backend = str(option).split("=", 1)[1].strip()
|
||||
break
|
||||
|
||||
backend = normalize_backend_name(flat.get("model_backend") or flat.get("student_backend") or "azure_openai")
|
||||
|
||||
def _has_model_override(dotted_key: str, legacy_key: str) -> bool:
|
||||
if getattr(args, legacy_key, None) is not None:
|
||||
return True
|
||||
for option in args.cfg_options or []:
|
||||
key = str(option).split("=", 1)[0].strip()
|
||||
if key == dotted_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
if explicit_backend is not None:
|
||||
backend = normalize_backend_name(explicit_backend)
|
||||
flat["model_backend"] = backend
|
||||
if backend in {"claude", "claude_chat"}:
|
||||
flat.setdefault("teacher_backend", "claude_chat")
|
||||
flat.setdefault("student_backend", "claude_chat")
|
||||
elif backend in {"codex", "codex_exec"}:
|
||||
flat.setdefault("teacher_backend", "openai_chat")
|
||||
flat.setdefault("student_backend", "codex_exec")
|
||||
elif backend == "claude_code_exec":
|
||||
flat.setdefault("teacher_backend", "openai_chat")
|
||||
flat.setdefault("student_backend", "claude_code_exec")
|
||||
else:
|
||||
flat.setdefault("teacher_backend", "openai_chat")
|
||||
flat.setdefault("student_backend", "openai_chat")
|
||||
else:
|
||||
flat.setdefault("teacher_backend", "openai_chat")
|
||||
flat.setdefault("student_backend", "openai_chat")
|
||||
|
||||
if flat.get("teacher_backend") == "claude_chat":
|
||||
if (
|
||||
str(flat.get("teacher_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
|
||||
and not _has_model_override("model.teacher", "teacher_model")
|
||||
):
|
||||
flat["teacher_model"] = default_model_for_backend("claude_chat")
|
||||
if flat.get("student_backend") == "claude_chat":
|
||||
if (
|
||||
str(flat.get("student_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
|
||||
and not _has_model_override("model.student", "student_model")
|
||||
):
|
||||
flat["student_model"] = default_model_for_backend("claude_chat")
|
||||
if flat.get("student_backend") == "claude_code_exec":
|
||||
if (
|
||||
str(flat.get("student_model", "") or "").strip() in _OPENAI_DEFAULT_MODEL_SENTINELS
|
||||
and not _has_model_override("model.student", "student_model")
|
||||
):
|
||||
flat["student_model"] = default_model_for_backend("claude_chat")
|
||||
|
||||
# Auto-generate output root
|
||||
if not flat.get("out_root"):
|
||||
env = flat.get("env", "unknown")
|
||||
model = flat.get("teacher_model", "unknown").replace("/", "-")
|
||||
ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
flat["out_root"] = os.path.join("outputs", f"skillopt_{env}_{model}_{ts}")
|
||||
|
||||
flat["out_root"] = os.path.abspath(flat["out_root"])
|
||||
return flat
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
cfg = load_config(args)
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f" ReflACT — Reflective Agent Tuning")
|
||||
print(f"{'='*60}")
|
||||
print(f" env: {cfg.get('env')}")
|
||||
print(f" teacher_model: {cfg.get('teacher_model')}")
|
||||
print(f" student_model: {cfg.get('student_model')}")
|
||||
print(f" teacher_backend:{cfg.get('teacher_backend', 'openai_chat')}")
|
||||
print(f" student_backend:{cfg.get('student_backend', 'openai_chat')}")
|
||||
print(f" reasoning: {cfg.get('reasoning_effort') or 'off'}")
|
||||
print(f" rewrite_effort: {cfg.get('rewrite_reasoning_effort') or 'off'}")
|
||||
print(f" epochs: {cfg.get('num_epochs')}")
|
||||
print(f" train_size: {cfg.get('train_size') or 'from dataset'}")
|
||||
print(f" steps/epoch: auto")
|
||||
print(f" batch_size: {cfg.get('batch_size')}")
|
||||
print(f" edit_budget: {cfg.get('edit_budget')}")
|
||||
print(f" lr_scheduler: {cfg.get('lr_scheduler', 'constant')}")
|
||||
print(f" update_mode: {cfg.get('skill_update_mode', 'patch')}")
|
||||
print(f" min_edit_budget:{cfg.get('min_edit_budget', 2)}")
|
||||
print(f" minibatch_size: {cfg.get('minibatch_size')}")
|
||||
print(f" seed: {cfg.get('seed')}")
|
||||
print(f" meta_reflect: {cfg.get('use_meta_reflect', False)}")
|
||||
print(f" meta_skill: {cfg.get('use_meta_skill', False)}")
|
||||
print(f" out_root: {cfg.get('out_root')}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Build adapter
|
||||
adapter = get_adapter(cfg)
|
||||
|
||||
# Build trainer and run
|
||||
from skillopt.engine.trainer import ReflACTTrainer
|
||||
trainer = ReflACTTrainer(cfg, adapter)
|
||||
summary = trainer.train()
|
||||
|
||||
print(f"\n Output saved to: {cfg['out_root']}")
|
||||
if summary.get("test_hard") is not None:
|
||||
print(f" Final test: {summary['test_hard']:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
scripts/train_searchqa.sh
Executable file
43
scripts/train_searchqa.sh
Executable file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT3 — SearchQA Training
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/train_searchqa.sh
|
||||
# bash scripts/train_searchqa.sh --limit 50
|
||||
# bash scripts/train_searchqa.sh --num_epochs 2 --workers 32
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
|
||||
# ── Models ───────────────────────────────────────────────────────────────────
|
||||
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
|
||||
|
||||
# ── Data ─────────────────────────────────────────────────────────────────────
|
||||
DATA_PATH="/home/azureuser/workspace-yqh/refleAct/search-qa/data/searchqa_train_2000.json"
|
||||
|
||||
# ── Output ───────────────────────────────────────────────────────────────────
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
DEFAULT_OUT_ROOT="${PROJECT_ROOT}/outputs/searchqa-metaskill/searchqa_${STUDENT_DEPLOYMENT}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " ReflACT3 — SearchQA Training"
|
||||
echo " Teacher: ${TEACHER_DEPLOYMENT}"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo " Data: ${DATA_PATH}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa_default.yaml \
|
||||
--data_path "${DATA_PATH}" \
|
||||
--out_root "${DEFAULT_OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${DEFAULT_OUT_ROOT}"
|
||||
48
scripts/train_spreadsheet_multi.sh
Executable file
48
scripts/train_spreadsheet_multi.sh
Executable file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT — SpreadsheetBench training (MULTI-ROUND codegen, no tool-call)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/train_spreadsheet_multi.sh
|
||||
# bash scripts/train_spreadsheet_multi.sh --num_epochs 2 --max_turns 5
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
|
||||
|
||||
DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH="${DATA_ROOT}/dataset.json"
|
||||
SPLIT_DIR="/home/azureuser/workspace-yqh/refleACT3/data/spreadsheetbench_split_2_1_7"
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUT_ROOT="${PROJECT_ROOT}/outputs/spreadsheet-metaskill-new/train_multi_${STUDENT_DEPLOYMENT}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " ReflACT — SpreadsheetBench Training (MULTI-ROUND)"
|
||||
echo "============================================================"
|
||||
echo " Teacher: ${TEACHER_DEPLOYMENT}"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo " Mode: multi"
|
||||
echo " Data: ${DATA_ROOT}"
|
||||
echo " Split: ${SPLIT_DIR}"
|
||||
echo " Output: ${OUT_ROOT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/train.py \
|
||||
--config configs/spreadsheetbench_default.yaml \
|
||||
--mode multi \
|
||||
--data_root "${DATA_ROOT}" \
|
||||
--jsonl_path "${JSONL_PATH}" \
|
||||
--split_dir "${SPLIT_DIR}" \
|
||||
--out_root "${OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${OUT_ROOT}"
|
||||
48
scripts/train_spreadsheet_single.sh
Executable file
48
scripts/train_spreadsheet_single.sh
Executable file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env bash
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
# ReflACT — SpreadsheetBench training (SINGLE-ROUND codegen, no tool-call)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/train_spreadsheet_single.sh
|
||||
# bash scripts/train_spreadsheet_single.sh --num_epochs 2 --edit_budget 6
|
||||
# ──────────────────────────────────────────────────────────────────────────────
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
||||
PROJECT_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"
|
||||
export TEACHER_DEPLOYMENT="${TEACHER_DEPLOYMENT:-gpt-5.5}"
|
||||
export STUDENT_DEPLOYMENT="${STUDENT_DEPLOYMENT:-gpt-5.5}"
|
||||
|
||||
DATA_ROOT="/home/azureuser/workspace-yqh/sr/spreadsheetbench/data/spreadsheetbench_verified_400"
|
||||
JSONL_PATH="${DATA_ROOT}/dataset.json"
|
||||
SPLIT_DIR="/home/azureuser/workspace-yqh/refleACT3/data/spreadsheetbench_split_2_1_7"
|
||||
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
OUT_ROOT="${PROJECT_ROOT}/outputs/spreadsheet-metaskill-new/train_single_${STUDENT_DEPLOYMENT}"
|
||||
|
||||
echo "============================================================"
|
||||
echo " ReflACT — SpreadsheetBench Training (SINGLE-ROUND)"
|
||||
echo "============================================================"
|
||||
echo " Teacher: ${TEACHER_DEPLOYMENT}"
|
||||
echo " Student: ${STUDENT_DEPLOYMENT}"
|
||||
echo " Mode: single"
|
||||
echo " Data: ${DATA_ROOT}"
|
||||
echo " Split: ${SPLIT_DIR}"
|
||||
echo " Output: ${OUT_ROOT}"
|
||||
echo "============================================================"
|
||||
|
||||
cd "${PROJECT_ROOT}"
|
||||
|
||||
python scripts/train.py \
|
||||
--config configs/spreadsheetbench_default.yaml \
|
||||
--mode single \
|
||||
--data_root "${DATA_ROOT}" \
|
||||
--jsonl_path "${JSONL_PATH}" \
|
||||
--split_dir "${SPLIT_DIR}" \
|
||||
--out_root "${OUT_ROOT}" \
|
||||
"$@"
|
||||
|
||||
echo ""
|
||||
echo "Done! Results saved to: ${OUT_ROOT}"
|
||||
218
scripts/watch_ablation.py
Normal file
218
scripts/watch_ablation.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Watch an ablation run root and rerun final failures.
|
||||
|
||||
This watcher is intended to run in tmux next to scripts/run_ablation_matrix.py.
|
||||
It writes STATUS.md on every poll and starts a direct rerun for any run that
|
||||
the launcher marks as final [FAIL].
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from run_ablation_matrix import PROJECT_ROOT, build_matrix, command_for
|
||||
|
||||
|
||||
RUN_RE = re.compile(r"\[(START|DONE|FAIL|RETRY)\]\s+([^\s]+)")
|
||||
ERROR_RE = re.compile(
|
||||
r"Traceback|RuntimeError|AuthenticationError|PermissionDenied|"
|
||||
r"DeploymentNotFound|LLM call failed|LLM message call failed|"
|
||||
r"BadRequestError|RateLimitError",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def read_text(path: Path) -> str:
|
||||
try:
|
||||
return path.read_text(encoding="utf-8", errors="replace")
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
|
||||
|
||||
def parse_launcher(path: Path) -> dict[str, list[str]]:
|
||||
events = {"START": [], "DONE": [], "FAIL": [], "RETRY": []}
|
||||
for line in read_text(path).splitlines():
|
||||
match = RUN_RE.search(line)
|
||||
if match:
|
||||
events[match.group(1)].append(match.group(2))
|
||||
return events
|
||||
|
||||
|
||||
def active_run_ids(run_root: Path) -> list[str]:
|
||||
try:
|
||||
raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
|
||||
except subprocess.CalledProcessError:
|
||||
return []
|
||||
active: list[str] = []
|
||||
pattern = re.compile(re.escape(str(run_root)) + r"/([^\s]+)")
|
||||
for line in raw.splitlines():
|
||||
for match in pattern.finditer(line):
|
||||
active.append(match.group(1))
|
||||
return sorted(set(active))
|
||||
|
||||
|
||||
def scan_errors(logs_dir: Path) -> dict[str, str]:
|
||||
errors: dict[str, str] = {}
|
||||
for log_path in sorted(logs_dir.glob("*.log")):
|
||||
text = read_text(log_path)
|
||||
match = ERROR_RE.search(text)
|
||||
if match:
|
||||
run_id = log_path.name.split(".watchrerun", 1)[0].removesuffix(".log")
|
||||
errors[run_id] = match.group(0)
|
||||
return errors
|
||||
|
||||
|
||||
def load_state(path: Path) -> dict:
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return {"reruns": {}}
|
||||
|
||||
|
||||
def save_state(path: Path, state: dict) -> None:
|
||||
tmp = path.with_suffix(".tmp")
|
||||
tmp.write_text(json.dumps(state, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||
tmp.replace(path)
|
||||
|
||||
|
||||
def write_status(
|
||||
run_root: Path,
|
||||
total: int,
|
||||
events: dict[str, list[str]],
|
||||
active: list[str],
|
||||
completed: list[str],
|
||||
pending: list[str],
|
||||
errors: dict[str, str],
|
||||
reruns: dict[str, int],
|
||||
) -> None:
|
||||
now = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
|
||||
failed = sorted(set(events["FAIL"]))
|
||||
retrying = sorted(set(events["RETRY"]))
|
||||
lines = [
|
||||
"# Ablation Status",
|
||||
"",
|
||||
f"Updated: {now}",
|
||||
f"Run root: `{run_root}`",
|
||||
"",
|
||||
"| Metric | Count |",
|
||||
"| --- | ---: |",
|
||||
f"| Total planned | {total} |",
|
||||
f"| Completed summaries | {len(completed)} |",
|
||||
f"| Active train processes | {len(active)} |",
|
||||
f"| Pending/not summarized | {len(pending)} |",
|
||||
f"| Launcher final fails | {len(failed)} |",
|
||||
f"| Launcher retries | {len(retrying)} |",
|
||||
f"| Logs with error patterns | {len(errors)} |",
|
||||
"",
|
||||
"## Active",
|
||||
"",
|
||||
*(f"- `{run_id}`" for run_id in active),
|
||||
"",
|
||||
"## Final Fails",
|
||||
"",
|
||||
*(f"- `{run_id}` watcher_reruns={reruns.get(run_id, 0)}" for run_id in failed),
|
||||
"",
|
||||
"## Error Patterns",
|
||||
"",
|
||||
*(f"- `{run_id}`: `{err}`" for run_id, err in sorted(errors.items())),
|
||||
"",
|
||||
"## Recent Launcher",
|
||||
"",
|
||||
"```text",
|
||||
"\n".join(read_text(run_root / "launcher.log").splitlines()[-30:]),
|
||||
"```",
|
||||
]
|
||||
(run_root / "STATUS.md").write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--run-root", required=True)
|
||||
parser.add_argument("--interval", type=int, default=60)
|
||||
parser.add_argument("--watcher-retries", type=int, default=1)
|
||||
parser.add_argument("--groups", nargs="+", default=["all"])
|
||||
parser.add_argument("--bench", nargs="+", default=["searchqa", "spreadsheetbench"])
|
||||
args = parser.parse_args()
|
||||
|
||||
run_root = Path(args.run_root).resolve()
|
||||
logs_dir = run_root / "logs"
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
state_path = run_root / "watcher_state.json"
|
||||
|
||||
groups = set(args.groups)
|
||||
if "all" in groups:
|
||||
groups = {"default", "split", "mbs", "lr", "sched", "slown", "mod", "smodel"}
|
||||
experiments = {
|
||||
exp.run_id: exp
|
||||
for exp in build_matrix(groups, args.bench, run_root, include_duplicate_defaults=False)
|
||||
}
|
||||
|
||||
active_reruns: dict[str, subprocess.Popen] = {}
|
||||
while True:
|
||||
state = load_state(state_path)
|
||||
reruns = state.setdefault("reruns", {})
|
||||
events = parse_launcher(run_root / "launcher.log")
|
||||
active = active_run_ids(run_root)
|
||||
completed = sorted(
|
||||
run_id for run_id in experiments
|
||||
if (run_root / run_id / "summary.json").exists()
|
||||
)
|
||||
pending = sorted(set(experiments) - set(completed))
|
||||
errors = scan_errors(logs_dir)
|
||||
|
||||
# Reap watcher-started reruns.
|
||||
for run_id, proc in list(active_reruns.items()):
|
||||
rc = proc.poll()
|
||||
if rc is None:
|
||||
continue
|
||||
active_reruns.pop(run_id, None)
|
||||
with open(logs_dir / f"{run_id}.watcher.log", "a", encoding="utf-8") as f:
|
||||
f.write(f"\n[WATCHER_DONE] rc={rc} time={time.time()}\n")
|
||||
|
||||
for run_id in sorted(set(events["FAIL"])):
|
||||
if run_id not in experiments:
|
||||
continue
|
||||
if (run_root / run_id / "summary.json").exists():
|
||||
continue
|
||||
if run_id in active or run_id in active_reruns:
|
||||
continue
|
||||
count = int(reruns.get(run_id, 0))
|
||||
if count >= args.watcher_retries:
|
||||
continue
|
||||
reruns[run_id] = count + 1
|
||||
save_state(state_path, state)
|
||||
log_path = logs_dir / f"{run_id}.watchrerun{count + 1}.log"
|
||||
with open(log_path, "w", encoding="utf-8") as log_f:
|
||||
log_f.write(f"[WATCHER_START] run_id={run_id} attempt={count + 1}\n")
|
||||
log_f.flush()
|
||||
proc = subprocess.Popen(
|
||||
command_for(experiments[run_id]),
|
||||
cwd=PROJECT_ROOT,
|
||||
stdout=log_f,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
close_fds=True,
|
||||
)
|
||||
active_reruns[run_id] = proc
|
||||
|
||||
save_state(state_path, state)
|
||||
write_status(
|
||||
run_root=run_root,
|
||||
total=len(experiments),
|
||||
events=events,
|
||||
active=active,
|
||||
completed=completed,
|
||||
pending=pending,
|
||||
errors=errors,
|
||||
reruns={k: int(v) for k, v in reruns.items()},
|
||||
)
|
||||
time.sleep(max(5, int(args.interval)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
29
skillopt/__init__.py
Normal file
29
skillopt/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""ReflACT: Reflective Agent Tuning.
|
||||
|
||||
A general-purpose framework for iteratively optimizing LLM agent skills
|
||||
through structured reflection and self-improvement.
|
||||
|
||||
Pipeline stages:
|
||||
1. Rollout — execute episodes with current skill
|
||||
2. Reflect — analyze trajectories, generate patches
|
||||
3. Aggregate — hierarchical merge of patches
|
||||
4. Select — rank and select top edits
|
||||
5. Update — apply edits to skill document
|
||||
6. Evaluate — validate candidate skill, accept/reject
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from skillopt.types import ( # noqa: F401
|
||||
BatchSpec,
|
||||
Edit,
|
||||
EditOp,
|
||||
FailureSummaryEntry,
|
||||
GateAction,
|
||||
GateResult,
|
||||
MetaReflectResult,
|
||||
Patch,
|
||||
RawPatch,
|
||||
RolloutResult,
|
||||
SlowUpdateResult,
|
||||
)
|
||||
263
skillopt/config.py
Normal file
263
skillopt/config.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""ReflACT config loading engine — structured YAML with inheritance.
|
||||
|
||||
Supports two config formats:
|
||||
1. **Structured** (new): sections like ``model``, ``train``, ``gradient``,
|
||||
``optimizer``, ``evaluation``, ``env`` — with ``_base_`` inheritance.
|
||||
2. **Flat** (legacy): all keys at top level — fully backward compatible.
|
||||
|
||||
Usage::
|
||||
|
||||
from skillopt.config import load_config, flatten_config
|
||||
|
||||
cfg = load_config("configs/searchqa_default.yaml")
|
||||
flat = flatten_config(cfg) # always returns flat dict for trainer
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
# ── Section names that indicate a structured config ──────────────────────
|
||||
|
||||
_STRUCTURED_SECTIONS = frozenset({
|
||||
"model", "train", "gradient", "optimizer", "evaluation", "env",
|
||||
})
|
||||
|
||||
# ── Structured → flat key mapping ────────────────────────────────────────
|
||||
|
||||
_FLATTEN_MAP: dict[str, str] = {
|
||||
"model.backend": "model_backend",
|
||||
"model.teacher": "teacher_model",
|
||||
"model.student": "student_model",
|
||||
"model.teacher_backend": "teacher_backend",
|
||||
"model.student_backend": "student_backend",
|
||||
"model.reasoning_effort": "reasoning_effort",
|
||||
"model.rewrite_reasoning_effort": "rewrite_reasoning_effort",
|
||||
"model.rewrite_max_completion_tokens": "rewrite_max_completion_tokens",
|
||||
"model.codex_exec_path": "codex_exec_path",
|
||||
"model.codex_exec_sandbox": "codex_exec_sandbox",
|
||||
"model.codex_exec_profile": "codex_exec_profile",
|
||||
"model.codex_exec_full_auto": "codex_exec_full_auto",
|
||||
"model.codex_exec_reasoning_effort": "codex_exec_reasoning_effort",
|
||||
"model.codex_exec_use_sdk": "codex_exec_use_sdk",
|
||||
"model.codex_exec_network_access": "codex_exec_network_access",
|
||||
"model.codex_exec_web_search": "codex_exec_web_search",
|
||||
"model.codex_exec_approval_policy": "codex_exec_approval_policy",
|
||||
"model.claude_code_exec_path": "claude_code_exec_path",
|
||||
"model.claude_code_exec_profile": "claude_code_exec_profile",
|
||||
"model.claude_code_exec_use_sdk": "claude_code_exec_use_sdk",
|
||||
"model.claude_code_exec_effort": "claude_code_exec_effort",
|
||||
"model.claude_code_exec_max_thinking_tokens": "claude_code_exec_max_thinking_tokens",
|
||||
"model.codex_trace_to_teacher": "codex_trace_to_teacher",
|
||||
"model.azure_endpoint": "azure_endpoint",
|
||||
"model.azure_api_version": "azure_api_version",
|
||||
"model.azure_api_key": "azure_api_key",
|
||||
"model.azure_openai_endpoint": "azure_openai_endpoint",
|
||||
"model.azure_openai_api_version": "azure_openai_api_version",
|
||||
"model.azure_openai_api_key": "azure_openai_api_key",
|
||||
"model.azure_openai_auth_mode": "azure_openai_auth_mode",
|
||||
"model.azure_openai_ad_scope": "azure_openai_ad_scope",
|
||||
"model.azure_openai_managed_identity_client_id": "azure_openai_managed_identity_client_id",
|
||||
"model.teacher_azure_openai_endpoint": "teacher_azure_openai_endpoint",
|
||||
"model.teacher_azure_openai_api_version": "teacher_azure_openai_api_version",
|
||||
"model.teacher_azure_openai_api_key": "teacher_azure_openai_api_key",
|
||||
"model.teacher_azure_openai_auth_mode": "teacher_azure_openai_auth_mode",
|
||||
"model.teacher_azure_openai_ad_scope": "teacher_azure_openai_ad_scope",
|
||||
"model.teacher_azure_openai_managed_identity_client_id": "teacher_azure_openai_managed_identity_client_id",
|
||||
"model.student_azure_openai_endpoint": "student_azure_openai_endpoint",
|
||||
"model.student_azure_openai_api_version": "student_azure_openai_api_version",
|
||||
"model.student_azure_openai_api_key": "student_azure_openai_api_key",
|
||||
"model.student_azure_openai_auth_mode": "student_azure_openai_auth_mode",
|
||||
"model.student_azure_openai_ad_scope": "student_azure_openai_ad_scope",
|
||||
"model.student_azure_openai_managed_identity_client_id": "student_azure_openai_managed_identity_client_id",
|
||||
"train.num_epochs": "num_epochs",
|
||||
"train.train_size": "train_size",
|
||||
"train.steps_per_epoch": "steps_per_epoch",
|
||||
"train.batch_size": "batch_size",
|
||||
"train.accumulation": "accumulation",
|
||||
"train.seed": "seed",
|
||||
"gradient.minibatch_size": "minibatch_size",
|
||||
"gradient.merge_batch_size": "merge_batch_size",
|
||||
"gradient.analyst_workers": "analyst_workers",
|
||||
"gradient.failure_only": "failure_only",
|
||||
"gradient.use_deep_reflect": "use_deep_reflect",
|
||||
"gradient.deep_reflect_failures": "deep_reflect_failures",
|
||||
"gradient.deep_reflect_successes": "deep_reflect_successes",
|
||||
"gradient.max_analyst_rounds": "max_analyst_rounds",
|
||||
"optimizer.learning_rate": "edit_budget",
|
||||
"optimizer.min_learning_rate": "min_edit_budget",
|
||||
"optimizer.lr_scheduler": "lr_scheduler",
|
||||
"optimizer.lr_control_mode": "lr_control_mode",
|
||||
"optimizer.skill_update_mode": "skill_update_mode",
|
||||
"optimizer.use_meta_reflect": "use_meta_reflect",
|
||||
"optimizer.meta_learning_rate": "meta_edit_budget",
|
||||
"optimizer.use_slow_update": "use_slow_update",
|
||||
"optimizer.slow_update_samples": "slow_update_samples",
|
||||
"optimizer.longitudinal_pair_policy": "longitudinal_pair_policy",
|
||||
"optimizer.use_meta_skill": "use_meta_skill",
|
||||
"evaluation.use_gate": "use_gate",
|
||||
"evaluation.sel_env_num": "sel_env_num",
|
||||
"evaluation.test_env_num": "test_env_num",
|
||||
"evaluation.eval_test": "eval_test",
|
||||
"env.name": "env",
|
||||
"env.skill_init": "skill_init",
|
||||
"env.out_root": "out_root",
|
||||
}
|
||||
|
||||
|
||||
# ── Deep merge ───────────────────────────────────────────────────────────
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
"""Recursively merge *override* into *base* (returns new dict)."""
|
||||
result = copy.deepcopy(base)
|
||||
for key, val in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(val, dict):
|
||||
result[key] = _deep_merge(result[key], val)
|
||||
else:
|
||||
result[key] = copy.deepcopy(val)
|
||||
return result
|
||||
|
||||
|
||||
# ── YAML loading with _base_ inheritance ─────────────────────────────────
|
||||
|
||||
def _load_yaml(path: str, _visited: set[str] | None = None) -> dict:
|
||||
"""Load a YAML file, resolving ``_base_`` inheritance recursively."""
|
||||
abs_path = os.path.abspath(path)
|
||||
if _visited is None:
|
||||
_visited = set()
|
||||
if abs_path in _visited:
|
||||
raise ValueError(f"Circular _base_ inheritance: {abs_path}")
|
||||
_visited.add(abs_path)
|
||||
|
||||
with open(abs_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
base_ref = cfg.pop("_base_", None)
|
||||
if base_ref:
|
||||
base_path = os.path.join(os.path.dirname(abs_path), base_ref)
|
||||
base_cfg = _load_yaml(base_path, _visited)
|
||||
cfg = _deep_merge(base_cfg, cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
# ── Format detection ─────────────────────────────────────────────────────
|
||||
|
||||
def is_structured(cfg: dict) -> bool:
|
||||
"""Return True if *cfg* uses the new structured section format."""
|
||||
return any(
|
||||
key in _STRUCTURED_SECTIONS and isinstance(cfg.get(key), dict)
|
||||
for key in cfg
|
||||
)
|
||||
|
||||
|
||||
# ── Flatten ──────────────────────────────────────────────────────────────
|
||||
|
||||
def flatten_config(cfg: dict) -> dict:
|
||||
"""Convert a structured config to the flat dict expected by the trainer.
|
||||
|
||||
If *cfg* is already flat, returns a shallow copy unchanged.
|
||||
"""
|
||||
if not is_structured(cfg):
|
||||
return dict(cfg)
|
||||
|
||||
flat: dict[str, Any] = {}
|
||||
|
||||
evaluation_section = cfg.get("evaluation", {})
|
||||
if isinstance(evaluation_section, dict) and evaluation_section.get("use_gate") is False:
|
||||
raise ValueError(
|
||||
"Gate validation is mandatory in this branch. Remove "
|
||||
"`evaluation.use_gate: false` from the config."
|
||||
)
|
||||
|
||||
# Apply the explicit mapping
|
||||
for dotted, flat_key in _FLATTEN_MAP.items():
|
||||
section, key = dotted.split(".", 1)
|
||||
section_dict = cfg.get(section, {})
|
||||
if isinstance(section_dict, dict) and key in section_dict:
|
||||
flat[flat_key] = section_dict[key]
|
||||
|
||||
# Pass through env-specific keys not in the explicit mapping
|
||||
env_section = cfg.get("env", {})
|
||||
if isinstance(env_section, dict):
|
||||
mapped_env_keys = {
|
||||
k.split(".", 1)[1]
|
||||
for k in _FLATTEN_MAP
|
||||
if k.startswith("env.")
|
||||
}
|
||||
for key, val in env_section.items():
|
||||
if key not in mapped_env_keys:
|
||||
flat[key] = val
|
||||
|
||||
return flat
|
||||
|
||||
|
||||
# ── Override application ─────────────────────────────────────────────────
|
||||
|
||||
def _cast_value(val_str: str) -> Any:
|
||||
"""Auto-cast a CLI string value to int / float / bool / str."""
|
||||
if val_str.lower() in ("true", "yes"):
|
||||
return True
|
||||
if val_str.lower() in ("false", "no"):
|
||||
return False
|
||||
try:
|
||||
return int(val_str)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(val_str)
|
||||
except ValueError:
|
||||
pass
|
||||
return val_str
|
||||
|
||||
|
||||
def apply_overrides(cfg: dict, overrides: list[str]) -> None:
|
||||
"""Apply ``key=value`` overrides to a structured config (in place).
|
||||
|
||||
Supports both ``section.key=value`` (for structured configs) and
|
||||
``key=value`` (for flat configs or flat keys in env section).
|
||||
"""
|
||||
for item in overrides:
|
||||
if "=" not in item:
|
||||
raise ValueError(f"Invalid override (expected key=value): {item!r}")
|
||||
key, val_str = item.split("=", 1)
|
||||
val = _cast_value(val_str)
|
||||
|
||||
if "." in key:
|
||||
section, subkey = key.split(".", 1)
|
||||
if section in cfg and isinstance(cfg[section], dict):
|
||||
cfg[section][subkey] = val
|
||||
else:
|
||||
cfg.setdefault(section, {})[subkey] = val
|
||||
else:
|
||||
# Flat key — apply to top level (for legacy compat)
|
||||
cfg[key] = val
|
||||
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────────
|
||||
|
||||
def load_config(
|
||||
path: str,
|
||||
overrides: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Load a config file with ``_base_`` inheritance and optional overrides.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Path to the YAML config file.
|
||||
overrides : list[str] | None
|
||||
``key=value`` strings from ``--cfg-options``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The merged config (structured or flat depending on the YAML).
|
||||
"""
|
||||
cfg = _load_yaml(path)
|
||||
if overrides:
|
||||
apply_overrides(cfg, overrides)
|
||||
return cfg
|
||||
7
skillopt/datasets/__init__.py
Normal file
7
skillopt/datasets/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""ReflACT Datasets -- task batch planning and data loading.
|
||||
|
||||
Analogous to the datasets and dataloaders in neural network training:
|
||||
provides batch sampling, epoch planning, and data management for the
|
||||
ReflACT training pipeline.
|
||||
"""
|
||||
from skillopt.datasets.base import BaseDataLoader, BatchSpec, SplitDataLoader # noqa: F401
|
||||
512
skillopt/datasets/base.py
Normal file
512
skillopt/datasets/base.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Generic task dataloader abstractions for ReflACT.
|
||||
|
||||
ReflACT does not train model parameters directly. Instead, it iterates over
|
||||
task batches, rolls out the current skill, reflects on failures/successes,
|
||||
and updates the skill document. Because of that, the "dataloader" abstraction
|
||||
here is closer to a batch sampler / episode planner than a tensor loader.
|
||||
|
||||
Class hierarchy::
|
||||
|
||||
BaseDataLoader # abstract — simulator-backed envs (e.g. ALFWorld)
|
||||
└── SplitDataLoader # abstract — dataset-backed envs with split_dir
|
||||
|
||||
SplitDataLoader supports two dataset entry modes:
|
||||
|
||||
1. ``split_mode="split_dir"``: consume an existing split directory.
|
||||
2. ``split_mode="ratio"``: build a deterministic split directory from a raw
|
||||
dataset path using an explicit train:val:test ratio.
|
||||
|
||||
In either case, the standardised split layout is:
|
||||
|
||||
split_dir/
|
||||
├── train/ # training items
|
||||
├── val/ # validation / selection items (gate)
|
||||
└── test/ # held-out test items
|
||||
|
||||
Each subdirectory's contents are benchmark-specific. Subclasses only need
|
||||
to implement ``load_split_items(split_path)`` to teach the loader how to
|
||||
read items from one of those directories.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BatchSpec:
|
||||
"""A concrete batch request consumed by the training loop.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
phase : str
|
||||
``"train"`` or ``"eval"``.
|
||||
split : str
|
||||
Dataset split name, typically ``"train"`` or an eval split.
|
||||
seed : int
|
||||
Random seed used to construct the batch deterministically.
|
||||
batch_size : int
|
||||
Requested number of items / episodes in this batch.
|
||||
payload : object | None
|
||||
Environment-specific batch payload. For dataset-backed environments
|
||||
this is often a list of sampled items; for simulator-backed
|
||||
environments this may be ``None`` and the seed alone can define the
|
||||
batch.
|
||||
metadata : dict[str, Any]
|
||||
Optional structured metadata for logging, resume, or curriculum logic.
|
||||
"""
|
||||
|
||||
phase: str
|
||||
split: str
|
||||
seed: int
|
||||
batch_size: int
|
||||
payload: object | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class BaseDataLoader(ABC):
|
||||
"""Abstract base class for task batch planning in ReflACT.
|
||||
|
||||
Subclasses are responsible for defining how a train or eval batch is
|
||||
sampled. The default implementation here provides deterministic epoch seed
|
||||
planning so all loaders share the same reproducibility behavior.
|
||||
"""
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
"""Optional one-time initialization with the full trainer config."""
|
||||
|
||||
def set_out_root(self, out_root: str) -> None:
|
||||
"""Optional hook for loaders that persist split files or state."""
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""Return serializable loader state for resume support."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
"""Restore loader state from :meth:`state_dict` output."""
|
||||
|
||||
def get_train_size(self) -> int | None:
|
||||
"""Return the size of the training pool when known."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def make_base_seeds(steps_per_epoch: int, accumulation: int, seed: int) -> list[int]:
|
||||
"""Return the deterministic seed pool used to define train batches."""
|
||||
batches_per_epoch = steps_per_epoch * accumulation
|
||||
return [seed + i + 1 for i in range(batches_per_epoch)]
|
||||
|
||||
@staticmethod
|
||||
def shuffle_epoch_seeds(base_seeds: list[int], epoch: int, seed: int) -> list[int]:
|
||||
"""Return the per-epoch deterministic shuffle of *base_seeds*."""
|
||||
epoch_rng = random.Random(seed + epoch * 1000)
|
||||
shuffled = list(base_seeds)
|
||||
epoch_rng.shuffle(shuffled)
|
||||
return shuffled
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
"""Build the full list of training batches for one epoch."""
|
||||
base_seeds = self.make_base_seeds(
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
accumulation=accumulation,
|
||||
seed=seed,
|
||||
)
|
||||
shuffled_seeds = self.shuffle_epoch_seeds(base_seeds, epoch=epoch, seed=seed)
|
||||
return [
|
||||
self.build_train_batch(batch_size=batch_size, seed=batch_seed, **kwargs)
|
||||
for batch_seed in shuffled_seeds
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
"""Construct one training batch specification."""
|
||||
|
||||
@abstractmethod
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
"""Construct one evaluation batch specification."""
|
||||
|
||||
|
||||
# ── Split-based dataloader for dataset-backed environments ──────────────
|
||||
|
||||
# Canonical split names expected under split_dir/
|
||||
SPLIT_NAMES = ("train", "val", "test")
|
||||
|
||||
# Maps legacy / trainer split names → canonical directory names
|
||||
_SPLIT_ALIAS: dict[str, str] = {
|
||||
"train": "train",
|
||||
"valid_seen": "val",
|
||||
"selection": "val",
|
||||
"val": "val",
|
||||
"valid_unseen": "test",
|
||||
"test": "test",
|
||||
}
|
||||
|
||||
|
||||
def _load_json_or_jsonl(path: str) -> list[dict]:
|
||||
"""Load a list of items from a JSON or JSONL file."""
|
||||
with open(path, encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
if not content:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
data = None
|
||||
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
nested = data.get("data")
|
||||
if isinstance(nested, list):
|
||||
return nested
|
||||
return list(data.values())
|
||||
|
||||
items: list[dict] = []
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
items.append(json.loads(line))
|
||||
return items
|
||||
|
||||
|
||||
def _parse_split_ratio(text: str) -> tuple[int, int, int]:
|
||||
parts = [part.strip() for part in str(text or "").split(":") if part.strip()]
|
||||
if len(parts) != 3:
|
||||
raise ValueError(
|
||||
f"split_ratio must be in train:val:test form, got {text!r}"
|
||||
)
|
||||
try:
|
||||
train, val, test = (int(part) for part in parts)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"split_ratio must contain integers, got {text!r}"
|
||||
) from exc
|
||||
if min(train, val, test) <= 0:
|
||||
raise ValueError(f"split_ratio parts must be positive, got {text!r}")
|
||||
return train, val, test
|
||||
|
||||
|
||||
def _compute_split_counts(total: int, ratio: tuple[int, int, int]) -> tuple[int, int, int]:
|
||||
weights = list(ratio)
|
||||
denom = sum(weights)
|
||||
raw = [total * weight / denom for weight in weights]
|
||||
counts = [int(value) for value in raw]
|
||||
remaining = total - sum(counts)
|
||||
order = sorted(
|
||||
range(len(raw)),
|
||||
key=lambda idx: (raw[idx] - counts[idx], weights[idx]),
|
||||
reverse=True,
|
||||
)
|
||||
for idx in order[:remaining]:
|
||||
counts[idx] += 1
|
||||
return counts[0], counts[1], counts[2]
|
||||
|
||||
|
||||
class SplitDataLoader(BaseDataLoader):
|
||||
"""Base class for dataset-backed environments.
|
||||
|
||||
Supported modes:
|
||||
|
||||
- ``split_mode="split_dir"``: load an existing ``train/``, ``val/``,
|
||||
``test/`` directory tree.
|
||||
- ``split_mode="ratio"``: load raw items from ``data_path`` and materialize
|
||||
a deterministic split directory with the requested ratio.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.split_dir = split_dir
|
||||
self.data_path = data_path
|
||||
self.split_mode = split_mode
|
||||
self.split_ratio = split_ratio
|
||||
self.split_seed = int(split_seed)
|
||||
self.split_output_dir = split_output_dir
|
||||
self.seed = seed
|
||||
self.limit = limit
|
||||
self._splits: dict[str, list[dict]] = {}
|
||||
|
||||
# ── Setup ────────────────────────────────────────────────────────────
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
if not self.split_mode:
|
||||
self.split_mode = str(cfg.get("split_mode", "ratio") or "ratio")
|
||||
if not self.split_dir:
|
||||
self.split_dir = cfg.get("split_dir", "")
|
||||
if not self.data_path:
|
||||
self.data_path = cfg.get("data_path", "")
|
||||
if not self.split_output_dir:
|
||||
self.split_output_dir = cfg.get("split_output_dir", "")
|
||||
if "split_seed" in cfg and not self.split_seed:
|
||||
self.split_seed = int(cfg.get("split_seed", 0) or 0)
|
||||
if not self.split_seed:
|
||||
self.split_seed = self.seed
|
||||
if not self.split_ratio:
|
||||
self.split_ratio = str(cfg.get("split_ratio", "2:1:7") or "2:1:7")
|
||||
|
||||
mode = str(self.split_mode or "ratio").strip().lower()
|
||||
if mode not in {"ratio", "split_dir"}:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} split_mode must be 'ratio' or 'split_dir', "
|
||||
f"got {self.split_mode!r}"
|
||||
)
|
||||
self.split_mode = mode
|
||||
|
||||
if self.split_mode == "ratio":
|
||||
self.split_dir = self._materialize_ratio_split(cfg)
|
||||
if not self.split_dir:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} requires either "
|
||||
"`split_mode=ratio` with `data_path`, or `split_mode=split_dir` "
|
||||
f"with `split_dir` pointing to {'/'.join(SPLIT_NAMES)}/."
|
||||
)
|
||||
self._load_all_splits()
|
||||
|
||||
def _resolve_split_output_dir(self, cfg: dict) -> str:
|
||||
if self.split_output_dir:
|
||||
return os.path.abspath(self.split_output_dir)
|
||||
out_root = os.path.abspath(str(cfg.get("out_root") or os.getcwd()))
|
||||
env_name = str(cfg.get("env") or type(self).__name__.replace("DataLoader", "").lower())
|
||||
ratio_tag = str(self.split_ratio or "2:1:7").replace(":", "-")
|
||||
return os.path.join(out_root, "_generated_splits", f"{env_name}_{ratio_tag}_seed{self.split_seed}")
|
||||
|
||||
def load_raw_items(self, data_path: str) -> list[dict]:
|
||||
"""Load raw items from a dataset path before ratio splitting.
|
||||
|
||||
Subclasses can override when the raw dataset is not a single JSON/JSONL
|
||||
file or when directory layouts require custom normalization.
|
||||
"""
|
||||
if os.path.isdir(data_path):
|
||||
if any(os.path.isdir(os.path.join(data_path, name)) for name in SPLIT_NAMES):
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} got a split directory as data_path. "
|
||||
"Use split_mode=split_dir and pass it as split_dir instead."
|
||||
)
|
||||
candidates = sorted(glob.glob(os.path.join(data_path, "*.json")))
|
||||
candidates += sorted(glob.glob(os.path.join(data_path, "*.jsonl")))
|
||||
if len(candidates) != 1:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} expected data_path to be one JSON/JSONL file "
|
||||
f"or a directory containing exactly one such file, got: {data_path}"
|
||||
)
|
||||
return _load_json_or_jsonl(candidates[0])
|
||||
return _load_json_or_jsonl(data_path)
|
||||
|
||||
def write_split_items(self, split_path: str, items: list[dict]) -> None:
|
||||
os.makedirs(split_path, exist_ok=True)
|
||||
out_path = os.path.join(split_path, "items.json")
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(items, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def _materialize_ratio_split(self, cfg: dict) -> str:
|
||||
data_path = os.path.abspath(str(self.data_path or "").strip())
|
||||
if not data_path:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} requires data_path when split_mode=ratio."
|
||||
)
|
||||
|
||||
ratio = _parse_split_ratio(self.split_ratio)
|
||||
items = self.load_raw_items(data_path)
|
||||
if not isinstance(items, list) or not items:
|
||||
raise ValueError(f"No raw items available for ratio split from {data_path}")
|
||||
|
||||
shuffled = list(items)
|
||||
rng = random.Random(self.split_seed)
|
||||
rng.shuffle(shuffled)
|
||||
|
||||
train_n, val_n, test_n = _compute_split_counts(len(shuffled), ratio)
|
||||
train_items = shuffled[:train_n]
|
||||
val_items = shuffled[train_n: train_n + val_n]
|
||||
test_items = shuffled[train_n + val_n: train_n + val_n + test_n]
|
||||
|
||||
split_dir = self._resolve_split_output_dir(cfg)
|
||||
manifest = {
|
||||
"source_data_path": data_path,
|
||||
"split_mode": "ratio",
|
||||
"split_ratio": self.split_ratio,
|
||||
"split_seed": self.split_seed,
|
||||
"counts": {
|
||||
"train": len(train_items),
|
||||
"val": len(val_items),
|
||||
"test": len(test_items),
|
||||
},
|
||||
}
|
||||
os.makedirs(split_dir, exist_ok=True)
|
||||
self.write_split_items(os.path.join(split_dir, "train"), train_items)
|
||||
self.write_split_items(os.path.join(split_dir, "val"), val_items)
|
||||
self.write_split_items(os.path.join(split_dir, "test"), test_items)
|
||||
with open(os.path.join(split_dir, "split_manifest.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||
print(
|
||||
f" [{type(self).__name__}] generated ratio split {self.split_ratio} "
|
||||
f"at {split_dir} from {data_path}"
|
||||
)
|
||||
return split_dir
|
||||
|
||||
def _load_all_splits(self) -> None:
|
||||
for name in SPLIT_NAMES:
|
||||
split_path = os.path.join(self.split_dir, name)
|
||||
if not os.path.isdir(split_path):
|
||||
raise ValueError(
|
||||
f"Missing '{name}/' subdirectory in split_dir: {self.split_dir}"
|
||||
)
|
||||
items = self.load_split_items(split_path)
|
||||
if self.limit:
|
||||
items = items[: self.limit]
|
||||
self._splits[name] = items
|
||||
|
||||
counts = " ".join(f"{k}={len(v)}" for k, v in self._splits.items())
|
||||
print(f" [{type(self).__name__}] {counts} (from {self.split_dir})")
|
||||
|
||||
def load_split_items(self, split_path: str) -> list[dict]:
|
||||
"""Load items from one split directory (e.g. ``split_dir/train/``).
|
||||
|
||||
Default: finds the first ``.json`` file in the directory and loads it
|
||||
as a JSON array. Subclasses can override for custom formats.
|
||||
"""
|
||||
json_files = sorted(glob.glob(os.path.join(split_path, "*.json")))
|
||||
if not json_files:
|
||||
raise FileNotFoundError(
|
||||
f"No .json file found in {split_path}"
|
||||
)
|
||||
with open(json_files[0], encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
if not isinstance(items, list):
|
||||
raise ValueError(
|
||||
f"Expected JSON array in {json_files[0]}, got {type(items).__name__}"
|
||||
)
|
||||
return items
|
||||
|
||||
# ── Accessors ────────────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def train_items(self) -> list[dict]:
|
||||
return self._splits.get("train", [])
|
||||
|
||||
@property
|
||||
def val_items(self) -> list[dict]:
|
||||
return self._splits.get("val", [])
|
||||
|
||||
@property
|
||||
def test_items(self) -> list[dict]:
|
||||
return self._splits.get("test", [])
|
||||
|
||||
def get_split_items(self, split: str) -> list[dict]:
|
||||
"""Resolve a split name (including legacy aliases) to its item list."""
|
||||
canonical = _SPLIT_ALIAS.get(split, split)
|
||||
return list(self._splits.get(canonical, self.val_items))
|
||||
|
||||
def get_train_size(self) -> int:
|
||||
return len(self.train_items)
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
"""Build one full epoch that covers the train split in shuffled order.
|
||||
|
||||
For split-backed datasets, an epoch should correspond to one pass over
|
||||
the available training items rather than repeated independent sampling.
|
||||
"""
|
||||
epoch_rng = random.Random(seed + epoch * 1000)
|
||||
items = list(self.train_items)
|
||||
epoch_rng.shuffle(items)
|
||||
|
||||
total_batches = steps_per_epoch * accumulation
|
||||
if total_batches <= 0:
|
||||
return []
|
||||
|
||||
batches: list[BatchSpec] = []
|
||||
cursor = 0
|
||||
for batch_idx in range(total_batches):
|
||||
batch_items = items[cursor: cursor + batch_size]
|
||||
cursor += len(batch_items)
|
||||
|
||||
# Extremely small datasets can leave trailing empty microbatches
|
||||
# when accumulation > 1. Reuse the shuffled prefix in that case so
|
||||
# the trainer still receives the expected batch count.
|
||||
if not batch_items and items:
|
||||
refill_rng = random.Random(seed + epoch * 1000 + batch_idx + 1)
|
||||
batch_items = list(items)
|
||||
refill_rng.shuffle(batch_items)
|
||||
batch_items = batch_items[:batch_size]
|
||||
|
||||
batches.append(
|
||||
BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed + epoch * 1000 + batch_idx + 1,
|
||||
batch_size=len(batch_items),
|
||||
payload=batch_items,
|
||||
)
|
||||
)
|
||||
|
||||
return batches
|
||||
|
||||
# ── Batch construction ───────────────────────────────────────────────
|
||||
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
rng = random.Random(seed)
|
||||
items = list(self.train_items)
|
||||
rng.shuffle(items)
|
||||
items = items[:batch_size]
|
||||
return BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
)
|
||||
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
items = self.get_split_items(split)
|
||||
if env_num and env_num < len(items):
|
||||
items = items[:env_num]
|
||||
return BatchSpec(
|
||||
phase="eval",
|
||||
split=split,
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
)
|
||||
9
skillopt/engine/__init__.py
Normal file
9
skillopt/engine/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""ReflACT Engine -- the training runner.
|
||||
|
||||
Analogous to the Runner in mmengine: orchestrates the full training pipeline
|
||||
including rollout, gradient computation, aggregation, optimization, and
|
||||
evaluation.
|
||||
"""
|
||||
from skillopt.engine.trainer import ReflACTTrainer # noqa: F401
|
||||
|
||||
__all__ = ["ReflACTTrainer"]
|
||||
2195
skillopt/engine/trainer.py
Normal file
2195
skillopt/engine/trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
1
skillopt/envs/__init__.py
Normal file
1
skillopt/envs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""ReflACT environment adapters."""
|
||||
5
skillopt/envs/alfworld/__init__.py
Normal file
5
skillopt/envs/alfworld/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""ALFWorld environment adapter for ReflACT."""
|
||||
|
||||
from skillopt.envs.alfworld.adapter import ALFWorldAdapter
|
||||
|
||||
__all__ = ["ALFWorldAdapter"]
|
||||
585
skillopt/envs/alfworld/adapter.py
Normal file
585
skillopt/envs/alfworld/adapter.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""ALFWorld environment adapter for ReflACT.
|
||||
|
||||
Connects the ReflACT training loop to ALFWorld by implementing
|
||||
:class:`~skillopt.envs.base.EnvAdapter`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
|
||||
from skillopt.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.alfworld.dataloader import ALFWorldDataLoader
|
||||
from skillopt.envs.alfworld.rollout import (
|
||||
build_alfworld_env,
|
||||
run_alfworld_batch,
|
||||
TASKS,
|
||||
)
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
from skillopt.utils import compute_score
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ALFWorldBatchRun:
|
||||
"""Lazy ALFWorld batch description.
|
||||
|
||||
The adapter materializes this in rollout chunks so a large evaluation set
|
||||
does not keep every ALFWorld simulator open at once.
|
||||
"""
|
||||
|
||||
env_num: int
|
||||
eval_dataset: str
|
||||
seed: int
|
||||
is_train: bool
|
||||
workers: int
|
||||
specific_gamefiles: list[str] | None = None
|
||||
result_ids: list[str] | None = None
|
||||
items: list[dict] | None = None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items or [])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return int(self.env_num or 0)
|
||||
|
||||
|
||||
class ALFWorldAdapter(EnvAdapter):
|
||||
"""ALFWorld environment adapter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_steps : int
|
||||
Maximum steps per ALFWorld episode (default 50).
|
||||
max_api_workers : int
|
||||
Maximum concurrent API calls during rollout (default 8).
|
||||
analyst_workers : int
|
||||
Parallel workers for analyst stage (default 16).
|
||||
failure_only : bool
|
||||
If True, only run error analyst (skip success analyst).
|
||||
minibatch_size : int
|
||||
Trajectories per analyst group, M (default 8).
|
||||
edit_budget : int
|
||||
Maximum edits per minibatch, L (default 4).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "split_dir",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
train_size: int = 0,
|
||||
max_steps: int = 50,
|
||||
workers: int = 8,
|
||||
max_api_workers: int = 8,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_steps = max_steps
|
||||
self.workers = max(int(workers or 1), 1)
|
||||
self.max_api_workers = max_api_workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = ALFWorldDataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
train_size=train_size,
|
||||
)
|
||||
self._traj_cache: dict[str, dict | None] = {}
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def _load_traj_data(self, item: dict) -> dict | None:
|
||||
gamefile = str(item.get("gamefile") or "").strip()
|
||||
if not gamefile:
|
||||
return None
|
||||
if gamefile in self._traj_cache:
|
||||
return self._traj_cache[gamefile]
|
||||
|
||||
traj_path = os.path.join(os.path.dirname(gamefile), "traj_data.json")
|
||||
try:
|
||||
with open(traj_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
data = None
|
||||
self._traj_cache[gamefile] = data
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _unique_lines(values: list[str], *, limit: int = 0) -> list[str]:
|
||||
lines: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw in values:
|
||||
line = str(raw or "").strip()
|
||||
if not line or line in seen:
|
||||
continue
|
||||
seen.add(line)
|
||||
lines.append(line)
|
||||
if limit > 0 and len(lines) >= limit:
|
||||
break
|
||||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _format_high_pddl(high_pddl: list[dict]) -> list[str]:
|
||||
steps: list[str] = []
|
||||
for idx, step in enumerate(high_pddl or [], start=1):
|
||||
discrete = step.get("discrete_action") or {}
|
||||
action = str(discrete.get("action") or "").strip()
|
||||
args = [str(arg).strip() for arg in (discrete.get("args") or []) if str(arg).strip()]
|
||||
if action and args:
|
||||
text = f"{action}({', '.join(args)})"
|
||||
elif action:
|
||||
text = action
|
||||
else:
|
||||
planner_action = step.get("planner_action") or {}
|
||||
text = str(planner_action.get("action") or "").strip()
|
||||
if text:
|
||||
steps.append(f"{idx}. {text}")
|
||||
return steps
|
||||
|
||||
def _build_reference_bundle(self, item: dict) -> dict:
|
||||
data = self._load_traj_data(item)
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
anns = ((data.get("turk_annotations") or {}).get("anns") or [])
|
||||
task_descs = self._unique_lines(
|
||||
[ann.get("task_desc", "") for ann in anns],
|
||||
limit=3,
|
||||
)
|
||||
high_descs = self._unique_lines(
|
||||
[step for ann in anns for step in (ann.get("high_descs") or [])],
|
||||
limit=12,
|
||||
)
|
||||
pddl_params = {
|
||||
key: value
|
||||
for key, value in (data.get("pddl_params") or {}).items()
|
||||
if value not in ("", None, [], {})
|
||||
}
|
||||
scene = data.get("scene") or {}
|
||||
scene_summary = {
|
||||
key: scene.get(key)
|
||||
for key in ("floor_plan", "scene_num", "dirty_and_empty")
|
||||
if scene.get(key) not in ("", None, [], {})
|
||||
}
|
||||
high_pddl = self._format_high_pddl((data.get("plan") or {}).get("high_pddl") or [])
|
||||
task_type = str(data.get("task_type") or item.get("task_type") or "").strip()
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task_descs": task_descs,
|
||||
"high_descs": high_descs,
|
||||
"pddl_params": pddl_params,
|
||||
"high_pddl": high_pddl,
|
||||
"scene_summary": scene_summary,
|
||||
}
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
bundle = self._build_reference_bundle(item)
|
||||
if not bundle:
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
if bundle["task_type"]:
|
||||
parts.append(f"## Reference Task Type\n{bundle['task_type']}")
|
||||
if bundle["task_descs"]:
|
||||
parts.append(
|
||||
"## Reference Human Task Descriptions\n"
|
||||
+ "\n".join(f"- {line}" for line in bundle["task_descs"])
|
||||
)
|
||||
if bundle["high_descs"]:
|
||||
parts.append(
|
||||
"## Reference Human High-Level Steps\n"
|
||||
+ "\n".join(f"{idx}. {line}" for idx, line in enumerate(bundle["high_descs"], start=1))
|
||||
)
|
||||
if bundle["pddl_params"]:
|
||||
parts.append(
|
||||
"## Reference PDDL Params\n"
|
||||
+ "\n".join(f"- {key}: {value}" for key, value in bundle["pddl_params"].items())
|
||||
)
|
||||
if bundle["high_pddl"]:
|
||||
parts.append(
|
||||
"## Reference Planner High-Level Plan\n" + "\n".join(bundle["high_pddl"])
|
||||
)
|
||||
if bundle["scene_summary"]:
|
||||
parts.append(
|
||||
"## Reference Scene Summary\n"
|
||||
+ "\n".join(f"- {key}: {value}" for key, value in bundle["scene_summary"].items())
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
bundle = self._build_reference_bundle(item)
|
||||
if not bundle:
|
||||
return {"fields": [], "preview": ""}
|
||||
|
||||
fields: list[str] = []
|
||||
previews: list[str] = []
|
||||
if bundle["task_type"]:
|
||||
fields.append("task_type")
|
||||
previews.append(f"[task_type] {bundle['task_type']}")
|
||||
if bundle["task_descs"]:
|
||||
fields.append("task_desc")
|
||||
previews.append("[task_desc]\n" + "\n".join(bundle["task_descs"][:2]))
|
||||
if bundle["high_descs"]:
|
||||
fields.append("high_descs")
|
||||
previews.append("[high_descs]\n" + "\n".join(bundle["high_descs"][:3]))
|
||||
if bundle["pddl_params"]:
|
||||
fields.append("pddl_params")
|
||||
previews.append(
|
||||
"[pddl_params]\n"
|
||||
+ "\n".join(
|
||||
f"{key}: {value}" for key, value in list(bundle["pddl_params"].items())[:4]
|
||||
)
|
||||
)
|
||||
if bundle["high_pddl"]:
|
||||
fields.append("plan.high_pddl")
|
||||
previews.append("[plan.high_pddl]\n" + "\n".join(bundle["high_pddl"][:3]))
|
||||
if bundle["scene_summary"]:
|
||||
fields.append("scene")
|
||||
previews.append(
|
||||
"[scene]\n"
|
||||
+ "\n".join(
|
||||
f"{key}: {value}" for key, value in bundle["scene_summary"].items()
|
||||
)
|
||||
)
|
||||
return {
|
||||
"fields": fields,
|
||||
"preview": "\n\n".join(previews)[:600],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _infer_dataset_from_gamefile(gamefile: str) -> tuple[str, bool]:
|
||||
path = str(gamefile or "")
|
||||
if "/valid_seen/" in path:
|
||||
return "eval_in_distribution", False
|
||||
if "/valid_unseen/" in path:
|
||||
return "eval_out_of_distribution", False
|
||||
return "train", True
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def _comparison_items(self, items: list[dict]) -> list[dict]:
|
||||
enriched: list[dict] = []
|
||||
for item in items:
|
||||
row = dict(item)
|
||||
bundle = self._build_reference_bundle(row)
|
||||
if bundle.get("task_descs"):
|
||||
row["task_description"] = bundle["task_descs"][0]
|
||||
elif bundle.get("task_type"):
|
||||
row["task_description"] = bundle["task_type"]
|
||||
enriched.append(row)
|
||||
return enriched
|
||||
|
||||
def requires_ray(self) -> bool:
|
||||
return False
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
gamefiles = list(batch.metadata.get("gamefiles") or [])
|
||||
result_ids = list(batch.metadata.get("result_ids") or [])
|
||||
items = self._comparison_items(list(batch.payload or []))
|
||||
return ALFWorldBatchRun(
|
||||
env_num=batch.batch_size,
|
||||
eval_dataset=batch.metadata.get("eval_dataset", batch.split),
|
||||
seed=batch.seed,
|
||||
is_train=batch.metadata.get("is_train", batch.phase == "train"),
|
||||
specific_gamefiles=gamefiles or None,
|
||||
result_ids=result_ids or None,
|
||||
items=items,
|
||||
workers=self.workers,
|
||||
)
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
results_path = os.path.join(out_dir, "results.jsonl")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Resume support
|
||||
if os.path.exists(results_path):
|
||||
existing: list[dict] = []
|
||||
with open(results_path) as f:
|
||||
for line in f:
|
||||
try:
|
||||
existing.append(json.loads(line))
|
||||
except Exception:
|
||||
pass
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if isinstance(env_manager, ALFWorldBatchRun):
|
||||
results = self._run_batch(
|
||||
env_manager,
|
||||
skill_content=skill_content,
|
||||
out_dir=out_dir,
|
||||
)
|
||||
else:
|
||||
results = run_alfworld_batch(
|
||||
env_manager=env_manager,
|
||||
skill_content=skill_content,
|
||||
max_steps=self.max_steps,
|
||||
out_root=out_dir,
|
||||
max_api_workers=self.max_api_workers,
|
||||
result_ids=getattr(env_manager, "_skillopt_result_ids", None),
|
||||
)
|
||||
|
||||
with open(results_path, "w") as f:
|
||||
for r in results:
|
||||
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _close_env(env_manager) -> None:
|
||||
close = getattr(env_manager, "close", None)
|
||||
if callable(close):
|
||||
close()
|
||||
|
||||
def _run_batch(
|
||||
self,
|
||||
batch: ALFWorldBatchRun,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
) -> list[dict]:
|
||||
total = int(batch.env_num or 0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
workers = max(1, min(int(batch.workers or self.workers), total))
|
||||
if total > workers:
|
||||
print(
|
||||
f" [alfworld rollout] episodes={total} "
|
||||
f"env_workers={workers} chunks={(total + workers - 1) // workers}"
|
||||
)
|
||||
|
||||
all_results: list[dict] = []
|
||||
for start in range(0, total, workers):
|
||||
chunk_size = min(workers, total - start)
|
||||
chunk_gamefiles = (
|
||||
batch.specific_gamefiles[start:start + chunk_size]
|
||||
if batch.specific_gamefiles
|
||||
else None
|
||||
)
|
||||
chunk_ids = (
|
||||
batch.result_ids[start:start + chunk_size]
|
||||
if batch.result_ids
|
||||
else [f"env_{idx:03d}" for idx in range(start, start + chunk_size)]
|
||||
)
|
||||
chunk_env = build_alfworld_env(
|
||||
env_num=chunk_size,
|
||||
eval_dataset=batch.eval_dataset,
|
||||
seed=batch.seed + start,
|
||||
is_train=batch.is_train,
|
||||
specific_gamefiles=chunk_gamefiles,
|
||||
)
|
||||
try:
|
||||
chunk_results = run_alfworld_batch(
|
||||
env_manager=chunk_env,
|
||||
skill_content=skill_content,
|
||||
max_steps=self.max_steps,
|
||||
out_root=out_dir,
|
||||
max_api_workers=min(self.max_api_workers, chunk_size),
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
result_ids=chunk_ids,
|
||||
)
|
||||
finally:
|
||||
self._close_env(chunk_env)
|
||||
all_results.extend(chunk_results)
|
||||
return all_results
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
results,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
|
||||
field_counts: dict[str, int] = {}
|
||||
selected_metadata: list[dict] = []
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
for field in meta["fields"]:
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
selected_metadata.append({
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("task_type") or "alfworld"),
|
||||
"gamefile": str(item.get("gamefile") or ""),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
})
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
field_summary = ", ".join(
|
||||
f"{field}({count}/{len(selected_items)})"
|
||||
for field, count in sorted(field_counts.items())
|
||||
) or "none"
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields={field_summary}"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
output_requirements=[
|
||||
"- Some trajectories may include a hidden Reference block. Use it to target the student's latent subgoal, missing precondition, or next-step intent, but do not reveal or paraphrase that reference to the student.",
|
||||
"- The instruction must request a brief diagnostic readout inside the existing <think>...</think> block.",
|
||||
"- The student must still output exactly one admissible action inside <action>...</action>.",
|
||||
"- Do not ask for exhaustive inventories, full plans, or long chain-of-thought.",
|
||||
"- The instruction text should be ready to append directly to the student's prompt.",
|
||||
],
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": field_counts,
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
gamefiles = [str(item.get("gamefile") or "") for item in selected_items]
|
||||
if any(not gamefile for gamefile in gamefiles):
|
||||
return []
|
||||
eval_dataset, is_train = self._infer_dataset_from_gamefile(gamefiles[0])
|
||||
deep_env = ALFWorldBatchRun(
|
||||
env_num=len(selected_items),
|
||||
eval_dataset=eval_dataset,
|
||||
seed=random_seed or 42,
|
||||
is_train=is_train,
|
||||
specific_gamefiles=gamefiles,
|
||||
workers=min(self.workers, max(len(selected_items), 1)),
|
||||
result_ids=[str(item["id"]) for item in selected_items],
|
||||
)
|
||||
deep_results = self._run_batch(
|
||||
deep_env,
|
||||
skill_content=skill_content,
|
||||
out_dir=rollout_dir,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, selected_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(TASKS)
|
||||
123
skillopt/envs/alfworld/dataloader.py
Normal file
123
skillopt/envs/alfworld/dataloader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""ALFWorld task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
from skillopt.datasets.base import BatchSpec, SplitDataLoader
|
||||
|
||||
|
||||
class ALFWorldDataLoader(SplitDataLoader):
|
||||
"""ALFWorld batch planner.
|
||||
|
||||
In split_dir mode, batches are fixed gamefile items so ablations differ
|
||||
only in how the same training set is batched.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "split_dir",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
train_size: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
self.train_size_override = int(train_size or 0)
|
||||
|
||||
@staticmethod
|
||||
def _metadata_for_items(items: list[dict], split: str, phase: str) -> dict:
|
||||
gamefiles = [str(item.get("gamefile") or "") for item in items]
|
||||
if any(not gamefile for gamefile in gamefiles):
|
||||
raise ValueError("ALFWorld split items must contain non-empty gamefile paths.")
|
||||
eval_dataset = "train"
|
||||
is_train = phase == "train"
|
||||
first = gamefiles[0] if gamefiles else ""
|
||||
if "/valid_seen/" in first:
|
||||
eval_dataset = "eval_in_distribution"
|
||||
is_train = False
|
||||
elif "/valid_unseen/" in first:
|
||||
eval_dataset = "eval_out_of_distribution"
|
||||
is_train = False
|
||||
return {
|
||||
"eval_dataset": eval_dataset,
|
||||
"is_train": is_train,
|
||||
"gamefiles": gamefiles,
|
||||
"result_ids": [str(item.get("id") or idx) for idx, item in enumerate(items)],
|
||||
}
|
||||
|
||||
def get_train_size(self) -> int:
|
||||
if self.train_size_override > 0:
|
||||
return self.train_size_override
|
||||
return super().get_train_size()
|
||||
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
batch = super().build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
items = list(batch.payload or [])
|
||||
batch.metadata.update(self._metadata_for_items(items, "train", "train"))
|
||||
return BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
metadata=batch.metadata,
|
||||
)
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
batches = super().plan_train_epoch(
|
||||
epoch=epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
accumulation=accumulation,
|
||||
batch_size=batch_size,
|
||||
seed=seed,
|
||||
**kwargs,
|
||||
)
|
||||
for batch in batches:
|
||||
items = list(batch.payload or [])
|
||||
batch.metadata.update(self._metadata_for_items(items, "train", "train"))
|
||||
return batches
|
||||
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
batch = super().build_eval_batch(
|
||||
env_num=env_num,
|
||||
split=split,
|
||||
seed=seed,
|
||||
**kwargs,
|
||||
)
|
||||
items = list(batch.payload or [])
|
||||
batch.metadata.update(self._metadata_for_items(items, split, "eval"))
|
||||
return BatchSpec(
|
||||
phase="eval",
|
||||
split=split,
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
metadata=batch.metadata,
|
||||
)
|
||||
55
skillopt/envs/alfworld/prompts/analyst_error.md
Normal file
55
skillopt/envs/alfworld/prompts/analyst_error.md
Normal file
@@ -0,0 +1,55 @@
|
||||
You are an expert failure-analysis agent for ALFWorld embodied household tasks.
|
||||
|
||||
You will be given MULTIPLE failed agent trajectories from a single minibatch
|
||||
and the current skill document.
|
||||
Your job is to identify the most important COMMON failure patterns across
|
||||
the batch and propose a concise set of skill edits.
|
||||
|
||||
## ALFWorld Task Types
|
||||
- pick_and_place: Put object in/on a receptacle
|
||||
- pick_two_obj_and_place: Put two instances of an object in/on a receptacle
|
||||
- look_at_obj_in_light: Examine an object under a desklamp
|
||||
- pick_heat_then_place_in_recep: Heat an object and put it in/on a receptacle
|
||||
- pick_cool_then_place_in_recep: Cool an object and put it in/on a receptacle
|
||||
- pick_clean_then_place_in_recep: Clean an object and put it in/on a receptacle
|
||||
|
||||
## Failure Type Categories
|
||||
- **navigation_loop**: the agent revisits the same locations repeatedly without progress
|
||||
- **missed_object**: the agent fails to pick up a visible/reachable goal object
|
||||
- **wrong_sequence**: the agent performs actions in the wrong order (e.g., placing before transforming)
|
||||
- **premature_stop**: the agent stops or gets stuck before completing all goal conditions
|
||||
- **action_loop**: the agent repeats the same action without advancing
|
||||
- **appliance_error**: the agent misuses or skips an appliance (microwave, fridge, sink)
|
||||
- **rule_missing**: the skill lacks a relevant rule for this situation
|
||||
- **rule_wrong**: an existing skill rule is misleading or incorrect
|
||||
- **rule_ignored**: the skill has the right rule but the agent did not follow it
|
||||
- **other**: none of the above
|
||||
|
||||
## Analysis Process
|
||||
1. Read ALL trajectories in the minibatch.
|
||||
2. Identify the most prevalent, systematic failure patterns across them.
|
||||
3. For each pattern, classify its failure type.
|
||||
4. Propose skill edits that address the COMMON patterns — not individual edge cases.
|
||||
5. Edits must be generalizable; do not hardcode task-specific values.
|
||||
6. Only patch gaps in the skill — do not duplicate existing content.
|
||||
|
||||
You will be told the maximum number of edits (the budget L). Produce AT MOST L edits,
|
||||
focusing on the highest-impact patterns. You may produce fewer if warranted.
|
||||
|
||||
Respond ONLY with a valid JSON object (no markdown fences, no extra text):
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the batch's common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown to add at end of skill>"},
|
||||
{"op": "insert_after", "target": "<exact heading/text to insert after>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<exact text to replace>", "content": "<replacement>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
Only include edits that are needed. "edits" can be an empty list if no patch is warranted.
|
||||
33
skillopt/envs/alfworld/prompts/analyst_success.md
Normal file
33
skillopt/envs/alfworld/prompts/analyst_success.md
Normal file
@@ -0,0 +1,33 @@
|
||||
You are an expert success-pattern analyst for AI agents operating in ALFWorld,
|
||||
a text-based embodied household environment.
|
||||
|
||||
You will be given MULTIPLE successful agent trajectories from a single minibatch
|
||||
and the current skill document. Your job is to identify generalizable behavior
|
||||
patterns that are COMMON across the batch and worth encoding in the skill.
|
||||
|
||||
## Rules
|
||||
- Only propose patches for patterns NOT already covered in the skill.
|
||||
- Focus on patterns that appear across MULTIPLE trajectories in the batch.
|
||||
- Be concise. Patterns must generalize beyond specific tasks.
|
||||
- Prefer reinforcing existing sections over adding new top-level sections.
|
||||
- If the agents' success involved efficient exploration or smart appliance usage,
|
||||
consider reinforcing that in the patch.
|
||||
|
||||
You will be told the maximum number of edits (the budget L). Produce AT MOST L edits,
|
||||
focusing on the most broadly applicable patterns. You may produce fewer if warranted.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns are worth encoding>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
"edits" may be empty if the skill already covers all observed patterns.
|
||||
35
skillopt/envs/alfworld/prompts/deep_probe.md
Normal file
35
skillopt/envs/alfworld/prompts/deep_probe.md
Normal file
@@ -0,0 +1,35 @@
|
||||
You are an expert diagnostic-probe designer for ALFWorld embodied tasks.
|
||||
|
||||
You will design one short diagnostic instruction to append to the student's prompt
|
||||
for a handful of representative ALFWorld trajectories.
|
||||
|
||||
The goal is to expose whether the student has the right intermediate subgoal,
|
||||
object/receptacle state, and next-step intention without substantially changing
|
||||
the current scaffold.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT substantially change the student's existing action-selection scaffold.
|
||||
2. Do NOT prescribe a brand-new planner or long multi-step policy.
|
||||
3. Do NOT ask for exhaustive search over all objects or all admissible actions.
|
||||
4. Keep the diagnostic readout brief and place it inside the existing <think>...</think> block.
|
||||
5. The student must still output exactly one admissible action inside <action>...</action>.
|
||||
6. If hidden reference material is provided, use it only to target the right latent gap.
|
||||
7. Never copy hidden reference content into the student-facing probe.
|
||||
|
||||
## Good Probe Targets
|
||||
- current subgoal
|
||||
- target object / target receptacle / target state
|
||||
- decisive missing precondition
|
||||
- why one candidate action is better than a tempting alternative
|
||||
- whether the current step should explore, transform an object, or place it
|
||||
|
||||
## Bad Probe Targets
|
||||
- a full optimal plan from start to finish
|
||||
- exhaustive object inventories
|
||||
- a new theorem-like or planner-like protocol
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this probe reveals the latent skill gap>",
|
||||
"probe_instruction": "<the exact instruction text to append to the student prompt>"
|
||||
}
|
||||
8
skillopt/envs/alfworld/prompts/rollout_no_history.md
Normal file
8
skillopt/envs/alfworld/prompts/rollout_no_history.md
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
You are an expert agent operating in the ALFRED Embodied Environment.
|
||||
Your current observation is: {current_observation}
|
||||
Your admissible actions of the current situation are: [{admissible_actions}].
|
||||
|
||||
Now it's your turn to take an action.
|
||||
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
|
||||
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
|
||||
9
skillopt/envs/alfworld/prompts/rollout_with_history.md
Normal file
9
skillopt/envs/alfworld/prompts/rollout_with_history.md
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
|
||||
Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history}
|
||||
You are now at step {current_step} and your current observation is: {current_observation}
|
||||
Your admissible actions of the current situation are: [{admissible_actions}].
|
||||
|
||||
Now it's your turn to take an action.
|
||||
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
|
||||
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
|
||||
16
skillopt/envs/alfworld/prompts/rollout_with_memory.md
Normal file
16
skillopt/envs/alfworld/prompts/rollout_with_memory.md
Normal file
@@ -0,0 +1,16 @@
|
||||
|
||||
You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
|
||||
|
||||
## Retrieved Relevant Experience
|
||||
|
||||
{retrieved_memories}
|
||||
|
||||
## Current Progress
|
||||
|
||||
Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history}
|
||||
You are now at step {current_step} and your current observation is: {current_observation}
|
||||
Your admissible actions of the current situation are: [{admissible_actions}].
|
||||
|
||||
Now it's your turn to take an action.
|
||||
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
|
||||
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
|
||||
4
skillopt/envs/alfworld/reflect.py
Normal file
4
skillopt/envs/alfworld/reflect.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""ALFWorld Reflect stage.
|
||||
|
||||
Prompts are now loaded from .md files by the base adapter.
|
||||
"""
|
||||
359
skillopt/envs/alfworld/rollout.py
Normal file
359
skillopt/envs/alfworld/rollout.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""ALFWorld rollout module for ReflACT.
|
||||
|
||||
Provides:
|
||||
- build_alfworld_env(): build ALFWorld environment (wraps vendored SkillRL env)
|
||||
- run_alfworld_batch(): run a batch of ALFWorld episodes in parallel
|
||||
- TASKS: list of ALFWorld task types
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import concurrent.futures
|
||||
import numpy as np
|
||||
|
||||
from skillopt.model import chat_student
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||
|
||||
TASKS = [
|
||||
"pick_and_place",
|
||||
"pick_two_obj_and_place",
|
||||
"look_at_obj_in_light",
|
||||
"pick_heat_then_place_in_recep",
|
||||
"pick_cool_then_place_in_recep",
|
||||
"pick_clean_then_place_in_recep",
|
||||
]
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_task_type(gamefile: str) -> str:
|
||||
for task in TASKS:
|
||||
if task in gamefile:
|
||||
return task
|
||||
return "other"
|
||||
|
||||
|
||||
def _extract_action(model_response: str) -> str | None:
|
||||
match = re.search(r"<action>(.*?)</action>", model_response, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def _extract_think(model_response: str) -> str | None:
|
||||
match = re.search(r"<think>(.*?)</think>", model_response, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def _build_skill_prompt(skill_content: str) -> str:
|
||||
"""Build the skill section to inject into the agent's system prompt."""
|
||||
if not skill_content or not skill_content.strip():
|
||||
return ""
|
||||
return (
|
||||
"\n\n## Skill Knowledge\n"
|
||||
"Below is a skill document with learned strategies. "
|
||||
"Use these guidelines to inform your decisions:\n\n"
|
||||
f"{skill_content}\n"
|
||||
)
|
||||
|
||||
|
||||
def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) -> str:
|
||||
if not diagnostic_instruction or not diagnostic_instruction.strip():
|
||||
return prompt
|
||||
return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n"
|
||||
|
||||
|
||||
# ── Environment builder ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def build_alfworld_env(
|
||||
env_num: int,
|
||||
eval_dataset: str = "eval_out_of_distribution",
|
||||
seed: int = 42,
|
||||
is_train: bool = False,
|
||||
specific_gamefiles: list[str] | None = None,
|
||||
):
|
||||
"""Build ALFWorld environment manager.
|
||||
|
||||
Args:
|
||||
env_num: number of parallel environments
|
||||
eval_dataset: 'eval_in_distribution' or 'eval_out_of_distribution' or train
|
||||
seed: random seed
|
||||
is_train: whether to use training set
|
||||
|
||||
Returns:
|
||||
env_manager: AlfWorldEnvironmentManager instance
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
from functools import partial
|
||||
|
||||
from skillopt.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs
|
||||
from skillopt.envs.alfworld.vendor.alfworld_projection import alfworld_projection
|
||||
from skillopt.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml")
|
||||
env_kwargs = {"eval_dataset": eval_dataset}
|
||||
|
||||
envs = build_alfworld_envs(
|
||||
alf_config_path,
|
||||
seed=seed,
|
||||
env_num=env_num,
|
||||
group_n=1,
|
||||
is_train=is_train,
|
||||
env_kwargs=env_kwargs,
|
||||
resources_per_worker=None,
|
||||
gamefiles=specific_gamefiles,
|
||||
)
|
||||
|
||||
config = OmegaConf.create(
|
||||
{
|
||||
"env": {
|
||||
"history_length": 2,
|
||||
"env_name": "alfworld/AlfredTWEnv",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
projection_f = partial(alfworld_projection)
|
||||
env_manager = AlfWorldEnvironmentManager(envs, projection_f, config)
|
||||
return env_manager
|
||||
|
||||
|
||||
# ── Batch rollout ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_alfworld_batch(
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
max_steps: int = 50,
|
||||
out_root: str = "",
|
||||
max_api_workers: int = 8,
|
||||
temperature: float = 0.4,
|
||||
max_completion_tokens: int = 2048,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
result_ids: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
"""Run a batch of ALFWorld episodes.
|
||||
|
||||
Returns a list of result dicts compatible with SkillReflection v2 pipeline:
|
||||
[
|
||||
{
|
||||
"id": "<env_idx>_<gamefile_hash>",
|
||||
"hard": 0 or 1,
|
||||
"soft": 0.0 or 1.0,
|
||||
"n_turns": <int>,
|
||||
"fail_reason": "<str>",
|
||||
"agent_ok": True,
|
||||
"task_type": "<str>",
|
||||
"gamefile": "<str>",
|
||||
"task_description": "<str>",
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Also saves conversation.json per environment in out_root/predictions/<task_id>/
|
||||
"""
|
||||
skill_prompt = _build_skill_prompt(skill_content)
|
||||
|
||||
obs, infos = env_manager.reset({})
|
||||
env_num = len(obs["text"])
|
||||
env_dones = [False] * env_num
|
||||
overall_success = [False] * env_num
|
||||
|
||||
# Build per-env metadata
|
||||
env_meta: list[dict] = []
|
||||
for i in range(env_num):
|
||||
gamefile = infos[i].get("extra.gamefile", "") if isinstance(infos[i], dict) else ""
|
||||
task_type = _get_task_type(gamefile)
|
||||
# Extract task description from initial observation
|
||||
task_desc = ""
|
||||
anchor_text = obs["anchor"][i] if "anchor" in obs else ""
|
||||
task_start = anchor_text.find("Your task is to: ")
|
||||
if task_start != -1:
|
||||
task_desc = anchor_text[task_start + len("Your task is to: "):].strip()
|
||||
|
||||
env_meta.append({
|
||||
"gamefile": gamefile,
|
||||
"task_type": task_type,
|
||||
"task_description": task_desc,
|
||||
})
|
||||
|
||||
# Per-env conversation records
|
||||
conversations: list[list[dict]] = [[] for _ in range(env_num)]
|
||||
|
||||
for step_idx in range(max_steps):
|
||||
if all(env_dones):
|
||||
break
|
||||
|
||||
active_indices = [i for i in range(env_num) if not env_dones[i]]
|
||||
|
||||
# Build prompts with skill injection
|
||||
prompts: dict[int, str] = {}
|
||||
for i in active_indices:
|
||||
prompt = obs["text"][i]
|
||||
if skill_prompt:
|
||||
# Inject skill before the action instruction
|
||||
prompt = skill_prompt + "\n" + prompt
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
prompt = _append_diagnostic_instruction(prompt, diagnostic_instruction)
|
||||
prompts[i] = prompt
|
||||
|
||||
# Call API in parallel
|
||||
actions = ["None"] * env_num
|
||||
action_timeout = 180
|
||||
|
||||
def call_api(idx):
|
||||
try:
|
||||
response, _ = chat_student(
|
||||
system="You are an expert agent operating in the ALFRED Embodied Environment.",
|
||||
user=prompts[idx],
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=120,
|
||||
)
|
||||
response = (response or "").strip()
|
||||
if not response:
|
||||
return idx, "<think>empty model response</think><action>look</action>"
|
||||
if _extract_action(response) is None:
|
||||
return idx, "<think>missing action tag</think><action>look</action>"
|
||||
return idx, response
|
||||
except Exception as e:
|
||||
return idx, "<think>error</think><action>look</action>"
|
||||
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers)
|
||||
try:
|
||||
futures = {executor.submit(call_api, i): i for i in active_indices}
|
||||
started_at = {future: time.time() for future in futures}
|
||||
pending_futs = set(futures)
|
||||
while pending_futs:
|
||||
done, _ = concurrent.futures.wait(
|
||||
pending_futs,
|
||||
timeout=5,
|
||||
return_when=concurrent.futures.FIRST_COMPLETED,
|
||||
)
|
||||
now = time.time()
|
||||
timed_out = [
|
||||
future for future in pending_futs - done
|
||||
if now - started_at[future] >= action_timeout
|
||||
]
|
||||
for future in done:
|
||||
pending_futs.remove(future)
|
||||
try:
|
||||
idx, response = future.result()
|
||||
except Exception: # noqa: BLE001
|
||||
idx = futures[future]
|
||||
response = "<think>error</think><action>look</action>"
|
||||
actions[idx] = response
|
||||
for future in timed_out:
|
||||
pending_futs.remove(future)
|
||||
idx = futures[future]
|
||||
actions[idx] = "<think>api timeout</think><action>look</action>"
|
||||
finally:
|
||||
executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
# Save model responses before stepping
|
||||
model_responses = {i: actions[i] for i in active_indices}
|
||||
|
||||
# Step environment
|
||||
obs, rewards, dones, infos = env_manager.step(actions)
|
||||
|
||||
# Record trajectory
|
||||
for i in active_indices:
|
||||
step_record = {
|
||||
"step": step_idx,
|
||||
"action": _extract_action(model_responses[i]),
|
||||
"reasoning": _extract_think(model_responses[i]),
|
||||
"model_response": model_responses[i],
|
||||
"env_feedback": obs["anchor"][i] if "anchor" in obs else "",
|
||||
"reward": float(rewards[i]),
|
||||
"done": bool(dones[i]),
|
||||
}
|
||||
conversations[i].append(step_record)
|
||||
|
||||
# Update done status
|
||||
for i in range(env_num):
|
||||
if env_dones[i]:
|
||||
continue
|
||||
if dones[i]:
|
||||
env_dones[i] = True
|
||||
won = bool(infos[i].get("won", False))
|
||||
overall_success[i] = won
|
||||
|
||||
# Build results and save conversations
|
||||
results: list[dict] = []
|
||||
pred_dir = os.path.join(out_root, "predictions") if out_root else ""
|
||||
|
||||
for i in range(env_num):
|
||||
gamefile = env_meta[i]["gamefile"]
|
||||
task_type = env_meta[i]["task_type"]
|
||||
task_desc = env_meta[i]["task_description"]
|
||||
n_turns = len(conversations[i])
|
||||
won = overall_success[i]
|
||||
|
||||
# Generate stable task ID from env index and gamefile
|
||||
task_id = str(result_ids[i]) if result_ids and i < len(result_ids) else f"env_{i:03d}"
|
||||
|
||||
fail_reason = ""
|
||||
if not won:
|
||||
if not env_dones[i]:
|
||||
fail_reason = f"Timeout after {max_steps} steps"
|
||||
else:
|
||||
fail_reason = "Episode ended without completing the task"
|
||||
|
||||
result = {
|
||||
"id": task_id,
|
||||
"hard": 1 if won else 0,
|
||||
"soft": 1.0 if won else 0.0,
|
||||
"n_turns": n_turns,
|
||||
"fail_reason": fail_reason,
|
||||
"agent_ok": True, # ALFWorld agent always runs OK (no crash)
|
||||
"task_type": task_type,
|
||||
"gamefile": gamefile,
|
||||
"task_description": task_desc,
|
||||
"instruction_type": task_type, # for compatibility with v2 pipeline
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
# Save conversation
|
||||
if pred_dir:
|
||||
conv_dir = os.path.join(pred_dir, task_id)
|
||||
os.makedirs(conv_dir, exist_ok=True)
|
||||
with open(os.path.join(conv_dir, "conversation.json"), "w") as f:
|
||||
json.dump(conversations[i], f, ensure_ascii=False, indent=2)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ── Item loading (for compatibility with split_three_way) ────────────────────
|
||||
|
||||
|
||||
def load_alfworld_items(
|
||||
eval_dataset: str,
|
||||
env_num: int,
|
||||
seed: int = 42,
|
||||
is_train: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Create pseudo-item dicts for ALFWorld environments.
|
||||
|
||||
Since ALFWorld doesn't have a static JSON dataset like SpreadsheetBench,
|
||||
we create lightweight item dicts that carry enough metadata for the pipeline.
|
||||
The actual environment is built dynamically.
|
||||
|
||||
Returns:
|
||||
List of dicts with "id" keys, one per environment slot.
|
||||
"""
|
||||
items = []
|
||||
for i in range(env_num):
|
||||
items.append({
|
||||
"id": f"env_{i:03d}",
|
||||
"eval_dataset": eval_dataset,
|
||||
"env_index": i,
|
||||
})
|
||||
return items
|
||||
45
skillopt/envs/alfworld/skills/initial.md
Normal file
45
skillopt/envs/alfworld/skills/initial.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# ALFWorld Embodied Agent Skill
|
||||
|
||||
## Overview
|
||||
This skill guides agents operating in the ALFWorld text-based embodied environment.
|
||||
The agent must complete household tasks by navigating rooms, interacting with objects,
|
||||
and using appliances. Actions must be chosen from the admissible action list provided
|
||||
at each step.
|
||||
|
||||
**Output format**: Always output `<think>...</think>` for reasoning, then `<action>...</action>` for the chosen action.
|
||||
|
||||
---
|
||||
|
||||
## Task Types
|
||||
|
||||
| Type | Goal | Key Steps |
|
||||
|------|------|-----------|
|
||||
| Pick & Place | Put object X in/on receptacle Y | Find X -> take X -> go to Y -> put X in/on Y |
|
||||
| Pick Two & Place | Put two instances of X in/on Y | Find X1 -> take -> place -> find X2 -> take -> place |
|
||||
| Examine in Light | Examine object X under desklamp | Find X -> take X -> find desklamp -> use desklamp |
|
||||
| Clean & Place | Clean object X and put in/on Y | Find X -> take X -> go to sink -> clean X -> go to Y -> put X |
|
||||
| Heat & Place | Heat object X and put in/on Y | Find X -> take X -> go to microwave -> heat X -> go to Y -> put X |
|
||||
| Cool & Place | Cool object X and put in/on Y | Find X -> take X -> go to fridge -> cool X -> go to Y -> put X |
|
||||
|
||||
---
|
||||
|
||||
## General Principles
|
||||
|
||||
1. **Decompose the task**: Parse the goal into ordered sub-goals (locate, acquire, transform, deliver). Complete each before moving to the next.
|
||||
2. **Systematic exploration**: Search each surface and container exactly once before revisiting. Open closed containers (drawers, cabinets, fridge) before judging them empty.
|
||||
3. **Grab immediately**: When a required object is visible and reachable, take it right away before moving elsewhere.
|
||||
4. **Transform before placing**: If the task requires cleaning, heating, or cooling, perform the state change at the appropriate appliance before heading to the final destination.
|
||||
5. **Direct delivery**: Once holding the transformed (or untransformed) goal object, navigate straight to the target receptacle and place it.
|
||||
6. **Track progress**: Maintain an internal count of how many objects still need to be found and placed. Only stop searching when the count reaches zero.
|
||||
7. **Avoid loops**: Never repeat the same action more than twice in a row. If stuck, move to a different unexplored location.
|
||||
8. **Only choose admissible actions**: Always pick an action from the admissible action list. Do not invent actions.
|
||||
|
||||
---
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
- **Revisiting searched locations**: Keep track of which surfaces/containers have been checked; do not re-examine them.
|
||||
- **Ignoring visible objects**: If the target object appears in the observation, pick it up immediately.
|
||||
- **Skipping state changes**: Do not place an object at the destination without first cleaning/heating/cooling it when required.
|
||||
- **Premature termination**: Do not stop the episode until all goal conditions are verified as met.
|
||||
- **Action loops**: Repeatedly toggling or examining the same object wastes steps. Move on to new locations instead.
|
||||
9
skillopt/envs/alfworld/vendor/__init__.py
vendored
Normal file
9
skillopt/envs/alfworld/vendor/__init__.py
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Vendored ALFWorld environment runtime.
|
||||
|
||||
Minimal subset of SkillRL's agent_system package needed to run
|
||||
ALFWorld environments with ReflACT. Original source:
|
||||
https://github.com/NTU-LANTERN/SkillRL (Apache-2.0 License)
|
||||
"""
|
||||
from .alfworld_envs import AlfworldEnvs, build_alfworld_envs
|
||||
from .alfworld_projection import alfworld_projection
|
||||
from .env_manager import AlfWorldEnvironmentManager
|
||||
221
skillopt/envs/alfworld/vendor/alfworld_envs.py
vendored
Normal file
221
skillopt/envs/alfworld/vendor/alfworld_envs.py
vendored
Normal file
@@ -0,0 +1,221 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/env_package/alfworld/envs.py
|
||||
# Modified: imports use pip-installed alfworld package instead of vendored copy.
|
||||
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
import traceback
|
||||
import yaml
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from alfworld.agents.environment import get_environment
|
||||
|
||||
|
||||
def load_config_file(path):
|
||||
assert os.path.exists(path), f"Invalid config file: {path}"
|
||||
with open(path) as reader:
|
||||
config = yaml.safe_load(reader)
|
||||
return config
|
||||
|
||||
|
||||
def compute_reward(info, multi_modal=False):
|
||||
if multi_modal:
|
||||
reward = 10.0 * float(info['won']) + float(info['goal_condition_success_rate'])
|
||||
else:
|
||||
reward = 10.0 * float(info['won'])
|
||||
return reward
|
||||
|
||||
|
||||
class AlfworldWorker:
|
||||
"""Stateful worker that holds one ALFWorld sub-environment."""
|
||||
|
||||
def __init__(self, config, seed, base_env, gamefile=None):
|
||||
if gamefile:
|
||||
base_env.game_files = [gamefile]
|
||||
if hasattr(base_env, "num_games"):
|
||||
base_env.num_games = 1
|
||||
self.env = base_env.init_env(batch_size=1)
|
||||
self.env.seed(seed)
|
||||
|
||||
def step(self, action):
|
||||
actions = [action]
|
||||
obs, scores, dones, infos = self.env.step(actions)
|
||||
infos['observation_text'] = obs
|
||||
return obs, scores, dones, infos
|
||||
|
||||
def reset(self):
|
||||
obs, infos = self.env.reset()
|
||||
infos['observation_text'] = obs
|
||||
return obs, infos
|
||||
|
||||
|
||||
def _worker_loop(cmd_q, result_q, config, seed, is_train, eval_dataset, gamefile):
|
||||
"""Run one ALFWorld environment in a child process."""
|
||||
try:
|
||||
env_type = config['env']['type']
|
||||
base_env = get_environment(env_type)(
|
||||
config,
|
||||
train_eval='train' if is_train else eval_dataset,
|
||||
)
|
||||
worker = AlfworldWorker(config, seed, base_env, gamefile)
|
||||
result_q.put((True, "ready"))
|
||||
except BaseException:
|
||||
result_q.put((False, traceback.format_exc()))
|
||||
return
|
||||
|
||||
while True:
|
||||
cmd, payload = cmd_q.get()
|
||||
if cmd == "close":
|
||||
result_q.put((True, None))
|
||||
return
|
||||
try:
|
||||
if cmd == "reset":
|
||||
result = worker.reset()
|
||||
elif cmd == "step":
|
||||
result = worker.step(payload)
|
||||
else:
|
||||
raise ValueError(f"Unknown ALFWorld worker command: {cmd}")
|
||||
result_q.put((True, result))
|
||||
except BaseException:
|
||||
result_q.put((False, traceback.format_exc()))
|
||||
|
||||
|
||||
class _ProcessWorker:
|
||||
"""Small stdlib actor wrapper for one environment process."""
|
||||
|
||||
def __init__(self, ctx, config, seed, is_train, eval_dataset, gamefile=None):
|
||||
self.cmd_q = ctx.Queue(maxsize=1)
|
||||
self.result_q = ctx.Queue(maxsize=1)
|
||||
self.process = ctx.Process(
|
||||
target=_worker_loop,
|
||||
args=(self.cmd_q, self.result_q, config, seed, is_train, eval_dataset, gamefile),
|
||||
)
|
||||
self.process.start()
|
||||
ok, payload = self.result_q.get()
|
||||
if not ok:
|
||||
self.close(kill=True)
|
||||
raise RuntimeError(f"Failed to start ALFWorld worker:\n{payload}")
|
||||
|
||||
def send(self, cmd, payload=None):
|
||||
self.cmd_q.put((cmd, payload))
|
||||
|
||||
def recv(self):
|
||||
ok, payload = self.result_q.get()
|
||||
if not ok:
|
||||
raise RuntimeError(f"ALFWorld worker failed:\n{payload}")
|
||||
return payload
|
||||
|
||||
def close(self, kill=False):
|
||||
if self.process.is_alive() and not kill:
|
||||
try:
|
||||
self.send("close")
|
||||
self.recv()
|
||||
except Exception:
|
||||
kill = True
|
||||
if kill and self.process.is_alive():
|
||||
self.process.terminate()
|
||||
self.process.join(timeout=5)
|
||||
if self.process.is_alive():
|
||||
self.process.kill()
|
||||
self.process.join(timeout=1)
|
||||
self.cmd_q.close()
|
||||
self.result_q.close()
|
||||
|
||||
|
||||
class AlfworldEnvs(gym.Env):
|
||||
"""Vectorized ALFWorld environment using local process workers."""
|
||||
|
||||
def __init__(self, alf_config_path, seed, env_num, group_n,
|
||||
resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None):
|
||||
super().__init__()
|
||||
if env_kwargs is None:
|
||||
env_kwargs = {}
|
||||
|
||||
eval_dataset = env_kwargs.get('eval_dataset', 'eval_in_distribution')
|
||||
config = load_config_file(alf_config_path)
|
||||
env_type = config['env']['type']
|
||||
self.multi_modal = (env_type == 'AlfredThorEnv')
|
||||
self.num_processes = env_num * group_n
|
||||
self.group_n = group_n
|
||||
self.gamefiles = list(gamefiles or [])
|
||||
if self.gamefiles and len(self.gamefiles) != self.num_processes:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_processes} gamefiles, got {len(self.gamefiles)}"
|
||||
)
|
||||
|
||||
start_method = os.environ.get("ALFWORLD_WORKER_START_METHOD") or None
|
||||
ctx = mp.get_context(start_method) if start_method else mp.get_context()
|
||||
self.workers = []
|
||||
for i in range(self.num_processes):
|
||||
worker_gamefile = self.gamefiles[i] if self.gamefiles else None
|
||||
worker = _ProcessWorker(
|
||||
ctx,
|
||||
config,
|
||||
seed + (i // self.group_n),
|
||||
is_train,
|
||||
eval_dataset,
|
||||
worker_gamefile,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
|
||||
self.prev_admissible_commands = [None for _ in range(self.num_processes)]
|
||||
|
||||
def step(self, actions):
|
||||
assert len(actions) == self.num_processes
|
||||
|
||||
for i, worker in enumerate(self.workers):
|
||||
worker.send("step", actions[i])
|
||||
results = [worker.recv() for worker in self.workers]
|
||||
|
||||
text_obs_list = []
|
||||
rewards_list = []
|
||||
dones_list = []
|
||||
info_list = []
|
||||
|
||||
for i, (obs, scores, dones, info) in enumerate(results):
|
||||
for k in info.keys():
|
||||
info[k] = info[k][0]
|
||||
text_obs_list.append(obs[0])
|
||||
dones_list.append(dones[0])
|
||||
info_list.append(info)
|
||||
self.prev_admissible_commands[i] = info['admissible_commands']
|
||||
rewards_list.append(compute_reward(info, self.multi_modal))
|
||||
|
||||
image_obs_list = None
|
||||
return text_obs_list, image_obs_list, rewards_list, dones_list, info_list
|
||||
|
||||
def reset(self):
|
||||
for worker in self.workers:
|
||||
worker.send("reset")
|
||||
results = [worker.recv() for worker in self.workers]
|
||||
|
||||
text_obs_list = []
|
||||
info_list = []
|
||||
|
||||
for i, (obs, info) in enumerate(results):
|
||||
for k in info.keys():
|
||||
info[k] = info[k][0]
|
||||
text_obs_list.append(obs[0])
|
||||
self.prev_admissible_commands[i] = info['admissible_commands']
|
||||
info_list.append(info)
|
||||
|
||||
image_obs_list = None
|
||||
return text_obs_list, image_obs_list, info_list
|
||||
|
||||
@property
|
||||
def get_admissible_commands(self):
|
||||
return self.prev_admissible_commands
|
||||
|
||||
def close(self):
|
||||
for worker in self.workers:
|
||||
worker.close()
|
||||
|
||||
|
||||
def build_alfworld_envs(alf_config_path, seed, env_num, group_n,
|
||||
resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None):
|
||||
"""Build vectorized ALFWorld environments."""
|
||||
return AlfworldEnvs(
|
||||
alf_config_path, seed, env_num, group_n,
|
||||
resources_per_worker, is_train, env_kwargs, gamefiles,
|
||||
)
|
||||
60
skillopt/envs/alfworld/vendor/alfworld_projection.py
vendored
Normal file
60
skillopt/envs/alfworld/vendor/alfworld_projection.py
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/env_package/alfworld/projection.py
|
||||
|
||||
from typing import List
|
||||
import re
|
||||
|
||||
|
||||
def alfworld_projection(actions: List[str], action_pools: List[List[str]]):
|
||||
"""Process raw model outputs into valid ALFWorld actions.
|
||||
|
||||
Extracts text from ``<action>...</action>`` tags and validates that
|
||||
the response also contains ``<think>...</think>`` tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
actions : list[str]
|
||||
Raw model outputs, one per environment.
|
||||
action_pools : list[list[str]]
|
||||
Admissible action lists per environment (unused but kept for API compat).
|
||||
|
||||
Returns
|
||||
-------
|
||||
actions : list[str]
|
||||
Cleaned action strings.
|
||||
valids : list[int]
|
||||
1 if the action was successfully parsed, 0 otherwise.
|
||||
"""
|
||||
valids = [0] * len(actions)
|
||||
|
||||
for i in range(len(actions)):
|
||||
original_str = actions[i]
|
||||
actions[i] = actions[i].lower()
|
||||
|
||||
start_tag = "<action>"
|
||||
end_tag = "</action>"
|
||||
start_idx = actions[i].find(start_tag)
|
||||
end_idx = actions[i].find(end_tag)
|
||||
try:
|
||||
if start_idx == -1 or end_idx == -1:
|
||||
actions[i] = actions[i][-30:]
|
||||
continue
|
||||
|
||||
extracted_action = actions[i][start_idx + len(start_tag):end_idx].strip().lower()
|
||||
actions[i] = extracted_action
|
||||
valids[i] = 1
|
||||
|
||||
except Exception:
|
||||
actions[i] = actions[i][-30:]
|
||||
|
||||
# Require <think>...</think>
|
||||
think_start_idx = original_str.find("<think>")
|
||||
think_end_idx = original_str.find("</think>")
|
||||
if think_start_idx == -1 or think_end_idx == -1:
|
||||
valids[i] = 0
|
||||
|
||||
# Reject responses containing Chinese characters
|
||||
if re.search(r'[\u4e00-\u9fff]', original_str):
|
||||
valids[i] = 0
|
||||
|
||||
return actions, valids
|
||||
8
skillopt/envs/alfworld/vendor/alfworld_prompts.py
vendored
Normal file
8
skillopt/envs/alfworld/vendor/alfworld_prompts.py
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/prompts/alfworld.py
|
||||
|
||||
from skillopt.prompts import load_prompt
|
||||
|
||||
ALFWORLD_TEMPLATE_NO_HIS = load_prompt("rollout_no_history", env="alfworld")
|
||||
ALFWORLD_TEMPLATE = load_prompt("rollout_with_history", env="alfworld")
|
||||
ALFWORLD_TEMPLATE_WITH_MEMORY = load_prompt("rollout_with_memory", env="alfworld")
|
||||
145
skillopt/envs/alfworld/vendor/config_tw.yaml
vendored
Normal file
145
skillopt/envs/alfworld/vendor/config_tw.yaml
vendored
Normal file
@@ -0,0 +1,145 @@
|
||||
dataset:
|
||||
data_path: '$ALFWORLD_DATA/json_2.1.1/train'
|
||||
eval_id_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_seen' # null/None to disable
|
||||
eval_ood_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_unseen' # null/None to disable
|
||||
num_train_games: -1 # max training games (<=0 indicates full dataset)
|
||||
num_eval_games: -1 # max evaluation games (<=0 indicates full dataset)
|
||||
|
||||
logic:
|
||||
domain: '$ALFWORLD_DATA/logic/alfred.pddl' # PDDL domain file that defines the world dynamics
|
||||
grammar: '$ALFWORLD_DATA/logic/alfred.twl2' # Grammar file that defines the text feedbacks
|
||||
|
||||
env:
|
||||
type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid'
|
||||
# regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file
|
||||
domain_randomization: False # shuffle Textworld print order and object id nums
|
||||
task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place
|
||||
expert_timeout_steps: 150 # max steps before timeout for expert to solve the task
|
||||
expert_type: "handcoded" # 'handcoded' or 'planner'. Note: the planner is very slow for real-time use
|
||||
goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED)
|
||||
|
||||
hybrid:
|
||||
start_eps: 100000 # starting episode of hybrid training, tw-only training upto this point
|
||||
thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training
|
||||
eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training
|
||||
|
||||
thor:
|
||||
screen_width: 300 # width of THOR window
|
||||
screen_height: 300 # height of THOR window
|
||||
smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow)
|
||||
save_frames_to_disk: False # save frame PNGs to disk (useful for making videos)
|
||||
save_frames_path: './videos/' # path to save frame PNGs
|
||||
|
||||
controller:
|
||||
type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER)
|
||||
debug: False
|
||||
load_receps: True # load receptacle locations from precomputed dict (if available)
|
||||
|
||||
mask_rcnn:
|
||||
pretrained_model_path: '$ALFWORLD_DATA/detectors/mrcnn.pth'
|
||||
|
||||
general:
|
||||
random_seed: 42
|
||||
use_cuda: True # disable this when running on machine without cuda
|
||||
visdom: False # plot training/eval curves, run with visdom server
|
||||
task: 'alfred'
|
||||
training_method: 'dagger' # 'dqn' or 'dagger'
|
||||
save_path: './training/' # path to save pytorch models
|
||||
observation_pool_capacity: 3 # k-size queue, 0 indicates no observation
|
||||
hide_init_receptacles: False # remove initial observation containing navigable receptacles
|
||||
|
||||
training:
|
||||
batch_size: 10
|
||||
max_episode: 50000
|
||||
smoothing_eps: 0.1
|
||||
optimizer:
|
||||
learning_rate: 0.001
|
||||
clip_grad_norm: 5
|
||||
|
||||
evaluate:
|
||||
run_eval: True
|
||||
batch_size: 10
|
||||
env:
|
||||
type: "AlfredTWEnv"
|
||||
|
||||
checkpoint:
|
||||
report_frequency: 1000 # report every N episode
|
||||
experiment_tag: 'test' # name of experiment
|
||||
load_pretrained: False # during test, enable this so that the agent load your pretrained model
|
||||
load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path
|
||||
|
||||
model:
|
||||
encoder_layers: 1
|
||||
decoder_layers: 1
|
||||
encoder_conv_num: 5
|
||||
block_hidden_dim: 64
|
||||
n_heads: 1
|
||||
dropout: 0.1
|
||||
block_dropout: 0.1
|
||||
recurrent: True
|
||||
|
||||
rl:
|
||||
action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working)
|
||||
max_target_length: 20 # max token length for seq2seq generation
|
||||
beam_width: 10 # 1 means greedy
|
||||
generate_top_k: 3
|
||||
|
||||
training:
|
||||
max_nb_steps_per_episode: 50 # terminate after this many steps
|
||||
learn_start_from_this_episode: 0 # delay updates until this epsiode
|
||||
target_net_update_frequency: 500 # sync target net with online net per this many epochs
|
||||
|
||||
replay:
|
||||
accumulate_reward_from_final: True
|
||||
count_reward_lambda: 0.0 # 0 to disable
|
||||
novel_object_reward_lambda: 0.0 # 0 to disable
|
||||
discount_gamma_game_reward: 0.9
|
||||
discount_gamma_count_reward: 0.5
|
||||
discount_gamma_novel_object_reward: 0.5
|
||||
replay_memory_capacity: 500000 # adjust this depending on your RAM size
|
||||
replay_memory_priority_fraction: 0.5
|
||||
update_per_k_game_steps: 5
|
||||
replay_batch_size: 64
|
||||
multi_step: 3
|
||||
replay_sample_history_length: 4
|
||||
replay_sample_update_from: 2
|
||||
|
||||
epsilon_greedy:
|
||||
noisy_net: False # if this is true, then epsilon greedy is disabled
|
||||
epsilon_anneal_episodes: 1000 # -1 if not annealing
|
||||
epsilon_anneal_from: 0.3
|
||||
epsilon_anneal_to: 0.1
|
||||
|
||||
dagger:
|
||||
action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working)
|
||||
max_target_length: 20 # max token length for seq2seq generation
|
||||
beam_width: 10 # 1 means greedy
|
||||
generate_top_k: 5
|
||||
unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation
|
||||
|
||||
training:
|
||||
max_nb_steps_per_episode: 50 # terminate after this many steps
|
||||
|
||||
fraction_assist:
|
||||
fraction_assist_anneal_episodes: 50000
|
||||
fraction_assist_anneal_from: 1.0
|
||||
fraction_assist_anneal_to: 0.01
|
||||
|
||||
fraction_random:
|
||||
fraction_random_anneal_episodes: 0
|
||||
fraction_random_anneal_from: 0.0
|
||||
fraction_random_anneal_to: 0.0
|
||||
|
||||
replay:
|
||||
replay_memory_capacity: 500000
|
||||
update_per_k_game_steps: 5
|
||||
replay_batch_size: 64
|
||||
replay_sample_history_length: 4
|
||||
replay_sample_update_from: 2
|
||||
|
||||
vision_dagger:
|
||||
model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input)
|
||||
resnet_fc_dim: 64
|
||||
maskrcnn_top_k_boxes: 10 # top k box features
|
||||
use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!)
|
||||
sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn'
|
||||
84
skillopt/envs/alfworld/vendor/env_base.py
vendored
Normal file
84
skillopt/envs/alfworld/vendor/env_base.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/base.py
|
||||
# Trimmed to only include what ALFWorld needs.
|
||||
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def to_numpy(data):
|
||||
"""Convert data to numpy array."""
|
||||
# Lazy-check for torch.Tensor to avoid hard dependency on torch
|
||||
_torch_tensor = None
|
||||
try:
|
||||
import torch
|
||||
_torch_tensor = torch.Tensor
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if _torch_tensor is not None and isinstance(data, _torch_tensor):
|
||||
data = data.detach().cpu().numpy()
|
||||
elif isinstance(data, np.ndarray):
|
||||
pass
|
||||
elif isinstance(data, (int, float, bool, Tuple, List)):
|
||||
data = np.array(data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(data)})")
|
||||
return data
|
||||
|
||||
|
||||
class EnvironmentManagerBase:
|
||||
"""Base class for vectorized environment managers.
|
||||
|
||||
Manages a set of parallel environments, handles action projection,
|
||||
observation post-processing, and history tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, envs, projection_f, config):
|
||||
self.envs = envs
|
||||
self.projection_f = projection_f
|
||||
self.config = config
|
||||
|
||||
def reset(self, kwargs) -> Dict[str, Any]:
|
||||
obs, infos = self.envs.reset()
|
||||
return {'text': None, 'image': obs, 'anchor': None}, infos
|
||||
|
||||
def step(self, text_actions: List[str]):
|
||||
actions, valids = self.projection_f(text_actions)
|
||||
next_obs, rewards, dones, infos = self.envs.step(actions)
|
||||
|
||||
next_observations = {
|
||||
'text': None,
|
||||
'image': next_obs,
|
||||
'anchor': None,
|
||||
}
|
||||
for i, info in enumerate(infos):
|
||||
info['is_action_valid'] = to_numpy(valids[i])
|
||||
|
||||
rewards = to_numpy(rewards)
|
||||
dones = to_numpy(dones)
|
||||
return next_observations, rewards, dones, infos
|
||||
|
||||
def close(self) -> None:
|
||||
self.envs.close()
|
||||
|
||||
def success_evaluator(self, *args, **kwargs) -> Dict[str, np.ndarray]:
|
||||
total_infos = kwargs['total_infos']
|
||||
total_batch_list = kwargs['total_batch_list']
|
||||
batch_size = len(total_batch_list)
|
||||
|
||||
success = defaultdict(list)
|
||||
for bs in range(batch_size):
|
||||
self._process_batch(bs, total_batch_list, total_infos, success)
|
||||
assert len(success['success_rate']) == batch_size
|
||||
return {key: np.array(value) for key, value in success.items()}
|
||||
|
||||
def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
|
||||
for i in reversed(range(len(total_batch_list[batch_idx]))):
|
||||
batch_item = total_batch_list[batch_idx][i]
|
||||
if batch_item['active_masks']:
|
||||
info = total_infos[batch_idx][i]
|
||||
won_value = float(info['won'])
|
||||
success['success_rate'].append(won_value)
|
||||
return
|
||||
139
skillopt/envs/alfworld/vendor/env_manager.py
vendored
Normal file
139
skillopt/envs/alfworld/vendor/env_manager.py
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/env_manager.py
|
||||
# Trimmed to only include AlfWorldEnvironmentManager and its helpers.
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from skillopt.envs.alfworld.vendor.env_base import EnvironmentManagerBase, to_numpy
|
||||
from skillopt.envs.alfworld.vendor.alfworld_prompts import (
|
||||
ALFWORLD_TEMPLATE,
|
||||
ALFWORLD_TEMPLATE_NO_HIS,
|
||||
ALFWORLD_TEMPLATE_WITH_MEMORY,
|
||||
)
|
||||
from skillopt.envs.alfworld.vendor.memory import SimpleMemory
|
||||
|
||||
|
||||
def parse_gamefile(infos):
|
||||
gamefile = []
|
||||
for info in infos:
|
||||
if 'extra.gamefile' in info:
|
||||
gamefile.append(info['extra.gamefile'])
|
||||
else:
|
||||
gamefile.append(None)
|
||||
return gamefile
|
||||
|
||||
|
||||
def set_gamefile(infos, gamefile):
|
||||
for i in range(len(infos)):
|
||||
if 'extra.gamefile' in infos[i]:
|
||||
infos[i]['extra.gamefile'] = gamefile[i]
|
||||
else:
|
||||
infos[i]['extra.gamefile'] = None
|
||||
return infos
|
||||
|
||||
|
||||
class AlfWorldEnvironmentManager(EnvironmentManagerBase):
|
||||
"""Manages parallel ALFWorld environments with observation templating."""
|
||||
|
||||
def __init__(self, envs, projection_f, config):
|
||||
self.memory = SimpleMemory()
|
||||
self.retrieval_memory = None
|
||||
super().__init__(envs, projection_f, config)
|
||||
|
||||
def reset(self, kwargs):
|
||||
text_obs, image_obs, infos = self.envs.reset()
|
||||
self.gamefile = parse_gamefile(infos)
|
||||
self.memory.reset(batch_size=len(text_obs))
|
||||
self.tasks = []
|
||||
self.pre_text_obs = text_obs
|
||||
self.extract_task(text_obs)
|
||||
|
||||
full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands, init=True)
|
||||
return {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}, infos
|
||||
|
||||
def step(self, text_actions: List[str]):
|
||||
actions, valids = self.projection_f(text_actions, self.envs.get_admissible_commands)
|
||||
text_obs, image_obs, rewards, dones, infos = self.envs.step(actions)
|
||||
self.memory.store({'text_obs': self.pre_text_obs, 'action': actions})
|
||||
self.pre_text_obs = text_obs
|
||||
|
||||
full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands)
|
||||
if infos[0].get("extra.gamefile") is None:
|
||||
infos = set_gamefile(infos, self.gamefile)
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
info['is_action_valid'] = to_numpy(valids[i])
|
||||
|
||||
next_observations = {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}
|
||||
rewards = to_numpy(rewards)
|
||||
dones = to_numpy(dones)
|
||||
return next_observations, rewards, dones, infos
|
||||
|
||||
def extract_task(self, text_obs: List[str]):
|
||||
for obs in text_obs:
|
||||
task_start = obs.find('Your task is to: ')
|
||||
if task_start != -1:
|
||||
self.tasks.append(obs[task_start + len('Your task is to: '):].strip())
|
||||
else:
|
||||
raise ValueError("Task description not found in text observation.")
|
||||
|
||||
def build_text_obs(self, text_obs: List[str], admissible_actions: List[List[str]], init: bool = False) -> List[str]:
|
||||
postprocess_text_obs = []
|
||||
if not init and self.config.env.history_length > 0:
|
||||
memory_contexts, valid_lens = self.memory.fetch(
|
||||
self.config.env.history_length,
|
||||
obs_key="text_obs",
|
||||
action_key="action",
|
||||
)
|
||||
|
||||
for i in range(len(text_obs)):
|
||||
reformatted_admissible_actions = "\n ".join(
|
||||
f"'{s}'" for s in admissible_actions[i] if s != 'help'
|
||||
)
|
||||
|
||||
if init or self.config.env.history_length <= 0:
|
||||
obs = ALFWORLD_TEMPLATE_NO_HIS.format(
|
||||
current_observation=text_obs[i],
|
||||
admissible_actions=reformatted_admissible_actions,
|
||||
)
|
||||
else:
|
||||
obs = ALFWORLD_TEMPLATE.format(
|
||||
task_description=self.tasks[i],
|
||||
step_count=len(self.memory[i]),
|
||||
history_length=valid_lens[i],
|
||||
action_history=memory_contexts[i],
|
||||
current_step=len(self.memory[i]) + 1,
|
||||
current_observation=text_obs[i],
|
||||
admissible_actions=reformatted_admissible_actions,
|
||||
)
|
||||
postprocess_text_obs.append(obs)
|
||||
return postprocess_text_obs
|
||||
|
||||
def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
|
||||
for i in reversed(range(len(total_batch_list[batch_idx]))):
|
||||
batch_item = total_batch_list[batch_idx][i]
|
||||
if batch_item['active_masks']:
|
||||
info = total_infos[batch_idx][i]
|
||||
won_value = float(info['won'])
|
||||
success['success_rate'].append(won_value)
|
||||
|
||||
gamefile = info.get("extra.gamefile")
|
||||
if gamefile:
|
||||
self._process_gamefile(gamefile, won_value, success)
|
||||
return
|
||||
|
||||
def _process_gamefile(self, gamefile, won_value, success):
|
||||
tasks = [
|
||||
"pick_and_place",
|
||||
"pick_two_obj_and_place",
|
||||
"look_at_obj_in_light",
|
||||
"pick_heat_then_place_in_recep",
|
||||
"pick_cool_then_place_in_recep",
|
||||
"pick_clean_then_place_in_recep",
|
||||
]
|
||||
for task in tasks:
|
||||
if task in gamefile:
|
||||
success[f"{task}_success_rate"].append(won_value)
|
||||
break
|
||||
87
skillopt/envs/alfworld/vendor/memory.py
vendored
Normal file
87
skillopt/envs/alfworld/vendor/memory.py
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/memory/base.py + agent_system/memory/memory.py
|
||||
# Merged into a single file for simplicity.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""Base class for memory management."""
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self, batch_size: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def store(self, record: Dict[str, List[Any]]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch(self, step: int):
|
||||
pass
|
||||
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
"""Per-environment history buffer for storing observations and actions."""
|
||||
|
||||
def __init__(self):
|
||||
self._data = None
|
||||
self.keys = None
|
||||
self.batch_size = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._data[idx]
|
||||
|
||||
def reset(self, batch_size: int):
|
||||
if self._data is not None:
|
||||
self._data.clear()
|
||||
self._data = [[] for _ in range(batch_size)]
|
||||
self.batch_size = batch_size
|
||||
self.keys = None
|
||||
|
||||
def store(self, record: Dict[str, List[Any]]):
|
||||
if self.keys is None:
|
||||
self.keys = list(record.keys())
|
||||
assert self.keys == list(record.keys())
|
||||
|
||||
for env_idx in range(self.batch_size):
|
||||
self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
history_length: int,
|
||||
obs_key: str = "text_obs",
|
||||
action_key: str = "action",
|
||||
) -> Tuple[List[str], List[int]]:
|
||||
memory_contexts, valid_lengths = [], []
|
||||
|
||||
for env_idx in range(self.batch_size):
|
||||
recent = self._data[env_idx][-history_length:]
|
||||
valid_len = len(recent)
|
||||
start_idx = len(self._data[env_idx]) - valid_len
|
||||
|
||||
lines = []
|
||||
for j, rec in enumerate(recent):
|
||||
step_num = start_idx + j + 1
|
||||
act = rec[action_key]
|
||||
obs = rec[obs_key]
|
||||
lines.append(
|
||||
f"[Observation {step_num}: '{obs}', Action {step_num}: '{act}']"
|
||||
)
|
||||
|
||||
memory_contexts.append("\n".join(lines))
|
||||
valid_lengths.append(valid_len)
|
||||
|
||||
return memory_contexts, valid_lengths
|
||||
1
skillopt/envs/babyvision/__init__.py
Normal file
1
skillopt/envs/babyvision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""BabyVision environment package for ReflACT."""
|
||||
267
skillopt/envs/babyvision/adapter.py
Normal file
267
skillopt/envs/babyvision/adapter.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""BabyVision environment adapter for ReflACT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from skillopt.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.babyvision.dataloader import BabyVisionDataLoader
|
||||
from skillopt.envs.babyvision.rollout import run_batch
|
||||
from skillopt.model import get_student_backend
|
||||
|
||||
|
||||
class BabyVisionAdapter(EnvAdapter):
|
||||
"""BabyVision adapter."""
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
cot = str(item.get("cot") or "").strip()
|
||||
if not cot:
|
||||
return ""
|
||||
return f"## Reference CoT\n{cot}"
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
cot = str(item.get("cot") or "").strip()
|
||||
if not cot:
|
||||
return {"fields": [], "preview": ""}
|
||||
return {
|
||||
"fields": ["cot"],
|
||||
"preview": cot[:400],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
max_turns: int = 1,
|
||||
workers: int = 32,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.workers = workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.image_detail = image_detail
|
||||
self.judge_model = judge_model
|
||||
self.judge_max_completion_tokens = judge_max_completion_tokens
|
||||
self.judge_retries = judge_retries
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = BabyVisionDataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
judge_model=self.judge_model,
|
||||
judge_max_completion_tokens=self.judge_max_completion_tokens,
|
||||
judge_retries=self.judge_retries,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
env_manager = kwargs.get("env_manager")
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
codex_backend = get_student_backend() == "codex_exec"
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
env_manager if isinstance(env_manager, list) else None,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
if codex_backend:
|
||||
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
|
||||
selected_metadata = []
|
||||
cot_count = 0
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
if meta["fields"]:
|
||||
cot_count += 1
|
||||
selected_metadata.append({
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("subtype") or item.get("task_type") or "babyvision"),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
})
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields=cot({cot_count}/{len(selected_items)})"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
diagnostic_trace_context_by_id = None
|
||||
if codex_backend:
|
||||
selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
|
||||
selected_items=selected_items,
|
||||
selected_examples=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
probe=probe,
|
||||
)
|
||||
probe_record = {
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": {
|
||||
"cot": cot_count,
|
||||
},
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
}
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(probe_record, f, ensure_ascii=False, indent=2)
|
||||
deep_results = run_batch(
|
||||
items=selected_items,
|
||||
out_root=rollout_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=min(self.workers, max(len(selected_items), 1)),
|
||||
image_detail=self.image_detail,
|
||||
judge_model=self.judge_model,
|
||||
judge_max_completion_tokens=self.judge_max_completion_tokens,
|
||||
judge_retries=self.judge_retries,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, selected_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return self.dataloader.get_task_types()
|
||||
214
skillopt/envs/babyvision/dataloader.py
Normal file
214
skillopt/envs/babyvision/dataloader.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""BabyVision task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from skillopt.datasets.base import SplitDataLoader
|
||||
|
||||
|
||||
# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
|
||||
|
||||
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]
|
||||
|
||||
|
||||
def _iter_jsonl(path: str) -> list[dict]:
|
||||
items: list[dict] = []
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
items.append(json.loads(line))
|
||||
return items
|
||||
|
||||
|
||||
def _normalize_ans_type(raw: Any, options: list[dict], choice_answer: Any) -> str:
|
||||
text = str(raw or "").strip().lower()
|
||||
if text in {"choice", "multiple_choice", "mcq", "option"}:
|
||||
return "choice"
|
||||
if text in {"blank", "open", "open_ended", "fill_blank", "short_answer"}:
|
||||
return "blank"
|
||||
if options or choice_answer not in (None, "", []):
|
||||
return "choice"
|
||||
return "blank"
|
||||
|
||||
|
||||
def _coerce_options(raw: Any) -> list[dict]:
|
||||
options: list[dict] = []
|
||||
if isinstance(raw, list):
|
||||
for idx, item in enumerate(raw):
|
||||
if isinstance(item, dict):
|
||||
text = str(item.get("text") or item.get("content") or item.get("option") or "").strip()
|
||||
label = str(item.get("label") or _CHOICE_LABELS[idx]).strip()
|
||||
else:
|
||||
text = str(item).strip()
|
||||
label = _CHOICE_LABELS[idx]
|
||||
if text:
|
||||
options.append({"label": label, "text": text})
|
||||
elif isinstance(raw, dict):
|
||||
for idx, (key, value) in enumerate(raw.items()):
|
||||
text = str(value).strip()
|
||||
if text:
|
||||
options.append({"label": str(key).strip() or _CHOICE_LABELS[idx], "text": text})
|
||||
return options
|
||||
|
||||
|
||||
def _normalize_choice_answer(choice_answer: Any, options: list[dict]) -> dict[str, str]:
|
||||
if not options:
|
||||
return {"label": "", "text": ""}
|
||||
|
||||
if isinstance(choice_answer, dict):
|
||||
label = str(choice_answer.get("label") or "").strip().upper()
|
||||
text = str(choice_answer.get("text") or "").strip()
|
||||
for option in options:
|
||||
if label and option["label"].strip().upper() == label:
|
||||
return {"label": option["label"], "text": option["text"]}
|
||||
if text and option["text"] == text:
|
||||
return {"label": option["label"], "text": option["text"]}
|
||||
|
||||
if isinstance(choice_answer, int):
|
||||
idx = choice_answer
|
||||
if 0 <= idx < len(options):
|
||||
return dict(options[idx])
|
||||
if 1 <= idx <= len(options):
|
||||
return dict(options[idx - 1])
|
||||
|
||||
text = str(choice_answer or "").strip()
|
||||
label = text.upper().rstrip(".):")
|
||||
for option in options:
|
||||
if option["label"].strip().upper() == label:
|
||||
return dict(option)
|
||||
if option["text"] == text:
|
||||
return dict(option)
|
||||
|
||||
return {"label": "", "text": ""}
|
||||
|
||||
|
||||
def _coerce_blank_answers(raw: Any) -> list[str]:
|
||||
if isinstance(raw, list):
|
||||
return [str(item).strip() for item in raw if str(item).strip()]
|
||||
if raw is None:
|
||||
return []
|
||||
text = str(raw).strip()
|
||||
return [text] if text else []
|
||||
|
||||
|
||||
def load_items(data_path: str) -> list[dict]:
|
||||
"""Load and normalise BabyVision items from a directory or JSONL file."""
|
||||
if not data_path:
|
||||
raise ValueError("BabyVision requires data_path pointing to a local dataset directory or meta_data.jsonl.")
|
||||
|
||||
if os.path.isdir(data_path):
|
||||
meta_path = os.path.join(data_path, "meta_data.jsonl")
|
||||
image_root = os.path.join(data_path, "images")
|
||||
else:
|
||||
meta_path = data_path
|
||||
image_root = os.path.join(os.path.dirname(data_path), "images")
|
||||
|
||||
if not os.path.exists(meta_path):
|
||||
raise ValueError(
|
||||
"BabyVision expected a meta_data.jsonl file. "
|
||||
f"Could not find: {meta_path}"
|
||||
)
|
||||
|
||||
raw_items = _iter_jsonl(meta_path)
|
||||
items: list[dict] = []
|
||||
for idx, raw in enumerate(raw_items):
|
||||
options = _coerce_options(raw.get("options") or raw.get("choices") or raw.get("choiceOptions"))
|
||||
ans_type = _normalize_ans_type(raw.get("ansType"), options, raw.get("choiceAns"))
|
||||
correct_choice = _normalize_choice_answer(raw.get("choiceAns"), options)
|
||||
blank_answers = _coerce_blank_answers(raw.get("blankAns"))
|
||||
|
||||
image_name = str(
|
||||
raw.get("image")
|
||||
or raw.get("image_path")
|
||||
or raw.get("image_file")
|
||||
or raw.get("img")
|
||||
or ""
|
||||
).strip()
|
||||
if not image_name:
|
||||
continue
|
||||
image_path = image_name if os.path.isabs(image_name) else os.path.join(image_root, image_name)
|
||||
if not os.path.exists(image_path):
|
||||
alt = os.path.join(os.path.dirname(meta_path), image_name)
|
||||
if os.path.exists(alt):
|
||||
image_path = alt
|
||||
else:
|
||||
continue
|
||||
|
||||
task_id = str(raw.get("taskId") or raw.get("id") or idx + 1)
|
||||
task_type = str(raw.get("type") or raw.get("taskType") or "unknown").strip() or "unknown"
|
||||
subtype = str(raw.get("subtype") or raw.get("subType") or task_type).strip() or task_type
|
||||
question = str(raw.get("question") or raw.get("query") or "").strip()
|
||||
if not question:
|
||||
continue
|
||||
|
||||
if ans_type == "choice" and not correct_choice["label"]:
|
||||
continue
|
||||
if ans_type != "choice" and not blank_answers:
|
||||
continue
|
||||
|
||||
items.append({
|
||||
"id": task_id,
|
||||
"task_type": task_type,
|
||||
"subtype": subtype,
|
||||
"question": question,
|
||||
"image_path": os.path.abspath(image_path),
|
||||
"ans_type": ans_type,
|
||||
"choices": options,
|
||||
"correct_choice": correct_choice,
|
||||
"blank_answers": blank_answers,
|
||||
"cot": str(raw.get("coT") or raw.get("cot") or "").strip(),
|
||||
"source_path": os.path.abspath(meta_path),
|
||||
})
|
||||
|
||||
if not items:
|
||||
raise ValueError(f"No valid BabyVision items loaded from {data_path}")
|
||||
return items
|
||||
|
||||
|
||||
# ── Dataloader ───────────────────────────────────────────────────────────
|
||||
|
||||
class BabyVisionDataLoader(SplitDataLoader):
|
||||
"""BabyVision dataloader."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
self._task_types: list[str] = []
|
||||
|
||||
def load_raw_items(self, data_path: str) -> list[dict]:
|
||||
return load_items(data_path)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
all_items = self.train_items + self.val_items + self.test_items
|
||||
task_types = {
|
||||
item.get("subtype") or item.get("task_type") or "unknown"
|
||||
for item in all_items
|
||||
}
|
||||
self._task_types = sorted(task_types)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(self._task_types)
|
||||
160
skillopt/envs/babyvision/evaluator.py
Normal file
160
skillopt/envs/babyvision/evaluator.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""BabyVision evaluation helpers using the official-style LLM judge."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
|
||||
import regex
|
||||
|
||||
from skillopt.model import chat_with_deployment
|
||||
from skillopt.prompts import load_prompt
|
||||
|
||||
_EVAL_MODE = "babyvision_judge_v2_official_style"
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
text = str(text).strip().lower()
|
||||
text = "".join(ch for ch in text if ch not in string.punctuation)
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def extract_boxed_answer(text: str | None) -> str | None:
|
||||
"""Extract the final answer using the official BabyVision rule."""
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
pattern = r'\\boxed\{((?:[^{}]|{(?:[^{}]|{.*})*})*)\}'
|
||||
matches = regex.findall(pattern, text)
|
||||
if matches:
|
||||
return matches[-1]
|
||||
|
||||
pattern_alt = r'<\|begin_of_box\|>(.*?)<\|end_of_box\|>'
|
||||
matches_alt = regex.findall(pattern_alt, text)
|
||||
if matches_alt:
|
||||
return matches_alt[-1].strip()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _token_f1(prediction: str, gold: str) -> float:
|
||||
pred_tokens = normalize_text(prediction).split()
|
||||
gold_tokens = normalize_text(gold).split()
|
||||
if not pred_tokens and not gold_tokens:
|
||||
return 1.0
|
||||
if not pred_tokens or not gold_tokens:
|
||||
return 0.0
|
||||
pred_set = {}
|
||||
gold_set = {}
|
||||
for tok in pred_tokens:
|
||||
pred_set[tok] = pred_set.get(tok, 0) + 1
|
||||
for tok in gold_tokens:
|
||||
gold_set[tok] = gold_set.get(tok, 0) + 1
|
||||
common = 0
|
||||
for tok, count in pred_set.items():
|
||||
common += min(count, gold_set.get(tok, 0))
|
||||
if common == 0:
|
||||
return 0.0
|
||||
precision = common / len(pred_tokens)
|
||||
recall = common / len(gold_tokens)
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def _format_choices(choices: list[dict]) -> str:
|
||||
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
|
||||
|
||||
|
||||
def _judge_answer(
|
||||
*,
|
||||
item: dict,
|
||||
prediction_text: str,
|
||||
extracted_answer: str,
|
||||
judge_model: str,
|
||||
max_completion_tokens: int,
|
||||
retries: int,
|
||||
) -> dict:
|
||||
if item["ans_type"] == "choice":
|
||||
ground_truth = str(item["correct_choice"]["label"])
|
||||
else:
|
||||
if len(item["blank_answers"]) == 1:
|
||||
ground_truth = item["blank_answers"][0]
|
||||
else:
|
||||
ground_truth = " | ".join(item["blank_answers"])
|
||||
|
||||
question = str(item["question"])
|
||||
if item["ans_type"] == "choice" and item.get("choices"):
|
||||
question = f"{question}\nChoices:\n{_format_choices(item['choices'])}"
|
||||
|
||||
raw, _ = chat_with_deployment(
|
||||
deployment=judge_model,
|
||||
system="You are a careful and strict evaluator.",
|
||||
user=load_prompt("judge", env="babyvision").format(
|
||||
question=question,
|
||||
groundtruth=ground_truth,
|
||||
modeloutput=extracted_answer,
|
||||
),
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=retries,
|
||||
stage="babyvision_judge",
|
||||
)
|
||||
judge_response_clean = str(raw).strip().lower()
|
||||
if "true" in judge_response_clean:
|
||||
correct = True
|
||||
elif "false" in judge_response_clean:
|
||||
correct = False
|
||||
else:
|
||||
correct = False
|
||||
return {
|
||||
"raw": raw,
|
||||
"correct": correct,
|
||||
"reason": judge_response_clean,
|
||||
"matched_gold": ground_truth if correct else "",
|
||||
}
|
||||
|
||||
|
||||
def evaluate_item(
|
||||
*,
|
||||
item: dict,
|
||||
prediction_text: str,
|
||||
judge_model: str,
|
||||
max_completion_tokens: int = 256,
|
||||
retries: int = 5,
|
||||
) -> dict:
|
||||
answer = extract_boxed_answer(prediction_text)
|
||||
judge = _judge_answer(
|
||||
item=item,
|
||||
prediction_text=prediction_text,
|
||||
extracted_answer=answer,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=retries,
|
||||
)
|
||||
hard = 1.0 if judge["correct"] else 0.0
|
||||
|
||||
result = {
|
||||
"evaluation_mode": _EVAL_MODE,
|
||||
"predicted_answer": answer,
|
||||
"em": hard,
|
||||
"f1": hard,
|
||||
"sub_em": hard,
|
||||
"judge_model": judge_model,
|
||||
"judge_raw": judge["raw"],
|
||||
"judge_reason": judge["reason"],
|
||||
"matched_gold": judge["matched_gold"],
|
||||
}
|
||||
|
||||
if item["ans_type"] == "choice":
|
||||
result["predicted_label"] = str(answer or "").strip().upper().rstrip(".):")
|
||||
result["predicted_text"] = ""
|
||||
result["correct_label"] = str(item["correct_choice"].get("label") or "")
|
||||
result["correct_text"] = str(item["correct_choice"].get("text") or "")
|
||||
else:
|
||||
result["gold_answers"] = list(item["blank_answers"])
|
||||
best_f1 = 0.0
|
||||
for gold in item["blank_answers"]:
|
||||
best_f1 = max(best_f1, _token_f1(str(answer or ""), gold))
|
||||
result["string_f1"] = best_f1
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def evaluation_mode() -> str:
|
||||
return _EVAL_MODE
|
||||
36
skillopt/envs/babyvision/prompts/analyst_error.md
Normal file
36
skillopt/envs/babyvision/prompts/analyst_error.md
Normal file
@@ -0,0 +1,36 @@
|
||||
You are an expert failure-analysis agent for child-level visual reasoning tasks.
|
||||
|
||||
You will be given MULTIPLE failed BabyVision trajectories from a minibatch and the current skill document.
|
||||
Each trajectory includes the text prompt, the model answer, and the evaluation result.
|
||||
You do not have direct access to raw pixel content during reflection, so focus on general reasoning,
|
||||
option-selection, and visual-question-answering behaviors that can be improved through prompting.
|
||||
|
||||
## Failure Type Categories
|
||||
- **visual_detail_miss**: the agent likely overlooked a salient visual attribute, relation, count, or object state
|
||||
- **option_mismatch**: the agent selected the wrong option despite relevant evidence likely being present
|
||||
- **instruction_slip**: the agent ignored output format or answered too vaguely
|
||||
- **answer_granularity**: the agent gave an answer that was too broad, too narrow, or mismatched the expected specificity
|
||||
- **other**: none of the above
|
||||
|
||||
## Rules
|
||||
1. Focus on patterns recurring across the minibatch.
|
||||
2. Prefer reusable behaviors for inspecting images and grounding answers in visible evidence.
|
||||
3. Do not memorize dataset-specific answers.
|
||||
4. Only patch gaps not already covered by the current skill.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
25
skillopt/envs/babyvision/prompts/analyst_success.md
Normal file
25
skillopt/envs/babyvision/prompts/analyst_success.md
Normal file
@@ -0,0 +1,25 @@
|
||||
You are an expert success-pattern analyst for child-level visual reasoning tasks.
|
||||
|
||||
You will be given MULTIPLE successful BabyVision trajectories from a minibatch and the current skill document.
|
||||
Identify generalizable behavior patterns that help the agent inspect the image carefully and answer at the right level of specificity.
|
||||
|
||||
## Rules
|
||||
- Focus on broadly useful visual QA behaviors.
|
||||
- Prefer patterns about systematic image inspection, comparing options, and concise grounded answers.
|
||||
- Do not add dataset-specific facts.
|
||||
- "edits" may be empty if the skill already captures the useful patterns.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns matter>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
25
skillopt/envs/babyvision/prompts/deep_probe.md
Normal file
25
skillopt/envs/babyvision/prompts/deep_probe.md
Normal file
@@ -0,0 +1,25 @@
|
||||
You are an expert diagnostic-probe designer for BabyVision-style visual reasoning tasks.
|
||||
|
||||
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
|
||||
Design one SMALL diagnostic instruction that exposes the student's intermediate visual judgment without materially changing the original scaffold.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT substantially change the original scaffold.
|
||||
2. Do NOT prescribe a new step-by-step solving method.
|
||||
3. You MAY ask for a short structured list of a few intermediate conclusions, candidate cues, or counted units, as long as it stays close to the original scaffold.
|
||||
4. Do NOT ask for exhaustive listing of all cells, all objects, or a full chain-of-thought.
|
||||
5. Ask only for a short readout that reveals the student's current latent state.
|
||||
6. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>.
|
||||
|
||||
## Good Probe Targets
|
||||
- top answer and runner-up
|
||||
- decisive visual cue
|
||||
- suspicious region or compared objects
|
||||
- counting unit or formatting interpretation
|
||||
- 2-4 short intermediate conclusions that directly support the final answer
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this probe is informative>",
|
||||
"probe_instruction": "<the exact instruction text to append to the student prompt>"
|
||||
}
|
||||
35
skillopt/envs/babyvision/prompts/judge.md
Normal file
35
skillopt/envs/babyvision/prompts/judge.md
Normal file
@@ -0,0 +1,35 @@
|
||||
You are a careful and strict evaluator. You will be given:
|
||||
|
||||
1. **Question**
|
||||
2. **Ground Truth Answer** (correct answer)
|
||||
3. **Model Output** (answer from another model)
|
||||
|
||||
**Your goal:** Determine if the Model Output **accurately matches** the Ground Truth Answer in meaning.
|
||||
|
||||
* Matching means: the facts, entities, and key details are equivalent, even if phrasing differs.
|
||||
* Not matching means: the Model Output is wrong, incomplete, contains extra incorrect facts, or changes the meaning.
|
||||
|
||||
**Process (internal reasoning):**
|
||||
|
||||
1. Read and understand the Question, Ground Truth Answer, and Model Output.
|
||||
2. Ignore small wording differences, formatting, or synonyms.
|
||||
3. If all factual content matches, conclude `1`. Otherwise, conclude `0`.
|
||||
|
||||
**Important:**
|
||||
|
||||
* Think through your decision step-by-step **internally** before responding.
|
||||
* In your final output, return **only** True or False, with no extra text or explanation.
|
||||
|
||||
**Output format:**
|
||||
|
||||
True
|
||||
|
||||
or
|
||||
|
||||
False
|
||||
|
||||
**Input:**
|
||||
|
||||
Question: {question},
|
||||
Ground Truth Answer: {groundtruth},
|
||||
Model Output: {modeloutput}
|
||||
13
skillopt/envs/babyvision/prompts/rollout_system.md
Normal file
13
skillopt/envs/babyvision/prompts/rollout_system.md
Normal file
@@ -0,0 +1,13 @@
|
||||
You are an expert visual reasoning agent solving child-level image understanding tasks.
|
||||
|
||||
{skill_section}## Task Format
|
||||
You will receive one image and one question about it.
|
||||
Inspect the image carefully before answering. Ground the answer in visible evidence.
|
||||
|
||||
## Answer Format
|
||||
Think step by step, then provide your final answer in \boxed{{Answer}} format.
|
||||
- For multiple-choice questions, output only the single choice label, such as \boxed{{A}}.
|
||||
- For open questions, output only a short final answer inside \boxed{{...}}.
|
||||
|
||||
Example:
|
||||
\boxed{{B}}
|
||||
4
skillopt/envs/babyvision/reflect.py
Normal file
4
skillopt/envs/babyvision/reflect.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""BabyVision Reflect stage.
|
||||
|
||||
Prompts are now loaded from .md files by the base adapter.
|
||||
"""
|
||||
467
skillopt/envs/babyvision/rollout.py
Normal file
467
skillopt/envs/babyvision/rollout.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""BabyVision rollout — multimodal visual QA with image input."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from skillopt.envs.babyvision.evaluator import evaluate_item, evaluation_mode, extract_boxed_answer
|
||||
from skillopt.model import chat_student_messages, get_student_backend, is_student_exec_backend
|
||||
from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
|
||||
from skillopt.prompts import load_prompt
|
||||
|
||||
def _build_system(skill_content: str) -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return load_prompt("rollout_system", env="babyvision").format(skill_section=skill_section)
|
||||
|
||||
|
||||
def _format_choices(choices: list[dict]) -> str:
|
||||
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
|
||||
|
||||
|
||||
def _build_user_text(
|
||||
item: dict,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> str:
|
||||
parts = []
|
||||
if diagnostic_trace_context.strip():
|
||||
parts.append(
|
||||
"## Previous Codex Trace Snapshot\n"
|
||||
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
|
||||
f"{diagnostic_trace_context.strip()}"
|
||||
)
|
||||
parts.append(f"## Question\n{item['question']}")
|
||||
if item["ans_type"] == "choice":
|
||||
parts.append(f"## Choices\n{_format_choices(item['choices'])}")
|
||||
parts.append("Answer using the single correct option label in \\boxed{...}.")
|
||||
else:
|
||||
parts.append("Answer with a short phrase in \\boxed{...}.")
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _image_to_data_uri(path: str) -> str:
|
||||
mime = mimetypes.guess_type(path)[0] or "image/png"
|
||||
with open(path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
def _build_messages(
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
image_detail: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> tuple[list[dict], str, str]:
|
||||
system = _build_system(skill_content)
|
||||
user_text = _build_user_text(
|
||||
item,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
image_url = {
|
||||
"url": _image_to_data_uri(item["image_path"]),
|
||||
}
|
||||
if image_detail and image_detail != "auto":
|
||||
image_url["detail"] = image_detail
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_text},
|
||||
{"type": "image_url", "image_url": image_url},
|
||||
],
|
||||
},
|
||||
]
|
||||
return messages, system, user_text
|
||||
|
||||
|
||||
def _build_codex_skill(skill_content: str) -> str:
|
||||
return render_skill_md(
|
||||
skill_content,
|
||||
description="Dynamic ReflACT skill for solving the current BabyVision visual reasoning question.",
|
||||
preamble=(
|
||||
"Use this skill when answering the current visual reasoning question.\n"
|
||||
"Inspect the attached image carefully and return the final answer in \\boxed{...}."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _run_codex_once(
|
||||
*,
|
||||
pred_dir: str,
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
model: str,
|
||||
timeout: int,
|
||||
image_detail: str,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
previous_response: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
user_text = _build_user_text(
|
||||
item,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
task_parts = [user_text]
|
||||
if previous_response:
|
||||
task_parts.append(
|
||||
"## Previous Attempt\n"
|
||||
f"{previous_response}\n\n"
|
||||
"Review the same image and question carefully. If needed, correct the answer."
|
||||
)
|
||||
task_text = "\n\n".join(task_parts)
|
||||
skill_md = _build_codex_skill(skill_content)
|
||||
work_dir = os.path.join(pred_dir, "codex_exec")
|
||||
prepare_workspace(
|
||||
work_dir=work_dir,
|
||||
skill_md=skill_md,
|
||||
task_text=task_text,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
prompt = (
|
||||
"Use the `skillopt-student` skill available in this workspace.\n"
|
||||
"Read `task.md`, inspect the attached image, and answer the question.\n"
|
||||
"Return the final answer in \\boxed{...}."
|
||||
)
|
||||
final_message, raw = run_student_exec(
|
||||
work_dir=work_dir,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
return final_message or raw, raw, skill_md, task_text
|
||||
|
||||
|
||||
def process_one(
|
||||
item: dict,
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
"id": item_id,
|
||||
"question": item["question"],
|
||||
"task_type": item.get("subtype") or item.get("task_type") or "babyvision",
|
||||
"task_description": item["question"],
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"predicted_label": "",
|
||||
"predicted_text": "",
|
||||
"response": "",
|
||||
"fail_reason": "",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
"image_path": item["image_path"],
|
||||
"ans_type": item["ans_type"],
|
||||
"evaluation_mode": evaluation_mode(),
|
||||
"judge_model": judge_model,
|
||||
}
|
||||
if item["ans_type"] == "choice":
|
||||
result["correct_label"] = item["correct_choice"]["label"]
|
||||
result["correct_text"] = item["correct_choice"]["text"]
|
||||
else:
|
||||
result["gold_answers"] = item["blank_answers"]
|
||||
|
||||
try:
|
||||
pred_dir = os.path.join(out_root, "predictions", item_id)
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
if is_student_exec_backend():
|
||||
from skillopt.model import azure_openai as _llm
|
||||
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{"role": "user", "content": f"{item['question']}\n\n[image] {os.path.basename(item['image_path'])}"}
|
||||
]
|
||||
system_prompt = ""
|
||||
user_text = ""
|
||||
for turn in range(max_turns):
|
||||
response, raw, system_prompt, user_text = _run_codex_once(
|
||||
pred_dir=pred_dir,
|
||||
item=item,
|
||||
skill_content=skill_content,
|
||||
model=_llm.STUDENT_DEPLOYMENT,
|
||||
timeout=120,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode if turn == 0 else False,
|
||||
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
|
||||
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
|
||||
previous_response=response if turn > 0 else "",
|
||||
)
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": response})
|
||||
if extract_boxed_answer(response) is not None:
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(
|
||||
item=item,
|
||||
prediction_text=response,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=judge_max_completion_tokens,
|
||||
retries=judge_retries,
|
||||
)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["judge_raw"] = eval_result["judge_raw"]
|
||||
result["judge_reason"] = eval_result["judge_reason"]
|
||||
result["matched_gold"] = eval_result["matched_gold"]
|
||||
if item["ans_type"] == "choice":
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}"
|
||||
)
|
||||
else:
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_answer']}' "
|
||||
f"but expected {item['blank_answers']} ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Gold answers: {item['blank_answers']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}\n"
|
||||
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
return result
|
||||
|
||||
messages, system_prompt, user_text = _build_messages(
|
||||
item,
|
||||
skill_content,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"}
|
||||
]
|
||||
|
||||
for turn in range(max_turns):
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=768,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
else:
|
||||
refinement_text = (
|
||||
f"Your previous answer was:\n{response}\n\n"
|
||||
"Review the same image and question carefully. "
|
||||
"If needed, correct your answer. Output the final answer in \\boxed{...}."
|
||||
)
|
||||
refinement_messages = [
|
||||
messages[0],
|
||||
messages[1],
|
||||
{"role": "assistant", "content": response},
|
||||
{"role": "user", "content": refinement_text},
|
||||
]
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=refinement_messages,
|
||||
max_completion_tokens=512,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
response = resp_text
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
|
||||
if extract_boxed_answer(resp_text) is not None:
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(
|
||||
item=item,
|
||||
prediction_text=response,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=judge_max_completion_tokens,
|
||||
retries=judge_retries,
|
||||
)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["judge_raw"] = eval_result["judge_raw"]
|
||||
result["judge_reason"] = eval_result["judge_reason"]
|
||||
result["matched_gold"] = eval_result["matched_gold"]
|
||||
|
||||
if item["ans_type"] == "choice":
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}"
|
||||
)
|
||||
else:
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_answer']}' "
|
||||
f"but expected {item['blank_answers']} ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Gold answers: {item['blank_answers']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}\n"
|
||||
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
|
||||
)
|
||||
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
result["fail_reason"] = f"error: {e}"
|
||||
return result
|
||||
|
||||
|
||||
def run_batch(
|
||||
items: list[dict],
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
workers: int = 32,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None,
|
||||
) -> list[dict]:
|
||||
results_path = os.path.join(out_root, "results.jsonl")
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
expected_eval_mode = evaluation_mode()
|
||||
done_ids: set[str] = set()
|
||||
existing: list[dict] = []
|
||||
rewrite_results = False
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
row = json.loads(line)
|
||||
if row.get("evaluation_mode") != expected_eval_mode:
|
||||
rewrite_results = True
|
||||
continue
|
||||
done_ids.add(str(row["id"]))
|
||||
existing.append(row)
|
||||
except Exception:
|
||||
rewrite_results = True
|
||||
|
||||
pending = [item for item in items if str(item["id"]) not in done_ids]
|
||||
if not pending and not rewrite_results:
|
||||
return existing
|
||||
|
||||
results = list(existing)
|
||||
file_mode = "w" if rewrite_results else "a"
|
||||
with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
if rewrite_results:
|
||||
for row in existing:
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
futs = {
|
||||
ex.submit(
|
||||
process_one,
|
||||
item,
|
||||
out_root,
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
image_detail=image_detail,
|
||||
judge_model=judge_model,
|
||||
judge_max_completion_tokens=judge_max_completion_tokens,
|
||||
judge_retries=judge_retries,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
|
||||
): item
|
||||
for item in pending
|
||||
}
|
||||
for fut in as_completed(futs):
|
||||
row = fut.result()
|
||||
results.append(row)
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
return results
|
||||
18
skillopt/envs/babyvision/skills/initial.md
Normal file
18
skillopt/envs/babyvision/skills/initial.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# BabyVision Visual QA Heuristics
|
||||
|
||||
## Image Inspection
|
||||
- First identify the main objects, their attributes, and their spatial relations before answering.
|
||||
- If the question involves counting, compare all relevant instances carefully instead of stopping after the first match.
|
||||
- If the question asks about color, size, position, or action, verify the specific visible evidence for that attribute.
|
||||
|
||||
## Multiple Choice
|
||||
- Compare every option against the visible image evidence before deciding.
|
||||
- Prefer the option that matches the image exactly; reject options that are only partially true or too vague.
|
||||
- When two options are close, check the smallest discriminating visual detail.
|
||||
|
||||
## Open Answers
|
||||
- Answer with the shortest phrase that is fully supported by the image.
|
||||
- Match the expected level of specificity: not broader than the image evidence, not narrower than the question asks.
|
||||
|
||||
## Final Answer
|
||||
- Output only the final answer inside <answer>...</answer>.
|
||||
396
skillopt/envs/base.py
Normal file
396
skillopt/envs/base.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""ReflACT environment adapter — abstract interface.
|
||||
|
||||
To connect ReflACT to a new environment (benchmark, simulator, etc.),
|
||||
implement a subclass of :class:`EnvAdapter` with environment-specific
|
||||
rollout and reflection logic.
|
||||
|
||||
Example::
|
||||
|
||||
class MyBenchAdapter(EnvAdapter):
|
||||
def build_train_env(self, batch_size, seed, **kw):
|
||||
return MyEnvManager(split="train", n=batch_size, seed=seed)
|
||||
|
||||
def build_eval_env(self, env_num, split, seed, **kw):
|
||||
return MyEnvManager(split=split, n=env_num, seed=seed)
|
||||
|
||||
def rollout(self, env_manager, skill_content, out_dir, **kw):
|
||||
# Run episodes, return [{"id": ..., "hard": 0/1, "soft": 0.0-1.0, ...}]
|
||||
...
|
||||
|
||||
def reflect(self, results, skill_content, out_dir, **kw):
|
||||
# Analyze trajectories, return list of patch dicts
|
||||
...
|
||||
|
||||
def get_task_types(self):
|
||||
return ["task_a", "task_b"]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import random
|
||||
|
||||
from skillopt.datasets.base import BaseDataLoader, BatchSpec
|
||||
from skillopt.model.codex_harness import extract_codex_trace_prefix, format_codex_trace_steps, parse_codex_raw
|
||||
from skillopt.prompts import load_prompt
|
||||
|
||||
|
||||
class EnvAdapter(ABC):
|
||||
"""Abstract adapter for connecting ReflACT to any environment.
|
||||
|
||||
Subclasses must implement all abstract methods. The ReflACT trainer
|
||||
calls these methods at the appropriate pipeline stages.
|
||||
"""
|
||||
|
||||
# ── Lifecycle hooks ────────────────────────────────────────────────────
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
"""Called once by the trainer before the training loop begins.
|
||||
|
||||
Override to perform one-time initialization that requires the full
|
||||
config (e.g., data loading, split creation). Default is a no-op.
|
||||
"""
|
||||
self._cfg = dict(cfg)
|
||||
|
||||
def get_dataloader(self) -> BaseDataLoader | None:
|
||||
"""Return the task dataloader used by this adapter, if any."""
|
||||
return None
|
||||
|
||||
def requires_ray(self) -> bool:
|
||||
"""Return whether this adapter requires Ray runtime initialization."""
|
||||
return False
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
"""Optional deeper diagnostic reflection pass.
|
||||
|
||||
Default behavior is a no-op. Dataset-backed adapters may override this
|
||||
to re-query the student on a small representative subset of the current
|
||||
batch using minimally-perturbed diagnostic prompts that expose
|
||||
intermediate reasoning state.
|
||||
"""
|
||||
return []
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
"""Return hidden reference material for deep reflection, if any."""
|
||||
return str(item.get("reference_text") or "").strip()
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
"""Return structured metadata about hidden reference material."""
|
||||
reference_text = self.build_reference_text(item)
|
||||
if not reference_text:
|
||||
return {"fields": [], "preview": ""}
|
||||
return {
|
||||
"fields": ["reference_text"],
|
||||
"preview": reference_text[:400],
|
||||
}
|
||||
|
||||
def get_codex_deep_probe_prompt(self) -> str | None:
|
||||
env_name = getattr(self, "_cfg", {}).get("env_name")
|
||||
return load_prompt("deep_probe_codex", env=env_name)
|
||||
|
||||
def attach_codex_probe_context(
|
||||
self,
|
||||
results: list[dict],
|
||||
prediction_dir: str,
|
||||
) -> list[dict]:
|
||||
"""Attach compact Codex step metadata for codex-aware deep reflection."""
|
||||
enriched: list[dict] = []
|
||||
for row in results:
|
||||
merged = dict(row)
|
||||
tid = str(row.get("id"))
|
||||
raw_path = os.path.join(prediction_dir, tid, "codex_raw.txt")
|
||||
if os.path.exists(raw_path):
|
||||
with open(raw_path, encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
parsed = parse_codex_raw(raw)
|
||||
merged["codex_probe_trace_steps"] = format_codex_trace_steps(raw)
|
||||
merged["codex_probe_step_count"] = len(parsed["steps"])
|
||||
enriched.append(merged)
|
||||
return enriched
|
||||
|
||||
def resolve_codex_probe_target(
|
||||
self,
|
||||
*,
|
||||
selected_items: list[dict],
|
||||
selected_examples: list[dict],
|
||||
prediction_dir: str,
|
||||
probe: dict,
|
||||
) -> tuple[list[dict], dict[str, str] | None, dict]:
|
||||
"""Resolve the teacher-selected codex probe target and raw trace prefix."""
|
||||
target_id = str(probe.get("probe_target_id", "")).strip()
|
||||
selected_id_set = {str(item["id"]) for item in selected_items}
|
||||
if target_id not in selected_id_set:
|
||||
target_id = str(selected_items[0]["id"])
|
||||
target_item = next(item for item in selected_items if str(item["id"]) == target_id)
|
||||
target_result = next(
|
||||
(row for row in selected_examples if str(row.get("id")) == target_id),
|
||||
None,
|
||||
)
|
||||
max_probe_step = int((target_result or {}).get("codex_probe_step_count", 0))
|
||||
default_probe_step = max_probe_step - 1 if max_probe_step > 1 else max_probe_step
|
||||
probe_after_step = int(probe.get("probe_after_step", default_probe_step))
|
||||
if max_probe_step > 0:
|
||||
probe_after_step = max(0, min(probe_after_step, max_probe_step))
|
||||
else:
|
||||
probe_after_step = 0
|
||||
raw_path = os.path.join(prediction_dir, target_id, "codex_raw.txt")
|
||||
trace_prefix = ""
|
||||
if os.path.exists(raw_path):
|
||||
with open(raw_path, encoding="utf-8") as f:
|
||||
trace_prefix = extract_codex_trace_prefix(f.read(), after_step=probe_after_step)
|
||||
updated_probe = dict(probe)
|
||||
updated_probe["probe_target_id"] = target_id
|
||||
updated_probe["probe_after_step"] = probe_after_step
|
||||
return [target_item], {target_id: trace_prefix}, updated_probe
|
||||
|
||||
def attach_reference_context(
|
||||
self,
|
||||
results: list[dict],
|
||||
items: list[dict] | None,
|
||||
) -> list[dict]:
|
||||
"""Attach environment-specific hidden reference text to result dicts."""
|
||||
if not results or not items:
|
||||
return list(results)
|
||||
|
||||
item_by_id = {
|
||||
str(item.get("id")): item
|
||||
for item in items
|
||||
if isinstance(item, dict) and item.get("id") is not None
|
||||
}
|
||||
enriched: list[dict] = []
|
||||
for row in results:
|
||||
merged = dict(row)
|
||||
item = item_by_id.get(str(row.get("id")))
|
||||
if item:
|
||||
reference_text = self.build_reference_text(item)
|
||||
if reference_text:
|
||||
merged["reference_text"] = reference_text
|
||||
enriched.append(merged)
|
||||
return enriched
|
||||
|
||||
def select_representative_items(
|
||||
self,
|
||||
results: list[dict],
|
||||
items: list[dict] | None,
|
||||
*,
|
||||
n_failures: int,
|
||||
n_successes: int,
|
||||
seed: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Select a small diverse subset of current-batch items by outcome."""
|
||||
if not items:
|
||||
return []
|
||||
|
||||
item_by_id = {
|
||||
str(item.get("id")): item
|
||||
for item in items
|
||||
if isinstance(item, dict) and item.get("id") is not None
|
||||
}
|
||||
failures = [
|
||||
(result, item_by_id[str(result.get("id"))])
|
||||
for result in results
|
||||
if not result.get("hard") and str(result.get("id")) in item_by_id
|
||||
]
|
||||
successes = [
|
||||
(result, item_by_id[str(result.get("id"))])
|
||||
for result in results
|
||||
if result.get("hard") and str(result.get("id")) in item_by_id
|
||||
]
|
||||
|
||||
rng = random.Random(seed)
|
||||
|
||||
def _pick(pool: list[tuple[dict, dict]], quota: int) -> list[dict]:
|
||||
if quota <= 0 or not pool:
|
||||
return []
|
||||
shuffled = list(pool)
|
||||
rng.shuffle(shuffled)
|
||||
|
||||
picked_ids: set[str] = set()
|
||||
picked: list[dict] = []
|
||||
seen_types: set[str] = set()
|
||||
|
||||
for result, item in shuffled:
|
||||
task_type = str(result.get("task_type") or item.get("task_type") or item.get("subtype") or "unknown")
|
||||
item_id = str(item["id"])
|
||||
if task_type in seen_types or item_id in picked_ids:
|
||||
continue
|
||||
picked.append(item)
|
||||
picked_ids.add(item_id)
|
||||
seen_types.add(task_type)
|
||||
if len(picked) >= quota:
|
||||
return picked
|
||||
|
||||
for _, item in shuffled:
|
||||
item_id = str(item["id"])
|
||||
if item_id in picked_ids:
|
||||
continue
|
||||
picked.append(item)
|
||||
picked_ids.add(item_id)
|
||||
if len(picked) >= quota:
|
||||
break
|
||||
return picked
|
||||
|
||||
selected = _pick(failures, n_failures)
|
||||
selected_ids = {str(item["id"]) for item in selected}
|
||||
selected.extend(
|
||||
item for item in _pick(successes, n_successes)
|
||||
if str(item["id"]) not in selected_ids
|
||||
)
|
||||
return selected
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
"""Build an environment manager or item list from a :class:`BatchSpec`.
|
||||
|
||||
Default behavior preserves the legacy adapter API by routing training
|
||||
batches through :meth:`build_train_env` and evaluation batches through
|
||||
:meth:`build_eval_env`.
|
||||
"""
|
||||
if batch.phase == "train":
|
||||
return self.build_train_env(batch_size=batch.batch_size, seed=batch.seed, **kwargs)
|
||||
return self.build_eval_env(
|
||||
env_num=batch.batch_size,
|
||||
split=batch.split,
|
||||
seed=batch.seed,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
"""Build a training environment manager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
An environment manager that can be passed to :meth:`rollout`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
"""Build an evaluation environment manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_num : int
|
||||
Number of evaluation environments.
|
||||
split : str
|
||||
Dataset split (e.g. ``"valid_seen"``, ``"valid_unseen"``).
|
||||
seed : int
|
||||
Random seed for reproducibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
An environment manager that can be passed to :meth:`rollout`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
"""Run a batch of episodes using the current skill.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict]
|
||||
Each dict conforms to :class:`~skillopt.types.RolloutResult`:
|
||||
must have ``"id"`` (str), ``"hard"`` (0/1), ``"soft"``
|
||||
(float 0-1). May include env-specific fields.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
"""Analyze rollout results and produce patches.
|
||||
|
||||
Each returned dict conforms to :class:`~skillopt.types.RawPatch`:
|
||||
``"patch"`` (with ``"edits"`` list) + ``"source_type"``
|
||||
(``"failure"`` or ``"success"``).
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict | None]
|
||||
Raw analyst outputs; ``None`` entries are filtered out.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_task_types(self) -> list[str]:
|
||||
"""Return the list of task type names for this environment."""
|
||||
|
||||
# ── Prompt configuration (two-level priority) ────────────────────────
|
||||
#
|
||||
# Priority: env-specific prompt file > generic default prompt file.
|
||||
#
|
||||
# Prompts are loaded from ``.md`` files via ``load_prompt(name, env)``:
|
||||
# 1. ``skillopt/envs/<env>/prompts/<name>.md`` (env-specific)
|
||||
# 2. ``skillopt/prompts/<name>.md`` (generic fallback)
|
||||
#
|
||||
# Subclasses can still override ``get_*_prompt()`` for full control.
|
||||
|
||||
@property
|
||||
def _env_name(self) -> str:
|
||||
"""Derive the env directory name from this adapter's module path."""
|
||||
# e.g. "skillopt.envs.searchqa.adapter" → "searchqa"
|
||||
module = type(self).__module__
|
||||
parts = module.split(".")
|
||||
if len(parts) >= 3 and parts[-3] == "envs":
|
||||
return parts[-2]
|
||||
return ""
|
||||
|
||||
def _load_env_prompt(self, name: str) -> str | None:
|
||||
"""Load a prompt with env-specific override. Returns None if not found."""
|
||||
try:
|
||||
return load_prompt(name, env=self._env_name)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def get_error_minibatch_prompt(self) -> str | None:
|
||||
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
|
||||
raw_mode = str(update_mode).strip().lower()
|
||||
if raw_mode in {"full_rewrite", "full_rewrite_minibatch", "minibatch_full_rewrite", "skill_rewrite_minibatch"}:
|
||||
prompt = self._load_env_prompt("analyst_error_full_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
if raw_mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
|
||||
prompt = self._load_env_prompt("analyst_error_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
return self._load_env_prompt("analyst_error")
|
||||
|
||||
def get_success_minibatch_prompt(self) -> str | None:
|
||||
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
|
||||
raw_mode = str(update_mode).strip().lower()
|
||||
if raw_mode in {"full_rewrite", "full_rewrite_minibatch", "minibatch_full_rewrite", "skill_rewrite_minibatch"}:
|
||||
prompt = self._load_env_prompt("analyst_success_full_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
if raw_mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
|
||||
prompt = self._load_env_prompt("analyst_success_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
return self._load_env_prompt("analyst_success")
|
||||
|
||||
def get_deep_probe_prompt(self) -> str | None:
|
||||
return self._load_env_prompt("deep_probe")
|
||||
|
||||
def get_meta_reflect_prompt(self) -> str | None:
|
||||
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
|
||||
if str(update_mode).strip().lower() == "rewrite_from_suggestions":
|
||||
prompt = self._load_env_prompt("meta_reflect_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
return self._load_env_prompt("meta_reflect")
|
||||
114
skillopt/envs/deep_reflect.py
Normal file
114
skillopt/envs/deep_reflect.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable
|
||||
|
||||
from skillopt.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
def run_no_reference_deep_reflect(
|
||||
adapter: Any,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
*,
|
||||
env_manager: Any = None,
|
||||
prediction_dir: str | None = None,
|
||||
random_seed: int | None = None,
|
||||
step_buffer_context: str = "",
|
||||
output_requirements: list[str] | None = None,
|
||||
metadata_builder: Callable[[dict], dict] | None = None,
|
||||
) -> list[dict | None]:
|
||||
"""Run teacher-designed diagnostic probing without hidden references."""
|
||||
if not getattr(adapter, "use_deep_reflect", False):
|
||||
return []
|
||||
if not isinstance(env_manager, list):
|
||||
return []
|
||||
|
||||
prediction_dir = prediction_dir or os.path.join(out_dir, "predictions")
|
||||
selected_items = adapter.select_representative_items(
|
||||
results,
|
||||
env_manager,
|
||||
n_failures=getattr(adapter, "deep_reflect_failures", 4),
|
||||
n_successes=getattr(adapter, "deep_reflect_successes", 2),
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
if metadata_builder is None:
|
||||
selected_metadata = [
|
||||
{
|
||||
"id": str(item.get("id")),
|
||||
"task_type": str(item.get("task_type") or item.get("topic") or "unknown"),
|
||||
"question_preview": str(item.get("question") or "")[:200],
|
||||
}
|
||||
for item in selected_items
|
||||
]
|
||||
else:
|
||||
selected_metadata = [metadata_builder(item) for item in selected_items]
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
"mode=no_reference_probe"
|
||||
)
|
||||
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_results,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=adapter.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
output_requirements=output_requirements,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"mode": "no_reference_probe",
|
||||
"selected_count": len(selected_items),
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
deep_results = adapter.rollout(
|
||||
selected_items,
|
||||
skill_content,
|
||||
rollout_dir,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=getattr(adapter, "analyst_workers", 8),
|
||||
failure_only=getattr(adapter, "failure_only", False),
|
||||
minibatch_size=getattr(adapter, "minibatch_size", 8),
|
||||
edit_budget=getattr(adapter, "edit_budget", 4),
|
||||
random_seed=random_seed,
|
||||
error_system=adapter.get_error_minibatch_prompt(),
|
||||
success_system=adapter.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(getattr(adapter, "_cfg", {}), "get", lambda *_: "patch")(
|
||||
"skill_update_mode",
|
||||
"patch",
|
||||
),
|
||||
)
|
||||
1
skillopt/envs/docvqa/__init__.py
Normal file
1
skillopt/envs/docvqa/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""DocVQA environment package for ReflACT."""
|
||||
153
skillopt/envs/docvqa/adapter.py
Normal file
153
skillopt/envs/docvqa/adapter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.deep_reflect import run_no_reference_deep_reflect
|
||||
from skillopt.envs.docvqa.dataloader import DocVQADataLoader
|
||||
from skillopt.envs.docvqa.rollout import run_batch
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
class DocVQAAdapter(EnvAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "split_dir",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 120,
|
||||
workers: int = 16,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
exec_timeout: int = 600,
|
||||
image_detail: str = "auto",
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.exec_timeout = exec_timeout
|
||||
self.image_detail = image_detail
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = DocVQADataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
exec_timeout=self.exec_timeout,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
|
||||
def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
return run_no_reference_deep_reflect(
|
||||
self,
|
||||
results,
|
||||
skill_content,
|
||||
out_dir,
|
||||
env_manager=kwargs.get("env_manager"),
|
||||
prediction_dir=kwargs.get("prediction_dir"),
|
||||
random_seed=kwargs.get("random_seed"),
|
||||
step_buffer_context=kwargs.get("step_buffer_context", ""),
|
||||
output_requirements=[
|
||||
"- There is no hidden reference block. Use only the document image prompt, student output, and evaluation result to infer what intermediate state is worth probing.",
|
||||
"- The instruction must explicitly request a short <analysis>...</analysis> block before the final <answer>...</answer>.",
|
||||
"- The readout should focus on visual region, field/table/figure label, OCR text read, candidate answer, and answer-format normalization.",
|
||||
"- Do not ask for exhaustive transcription or a full chain-of-thought.",
|
||||
"- The instruction text should be ready to append directly to the student's prompt.",
|
||||
],
|
||||
metadata_builder=lambda item: {
|
||||
"id": str(item.get("id")),
|
||||
"task_type": str(item.get("task_type") or "docvqa"),
|
||||
"question_preview": str(item.get("question") or "")[:200],
|
||||
"image_path": item.get("image_path", ""),
|
||||
"docId": item.get("docId", ""),
|
||||
"page": item.get("ucsf_document_page_no", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
seen: list[str] = []
|
||||
for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items:
|
||||
task_type = str(item.get("task_type") or "docvqa")
|
||||
if task_type not in seen:
|
||||
seen.append(task_type)
|
||||
return seen or ["docvqa"]
|
||||
61
skillopt/envs/docvqa/dataloader.py
Normal file
61
skillopt/envs/docvqa/dataloader.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
from skillopt.datasets.base import SplitDataLoader
|
||||
|
||||
|
||||
def _parse_answers(raw: str) -> list[str]:
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = ast.literal_eval(text)
|
||||
except Exception:
|
||||
return [text]
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()]
|
||||
return [str(parsed).strip()]
|
||||
|
||||
|
||||
def _extract_document_path(question: str) -> tuple[str, str]:
|
||||
marker = "document_path:"
|
||||
if marker not in question:
|
||||
return question.strip(), ""
|
||||
main, tail = question.split(marker, 1)
|
||||
return main.strip(), tail.strip()
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, str]) -> dict:
|
||||
question_text, document_path = _extract_document_path(str(row.get("question") or ""))
|
||||
answers = _parse_answers(row.get("answer") or row.get("ground_truth") or "")
|
||||
image_path = str(row.get("image_path") or document_path or "").strip()
|
||||
task_type = str(row.get("topic") or row.get("category") or "docvqa").strip() or "docvqa"
|
||||
return {
|
||||
"id": str(row.get("questionId") or row.get("id") or "").strip(),
|
||||
"question": question_text,
|
||||
"answer": answers[0] if answers else "",
|
||||
"answers": answers,
|
||||
"task_type": task_type,
|
||||
"subtask": task_type,
|
||||
"image_paths": [image_path] if image_path else [],
|
||||
"image_path": image_path,
|
||||
"questionId": str(row.get("questionId") or "").strip(),
|
||||
"docId": str(row.get("docId") or "").strip(),
|
||||
"ucsf_document_id": str(row.get("ucsf_document_id") or "").strip(),
|
||||
"ucsf_document_page_no": str(row.get("ucsf_document_page_no") or "").strip(),
|
||||
"source_split": str(row.get("source_split") or "").strip(),
|
||||
}
|
||||
|
||||
|
||||
class DocVQADataLoader(SplitDataLoader):
|
||||
def load_split_items(self, split_path: str) -> list[dict]:
|
||||
path = Path(split_path)
|
||||
csv_files = sorted(path.glob("*.csv"))
|
||||
if not csv_files:
|
||||
raise FileNotFoundError(f"No .csv file found in {split_path}")
|
||||
with csv_files[0].open(encoding="utf-8", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
return [_normalize_row(row) for row in reader]
|
||||
113
skillopt/envs/docvqa/evaluator.py
Normal file
113
skillopt/envs/docvqa/evaluator.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
DEFAULT_ANLS_THRESHOLD = 0.5
|
||||
|
||||
|
||||
def _normalize_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
text = str(value).strip().lower()
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _levenshtein_distance(a: str, b: str) -> int:
|
||||
if a == b:
|
||||
return 0
|
||||
if not a:
|
||||
return len(b)
|
||||
if not b:
|
||||
return len(a)
|
||||
if len(a) > len(b):
|
||||
a, b = b, a
|
||||
previous = list(range(len(b) + 1))
|
||||
for i, char_a in enumerate(a, start=1):
|
||||
current = [i]
|
||||
for j, char_b in enumerate(b, start=1):
|
||||
insert_cost = current[j - 1] + 1
|
||||
delete_cost = previous[j] + 1
|
||||
replace_cost = previous[j - 1] + (char_a != char_b)
|
||||
current.append(min(insert_cost, delete_cost, replace_cost))
|
||||
previous = current
|
||||
return previous[-1]
|
||||
|
||||
|
||||
def _score_single_answer(predicted: Any, target: Any, threshold: float) -> float:
|
||||
predicted_norm = _normalize_text(predicted)
|
||||
target_norm = _normalize_text(target)
|
||||
if not predicted_norm and not target_norm:
|
||||
return 1.0
|
||||
if not predicted_norm or not target_norm:
|
||||
return 0.0
|
||||
distance = _levenshtein_distance(predicted_norm, target_norm)
|
||||
normalized_distance = distance / max(len(predicted_norm), len(target_norm))
|
||||
if normalized_distance >= threshold:
|
||||
return 0.0
|
||||
return 1.0 - normalized_distance
|
||||
|
||||
|
||||
def _extract_answer_strings(raw: Any) -> list[str]:
|
||||
if raw is None:
|
||||
return [""]
|
||||
if isinstance(raw, str):
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return [""]
|
||||
parsed = None
|
||||
if text[0] in "[{":
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
parsed = ast.literal_eval(text)
|
||||
except (ValueError, SyntaxError):
|
||||
parsed = None
|
||||
if parsed is None:
|
||||
return [text]
|
||||
return _extract_answer_strings(parsed)
|
||||
if isinstance(raw, dict):
|
||||
for key in ("answers", "ground_truth", "answer"):
|
||||
if key in raw:
|
||||
return _extract_answer_strings(raw[key])
|
||||
return [str(raw)]
|
||||
if isinstance(raw, Iterable) and not isinstance(raw, (bytes, bytearray)):
|
||||
answers: list[str] = []
|
||||
for item in raw:
|
||||
if isinstance(item, dict):
|
||||
for key in ("text", "answer", "value"):
|
||||
if key in item:
|
||||
answers.extend(_extract_answer_strings(item[key]))
|
||||
break
|
||||
else:
|
||||
answers.append(str(item))
|
||||
continue
|
||||
answers.append(str(item))
|
||||
return answers or [""]
|
||||
return [str(raw)]
|
||||
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
lower = text.lower()
|
||||
start = lower.rfind("<answer>")
|
||||
end = lower.rfind("</answer>")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return text[start + len("<answer>"):end].strip()
|
||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
return lines[-1] if lines else text.strip()
|
||||
|
||||
|
||||
def evaluate(prediction_text: str, gold_answers: Any) -> dict:
|
||||
answer = extract_answer(prediction_text)
|
||||
answers = _extract_answer_strings(gold_answers)
|
||||
score = 0.0
|
||||
for target in answers:
|
||||
score = max(score, _score_single_answer(answer, target, DEFAULT_ANLS_THRESHOLD))
|
||||
return {
|
||||
"anls": score,
|
||||
"predicted_answer": answer,
|
||||
"gold_answers": answers,
|
||||
}
|
||||
35
skillopt/envs/docvqa/prompts/analyst_error.md
Normal file
35
skillopt/envs/docvqa/prompts/analyst_error.md
Normal file
@@ -0,0 +1,35 @@
|
||||
You are an expert failure-analysis agent for visual document question answering tasks.
|
||||
|
||||
You will be given MULTIPLE failed DocVQA trajectories from a single minibatch and the current skill document. Each trajectory includes the model response and an evaluation result scored with ANLS against one or more acceptable answers.
|
||||
|
||||
Your job is to identify the most important COMMON failure patterns across the batch and propose concise skill edits.
|
||||
|
||||
## Failure Type Categories
|
||||
- evidence_miss: the model overlooked the relevant visible region or line
|
||||
- near_match_confusion: the model selected a nearby but incorrect text span
|
||||
- normalization_error: the answer differed mainly in formatting, spacing, punctuation, or minor text normalization
|
||||
- reading_error: the model misread the document content
|
||||
- other: none of the above
|
||||
|
||||
## Rules
|
||||
- Focus on common, reusable reading and extraction behaviors.
|
||||
- Do not hardcode image-specific answers.
|
||||
- Prefer concise edits that improve evidence selection and exact span extraction.
|
||||
|
||||
Respond ONLY with a valid JSON object (no markdown fences, no extra text):
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the batch's common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown to add at end of skill>"},
|
||||
{"op": "insert_after", "target": "<exact heading/text to insert after>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<exact text to replace>", "content": "<replacement>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
Only include edits that are needed. "edits" can be an empty list if no patch is warranted.
|
||||
24
skillopt/envs/docvqa/prompts/analyst_success.md
Normal file
24
skillopt/envs/docvqa/prompts/analyst_success.md
Normal file
@@ -0,0 +1,24 @@
|
||||
You are an expert success-pattern analyst for visual document question answering tasks.
|
||||
|
||||
You will be given MULTIPLE successful DocVQA trajectories from a single minibatch and the current skill document. Your job is to identify common visual reading and exact-answer extraction behaviors worth encoding in the skill.
|
||||
|
||||
## Rules
|
||||
- Focus on patterns shared across multiple successful trajectories.
|
||||
- Reinforce reusable behaviors like locating the right region, copying exact spans, and preferring the shortest exact answer over paraphrase.
|
||||
- Only propose patches for patterns not already captured by the current skill.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns are worth encoding>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
"edits" may be empty if the skill already covers all observed patterns.
|
||||
12
skillopt/envs/docvqa/prompts/rollout_system.md
Normal file
12
skillopt/envs/docvqa/prompts/rollout_system.md
Normal file
@@ -0,0 +1,12 @@
|
||||
You are an expert visual document question answering agent.
|
||||
|
||||
{skill_section}You will receive a document image and a question about the document.
|
||||
Read the visual evidence carefully and answer concisely.
|
||||
|
||||
Rules:
|
||||
- Ground the answer in the visible document content.
|
||||
- Prefer exact spans, numbers, dates, and names from the document.
|
||||
- Do not invent content that is not visible.
|
||||
- If multiple near-matches exist, choose the one best supported by the document.
|
||||
|
||||
Return the final answer inside <answer>...</answer>.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user