Initial commit

This commit is contained in:
carpedkm
2026-05-08 18:12:45 +00:00
commit 866ba52287
243 changed files with 31492 additions and 0 deletions

15
.gitignore vendored Normal file

@@ -0,0 +1,15 @@
__pycache__/
*.pyc
data/
outputs/
logs/
external/
/BabyVision/
/MMRB/
/SpreadsheetBench/
/dl4ir-searchQA/
configs/local/
configs/**/*.local.yaml
*.local.md
.secrets/
.codex_azure*/

557
README.md Normal file

@@ -0,0 +1,557 @@
# ReflACT: Reflective Agent Tuning
ReflACT is a framework for optimizing an external skill document through iterative rollout, reflection, editing, and gated validation.
It does **not** fine-tune model weights. Instead, it treats the skill document as the optimization target:
- the **student** model executes tasks with the current skill
- the **teacher** model analyzes trajectories and proposes edits
- the framework merges, ranks, applies, and validates those edits
- only validated skill updates are kept
This branch implements a full training loop with step-level skill optimization and optional epoch-level memory mechanisms (`slow_update`, `meta_skill`, `meta_reflect`).
## Method Overview
### Optimization Target
Each run maintains a mutable markdown skill document. The framework repeatedly improves that document instead of changing model parameters.
This gives a training-style loop for prompt / policy optimization:
1. Roll out the current skill on a batch of tasks.
2. Reflect on failures and successes.
3. Merge patch proposals into a coherent candidate update.
4. Rank and select a bounded number of edits.
5. Apply those edits to produce a candidate skill.
6. Validate the candidate skill on a held-out selection split.
7. Keep the update only if the gate accepts it.
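
The seven steps above can be sketched as a single function. This is a minimal illustration, not the code in `reflact/engine/trainer.py`: the function name `train_step`, its signature, and the injected callables are all hypothetical, and strict improvement is assumed for acceptance.

```python
from typing import Any, Callable, Sequence


def train_step(
    skill: str,
    batch: Sequence[Any],
    current_score: float,
    edit_budget: int,
    rollout: Callable[[str, Sequence[Any]], list],
    reflect: Callable[[str, list], list],
    aggregate: Callable[[list], list],
    select: Callable[[list, int], list],
    apply_edits: Callable[[str, list], str],
    evaluate: Callable[[str], float],
) -> tuple[str, float, bool]:
    """Run one rollout -> reflect -> merge -> select -> apply -> gate cycle.

    Every callable is a stand-in (hypothetical) for a framework stage.
    Returns (skill_to_keep, score_to_keep, accepted).
    """
    trajectories = rollout(skill, batch)        # 1. student runs the tasks
    raw_patches = reflect(skill, trajectories)  # 2. teacher proposes patches
    merged = aggregate(raw_patches)             # 3. merge into one edit pool
    edits = select(merged, edit_budget)         # 4. keep a bounded number
    candidate = apply_edits(skill, edits)       # 5. build the candidate skill
    candidate_score = evaluate(candidate)       # 6. score on the selection split
    if candidate_score > current_score:         # 7. gate (strictness assumed)
        return candidate, candidate_score, True
    return skill, current_score, False
```

Passing the stages in as callables keeps the sketch self-contained; the real trainer wires concrete modules instead.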
### Per-Step Pipeline
Every training step executes the following pipeline in `reflact/engine/trainer.py`:
1. **Rollout**
The student model runs a batch of tasks using the current skill.
2. **Reflect**
The teacher analyzes minibatches of trajectories and emits raw patches.
Failure-driven and success-driven patches are tracked separately.
3. **Aggregate**
Raw patches are merged hierarchically. Metadata such as `support_count` and `source_type` is carried into the merged patch so later ranking can use it.
4. **Select**
The teacher ranks the merged edit pool and keeps up to `edit_budget` edits.
5. **Update**
The selected edits are applied to the skill document. The framework records an `edit_apply_report.json` so you can see which edits actually landed, which were skipped, and why.
6. **Evaluate / Gate**
The candidate skill is evaluated on the selection split. Gate validation is mandatory in this branch. A candidate update is accepted only if it improves over the current selection score; a new global best is tracked separately.
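
The gate rule in step 6 can be illustrated with a small pure function. This is a sketch under assumptions: `GateState` and `gate_step` are hypothetical names, and "improves" is read as strictly greater. The actual logic lives in `reflact/evaluation/gate.py` and may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GateState:
    """Scores carried across steps (illustrative names, not the repo's)."""
    current_score: float
    best_score: float


def gate_step(state: GateState, candidate_score: float) -> tuple[GateState, bool, bool]:
    """Return (new_state, accepted, is_new_best) for one candidate skill."""
    # Accept only when the candidate beats the current selection score.
    accepted = candidate_score > state.current_score
    if not accepted:
        return state, False, False
    # The global best is tracked separately from the current score.
    is_new_best = candidate_score > state.best_score
    best = candidate_score if is_new_best else state.best_score
    return GateState(candidate_score, best), True, is_new_best
```

A candidate can therefore be accepted as the new current skill without becoming the new global best.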
### Within-Epoch Memory
Inside an epoch, the trainer maintains a step buffer containing:
- compact failure-pattern summaries from previous steps
- rejected edits and their score deltas
That context is fed back into later reflection calls so the teacher can avoid repeating ineffective edits and can focus on unsolved error patterns.
### Epoch-Level Mechanisms
This branch supports three optional epoch-level mechanisms.
#### Slow Update
At the end of each epoch, `slow_update` compares the previous epoch's terminal skill and the current epoch's terminal skill on a sampled train subset. It then writes longitudinal guidance into a protected slow-update region inside the skill document.
Importantly, this guidance is **not** blindly written through. It is converted into a candidate skill and sent through the same selection gate as step-level updates.
#### Meta Skill
`meta_skill` is teacher-side cross-epoch memory. It does not directly edit the current skill. Instead, it writes a compact memory artifact describing longer-term patterns across adjacent epochs. That memory is loaded into later reflection / merge / ranking calls as extra context.
#### Meta Reflect
`meta_reflect` runs at epoch end over the step history of the current epoch. It looks at accepted and rejected directions from the whole epoch, proposes higher-level patch edits, applies them to a meta candidate, and then sends that candidate through the same selection gate.
## What This Branch Guarantees
The current implementation assumes the following as the mainline method contract:
- gate validation is always on
- the current skill, current score, best skill, and best score stay aligned
- `slow_update` is gated before being committed
- patch provenance (`source_type`, `support_count`) reaches selection
- patch application is observable through per-edit reports
- resume state is restored from `runtime_state.json` rather than inferred only from history
- all benchmark model calls go through the unified backend router
## Model Backends
All model access now goes through the split teacher/student model layer in `reflact.model`.
Supported teacher backends:
- `openai_chat`
- `claude_chat`
Supported student backends:
- `openai_chat`
- `claude_chat`
- `codex_exec`
- `claude_code_exec`
Recommended config shape:
```yaml
model:
  teacher_backend: openai_chat
  student_backend: codex_exec
  teacher: gpt-5.4
  student: gpt-5.4-codex
  reasoning_effort: medium
```
Legacy `model.backend` and CLI flags like `--backend codex` still work. They are mapped onto the split backend model for backward compatibility.
The same routing is used by:
- training (`scripts/train.py`)
- eval-only runs (`scripts/eval_only.py`)
- SpreadsheetBench standalone prompt eval scripts
- LiveMathematicianBench baseline eval script
- benchmark rollout code inside the main framework
### Azure OpenAI
If you use `openai_chat`, configure either environment variables or config values:
```bash
export AZURE_OPENAI_ENDPOINT="https://your-endpoint.openai.azure.com/"
export AZURE_OPENAI_API_KEY="your-api-key"
export AZURE_OPENAI_API_VERSION="2025-04-01-preview"
```
The config supports both the old keys and the new explicit names:
```yaml
model:
  azure_openai_endpoint: "..."
  azure_openai_api_version: "..."
  azure_openai_api_key: ""
  azure_openai_auth_mode: api_key
  azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
  azure_openai_managed_identity_client_id: ""
```
`azure_openai_auth_mode` can be used for API-key auth or Azure AD / managed identity flows.
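
As an illustration of "environment variables or config values", the sketch below resolves the three connection settings. The function name `resolve_azure_settings` is hypothetical, and the precedence (a non-empty config value wins, otherwise the environment variable) is an assumption; the real loader may order these differently.

```python
import os
from typing import Mapping, Optional

# Config key -> environment variable, as named in this README.
_ENV_NAMES = {
    "azure_openai_endpoint": "AZURE_OPENAI_ENDPOINT",
    "azure_openai_api_key": "AZURE_OPENAI_API_KEY",
    "azure_openai_api_version": "AZURE_OPENAI_API_VERSION",
}


def resolve_azure_settings(
    model_cfg: Mapping[str, str],
    environ: Optional[Mapping[str, str]] = None,
) -> dict:
    """Pick each setting from the config, falling back to the environment.

    Assumed precedence: non-empty config value first, then the env var.
    """
    env = os.environ if environ is None else environ
    return {
        key: model_cfg.get(key) or env.get(env_name, "")
        for key, env_name in _ENV_NAMES.items()
    }
```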
### Exec Harness
`codex_exec` and `claude_code_exec` run the student inside a workspace harness instead of a plain chat call. The harness writes task files, renders a dynamic `SKILL.md`, runs the student CLI, and saves raw execution artifacts such as:
- `codex_raw.txt`
- `codex_trace_summary.txt`
- workspace-local task / skill files
This branch keeps `meta_skill` and `apply_patch_with_report`, while upgrading the student path to the more realistic workspace-exec setup.
### Trace-Aware Deep Reflect
When `student_backend=codex_exec` and `gradient.use_deep_reflect=true`, deep reflection can probe a specific earlier Codex attempt:
- the teacher sees a compact Codex trace summary
- deep probe can target `probe_target_id`
- the follow-up rollout can resume from `probe_after_step`
This is wired for the dataset-backed environments in this branch.
### Rewrite Mode
Skill updates support two modes:
- `optimizer.skill_update_mode=patch`
- `optimizer.skill_update_mode=rewrite_from_suggestions`
`patch` keeps the existing fine-grained edit application path and still records `edit_apply_report.json`.
`rewrite_from_suggestions` asks the teacher to emit higher-level rewrite suggestions, then rewrites the whole skill in one pass. This is useful when patch edits become too fragmented.
## Repository Layout
```text
reflact/
  engine/
    trainer.py              main training loop
  gradient/
    reflect.py              minibatch reflection
    aggregate.py            hierarchical patch merge
    deep_probe.py           diagnostic probing for deep reflect
  optimizer/
    clip.py                 edit ranking / selection
    skill.py                patch application + apply report
    slow_update.py          epoch-level longitudinal guidance
    meta_skill.py           teacher-side cross-epoch memory
    meta_reflect.py         epoch-level macro editing
  evaluation/
    gate.py                 pure gate decision logic
  model/
    backend_config.py       teacher/student backend routing
    azure_openai.py         Azure backend
    codex_harness.py        workspace exec harness + Codex trace parsing
    claude_backend.py       Claude backend
  envs/
    ...                     environment adapters and rollout logic
scripts/
  train.py                  unified training entry
  eval_only.py              evaluate one skill without training
configs/
  _base_/default.yaml       shared defaults
  <env>/default.yaml        environment-specific configs
```
## Configuration
Configs use structured YAML with `_base_` inheritance.
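
To make `_base_` inheritance concrete, here is a minimal sketch of merging an environment config over an already-loaded base. The helper name `merge_config` and the recursive-merge semantics are assumptions; the real loader may treat lists or special keys differently.

```python
from copy import deepcopy


def merge_config(base: dict, override: dict) -> dict:
    """Recursively overlay `override` on `base` without mutating either.

    Nested mappings are merged key by key; any other value replaces the
    base value. The `_base_` pointer itself is dropped from the result.
    """
    merged = deepcopy(base)
    for key, value in override.items():
        if key == "_base_":
            continue  # inheritance pointer, not a setting
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged
```

With this behavior, an environment config only needs to list the keys it changes.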
The base config is `configs/_base_/default.yaml`. Key defaults in this branch are:
- `model.teacher_backend = openai_chat`
- `model.student_backend = openai_chat`
- `model.reasoning_effort = medium`
- `optimizer.use_slow_update = true`
- `optimizer.use_meta_skill = true`
- `optimizer.use_meta_reflect = false`
- `gradient.use_deep_reflect = false`
- `optimizer.skill_update_mode = patch`
Default setting snapshot:
```yaml
model:
  backend: azure_openai
  teacher: gpt-5.4
  student: gpt-5.4
  teacher_backend: openai_chat
  student_backend: openai_chat
  reasoning_effort: medium
  rewrite_reasoning_effort: ""
  rewrite_max_completion_tokens: 64000
  codex_exec_path: codex
  codex_exec_sandbox: workspace-write
  codex_exec_profile: ""
  codex_exec_full_auto: false
  codex_exec_reasoning_effort: none
  claude_code_exec_path: claude
  claude_code_exec_profile: ""
  codex_trace_to_teacher: true
train:
  num_epochs: 4
  train_size: 0
  batch_size: 80
  accumulation: 1
  seed: 42
gradient:
  minibatch_size: 16
  merge_batch_size: 16
  analyst_workers: 16
  max_analyst_rounds: 3
  failure_only: false
  use_deep_reflect: false
  deep_reflect_failures: 4
  deep_reflect_successes: 2
optimizer:
  learning_rate: 8
  min_learning_rate: 2
  lr_scheduler: cosine
  skill_update_mode: patch
  use_meta_reflect: false
  meta_learning_rate: 8
  use_slow_update: true
  slow_update_samples: 20
  use_meta_skill: true
evaluation:
  use_gate: true
  sel_env_num: 0
  test_env_num: 0
  eval_test: true
env:
  split_mode: ratio
  split_ratio: "2:1:7"
  split_seed: 42
```
For the full source of truth, see [configs/_base_/default.yaml](configs/_base_/default.yaml).
Selected fields:
| Section | Key | Meaning |
|---|---|---|
| `model` | `teacher_backend` | teacher backend: `openai_chat` or `claude_chat` |
| `model` | `student_backend` | student backend: chat backend or exec backend |
| `model` | `teacher` | teacher model / deployment |
| `model` | `student` | student model / deployment |
| `model` | `reasoning_effort` | reasoning budget passed to the backend when supported |
| `model` | `codex_trace_to_teacher` | include Codex trace summaries in teacher reflection context |
| `train` | `num_epochs` | number of epochs |
| `train` | `train_size` | expected train split size, or `0` to infer |
| `train` | `batch_size` | tasks per rollout batch |
| `train` | `accumulation` | number of rollout/reflect minibatches per step |
| `gradient` | `minibatch_size` | trajectories per analyst minibatch |
| `gradient` | `merge_batch_size` | patches per aggregate batch |
| `gradient` | `use_deep_reflect` | enable diagnostic probe rollouts |
| `gradient` | `max_analyst_rounds` | teacher reflection retries / refinement budget |
| `optimizer` | `learning_rate` | max edits kept after selection |
| `optimizer` | `lr_scheduler` | edit-budget scheduler |
| `optimizer` | `use_slow_update` | epoch-level longitudinal guidance |
| `optimizer` | `use_meta_skill` | teacher-side epoch memory |
| `optimizer` | `use_meta_reflect` | epoch-level macro editing |
| `optimizer` | `skill_update_mode` | `patch` or `rewrite_from_suggestions` |
| `evaluation` | `sel_env_num` | selection set size (`0` means full split) |
| `evaluation` | `test_env_num` | test set size (`0` means full split) |
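
Since `learning_rate` is an edit budget and `lr_scheduler` decays it, a cosine schedule could look like the sketch below. The formula, the rounding, and the function name `edit_budget` are assumptions for illustration; the framework's scheduler may compute the budget differently.

```python
import math


def edit_budget(step: int, total_steps: int, learning_rate: int, min_learning_rate: int) -> int:
    """Cosine-decay the per-step edit budget from max to min (assumed form).

    `step` is zero-based. The budget starts at `learning_rate` and ends at
    `min_learning_rate` on the final step.
    """
    if total_steps <= 1:
        return learning_rate
    clamped = min(max(step, 0), total_steps - 1)
    progress = clamped / (total_steps - 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return round(min_learning_rate + (learning_rate - min_learning_rate) * cosine)
```

Rounding to an integer keeps the budget usable as "max edits kept after selection".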
### Important Branch Rule
`use_gate=false` is intentionally not supported in this branch. Gate validation is part of the method contract here.
If an old config still contains `evaluation.use_gate: false`, the loader / trainer will raise instead of silently continuing.
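
A fail-fast check of this kind might look like the following sketch. The function name `validate_gate` and the choice of `ValueError` are assumptions; the loader may raise a different exception type.

```python
def validate_gate(config: dict) -> None:
    """Reject configs that try to disable the selection gate.

    A missing key is treated as the default (`true`); anything other than
    `True` raises instead of silently continuing.
    """
    use_gate = config.get("evaluation", {}).get("use_gate", True)
    if use_gate is not True:
        raise ValueError(
            "evaluation.use_gate=false is not supported in this branch; "
            "remove the key or set it to true."
        )
```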
## Supported Environments
The main training entry and eval-only entry now register 11 environments:
| Env | Default rollout shape | Current default split / data setting | Branch alignment |
|---|---|---|---|
| `alfworld` | environment-backed episodic rollout | native ALFWorld train/eval splits | in `reflact_new_zzw` |
| `babyvision` | single-round multimodal QA | `split_mode=ratio` from raw metadata/images, or prepared `split_dir` | in `reflact_new_zzw` |
| `docvqa` | single-round multimodal QA | `split_dir: data/docvqa_split` | in `reflact_new_zzw` |
| `livemathematicianbench` | single-round QA | `split_mode=ratio` or prepared `split_dir` | in `reflact_new_zzw` |
| `mathverse` | single-round multimodal math QA | `data_root: data/MathVerse`, split files loaded from `split_dir` when provided | in `reflact_new_zzw` |
| `mmrb` | single-round multimodal reasoning QA | `split_mode=ratio` or prepared `split_dir` | in `reflact_new_zzw` |
| `officeqa` | multi-turn tool loop | `split_dir: data/officeqa_split` plus `data_dirs: [data/officeqa_docs_official]` | in `reflact_new_zzw` |
| `sealqa` | multi-turn tool loop | `split_dir: data/sealqa_split` | in `reflact_new_zzw` |
| `searchqa` | single-round QA (`max_turns=1`) | `split_dir: data/searchqa_split` | in `reflact_new_zzw` |
| `spreadsheetbench` | codegen loop, default `mode=multi`, `max_turns=30` | `split_dir: data/spreadsheetbench_split`, `data_root: data/spreadsheetbench_verified_400` | in `reflact_new_zzw`, default adjusted here to multi-round |
| `swebench` | mini-swe-agent multi-step bug-fixing rollout | `split_mode=ratio`, `dataset_name=lite`, repo-stratified `2:1:7` split materialized under `out_root/_generated_splits/...` unless `split_dir` is provided | added here, aligned to `swe-bench-old` |
## Data Expectations
The standard two-mode dataset entry path is:
- `split_mode: ratio`
  - load raw data from `env.data_path`
  - build a deterministic `train/`, `val/`, `test/` split under `env.split_output_dir` (or under `out_root/_generated_splits/` if unset)
  - default ratio is explicitly `2:1:7`
- `split_mode: split_dir`
  - load an existing `env.split_dir` with `train/`, `val/`, `test/` subdirectories
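
For intuition, a deterministic ratio split can be sketched as a seeded shuffle followed by proportional slicing. This is a simplified illustration with a hypothetical `ratio_split` helper; it is not the framework's splitter, and it does not cover the repo-stratified split used for `swebench`.

```python
import random
from typing import Sequence


def ratio_split(items: Sequence, ratio: str = "2:1:7", seed: int = 42) -> dict:
    """Split `items` into train/val/test by an integer ratio string.

    The same seed and input order always produce the same assignment.
    Any rounding remainder goes to the test split.
    """
    parts = [int(p) for p in ratio.split(":")]
    if len(parts) != 3 or min(parts) < 0 or sum(parts) == 0:
        raise ValueError(f"expected a ratio like '2:1:7', got {ratio!r}")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # private RNG, global state untouched
    total = sum(parts)
    n_train = len(shuffled) * parts[0] // total
    n_val = len(shuffled) * parts[1] // total
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }
```

With 2000 items and the default ratio this yields 400 / 200 / 1400 examples.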
This currently applies to:
- `searchqa`
- `spreadsheetbench`
- `babyvision`
- `livemathematicianbench`
- `mmrb`
- `swebench`
`ALFWorld` is the exception: it is environment-backed rather than JSON split-backed.
The following environments currently expect prepared split directories or extra rooted assets rather than the generic ratio-split path:
- `docvqa`
- `mathverse`
- `officeqa`
- `sealqa`
At a high level:
- `SearchQA`: raw QA json / jsonl or pre-split QA json files
- `SpreadsheetBench`: raw task manifest json plus spreadsheet task directory, or a pre-split task manifest
- `ALFWorld`: installed game environment and configured eval/train splits
- `BabyVision`: raw `meta_data.jsonl` plus images, or a pre-split directory
- `DocVQA`: pre-split CSV / JSON data under `split_dir`
- `LiveMathematicianBench`: raw monthly QA json files, or a pre-split directory
- `MathVerse`: split files plus `data_root` image assets
- `MMRB`: raw extracted dataset json files, or a pre-split directory
- `OfficeQA`: pre-split metadata plus resolved office document directories
- `SealQA`: pre-split metadata for tool-augmented QA tasks
- `SWEBench`: HuggingFace SWE-bench dataset alias (`lite` / `verified` / `full`) or a prepared split directory
### Split References Across Branches
The split-related defaults are not identical across `skillopt-final`, `reflact_new_zzw`, `gepa`, and `swe-bench-old`. The practical reference points are:
| Source branch | Explicit split settings / dirs |
|---|---|
| `skillopt-final` | `searchqa -> data/searchqa_split`; `spreadsheetbench -> data/spreadsheetbench_split`; `docvqa -> data/docvqa_split`; `officeqa -> data/officeqa_split`; `sealqa -> data/sealqa_split`; `swebench -> ratio split 2:1:7 over the default lite dataset, materialized under out_root/_generated_splits/...` |
| `reflact_new_zzw` | Same 10-benchmark env set as above except no `swebench`; explicit split dirs are `data/searchqa_split`, `data/spreadsheetbench_split`, `data/docvqa_split`, `data/officeqa_split`, `data/sealqa_split`; `spreadsheetbench` there defaults to `mode=single`; `officeqa` uses `max_tool_turns=24`; `sealqa` uses `max_tool_turns=12` |
| `gepa` | `configs/spreadsheetbench.yaml` uses `data.splits_dir = data/spreadsheetbench/splits`, `eval.mode = react`, `eval.max_turns = 20`; `configs/swebench.yaml` uses `dataset = SWE-bench/SWE-bench_Verified` with `train_size = 100`, `val_size = 50`, `test_size = 350` |
| `swe-bench-old` | Repo-stratified `2:1:7` split over `SWE-Bench_Lite`, persisted as `outputs/.../split/train.json`, `selection.json`, `test.json`; the example split in that branch is `train=60`, `selection=33`, `test=207` |
For the 10 benches shared with `reflact_new_zzw`, the current branch is now aligned on env coverage. The main intentional delta is `spreadsheetbench`: this branch defaults to multi-round codegen, while `reflact_new_zzw` kept `mode=single` by default.
## Running Training
Example:
```bash
python scripts/train.py --config configs/searchqa/default.yaml
```
Explicit 2:1:7 split from raw data:
```bash
python scripts/train.py \
--config configs/searchqa/default.yaml \
--split_mode ratio \
--data_path /path/to/searchqa_train_2000.json
```
Directly consume a prepared split directory:
```bash
python scripts/train.py \
--config configs/searchqa/default.yaml \
--split_mode split_dir \
--split_dir /path/to/searchqa_split
```
You can override structured config keys from the CLI:
```bash
python scripts/train.py \
--config configs/spreadsheetbench/default.yaml \
--cfg-options model.teacher_backend=openai_chat model.student_backend=codex_exec train.batch_size=40 optimizer.learning_rate=4
```
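
Dotted `key=value` overrides like the ones above can be applied to a nested config as sketched here. The helper names and the value-coercion rules (booleans, then integers, then floats, else string) are assumptions; `scripts/train.py` may parse values differently.

```python
def parse_value(raw: str):
    """Coerce a CLI string to bool, int, or float when it looks like one."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw  # e.g. "codex_exec" or "2:1:7" stays a string


def apply_cfg_options(config: dict, options: list) -> dict:
    """Apply `section.key=value` overrides in place and return the config."""
    for option in options:
        dotted_key, sep, raw = option.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {option!r}")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})  # create missing sections
        node[leaf] = parse_value(raw)
    return config
```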
Legacy flat overrides still work for common keys:
```bash
python scripts/train.py \
--config configs/searchqa/default.yaml \
--backend azure_openai \
--teacher_model gpt-5.4 \
--student_model gpt-5.4 \
--reasoning_effort medium
```
Exec harness example:
```bash
python scripts/train.py \
--config configs/searchqa/default.yaml \
--teacher_backend openai_chat \
--student_backend codex_exec \
--teacher_model gpt-5.4 \
--student_model gpt-5.4-codex \
--use_deep_reflect true \
--skill_update_mode rewrite_from_suggestions
```
SWEBench example:
```bash
python scripts/train.py \
--config configs/swebench/default.yaml \
--cfg-options env.dataset_name=lite env.split_ratio=2:1:7
```
## Eval-Only and Standalone Evaluation
Evaluate a specific skill without training:
```bash
python scripts/eval_only.py \
--config configs/searchqa/default.yaml \
--skill reflact/envs/searchqa/skills/initial.md
```
The same dataset entry modes apply in eval-only runs:
- `--split_mode ratio --data_path ...`
- `--split_mode split_dir --split_dir ...`
Standalone scripts also exist for benchmark-specific comparisons, including:
- `scripts/eval_prompt_custom.py`
- `scripts/eval_prompt_official.py`
- `scripts/eval_livemathematicianbench_baseline.py`
These scripts now also support backend selection through the unified model layer.
## Output Structure
Each run writes a structured output directory under `out_root`.
Important top-level artifacts:
- `config.json` — flattened runtime config
- `history.json` — per-step history records
- `runtime_state.json` — resume state for current/best skill tracking
- `best_skill.md` — current best validated skill
- `skills/skill_vXXXX.md` — persisted skill snapshot per step
Per-step artifacts live under `steps/step_XXXX/`, including:
- `merged_patch.json`
- `ranked_edits.json`
- `candidate_skill.md`
- `edit_apply_report.json`
- `rewrite_result.json` when rewrite mode is enabled
- `selection_eval/`
- `trajectory_digest.json`
- rollout and patch subdirectories
Epoch-level artifacts live under:
- `slow_update/epoch_XX/`
- `meta_skill/epoch_XX/`
- `meta_reflect/epoch_XX/`
## Resume Behavior
The trainer resumes from `runtime_state.json` when present. That state tracks:
- last completed step
- current skill path
- current score
- best skill path
- best score
- origin tags for current and best skill
This is important because skill state can change at both step level and epoch level; resuming only from `history.json` is not sufficient for this branch's method logic.
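
A resume check along these lines could be sketched as follows. The JSON key `last_completed_step` and the one-based step numbering are assumptions made for illustration; inspect an actual `runtime_state.json` for the real field names.

```python
import json
from pathlib import Path
from typing import Optional


def load_runtime_state(out_root: str) -> Optional[dict]:
    """Return the saved resume state, or None for a fresh run."""
    path = Path(out_root) / "runtime_state.json"
    if not path.exists():
        return None
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)


def first_step_to_run(state: Optional[dict]) -> int:
    """Pick the next step index (hypothetical key, one-based steps assumed)."""
    if state is None:
        return 1
    return int(state["last_completed_step"]) + 1
```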
## Notes
- This repository focuses on skill optimization logic; datasets are not included.
- Patch application is intentionally observable. Inspect `edit_apply_report.json` when candidate skills do not behave as expected.
- `SpreadsheetBench` now defaults to `mode=multi`. If you run an exec student backend there, override back to `env.mode=single` because exec backends are still only wired for SpreadsheetBench single-mode rollout.
- `SWEBench` follows the older mini-swe-agent + `swebench.harness.run_evaluation` path, so it requires the SWE-bench / Docker toolchain rather than the generic chat-only stack.
- `slow_update` writes into a protected skill region and normal edits are prevented from overwriting that region directly.
- `meta_skill` is context memory, not a direct skill edit.
- `meta_reflect` is a gated skill edit stage, not just logging.
## Minimal Setup
```bash
conda create -n reflact python=3.11
conda activate reflact
pip install openai pyyaml openpyxl
```
Depending on the environment, you may also need:
```bash
pip install datasets gymnasium numpy ray regex
```
For `SWEBench`, you also need a working Docker environment plus the SWE-bench / mini-swe-agent dependencies used in `swe-bench-old`.


@@ -0,0 +1,93 @@
# ReflACT default configuration — base for all environments.
# Environment configs should inherit via: _base_: default.yaml
model:
  backend: azure_openai
  teacher: gpt-5.5
  student: gpt-5.5
  teacher_backend: openai_chat
  student_backend: openai_chat
  reasoning_effort: medium
  rewrite_reasoning_effort: ""
  rewrite_max_completion_tokens: 64000
  codex_exec_path: codex
  codex_exec_sandbox: workspace-write
  codex_exec_profile: ""
  codex_exec_full_auto: false
  codex_exec_reasoning_effort: none
  codex_exec_use_sdk: auto
  codex_exec_network_access: false
  codex_exec_web_search: false
  codex_exec_approval_policy: never
  claude_code_exec_path: claude
  claude_code_exec_profile: ""
  claude_code_exec_use_sdk: auto
  claude_code_exec_effort: medium
  claude_code_exec_max_thinking_tokens: 16384
  codex_trace_to_teacher: true
  azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
  azure_openai_api_version: "2024-12-01-preview"
  azure_openai_api_key: "" # Fill locally if you do not export AZURE_OPENAI_API_KEY
  azure_openai_auth_mode: azure_cli
  azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
  azure_openai_managed_identity_client_id: ""
  teacher_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
  teacher_azure_openai_api_version: "2024-12-01-preview"
  teacher_azure_openai_api_key: ""
  teacher_azure_openai_auth_mode: azure_cli
  teacher_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
  teacher_azure_openai_managed_identity_client_id: ""
  student_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
  student_azure_openai_api_version: "2024-12-01-preview"
  student_azure_openai_api_key: ""
  student_azure_openai_auth_mode: azure_cli
  student_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
  student_azure_openai_managed_identity_client_id: ""
train:
  num_epochs: 4
  train_size: 0 # 0 = derive from dataset split when available
  batch_size: 40
  accumulation: 1
  seed: 42
gradient:
  minibatch_size: 8
  merge_batch_size: 8
  analyst_workers: 16
  max_analyst_rounds: 3
  failure_only: false
  use_deep_reflect: false
  deep_reflect_failures: 4
  deep_reflect_successes: 2
optimizer:
  learning_rate: 4 # max edits per step (edit_budget)
  min_learning_rate: 2 # min edits for decay schedulers
  lr_scheduler: cosine # constant / linear / cosine / autonomous
  lr_control_mode: fixed # fixed / autonomous / none
  skill_update_mode: patch # patch / rewrite_from_suggestions / full_rewrite_minibatch
  use_meta_reflect: false
  meta_learning_rate: 4 # max edits per epoch-level meta-reflect
  use_slow_update: true
  slow_update_samples: 20
  longitudinal_pair_policy: mixed # mixed / changed / unchanged
  use_meta_skill: true
evaluation:
  use_gate: true
  sel_env_num: 0
  test_env_num: 0
  eval_test: true
env:
  name: ""
  skill_init: ""
  split_mode: ratio # ratio = build deterministic split from data_path; split_dir = use pre-split train/val/test
  split_ratio: "2:1:7" # explicit default for dataset-backed benchmarks: train:val:test
  split_seed: 42
  split_dir: ""
  data_path: ""
  split_output_dir: ""
  exec_timeout: 120 # per student model/code-agent call timeout in seconds
  out_root: ""


@@ -0,0 +1,305 @@
# Ablation Study Configuration Manifest
This folder records the final, reproducible settings for the ablation runs used
in `docs/ablation_paper_tables.md`.
It is intentionally separate from the benchmark default configs. The benchmark
configs under `configs/<benchmark>/default.yaml` remain the source task configs;
this folder records the exact matrix-level overrides, run roots, launch commands,
and validation rules used for the paper ablations.
## Files
- `matrix.yaml`: canonical ablation matrix, common overrides, benchmark splits,
token/output caps, and invalid-run rules.
- `launch_commands.sh`: exact launcher commands for the valid run roots.
- `validation.md`: monitoring, result extraction, and invalidation checklist.
## Source Of Truth
Use the matrix launcher:
```bash
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python scripts/run_ablation_matrix.py
```
The launcher builds runs from the same defaults and values recorded in
`matrix.yaml`. It skips completed runs by checking `summary.json` and skips
active runs by checking `env.out_root` in active `scripts/train.py` processes.
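
The skip rule can be pictured with the sketch below. `should_launch` is a hypothetical helper, not a function in `scripts/run_ablation_matrix.py`; it assumes `summary.json` sits at the top level of the run's `env.out_root`, and that the set of active roots has been collected elsewhere.

```python
from pathlib import Path
from typing import Set


def should_launch(out_root: str, active_roots: Set[str]) -> bool:
    """Decide whether a run still needs to be started.

    Completed runs are recognised by a top-level summary.json; active runs
    by their out_root appearing in the set of running train.py processes.
    """
    if (Path(out_root) / "summary.json").exists():
        return False  # already finished
    if out_root in active_roots:
        return False  # currently running
    return True
```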
Do not manually rerun a completed run into the same `env.out_root`. If a run is
invalid, archive or remove its output directory first, then let the launcher
start it cleanly.
## Current Correct Run Roots
- SearchQA / SpreadsheetBench original ablations:
`outputs/ablation_20260502_040604_unique48`
- SearchQA / SpreadsheetBench batch-size ablations:
`outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run`
- LiveMathBench / ALFWorld clean ablations:
`outputs/ablation_livemath_alfworld_clean_20260503_155155_run`
- DocVQA ablations:
`outputs/ablation_docvqa_20260503_160225_run`
Archived, superseded, misaligned, dry-run, or pre-fix directories must not be
used for paper tables.
## End-To-End Runbook
### Environment
Run from the repository root:
```bash
cd /home/azureuser/workspace-gzy/SkillReflection
```
Always use:
```bash
PY=/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
```
Default model/auth settings are generated by `scripts/run_ablation_matrix.py`:
```text
teacher=gpt-5.5
student=gpt-5.5
teacher_backend=openai_chat
student_backend=openai_chat
reasoning_effort=medium
teacher/student endpoint=https://t2vgoaigpt4o3.openai.azure.com/
teacher/student api_version=2024-12-01-preview
teacher/student auth_mode=azure_cli
```
Core training settings:
```text
train.num_epochs=4
train.train_size=0
train.batch_size=40
train.accumulation=1
train.seed=42
gradient.minibatch_size=8
gradient.merge_batch_size=8
gradient.analyst_workers=16
gradient.use_deep_reflect=false
optimizer.learning_rate=4
optimizer.min_learning_rate=2
optimizer.lr_scheduler=cosine
optimizer.lr_control_mode=fixed
optimizer.use_slow_update=true
optimizer.slow_update_samples=20
optimizer.use_meta_skill=true
optimizer.use_meta_reflect=false
optimizer.longitudinal_pair_policy=mixed
evaluation.use_gate=true
evaluation.eval_test=true
env.split_mode=split_dir
```
`train.train_size=0` is intentional. The dataloader derives the train size from
the fixed split. Batch-size ablations rely on the default `ceil(train_size /
batch_size)` behavior; the last batch can be smaller than `train.batch_size`.
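
As a worked example of the `ceil` rule, the sketch below computes how many batches cover the train split and how large each one is. It reads the count as batches per epoch, which is an assumption, and the helper names are illustrative, not taken from the dataloader.

```python
import math


def steps_per_epoch(train_size: int, batch_size: int) -> int:
    """Number of batches needed to cover the train split once."""
    return math.ceil(train_size / batch_size)


def batch_sizes(train_size: int, batch_size: int) -> list:
    """Size of each batch in order; only the last one may be smaller."""
    full, rest = divmod(train_size, batch_size)
    return [batch_size] * full + ([rest] if rest else [])
```

For SpreadsheetBench (train size 80) at batch size 56, this gives two batches of 56 and 24 tasks.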
### Fixed Splits
Default split directories:
```text
searchqa: data/ablation_splits/searchqa/2-1-7_seed42
spreadsheetbench: data/ablation_splits/spreadsheetbench/2-1-7_seed42
livemathematicianbench: data/ablation_splits/livemathematicianbench/2-1-7_seed42
alfworld: data/ablation_splits/alfworld/2-1-7_seed42
docvqa: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
```
Default train/val/test sizes:
| Benchmark | Train | Val | Test |
| --- | ---: | ---: | ---: |
| SearchQA | 400 | 200 | 1400 |
| SpreadsheetBench | 80 | 40 | 280 |
| LiveMathBench | 35 | 18 | 124 |
| ALFWorld | 39 | 18 | 134 |
| DocVQA | 1070 | 535 | 3744 |
DocVQA images are not copied. The valid setup uses:
```text
data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images
```
2026-05-05 DocVQA data correction: all DocVQA final reruns should use the zisu
10% split above and a fresh output root such as
`outputs/ablation_docvqa_zisu10pct_20260505_run`. The older local
`data/ablation_splits/docvqa/2-1-7_seed42` contains the same 5349 questionId
pool but a different train/val/test assignment, so its completed summaries are
historical only.
### Matrix Groups
Use these group names with `scripts/run_ablation_matrix.py`:
```text
default split batch mbs lr sched slown mod smodel longpair lrctrl
```
`longpair` is the slow-update/meta-skill comparison-example ablation. It keeps
all prompts and training settings unchanged and only overrides:
```text
optimizer.longitudinal_pair_policy=changed
optimizer.longitudinal_pair_policy=unchanged
```
The default paper setting remains `mixed`.
`lrctrl` contains the two learning-rate-control baselines:
```text
optimizer.lr_control_mode=autonomous
optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch
```
The autonomous run logs the chosen integer per step in `lr_decision.json` and
`lr_history.jsonl`. The full-rewrite run removes the LR/edit-selection concept:
each minibatch analyst produces a complete skill candidate, and aggregate/merge
produces the candidate skill directly.
Batch-size values are:
```text
8 / 24 / 40 / 56 / full
```
`40` is the default point. `full` expands to the benchmark train size.
### Launch Commands Used In This Session
The exact commands are recorded in `launch_commands.sh` and in
`docs/ablation_plan.md`. The important current policy is:
- SearchQA / SpreadsheetBench batch-only matrix can run at `--max-parallel 8`.
- DocVQA matrix can run with its launcher at `--max-parallel 8`; later top-up used `--max-parallel 16` only because completed runs were skipped and active roots were checked.
- LiveMathBench is safe to run as an API-only benchmark after the token cap fix.
- ALFWorld must not be mixed into a 24-way run on this shared machine. Use `--bench alfworld --max-parallel 1` only after memory is available.
### Token And Timeout Fixes
LiveMathBench must use a large student completion cap:
```text
max_completion_tokens=16384
timeout=300
```
The old 768/512 cap produced many empty visible responses because hidden
reasoning consumed the budget.
ALFWorld must use:
```text
max_completion_tokens=2048
empty response fallback -> <action>look</action>
missing action fallback -> <action>look</action>
```
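
The two ALFWorld fallbacks can be sketched as a small normalization step applied to the student response. The function name `normalize_action` and the regex-based extraction are assumptions; the environment adapter may parse actions differently.

```python
import re
from typing import Optional

FALLBACK_ACTION = "<action>look</action>"
_ACTION_RE = re.compile(r"<action>(.*?)</action>", re.DOTALL)


def normalize_action(response: Optional[str]) -> str:
    """Return a well-formed action tag, falling back to `look`.

    Covers both cases listed above: an empty response and a response
    that contains no (or an empty) <action> tag.
    """
    if response is None or not response.strip():
        return FALLBACK_ACTION  # empty response fallback
    match = _ACTION_RE.search(response)
    if match is None or not match.group(1).strip():
        return FALLBACK_ACTION  # missing action fallback
    return f"<action>{match.group(1).strip()}</action>"
```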
### Invalid Runs
Never fill paper tables from these archive directories:
```text
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258/
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417/
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311/
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402/
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517/
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_serial_lowmem_20260504_1300/
```
### Current Resource Notes
ALFWorld model calls are API calls. The ablation branch now creates local
ALFWorld/TextWorld environments through multiprocessing workers ported from
`skillopt_final_zzw`, not through Ray actors. Old Ray-based archived runs are
not valid for table fill. The failure mode observed earlier in this session
was system RAM pressure and Ray OOM prevention, not model GPU memory.
The GPU memory usage currently shown by `nvidia-smi` comes from unrelated Ray
Serve visual models under:
```text
/home/azureuser/workspace-gzy/zyf/gca-skill
```
Those processes are `GroundingDINOModel` / `DA3Model`, not the
SkillReflection ablation ALFWorld run.
There are also unrelated ALFWorld jobs under:
```text
/home/azureuser/zisu/skill_distill
```
Do not confuse those with this repository's ablation outputs.
### Monitoring
Active run and duplicate output-root check:
```bash
$PY - <<'PY'
import subprocess, re, collections, time
try:
    raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
except subprocess.CalledProcessError:
    raw = ""
roots = []
for line in raw.splitlines():
    m = re.search(r"env\.out_root=([^\s]+)", line)
    if m:
        roots.append(m.group(1))
ctr = collections.Counter(roots)
print("time", time.strftime("%F %T"))
print("active_count", len(roots))
print("duplicates", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1])
for root in sorted(roots):
    print(root.rsplit("/", 1)[-1])
PY
```
Error scan:
```bash
rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|CUDA out of memory|\\[FAIL\\]|LLM call failed" \
outputs/ablation_docvqa_20260503_160225_run/logs \
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \
outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \
-g '*.log' | tail -100 || true
```
Resource checks:
```bash
free -h | sed -n '1,3p'
df -h /tmp
du -sh /tmp/ray 2>/dev/null || true
nvidia-smi
```
### Filling Tables
Only use top-level `summary.json` from valid run roots. Fill
`docs/ablation_paper_tables.md` from:
```text
best_selection_hard
baseline_test_hard
test_hard
test_delta_hard
token_summary._total.total_tokens
```


@@ -0,0 +1,81 @@
#!/usr/bin/env bash
set -euo pipefail
cd /home/azureuser/workspace-gzy/SkillReflection
PY=/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
# Original SearchQA / SpreadsheetBench full matrix reproduction command.
# Do not run this into the existing root unless intentionally reproducing from
# scratch; the current valid root is already populated:
# outputs/ablation_20260502_040604_unique48
#
# setsid "$PY" scripts/run_ablation_matrix.py \
# --groups default split mbs lr sched slown mod smodel \
# --bench searchqa spreadsheetbench \
# --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48 \
# --max-parallel 24 \
# --execute \
# > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48/launcher_reproduce_full_matrix.log 2>&1 < /dev/null &
#
# SearchQA / SpreadsheetBench batch-size ablations only.
# Original non-batch SearchQA/SpreadsheetBench ablations live in:
# outputs/ablation_20260502_040604_unique48
setsid "$PY" scripts/run_ablation_matrix.py \
--groups batch \
--bench searchqa spreadsheetbench \
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run \
--max-parallel 8 \
--execute \
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/launcher_parallel8.log 2>&1 < /dev/null &
# DocVQA full matrix.
setsid "$PY" scripts/run_ablation_matrix.py \
--groups default split batch mbs lr sched slown mod smodel \
--bench docvqa \
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run \
--max-parallel 8 \
--execute \
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>&1 < /dev/null &
# LiveMathBench clean matrix. ALFWorld should be launched separately at lower
# concurrency because Ray OOM occurred when many ALFWorld runs were mixed into a
# 24-way run.
setsid "$PY" scripts/run_ablation_matrix.py \
--groups default split batch mbs lr sched slown mod smodel \
--bench livemathematicianbench \
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \
--max-parallel 8 \
--execute \
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>&1 < /dev/null &
# ALFWorld clean matrix. Increase to 2 only after checking memory, /tmp/ray,
# and that no other ALFWorld run is active. Do not use 8/16/24 for ALFWorld on
# the current shared machine unless resources are explicitly reserved.
setsid "$PY" scripts/run_ablation_matrix.py \
--groups default split batch mbs lr sched slown mod smodel \
--bench alfworld \
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \
--max-parallel 1 \
--execute \
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>&1 < /dev/null &
# Longitudinal comparison-example policy ablations. This intentionally excludes
# ALFWorld. The only varied setting is optimizer.longitudinal_pair_policy.
setsid "$PY" scripts/run_ablation_matrix.py \
--groups longpair \
--bench searchqa spreadsheetbench livemathematicianbench docvqa \
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run \
--max-parallel 8 \
--execute \
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run/launcher_longpair_parallel8.log 2>&1 < /dev/null &
# Learning-rate-control baselines. This intentionally excludes ALFWorld.
setsid "$PY" scripts/run_ablation_matrix.py \
--groups lrctrl \
--bench searchqa spreadsheetbench livemathematicianbench docvqa \
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run \
--max-parallel 8 \
--execute \
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run/launcher_lrctrl_parallel8.log 2>&1 < /dev/null &


@@ -0,0 +1,257 @@
version: 2026-05-04
purpose: "Canonical paper ablation settings matching the valid current runs."
launcher:
script: scripts/run_ablation_matrix.py
python: /home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python
skip_completed_by: summary.json
skip_active_by: "active scripts/train.py env.out_root"
environment:
working_directory: /home/azureuser/workspace-gzy/SkillReflection
required_env:
ALFWORLD_DATA: /home/azureuser/.cache/alfworld
docvqa_images_symlink: "data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images"
common_overrides:
model.teacher_backend: openai_chat
model.student_backend: openai_chat
model.teacher: gpt-5.5
model.student: gpt-5.5
model.teacher_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
model.teacher_azure_openai_api_version: 2024-12-01-preview
model.teacher_azure_openai_auth_mode: azure_cli
model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
model.student_azure_openai_api_version: 2024-12-01-preview
model.student_azure_openai_auth_mode: azure_cli
model.reasoning_effort: medium
train.num_epochs: 4
train.train_size: 0
train.batch_size: 40
train.accumulation: 1
train.seed: 42
gradient.minibatch_size: 8
gradient.merge_batch_size: 8
gradient.analyst_workers: 16
gradient.use_deep_reflect: false
optimizer.learning_rate: 4
optimizer.min_learning_rate: 2
optimizer.lr_scheduler: cosine
optimizer.skill_update_mode: patch
optimizer.use_slow_update: true
optimizer.slow_update_samples: 20
optimizer.use_meta_skill: true
optimizer.use_meta_reflect: false
evaluation.use_gate: true
evaluation.eval_test: true
env.split_mode: split_dir
benchmarks:
searchqa:
config: configs/searchqa/default.yaml
run_roots:
original_matrix: outputs/ablation_20260502_040604_unique48
batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run
default_split: data/ablation_splits/searchqa/2-1-7_seed42
train: 400
val: 200
test: 1400
student_rollout:
function: reflact/envs/searchqa/rollout.py::chat_student
max_completion_tokens:
first_turn: 512
refinement: 512
rationale: "Short-answer QA; sampled empties are low and not LiveMath-like."
spreadsheetbench:
config: configs/spreadsheetbench/default.yaml
run_roots:
original_matrix: outputs/ablation_20260502_040604_unique48
batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run
default_split: data/ablation_splits/spreadsheetbench/2-1-7_seed42
train: 80
val: 40
test: 280
student_rollout:
function: reflact/envs/spreadsheetbench/codegen_agent.py::run_multi
max_output_tokens: 16384
result_note: "results.jsonl stores execution fields, not a response field."
livemathematicianbench:
config: configs/livemathematicianbench/default.yaml
run_roots:
clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run
default_split: data/ablation_splits/livemathematicianbench/2-1-7_seed42
train: 35
val: 18
test: 124
student_rollout:
function: reflact/envs/livemathematicianbench/rollout.py::chat_student
max_completion_tokens:
first_turn: 16384
refinement: 16384
timeout_seconds: 300
invalid_old_caps:
first_turn: 768
refinement: 512
invalid_archive: outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258
rationale: "GPT-5 reasoning consumed small budgets and produced many empty visible responses."
alfworld:
config: configs/alfworld/default.yaml
run_roots:
clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run
default_split: data/ablation_splits/alfworld/2-1-7_seed42
train: 39
val: 18
test: 134
student_rollout:
function: reflact/envs/alfworld/rollout.py::chat_student
max_completion_tokens: 2048
timeout_seconds: 120
max_steps: 50
fallback_action: look
invalid_old_cap: 512
invalid_archives:
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517
concurrency_note: "Do not mix many ALFWorld runs into 24-way total concurrency; Ray OOM occurred. Prefer 1-2 ALFWorld runs at a time unless resources are clearly free."
docvqa:
config: configs/docvqa/default.yaml
run_roots:
matrix: outputs/ablation_docvqa_zisu10pct_20260505_run
default_split: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
train: 1070
val: 535
test: 3744
data_note: "2026-05-05: use zisu-provided 5349-item DocVQA split directly; previous local 2-1-7_seed42 used the same item pool but a different train/val/test assignment and must not be used for final DocVQA reruns."
student_rollout:
function: reflact/envs/docvqa/rollout.py::chat_student_messages
max_completion_tokens:
first_turn: 768
refinement: 512
rationale: "Short-answer VQA output; preserve current setting for alignment unless explicitly rerunning all affected DocVQA."
splits:
tags:
1shot:
extra_overrides:
optimizer.slow_update_samples: 1
1-1-8: {}
2-1-7:
default: true
4-1-5: {}
paths:
searchqa:
1shot: data/ablation_splits/searchqa/1shot_seed42
1-1-8: data/ablation_splits/searchqa/1-1-8_seed42
2-1-7: data/ablation_splits/searchqa/2-1-7_seed42
4-1-5: data/ablation_splits/searchqa/4-1-5_seed42
spreadsheetbench:
1shot: data/ablation_splits/spreadsheetbench/1shot_seed42
1-1-8: data/ablation_splits/spreadsheetbench/1-1-8_seed42
2-1-7: data/ablation_splits/spreadsheetbench/2-1-7_seed42
4-1-5: data/ablation_splits/spreadsheetbench/4-1-5_seed42
livemathematicianbench:
1shot: data/ablation_splits/livemathematicianbench/1shot_seed42
1-1-8: data/ablation_splits/livemathematicianbench/1-1-8_seed42
2-1-7: data/ablation_splits/livemathematicianbench/2-1-7_seed42
4-1-5: data/ablation_splits/livemathematicianbench/4-1-5_seed42
alfworld:
1shot: data/ablation_splits/alfworld/1shot_seed42
1-1-8: data/ablation_splits/alfworld/1-1-8_seed42
2-1-7: data/ablation_splits/alfworld/2-1-7_seed42
4-1-5: data/ablation_splits/alfworld/4-1-5_seed42
docvqa:
1shot: data/ablation_splits/docvqa/1shot_seed42
1-1-8: data/ablation_splits/docvqa/1-1-8_seed42
2-1-7: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
4-1-5: data/ablation_splits/docvqa/4-1-5_seed42
groups:
default:
run_id: "DEFAULT-{benchmark}-5.5"
overrides: {}
split:
values: [1shot, 1-1-8, 4-1-5]
skip_default_2_1_7: true
override_template: "env.split_dir={split_path}"
batch:
values: [8, 24, 56, full]
default_value_reused: 40
full_values:
searchqa: 400
spreadsheetbench: 80
livemathematicianbench: 35
alfworld: 39
docvqa: 1070
fixed_overrides:
gradient.minibatch_size: 8
mbs:
values: [1, 2, 4, 16, 32]
default_value_reused: 8
override_template: "gradient.minibatch_size={value}"
lr:
values: [1, 2, 4, 8, 16]
fixed_overrides:
optimizer.lr_scheduler: constant
optimizer.min_learning_rate: 1
override_template: "optimizer.learning_rate={value}"
sched:
values: [constant, linear]
default_value_reused: cosine
override_template: "optimizer.lr_scheduler={value}"
slown:
values: [5, 10, 40]
default_value_reused: 20
override_template: "optimizer.slow_update_samples={value}"
mod:
values:
slow-only:
optimizer.use_slow_update: true
optimizer.use_meta_skill: false
meta-only:
optimizer.use_slow_update: false
optimizer.use_meta_skill: true
none:
optimizer.use_slow_update: false
optimizer.use_meta_skill: false
default_value_reused: slow-meta
longpair:
values: [changed, unchanged]
default_value_reused: mixed
override_template: "optimizer.longitudinal_pair_policy={value}"
note: "Only changes slow-update/meta-skill comparison examples; prompts and other settings remain unchanged."
lrctrl:
values:
autonomous:
optimizer.lr_control_mode: autonomous
full-rewrite:
optimizer.lr_control_mode: none
optimizer.skill_update_mode: full_rewrite_minibatch
default_value_reused: "fixed patch learning_rate=4"
note: "autonomous records lr_decision.json/lr_history.jsonl; full-rewrite removes LR/select/apply-edit and uses full skill candidates."
smodel:
values:
"5.4":
model.student: gpt-5.4-pro
model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
model.student_azure_openai_api_version: 2025-03-01-preview
model.student_azure_openai_auth_mode: azure_cli
"5.4-mini":
model.student: gpt-5.4-mini
model.student_azure_openai_endpoint: https://searchagent5.cognitiveservices.azure.com/
model.student_azure_openai_api_version: 2024-12-01-preview
model.student_azure_openai_auth_mode: azure_cli
default_value_reused: "5.5"
validity_rules:
use_for_tables:
- "Only runs with summary.json in valid run roots."
- "Do not use archive, archived, MISALIGNED, SUPERSEDED, dryrun, smoke, or debug directories."
- "Do not use ALFWorld runs started before empty/missing-action fallback."
- "Do not use old LiveMath runs with 768/512 token caps."
rerun_rule: "Archive or remove invalid out_root before relaunch; never write a rerun into a polluted output directory."


@@ -0,0 +1,141 @@
# Ablation Validation Checklist
Use this checklist before launch, during monitoring, and before filling
`docs/ablation_paper_tables.md`.
## Before Launch
Run from repo root:
```bash
cd /home/azureuser/workspace-gzy/SkillReflection
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
```
Verify syntax for edited files:
```bash
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python -m py_compile \
scripts/run_ablation_matrix.py \
scripts/train.py \
reflact/model/azure_openai.py \
reflact/envs/searchqa/rollout.py \
reflact/envs/spreadsheetbench/rollout.py \
reflact/envs/livemathematicianbench/rollout.py \
reflact/envs/alfworld/rollout.py \
reflact/envs/docvqa/rollout.py
```
Check active runs and duplicate `env.out_root` before starting more:
```bash
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python - <<'PY'
import subprocess, re, collections
try:
    raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
except subprocess.CalledProcessError:
    raw = ""
roots = []
for line in raw.splitlines():
    m = re.search(r"env\.out_root=([^\s]+)", line)
    if m:
        roots.append(m.group(1))
ctr = collections.Counter(roots)
print("train_count", len(roots))
print("duplicate_roots", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1])
for root in sorted(roots):
    print(root.rsplit("/", 1)[-1])
PY
```
## During Monitoring
Check launchers:
```bash
pgrep -af 'scripts/run_ablation_matrix.py' || true
tail -80 outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>/dev/null || true
tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>/dev/null || true
tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>/dev/null || true
```
Scan current logs for new hard failures:
```bash
rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|\\[FAIL\\]|\\[RETRY\\]" \
outputs/ablation_docvqa_20260503_160225_run/logs \
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \
outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \
-g '*.log' | tail -160 || true
```
Check resource pressure:
```bash
df -h /tmp
du -sh /tmp/ray 2>/dev/null || true
free -h | sed -n '1,3p'
```
## Quality Checks
Current valid LiveMathBench runs should not resemble the old 768/512-cap runs. The script prints every results file that still contains empty responses:
```bash
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python - <<'PY'
import json, pathlib
root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run")
for run in sorted(root.glob("*livemathematicianbench*")):
    if not run.is_dir() or "archive" in str(run):
        continue
    for rel in ["test_eval_baseline/results.jsonl", "test_eval/results.jsonl"]:
        p = run / rel
        if not p.exists():
            continue
        rows = [json.loads(l) for l in p.open(errors="ignore") if l.strip()]
        empty = sum(1 for r in rows if not str(r.get("response", "")).strip())
        answer = sum(1 for r in rows if "<answer>" in str(r.get("response", "")).lower())
        if empty:
            print(run.name, rel, "empty", empty, "answer", answer, "n", len(rows))
PY
```
Valid ALFWorld runs must not contain empty or missing actions:
```bash
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python - <<'PY'
import json, pathlib
root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run")
for run in sorted(root.glob("*alfworld*")):
    if not run.is_dir() or "archive" in str(run):
        continue
    bad = []
    fallback = 0
    for c in run.glob("**/conversation.json"):
        data = json.load(c.open(errors="ignore"))
        for step in data:
            if step.get("step") is None:
                continue
            if not step.get("action"):
                bad.append(str(c.relative_to(run)))
                break
            mr = str(step.get("model_response", ""))
            if "empty model response" in mr or "missing action tag" in mr:
                fallback += 1
    print(run.name, "bad_action_files", len(bad), "fallback", fallback)
PY
```
## Filling Tables
Use only `summary.json` fields:
- `best_selection_hard` -> Best Sel
- `baseline_test_hard` -> Base Test
- `test_hard` -> Best Test
- `test_delta_hard` -> Delta
- `total_accepts` -> Accept
- `total_rejects` -> Reject
- `token_summary._total.total_tokens` -> Tokens
Do not fill table rows from logs alone.
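The field-to-column mapping above can be applied mechanically. The following is a minimal sketch, assuming `summary.json` is a JSON object in which `token_summary._total.total_tokens` is a nested key path; `table_row`, `lookup`, and `FIELD_TO_COLUMN` are hypothetical names introduced only for this example.

```python
import json
from pathlib import Path

# Mapping copied from the list above: summary.json field -> table column.
FIELD_TO_COLUMN = {
    "best_selection_hard": "Best Sel",
    "baseline_test_hard": "Base Test",
    "test_hard": "Best Test",
    "test_delta_hard": "Delta",
    "total_accepts": "Accept",
    "total_rejects": "Reject",
    "token_summary._total.total_tokens": "Tokens",
}


def lookup(data, dotted):
    """Follow a dotted key path through nested dicts; return None if any part is missing."""
    current = data
    for part in dotted.split("."):
        if not isinstance(current, dict) or part not in current:
            return None
        current = current[part]
    return current


def table_row(summary_path):
    """Build one table row (column name -> value) from a run's summary.json."""
    summary = json.loads(Path(summary_path).read_text())
    return {column: lookup(summary, field) for field, column in FIELD_TO_COLUMN.items()}
```

Missing fields come back as `None` rather than raising, so an incomplete `summary.json` shows up as a visible gap in the row instead of being silently filled from logs.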


@@ -0,0 +1,30 @@
_base_: ../_base_/default.yaml
train:
  train_size: 0
  accumulation: 1
gradient:
  minibatch_size: 8
  merge_batch_size: 8
optimizer:
  learning_rate: 4
  use_meta_reflect: false
evaluation:
  sel_env_num: 0
  test_env_num: 0
env:
  name: alfworld
  skill_init: reflact/envs/alfworld/skills/initial.md
  split_mode: split_dir
  split_ratio: "2:1:7"
  split_dir: data/ablation_splits/alfworld/2-1-7_seed42
  data_path: ""
  split_output_dir: ""
  max_steps: 50
  workers: 8
  max_api_workers: 8
  limit: 0


@@ -0,0 +1,4 @@
_base_: default.yaml
optimizer:
  use_meta_reflect: true


@@ -0,0 +1,21 @@
_base_: ../_base_/default.yaml
train:
  batch_size: 64
  accumulation: 1
env:
  name: babyvision
  skill_init: reflact/envs/babyvision/skills/initial.md
  split_mode: ratio
  split_ratio: "2:1:7"
  split_dir: ""
  data_path: ""
  split_output_dir: ""
  max_turns: 1
  workers: 16
  limit: 0
  image_detail: auto
  judge_model: gpt-5.4
  judge_max_completion_tokens: 256
  judge_retries: 5


@@ -0,0 +1,28 @@
_base_: ../_base_/default.yaml
model:
  reasoning_effort: medium
train:
  batch_size: 40
  accumulation: 1
gradient:
  minibatch_size: 8
  merge_batch_size: 8
optimizer:
  learning_rate: 4
env:
  name: docvqa
  skill_init: reflact/envs/docvqa/skills/initial.md
  split_mode: split_dir
  split_ratio: "2:1:7"
  split_dir: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
  data_path: ""
  split_output_dir: ""
  max_turns: 1
  workers: 16
  image_detail: auto
  limit: 0


@@ -0,0 +1,22 @@
_base_: ../_base_/default.yaml
train:
  train_size: 0
  batch_size: 40
  accumulation: 1
env:
  name: livemathematicianbench
  skill_init: reflact/envs/livemathematicianbench/skills/initial.md
  split_mode: split_dir
  split_ratio: "2:1:7"
  split_dir: data/ablation_splits/livemathematicianbench/2-1-7_seed42
  data_path: ""
  split_output_dir: ""
  max_turns: 1
  exec_timeout: 300
  workers: 64
  limit: 0
  shuffle_choices: true
  use_theorem: false
  use_sketch: false


@@ -0,0 +1,23 @@
_base_: ../_base_/default.yaml
model:
  codex_exec_sandbox: danger-full-access
train:
  batch_size: 64
  accumulation: 1
env:
  name: mathverse
  skill_init: reflact/envs/mathverse/skills/initial.md
  split_dir: ""
  data_root: data/MathVerse
  problem_version: Text Lite
  use_text_dominant_reference: false
  max_turns: 1
  workers: 16
  limit: 0
  image_detail: auto
  judge_model: gpt-5.4
  judge_max_completion_tokens: 256
  judge_retries: 5

configs/mmrb/default.yaml Normal file

@@ -0,0 +1,18 @@
_base_: ../_base_/default.yaml
train:
  batch_size: 128
  accumulation: 1
env:
  name: mmrb
  skill_init: reflact/envs/mmrb/skills/initial.md
  split_mode: ratio
  split_ratio: "2:1:7"
  split_dir: ""
  data_path: ""
  split_output_dir: ""
  max_turns: 1
  workers: 16
  limit: 0
  image_detail: auto


@@ -0,0 +1,25 @@
_base_: ../_base_/default.yaml
model:
  reasoning_effort: medium
train:
  batch_size: 40
  accumulation: 1
gradient:
  minibatch_size: 8
  merge_batch_size: 8
optimizer:
  learning_rate: 4
env:
  name: officeqa
  skill_init: reflact/envs/officeqa/skills/initial.md
  split_dir: data/officeqa_split
  data_dirs:
    - data/officeqa_docs_official
  workers: 4
  max_tool_turns: 24
  limit: 0


@@ -0,0 +1,23 @@
_base_: ../_base_/default.yaml
model:
  reasoning_effort: medium
train:
  batch_size: 10
  accumulation: 1
gradient:
  minibatch_size: 8
  merge_batch_size: 8
optimizer:
  learning_rate: 4
env:
  name: sealqa
  skill_init: reflact/envs/sealqa/skills/initial.md
  split_dir: data/sealqa_split
  workers: 4
  max_tool_turns: 12
  limit: 0


@@ -0,0 +1,32 @@
_base_: ../_base_/default.yaml
model:
  reasoning_effort: medium
train:
  train_size: 400
  batch_size: 40
  accumulation: 1
gradient:
  minibatch_size: 8
  merge_batch_size: 8
optimizer:
  learning_rate: 4
evaluation:
  sel_env_num: 0
  test_env_num: 0
env:
  name: searchqa
  skill_init: reflact/envs/searchqa/skills/initial.md
  split_mode: split_dir
  split_ratio: "2:1:7"
  split_dir: data/searchqa_split
  data_path: ""
  split_output_dir: ""
  max_turns: 1
  workers: 24
  limit: 0


@@ -0,0 +1,34 @@
_base_: ../_base_/default.yaml
model:
  reasoning_effort: medium
train:
  train_size: 80
  batch_size: 40
  accumulation: 1
gradient:
  minibatch_size: 8
  merge_batch_size: 8
optimizer:
  learning_rate: 4
evaluation:
  sel_env_num: 0
  test_env_num: 0
env:
  name: spreadsheetbench
  skill_init: reflact/envs/spreadsheetbench/skills/initial.md
  split_mode: split_dir
  split_ratio: "2:1:7"
  split_dir: data/spreadsheetbench_split
  data_path: ""
  split_output_dir: ""
  data_root: data/spreadsheetbench_verified_400
  mode: multi
  max_turns: 30
  exec_timeout: 600
  workers: 24


@@ -0,0 +1,36 @@
_base_: ../_base_/default.yaml
model:
  reasoning_effort: medium
train:
  batch_size: 20
  accumulation: 1
gradient:
  minibatch_size: 4
  merge_batch_size: 8
optimizer:
  learning_rate: 4
evaluation:
  sel_env_num: 0
  test_env_num: 0
env:
  name: swebench
  skill_init: reflact/envs/swebench/skills/initial.md
  split_mode: ratio
  split_ratio: "2:1:7"
  split_dir: ""
  data_path: ""
  split_output_dir: ""
  dataset_name: lite
  hf_split: test
  workers: 8
  eval_workers: 8
  step_limit: 50
  cost_limit: 3.0
  timeout_per_instance: 600
  limit: 0

reflact/__init__.py Normal file

@@ -0,0 +1,29 @@
"""ReflACT: Reflective Agent Tuning.
A general-purpose framework for iteratively optimizing LLM agent skills
through structured reflection and self-improvement.
Pipeline stages:
1. Rollout — execute episodes with current skill
2. Reflect — analyze trajectories, generate patches
3. Aggregate — hierarchical merge of patches
4. Select — rank and select top edits
5. Update — apply edits to skill document
6. Evaluate — validate candidate skill, accept/reject
"""
__version__ = "0.1.0"
from reflact.types import ( # noqa: F401
BatchSpec,
Edit,
EditOp,
FailureSummaryEntry,
GateAction,
GateResult,
MetaReflectResult,
Patch,
RawPatch,
RolloutResult,
SlowUpdateResult,
)

reflact/config.py Normal file

@@ -0,0 +1,263 @@
"""ReflACT config loading engine — structured YAML with inheritance.
Supports two config formats:
1. **Structured** (new): sections like ``model``, ``train``, ``gradient``,
``optimizer``, ``evaluation``, ``env`` — with ``_base_`` inheritance.
2. **Flat** (legacy): all keys at top level — fully backward compatible.
Usage::
from reflact.config import load_config, flatten_config
cfg = load_config("configs/searchqa/default.yaml")
flat = flatten_config(cfg) # always returns flat dict for trainer
"""
from __future__ import annotations
import copy
import os
from typing import Any
import yaml
# ── Section names that indicate a structured config ──────────────────────
_STRUCTURED_SECTIONS = frozenset({
"model", "train", "gradient", "optimizer", "evaluation", "env",
})
# ── Structured → flat key mapping ────────────────────────────────────────
_FLATTEN_MAP: dict[str, str] = {
"model.backend": "model_backend",
"model.teacher": "teacher_model",
"model.student": "student_model",
"model.teacher_backend": "teacher_backend",
"model.student_backend": "student_backend",
"model.reasoning_effort": "reasoning_effort",
"model.rewrite_reasoning_effort": "rewrite_reasoning_effort",
"model.rewrite_max_completion_tokens": "rewrite_max_completion_tokens",
"model.codex_exec_path": "codex_exec_path",
"model.codex_exec_sandbox": "codex_exec_sandbox",
"model.codex_exec_profile": "codex_exec_profile",
"model.codex_exec_full_auto": "codex_exec_full_auto",
"model.codex_exec_reasoning_effort": "codex_exec_reasoning_effort",
"model.codex_exec_use_sdk": "codex_exec_use_sdk",
"model.codex_exec_network_access": "codex_exec_network_access",
"model.codex_exec_web_search": "codex_exec_web_search",
"model.codex_exec_approval_policy": "codex_exec_approval_policy",
"model.claude_code_exec_path": "claude_code_exec_path",
"model.claude_code_exec_profile": "claude_code_exec_profile",
"model.claude_code_exec_use_sdk": "claude_code_exec_use_sdk",
"model.claude_code_exec_effort": "claude_code_exec_effort",
"model.claude_code_exec_max_thinking_tokens": "claude_code_exec_max_thinking_tokens",
"model.codex_trace_to_teacher": "codex_trace_to_teacher",
"model.azure_endpoint": "azure_endpoint",
"model.azure_api_version": "azure_api_version",
"model.azure_api_key": "azure_api_key",
"model.azure_openai_endpoint": "azure_openai_endpoint",
"model.azure_openai_api_version": "azure_openai_api_version",
"model.azure_openai_api_key": "azure_openai_api_key",
"model.azure_openai_auth_mode": "azure_openai_auth_mode",
"model.azure_openai_ad_scope": "azure_openai_ad_scope",
"model.azure_openai_managed_identity_client_id": "azure_openai_managed_identity_client_id",
"model.teacher_azure_openai_endpoint": "teacher_azure_openai_endpoint",
"model.teacher_azure_openai_api_version": "teacher_azure_openai_api_version",
"model.teacher_azure_openai_api_key": "teacher_azure_openai_api_key",
"model.teacher_azure_openai_auth_mode": "teacher_azure_openai_auth_mode",
"model.teacher_azure_openai_ad_scope": "teacher_azure_openai_ad_scope",
"model.teacher_azure_openai_managed_identity_client_id": "teacher_azure_openai_managed_identity_client_id",
"model.student_azure_openai_endpoint": "student_azure_openai_endpoint",
"model.student_azure_openai_api_version": "student_azure_openai_api_version",
"model.student_azure_openai_api_key": "student_azure_openai_api_key",
"model.student_azure_openai_auth_mode": "student_azure_openai_auth_mode",
"model.student_azure_openai_ad_scope": "student_azure_openai_ad_scope",
"model.student_azure_openai_managed_identity_client_id": "student_azure_openai_managed_identity_client_id",
"train.num_epochs": "num_epochs",
"train.train_size": "train_size",
"train.steps_per_epoch": "steps_per_epoch",
"train.batch_size": "batch_size",
"train.accumulation": "accumulation",
"train.seed": "seed",
"gradient.minibatch_size": "minibatch_size",
"gradient.merge_batch_size": "merge_batch_size",
"gradient.analyst_workers": "analyst_workers",
"gradient.failure_only": "failure_only",
"gradient.use_deep_reflect": "use_deep_reflect",
"gradient.deep_reflect_failures": "deep_reflect_failures",
"gradient.deep_reflect_successes": "deep_reflect_successes",
"gradient.max_analyst_rounds": "max_analyst_rounds",
"optimizer.learning_rate": "edit_budget",
"optimizer.min_learning_rate": "min_edit_budget",
"optimizer.lr_scheduler": "lr_scheduler",
"optimizer.lr_control_mode": "lr_control_mode",
"optimizer.skill_update_mode": "skill_update_mode",
"optimizer.use_meta_reflect": "use_meta_reflect",
"optimizer.meta_learning_rate": "meta_edit_budget",
"optimizer.use_slow_update": "use_slow_update",
"optimizer.slow_update_samples": "slow_update_samples",
"optimizer.longitudinal_pair_policy": "longitudinal_pair_policy",
"optimizer.use_meta_skill": "use_meta_skill",
"evaluation.use_gate": "use_gate",
"evaluation.sel_env_num": "sel_env_num",
"evaluation.test_env_num": "test_env_num",
"evaluation.eval_test": "eval_test",
"env.name": "env",
"env.skill_init": "skill_init",
"env.out_root": "out_root",
}
# ── Deep merge ───────────────────────────────────────────────────────────
def _deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge *override* into *base* (returns new dict)."""
    result = copy.deepcopy(base)
    for key, val in override.items():
        if key in result and isinstance(result[key], dict) and isinstance(val, dict):
            result[key] = _deep_merge(result[key], val)
        else:
            result[key] = copy.deepcopy(val)
    return result
# ── YAML loading with _base_ inheritance ─────────────────────────────────
def _load_yaml(path: str, _visited: set[str] | None = None) -> dict:
    """Load a YAML file, resolving ``_base_`` inheritance recursively."""
    abs_path = os.path.abspath(path)
    if _visited is None:
        _visited = set()
    if abs_path in _visited:
        raise ValueError(f"Circular _base_ inheritance: {abs_path}")
    _visited.add(abs_path)
    with open(abs_path) as f:
        cfg = yaml.safe_load(f) or {}
    base_ref = cfg.pop("_base_", None)
    if base_ref:
        base_path = os.path.join(os.path.dirname(abs_path), base_ref)
        base_cfg = _load_yaml(base_path, _visited)
        cfg = _deep_merge(base_cfg, cfg)
    return cfg
# ── Format detection ─────────────────────────────────────────────────────
def is_structured(cfg: dict) -> bool:
    """Return True if *cfg* uses the new structured section format."""
    return any(
        key in _STRUCTURED_SECTIONS and isinstance(cfg.get(key), dict)
        for key in cfg
    )
# ── Flatten ──────────────────────────────────────────────────────────────
def flatten_config(cfg: dict) -> dict:
    """Convert a structured config to the flat dict expected by the trainer.

    If *cfg* is already flat, returns a shallow copy unchanged.
    """
    if not is_structured(cfg):
        return dict(cfg)
    flat: dict[str, Any] = {}
    evaluation_section = cfg.get("evaluation", {})
    if isinstance(evaluation_section, dict) and evaluation_section.get("use_gate") is False:
        raise ValueError(
            "Gate validation is mandatory in this branch. Remove "
            "`evaluation.use_gate: false` from the config."
        )
    # Apply the explicit mapping
    for dotted, flat_key in _FLATTEN_MAP.items():
        section, key = dotted.split(".", 1)
        section_dict = cfg.get(section, {})
        if isinstance(section_dict, dict) and key in section_dict:
            flat[flat_key] = section_dict[key]
    # Pass through env-specific keys not in the explicit mapping
    env_section = cfg.get("env", {})
    if isinstance(env_section, dict):
        mapped_env_keys = {
            k.split(".", 1)[1]
            for k in _FLATTEN_MAP
            if k.startswith("env.")
        }
        for key, val in env_section.items():
            if key not in mapped_env_keys:
                flat[key] = val
    return flat
# ── Override application ─────────────────────────────────────────────────
def _cast_value(val_str: str) -> Any:
"""Auto-cast a CLI string value to int / float / bool / str."""
if val_str.lower() in ("true", "yes"):
return True
if val_str.lower() in ("false", "no"):
return False
try:
return int(val_str)
except ValueError:
pass
try:
return float(val_str)
except ValueError:
pass
return val_str
def apply_overrides(cfg: dict, overrides: list[str]) -> None:
"""Apply ``key=value`` overrides to a config (in place).
``section.key=value`` sets ``cfg[section][key]`` and is meant for structured
configs. A bare ``key=value`` sets ``cfg[key]`` at the top level (legacy flat
configs); it is not placed inside any section.
"""
for item in overrides:
if "=" not in item:
raise ValueError(f"Invalid override (expected key=value): {item!r}")
key, val_str = item.split("=", 1)
val = _cast_value(val_str)
if "." in key:
section, subkey = key.split(".", 1)
if section in cfg and isinstance(cfg[section], dict):
cfg[section][subkey] = val
else:
cfg.setdefault(section, {})[subkey] = val
else:
# Flat key — apply to top level (for legacy compat)
cfg[key] = val
# ── Public API ───────────────────────────────────────────────────────────
def load_config(
path: str,
overrides: list[str] | None = None,
) -> dict:
"""Load a config file with ``_base_`` inheritance and optional overrides.
Parameters
----------
path : str
Path to the YAML config file.
overrides : list[str] | None
``key=value`` strings from ``--cfg-options``.
Returns
-------
dict
The merged config (structured or flat depending on the YAML).
"""
cfg = _load_yaml(path)
if overrides:
apply_overrides(cfg, overrides)
return cfg
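The merge and override behaviour of the loader above can be tried without PyYAML. The following is a minimal sketch, not the loader itself: `deep_merge` and `cast_value` mirror `_deep_merge` and `_cast_value`, while the config values (including the `skills/init.md` path) are hypothetical and chosen only for illustration.

```python
import copy
from typing import Any


def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into base and return a new dict."""
    result = copy.deepcopy(base)
    for key, val in override.items():
        if isinstance(result.get(key), dict) and isinstance(val, dict):
            # Both sides are sections: merge key by key.
            result[key] = deep_merge(result[key], val)
        else:
            # Scalars, lists, and type mismatches: the override wins.
            result[key] = copy.deepcopy(val)
    return result


def cast_value(text: str) -> Any:
    """Cast a CLI string to bool, int, or float, else keep it as str."""
    lowered = text.lower()
    if lowered in ("true", "yes"):
        return True
    if lowered in ("false", "no"):
        return False
    for caster in (int, float):
        try:
            return caster(text)
        except ValueError:
            continue
    return text


# Hypothetical parent config (what a `_base_` file might contain).
base_cfg = {
    "optimizer": {"use_meta_skill": False},
    "env": {"name": "alfworld", "skill_init": "skills/init.md"},
}
# Hypothetical child config that only changes one key.
child_cfg = {"optimizer": {"use_meta_skill": True}}

merged = deep_merge(base_cfg, child_cfg)

# Emulate `--cfg-options evaluation.sel_env_num=32 env.name=demo`.
for item in ["evaluation.sel_env_num=32", "env.name=demo"]:
    key, raw = item.split("=", 1)
    section, subkey = key.split(".", 1)
    merged.setdefault(section, {})[subkey] = cast_value(raw)

print(merged)
```

Because the merge works on deep copies, the parent config is left untouched, so the same base can be reused by several child configs.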

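The structured-to-flat conversion can be illustrated in the same spirit. This sketch assumes a small subset of the mapping table and a hypothetical config; `flatten` follows the logic of `flatten_config` (explicit mapping first, then pass-through of unmapped `env` keys) but is not the function itself.

```python
from typing import Any

# Subset of the dotted-key mapping, copied from the table in the loader.
FLATTEN_MAP = {
    "evaluation.sel_env_num": "sel_env_num",
    "evaluation.eval_test": "eval_test",
    "env.name": "env",
    "env.out_root": "out_root",
}


def flatten(cfg: dict, mapping: dict[str, str]) -> dict[str, Any]:
    """Convert a structured config into the flat dict used by the trainer."""
    evaluation = cfg.get("evaluation", {})
    if isinstance(evaluation, dict) and evaluation.get("use_gate") is False:
        raise ValueError("Gate validation is mandatory.")

    flat: dict[str, Any] = {}
    # 1. Explicit mapping: "section.key" becomes a flat key.
    for dotted, flat_key in mapping.items():
        section, key = dotted.split(".", 1)
        section_dict = cfg.get(section, {})
        if isinstance(section_dict, dict) and key in section_dict:
            flat[flat_key] = section_dict[key]

    # 2. Env keys without an explicit mapping pass through unchanged.
    mapped_env = {d.split(".", 1)[1] for d in mapping if d.startswith("env.")}
    env_section = cfg.get("env", {})
    if isinstance(env_section, dict):
        for key, val in env_section.items():
            if key not in mapped_env:
                flat[key] = val
    return flat


# Hypothetical structured config; values are illustrative only.
structured = {
    "evaluation": {"sel_env_num": 16, "eval_test": True},
    "env": {"name": "alfworld", "out_root": "outputs/demo", "max_steps": 50},
}
flat = flatten(structured, FLATTEN_MAP)
print(flat)
```

Note that `env.name` is renamed to `env`, while an adapter-specific key such as `max_steps` keeps its name because it has no entry in the mapping.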

@@ -0,0 +1,7 @@
"""ReflACT Datasets -- task batch planning and data loading.
Analogous to the datasets and dataloaders in neural network training:
provides batch sampling, epoch planning, and data management for the
ReflACT training pipeline.
"""
from reflact.datasets.base import BaseDataLoader, BatchSpec, SplitDataLoader # noqa: F401

512
reflact/datasets/base.py Normal file

@@ -0,0 +1,512 @@
"""Generic task dataloader abstractions for ReflACT.
ReflACT does not train model parameters directly. Instead, it iterates over
task batches, rolls out the current skill, reflects on failures/successes,
and updates the skill document. Because of that, the "dataloader" abstraction
here is closer to a batch sampler / episode planner than a tensor loader.
Class hierarchy::
BaseDataLoader # abstract — simulator-backed envs (e.g. ALFWorld)
└── SplitDataLoader # abstract — dataset-backed envs with split_dir
SplitDataLoader supports two dataset entry modes:
1. ``split_mode="split_dir"``: consume an existing split directory.
2. ``split_mode="ratio"``: build a deterministic split directory from a raw
dataset path using an explicit train:val:test ratio.
In either case, the standardised split layout is:
split_dir/
├── train/ # training items
├── val/ # validation / selection items (gate)
└── test/ # held-out test items
Each subdirectory's contents are benchmark-specific. Subclasses only need
to implement ``load_split_items(split_path)`` to teach the loader how to
read items from one of those directories.
"""
from __future__ import annotations
import glob
import json
import os
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
@dataclass(slots=True)
class BatchSpec:
"""A concrete batch request consumed by the training loop.
Parameters
----------
phase : str
``"train"`` or ``"eval"``.
split : str
Dataset split name, typically ``"train"`` or an eval split.
seed : int
Random seed used to construct the batch deterministically.
batch_size : int
Requested number of items / episodes in this batch.
payload : object | None
Environment-specific batch payload. For dataset-backed environments
this is often a list of sampled items; for simulator-backed
environments this may be ``None`` and the seed alone can define the
batch.
metadata : dict[str, Any]
Optional structured metadata for logging, resume, or curriculum logic.
"""
phase: str
split: str
seed: int
batch_size: int
payload: object | None = None
metadata: dict[str, Any] = field(default_factory=dict)
class BaseDataLoader(ABC):
"""Abstract base class for task batch planning in ReflACT.
Subclasses are responsible for defining how a train or eval batch is
sampled. The default implementation here provides deterministic epoch seed
planning so all loaders share the same reproducibility behavior.
"""
def setup(self, cfg: dict) -> None:
"""Optional one-time initialization with the full trainer config."""
def set_out_root(self, out_root: str) -> None:
"""Optional hook for loaders that persist split files or state."""
def state_dict(self) -> dict[str, Any]:
"""Return serializable loader state for resume support."""
return {}
def load_state_dict(self, state: dict[str, Any]) -> None:
"""Restore loader state from :meth:`state_dict` output."""
def get_train_size(self) -> int | None:
"""Return the size of the training pool when known."""
return None
@staticmethod
def make_base_seeds(steps_per_epoch: int, accumulation: int, seed: int) -> list[int]:
"""Return the deterministic seed pool used to define train batches."""
batches_per_epoch = steps_per_epoch * accumulation
return [seed + i + 1 for i in range(batches_per_epoch)]
@staticmethod
def shuffle_epoch_seeds(base_seeds: list[int], epoch: int, seed: int) -> list[int]:
"""Return the per-epoch deterministic shuffle of *base_seeds*."""
epoch_rng = random.Random(seed + epoch * 1000)
shuffled = list(base_seeds)
epoch_rng.shuffle(shuffled)
return shuffled
def plan_train_epoch(
self,
*,
epoch: int,
steps_per_epoch: int,
accumulation: int,
batch_size: int,
seed: int,
**kwargs,
) -> list[BatchSpec]:
"""Build the full list of training batches for one epoch."""
base_seeds = self.make_base_seeds(
steps_per_epoch=steps_per_epoch,
accumulation=accumulation,
seed=seed,
)
shuffled_seeds = self.shuffle_epoch_seeds(base_seeds, epoch=epoch, seed=seed)
return [
self.build_train_batch(batch_size=batch_size, seed=batch_seed, **kwargs)
for batch_seed in shuffled_seeds
]
@abstractmethod
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
"""Construct one training batch specification."""
@abstractmethod
def build_eval_batch(
self,
env_num: int,
split: str,
seed: int,
**kwargs,
) -> BatchSpec:
"""Construct one evaluation batch specification."""
# ── Split-based dataloader for dataset-backed environments ──────────────
# Canonical split names expected under split_dir/
SPLIT_NAMES = ("train", "val", "test")
# Maps legacy / trainer split names → canonical directory names
_SPLIT_ALIAS: dict[str, str] = {
"train": "train",
"valid_seen": "val",
"selection": "val",
"val": "val",
"valid_unseen": "test",
"test": "test",
}
def _load_json_or_jsonl(path: str) -> list[dict]:
"""Load a list of items from a JSON or JSONL file."""
with open(path, encoding="utf-8") as f:
content = f.read().strip()
if not content:
return []
try:
data = json.loads(content)
except json.JSONDecodeError:
data = None
if isinstance(data, list):
return data
if isinstance(data, dict):
nested = data.get("data")
if isinstance(nested, list):
return nested
return list(data.values())
items: list[dict] = []
for line in content.splitlines():
line = line.strip()
if line:
items.append(json.loads(line))
return items
def _parse_split_ratio(text: str) -> tuple[int, int, int]:
parts = [part.strip() for part in str(text or "").split(":") if part.strip()]
if len(parts) != 3:
raise ValueError(
f"split_ratio must be in train:val:test form, got {text!r}"
)
try:
train, val, test = (int(part) for part in parts)
except ValueError as exc:
raise ValueError(
f"split_ratio must contain integers, got {text!r}"
) from exc
if min(train, val, test) <= 0:
raise ValueError(f"split_ratio parts must be positive, got {text!r}")
return train, val, test
def _compute_split_counts(total: int, ratio: tuple[int, int, int]) -> tuple[int, int, int]:
weights = list(ratio)
denom = sum(weights)
raw = [total * weight / denom for weight in weights]
counts = [int(value) for value in raw]
remaining = total - sum(counts)
order = sorted(
range(len(raw)),
key=lambda idx: (raw[idx] - counts[idx], weights[idx]),
reverse=True,
)
for idx in order[:remaining]:
counts[idx] += 1
return counts[0], counts[1], counts[2]
class SplitDataLoader(BaseDataLoader):
"""Base class for dataset-backed environments.
Supported modes:
- ``split_mode="split_dir"``: load an existing ``train/``, ``val/``,
``test/`` directory tree.
- ``split_mode="ratio"``: load raw items from ``data_path`` and materialize
a deterministic split directory with the requested ratio.
"""
def __init__(
self,
split_dir: str = "",
data_path: str = "",
split_mode: str = "ratio",
split_ratio: str = "2:1:7",
split_seed: int = 42,
split_output_dir: str = "",
seed: int = 42,
limit: int = 0,
**kwargs,
) -> None:
self.split_dir = split_dir
self.data_path = data_path
self.split_mode = split_mode
self.split_ratio = split_ratio
self.split_seed = int(split_seed)
self.split_output_dir = split_output_dir
self.seed = seed
self.limit = limit
self._splits: dict[str, list[dict]] = {}
# ── Setup ────────────────────────────────────────────────────────────
def setup(self, cfg: dict) -> None:
if not self.split_mode:
self.split_mode = str(cfg.get("split_mode", "ratio") or "ratio")
if not self.split_dir:
self.split_dir = cfg.get("split_dir", "")
if not self.data_path:
self.data_path = cfg.get("data_path", "")
if not self.split_output_dir:
self.split_output_dir = cfg.get("split_output_dir", "")
if "split_seed" in cfg and not self.split_seed:
self.split_seed = int(cfg.get("split_seed", 0) or 0)
if not self.split_seed:
self.split_seed = self.seed
if not self.split_ratio:
self.split_ratio = str(cfg.get("split_ratio", "2:1:7") or "2:1:7")
mode = str(self.split_mode or "ratio").strip().lower()
if mode not in {"ratio", "split_dir"}:
raise ValueError(
f"{type(self).__name__} split_mode must be 'ratio' or 'split_dir', "
f"got {self.split_mode!r}"
)
self.split_mode = mode
if self.split_mode == "ratio":
self.split_dir = self._materialize_ratio_split(cfg)
if not self.split_dir:
raise ValueError(
f"{type(self).__name__} requires either "
"`split_mode=ratio` with `data_path`, or `split_mode=split_dir` "
f"with `split_dir` pointing to {'/'.join(SPLIT_NAMES)}/."
)
self._load_all_splits()
def _resolve_split_output_dir(self, cfg: dict) -> str:
if self.split_output_dir:
return os.path.abspath(self.split_output_dir)
out_root = os.path.abspath(str(cfg.get("out_root") or os.getcwd()))
env_name = str(cfg.get("env") or type(self).__name__.replace("DataLoader", "").lower())
ratio_tag = str(self.split_ratio or "2:1:7").replace(":", "-")
return os.path.join(out_root, "_generated_splits", f"{env_name}_{ratio_tag}_seed{self.split_seed}")
def load_raw_items(self, data_path: str) -> list[dict]:
"""Load raw items from a dataset path before ratio splitting.
Subclasses can override when the raw dataset is not a single JSON/JSONL
file or when directory layouts require custom normalization.
"""
if os.path.isdir(data_path):
if any(os.path.isdir(os.path.join(data_path, name)) for name in SPLIT_NAMES):
raise ValueError(
f"{type(self).__name__} got a split directory as data_path. "
"Use split_mode=split_dir and pass it as split_dir instead."
)
candidates = sorted(glob.glob(os.path.join(data_path, "*.json")))
candidates += sorted(glob.glob(os.path.join(data_path, "*.jsonl")))
if len(candidates) != 1:
raise ValueError(
f"{type(self).__name__} expected data_path to be one JSON/JSONL file "
f"or a directory containing exactly one such file, got: {data_path}"
)
return _load_json_or_jsonl(candidates[0])
return _load_json_or_jsonl(data_path)
def write_split_items(self, split_path: str, items: list[dict]) -> None:
os.makedirs(split_path, exist_ok=True)
out_path = os.path.join(split_path, "items.json")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(items, f, ensure_ascii=False, indent=2)
def _materialize_ratio_split(self, cfg: dict) -> str:
data_path = os.path.abspath(str(self.data_path or "").strip())
if not data_path:
raise ValueError(
f"{type(self).__name__} requires data_path when split_mode=ratio."
)
ratio = _parse_split_ratio(self.split_ratio)
items = self.load_raw_items(data_path)
if not isinstance(items, list) or not items:
raise ValueError(f"No raw items available for ratio split from {data_path}")
shuffled = list(items)
rng = random.Random(self.split_seed)
rng.shuffle(shuffled)
train_n, val_n, test_n = _compute_split_counts(len(shuffled), ratio)
train_items = shuffled[:train_n]
val_items = shuffled[train_n: train_n + val_n]
test_items = shuffled[train_n + val_n: train_n + val_n + test_n]
split_dir = self._resolve_split_output_dir(cfg)
manifest = {
"source_data_path": data_path,
"split_mode": "ratio",
"split_ratio": self.split_ratio,
"split_seed": self.split_seed,
"counts": {
"train": len(train_items),
"val": len(val_items),
"test": len(test_items),
},
}
os.makedirs(split_dir, exist_ok=True)
self.write_split_items(os.path.join(split_dir, "train"), train_items)
self.write_split_items(os.path.join(split_dir, "val"), val_items)
self.write_split_items(os.path.join(split_dir, "test"), test_items)
with open(os.path.join(split_dir, "split_manifest.json"), "w", encoding="utf-8") as f:
json.dump(manifest, f, ensure_ascii=False, indent=2)
print(
f" [{type(self).__name__}] generated ratio split {self.split_ratio} "
f"at {split_dir} from {data_path}"
)
return split_dir
def _load_all_splits(self) -> None:
for name in SPLIT_NAMES:
split_path = os.path.join(self.split_dir, name)
if not os.path.isdir(split_path):
raise ValueError(
f"Missing '{name}/' subdirectory in split_dir: {self.split_dir}"
)
items = self.load_split_items(split_path)
if self.limit:
items = items[: self.limit]
self._splits[name] = items
counts = " ".join(f"{k}={len(v)}" for k, v in self._splits.items())
print(f" [{type(self).__name__}] {counts} (from {self.split_dir})")
def load_split_items(self, split_path: str) -> list[dict]:
"""Load items from one split directory (e.g. ``split_dir/train/``).
Default: loads the alphabetically first ``.json`` file in the directory,
which must contain a JSON array. Subclasses can override for custom formats.
"""
json_files = sorted(glob.glob(os.path.join(split_path, "*.json")))
if not json_files:
raise FileNotFoundError(
f"No .json file found in {split_path}"
)
with open(json_files[0], encoding="utf-8") as f:
items = json.load(f)
if not isinstance(items, list):
raise ValueError(
f"Expected JSON array in {json_files[0]}, got {type(items).__name__}"
)
return items
# ── Accessors ────────────────────────────────────────────────────────
@property
def train_items(self) -> list[dict]:
return self._splits.get("train", [])
@property
def val_items(self) -> list[dict]:
return self._splits.get("val", [])
@property
def test_items(self) -> list[dict]:
return self._splits.get("test", [])
def get_split_items(self, split: str) -> list[dict]:
"""Resolve a split name (including legacy aliases) to a copy of its item list; unknown names fall back to the validation items."""
canonical = _SPLIT_ALIAS.get(split, split)
return list(self._splits.get(canonical, self.val_items))
def get_train_size(self) -> int:
return len(self.train_items)
def plan_train_epoch(
self,
*,
epoch: int,
steps_per_epoch: int,
accumulation: int,
batch_size: int,
seed: int,
**kwargs,
) -> list[BatchSpec]:
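"""Build one epoch of batches by slicing the shuffled train split.
For split-backed datasets, an epoch corresponds to one pass over the
available training items rather than repeated independent sampling. The pass
is complete only when ``steps_per_epoch * accumulation * batch_size`` is at
least ``len(train_items)``; otherwise the tail of the shuffled list is not
visited in that epoch.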
"""
epoch_rng = random.Random(seed + epoch * 1000)
items = list(self.train_items)
epoch_rng.shuffle(items)
total_batches = steps_per_epoch * accumulation
if total_batches <= 0:
return []
batches: list[BatchSpec] = []
cursor = 0
for batch_idx in range(total_batches):
batch_items = items[cursor: cursor + batch_size]
cursor += len(batch_items)
# Small datasets can run out of items before all microbatches are filled
# (for example when accumulation > 1). Fill each empty microbatch from a
# freshly reshuffled copy of the items so the trainer still receives the
# expected batch count.
if not batch_items and items:
refill_rng = random.Random(seed + epoch * 1000 + batch_idx + 1)
batch_items = list(items)
refill_rng.shuffle(batch_items)
batch_items = batch_items[:batch_size]
batches.append(
BatchSpec(
phase="train",
split="train",
seed=seed + epoch * 1000 + batch_idx + 1,
batch_size=len(batch_items),
payload=batch_items,
)
)
return batches
# ── Batch construction ───────────────────────────────────────────────
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
rng = random.Random(seed)
items = list(self.train_items)
rng.shuffle(items)
items = items[:batch_size]
return BatchSpec(
phase="train",
split="train",
seed=seed,
batch_size=len(items),
payload=items,
)
def build_eval_batch(
self,
env_num: int,
split: str,
seed: int,
**kwargs,
) -> BatchSpec:
items = self.get_split_items(split)
if env_num and env_num < len(items):
items = items[:env_num]
return BatchSpec(
phase="eval",
split=split,
seed=seed,
batch_size=len(items),
payload=items,
)
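The ratio split above assigns item counts with a largest-remainder rule: floor each share, then hand the leftover items to the splits with the largest fractional parts (ties go to the larger weight). A minimal sketch of that rule plus the seeded shuffle, assuming plain Python lists as items; `split_counts` and `ratio_split` are illustrative names, not part of the package.

```python
import random


def split_counts(total: int, ratio: tuple[int, int, int]) -> tuple[int, ...]:
    """Distribute `total` items over three splits by largest remainder."""
    denom = sum(ratio)
    raw = [total * weight / denom for weight in ratio]
    counts = [int(value) for value in raw]  # floor of each share
    remaining = total - sum(counts)
    # Leftover items go to the largest fractional parts first.
    order = sorted(
        range(len(raw)),
        key=lambda idx: (raw[idx] - counts[idx], ratio[idx]),
        reverse=True,
    )
    for idx in order[:remaining]:
        counts[idx] += 1
    return tuple(counts)


def ratio_split(items: list, ratio: tuple[int, int, int], seed: int) -> dict:
    """Shuffle deterministically, then slice into train / val / test."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train, n_val, _ = split_counts(len(shuffled), ratio)
    return {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],
    }


splits = ratio_split(list(range(20)), (2, 1, 7), seed=42)
print({name: len(part) for name, part in splits.items()})
```

One consequence worth keeping in mind: for very small datasets a split can end up empty. With three items and the 2:1:7 ratio, for example, the rule yields one train item, no validation items, and two test items, which would leave the gate with nothing to validate on.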

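The epoch planning in `SplitDataLoader.plan_train_epoch` can be sketched as a seeded shuffle followed by consecutive slicing, with a refill step once the items run out. The function below is a simplified stand-in under that reading: it returns bare lists rather than `BatchSpec` objects, and `plan_epoch` is a hypothetical name.

```python
import random


def plan_epoch(items: list, *, epoch: int, total_batches: int,
               batch_size: int, seed: int) -> list[list]:
    """Slice one seeded shuffle of `items` into `total_batches` batches."""
    rng = random.Random(seed + epoch * 1000)
    shuffled = list(items)
    rng.shuffle(shuffled)

    batches: list[list] = []
    cursor = 0
    for batch_idx in range(total_batches):
        batch = shuffled[cursor:cursor + batch_size]
        cursor += len(batch)
        if not batch and shuffled:
            # Items exhausted: draw from a reshuffled copy so the caller
            # still gets the number of batches it asked for.
            refill_rng = random.Random(seed + epoch * 1000 + batch_idx + 1)
            batch = list(shuffled)
            refill_rng.shuffle(batch)
            batch = batch[:batch_size]
        batches.append(batch)
    return batches


items = list(range(10))
plan = plan_epoch(items, epoch=0, total_batches=3, batch_size=4, seed=42)
print([len(batch) for batch in plan])
```

The same `(seed, epoch)` pair always produces the same plan, which is what makes resuming a run reproducible. If `total_batches * batch_size` is smaller than the dataset, the items at the end of the shuffled order are simply not visited in that epoch.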

@@ -0,0 +1,9 @@
"""ReflACT Engine -- the training runner.
Analogous to the Runner in mmengine: orchestrates the full training pipeline
including rollout, gradient computation, aggregation, optimization, and
evaluation.
"""
from reflact.engine.trainer import ReflACTTrainer # noqa: F401
__all__ = ["ReflACTTrainer"]

2195
reflact/engine/trainer.py Normal file

File diff suppressed because it is too large

1
reflact/envs/__init__.py Normal file

@@ -0,0 +1 @@
"""ReflACT environment adapters."""


@@ -0,0 +1,5 @@
"""ALFWorld environment adapter for ReflACT."""
from reflact.envs.alfworld.adapter import ALFWorldAdapter
__all__ = ["ALFWorldAdapter"]


@@ -0,0 +1,585 @@
"""ALFWorld environment adapter for ReflACT.
Connects the ReflACT training loop to ALFWorld by implementing
:class:`~reflact.envs.base.EnvAdapter`.
"""
from __future__ import annotations
from dataclasses import dataclass
import json
import os
from reflact.gradient.deep_probe import generate_deep_probe_instruction
from reflact.datasets.base import BatchSpec
from reflact.envs.base import EnvAdapter
from reflact.envs.alfworld.dataloader import ALFWorldDataLoader
from reflact.envs.alfworld.rollout import (
build_alfworld_env,
run_alfworld_batch,
TASKS,
)
from reflact.gradient.reflect import run_minibatch_reflect
from reflact.utils import compute_score
@dataclass(frozen=True)
class ALFWorldBatchRun:
"""Lazy ALFWorld batch description.
The adapter materializes this in rollout chunks so a large evaluation set
does not keep every ALFWorld simulator open at once.
"""
env_num: int
eval_dataset: str
seed: int
is_train: bool
workers: int
specific_gamefiles: list[str] | None = None
result_ids: list[str] | None = None
items: list[dict] | None = None
def __iter__(self):
return iter(self.items or [])
def __len__(self) -> int:
return int(self.env_num or 0)
class ALFWorldAdapter(EnvAdapter):
"""ALFWorld environment adapter.
Parameters
----------
max_steps : int
Maximum steps per ALFWorld episode (default 50).
workers : int
Number of ALFWorld environments built per rollout chunk (default 8).
max_api_workers : int
Maximum concurrent API calls during rollout (default 8).
analyst_workers : int
Parallel workers for the analyst stage (default 16).
failure_only : bool
If True, run only the error analyst and skip the success analyst.
minibatch_size : int
Trajectories per analyst group, M (default 8).
edit_budget : int
Maximum edits per minibatch, L (default 4).
"""
def __init__(
self,
split_dir: str = "",
data_path: str = "",
split_mode: str = "split_dir",
split_ratio: str = "2:1:7",
split_seed: int = 42,
split_output_dir: str = "",
seed: int = 42,
limit: int = 0,
train_size: int = 0,
max_steps: int = 50,
workers: int = 8,
max_api_workers: int = 8,
analyst_workers: int = 16,
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4,
use_deep_reflect: bool = False,
deep_reflect_failures: int = 4,
deep_reflect_successes: int = 2,
) -> None:
self.max_steps = max_steps
self.workers = max(int(workers or 1), 1)
self.max_api_workers = max_api_workers
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
self.edit_budget = edit_budget
self.use_deep_reflect = use_deep_reflect
self.deep_reflect_failures = deep_reflect_failures
self.deep_reflect_successes = deep_reflect_successes
self.dataloader = ALFWorldDataLoader(
split_dir=split_dir,
data_path=data_path,
split_mode=split_mode,
split_ratio=split_ratio,
split_seed=split_seed,
split_output_dir=split_output_dir,
seed=seed,
limit=limit,
train_size=train_size,
)
self._traj_cache: dict[str, dict | None] = {}
def setup(self, cfg: dict) -> None:
super().setup(cfg)
self.dataloader.setup(cfg)
def _load_traj_data(self, item: dict) -> dict | None:
gamefile = str(item.get("gamefile") or "").strip()
if not gamefile:
return None
if gamefile in self._traj_cache:
return self._traj_cache[gamefile]
traj_path = os.path.join(os.path.dirname(gamefile), "traj_data.json")
try:
with open(traj_path, encoding="utf-8") as f:
data = json.load(f)
except Exception:
data = None
self._traj_cache[gamefile] = data
return data
@staticmethod
def _unique_lines(values: list[str], *, limit: int = 0) -> list[str]:
lines: list[str] = []
seen: set[str] = set()
for raw in values:
line = str(raw or "").strip()
if not line or line in seen:
continue
seen.add(line)
lines.append(line)
if limit > 0 and len(lines) >= limit:
break
return lines
@staticmethod
def _format_high_pddl(high_pddl: list[dict]) -> list[str]:
steps: list[str] = []
for idx, step in enumerate(high_pddl or [], start=1):
discrete = step.get("discrete_action") or {}
action = str(discrete.get("action") or "").strip()
args = [str(arg).strip() for arg in (discrete.get("args") or []) if str(arg).strip()]
if action and args:
text = f"{action}({', '.join(args)})"
elif action:
text = action
else:
planner_action = step.get("planner_action") or {}
text = str(planner_action.get("action") or "").strip()
if text:
steps.append(f"{idx}. {text}")
return steps
def _build_reference_bundle(self, item: dict) -> dict:
data = self._load_traj_data(item)
if not data:
return {}
anns = ((data.get("turk_annotations") or {}).get("anns") or [])
task_descs = self._unique_lines(
[ann.get("task_desc", "") for ann in anns],
limit=3,
)
high_descs = self._unique_lines(
[step for ann in anns for step in (ann.get("high_descs") or [])],
limit=12,
)
pddl_params = {
key: value
for key, value in (data.get("pddl_params") or {}).items()
if value not in ("", None, [], {})
}
scene = data.get("scene") or {}
scene_summary = {
key: scene.get(key)
for key in ("floor_plan", "scene_num", "dirty_and_empty")
if scene.get(key) not in ("", None, [], {})
}
high_pddl = self._format_high_pddl((data.get("plan") or {}).get("high_pddl") or [])
task_type = str(data.get("task_type") or item.get("task_type") or "").strip()
return {
"task_type": task_type,
"task_descs": task_descs,
"high_descs": high_descs,
"pddl_params": pddl_params,
"high_pddl": high_pddl,
"scene_summary": scene_summary,
}
def build_reference_text(self, item: dict) -> str:
bundle = self._build_reference_bundle(item)
if not bundle:
return ""
parts: list[str] = []
if bundle["task_type"]:
parts.append(f"## Reference Task Type\n{bundle['task_type']}")
if bundle["task_descs"]:
parts.append(
"## Reference Human Task Descriptions\n"
+ "\n".join(f"- {line}" for line in bundle["task_descs"])
)
if bundle["high_descs"]:
parts.append(
"## Reference Human High-Level Steps\n"
+ "\n".join(f"{idx}. {line}" for idx, line in enumerate(bundle["high_descs"], start=1))
)
if bundle["pddl_params"]:
parts.append(
"## Reference PDDL Params\n"
+ "\n".join(f"- {key}: {value}" for key, value in bundle["pddl_params"].items())
)
if bundle["high_pddl"]:
parts.append(
"## Reference Planner High-Level Plan\n" + "\n".join(bundle["high_pddl"])
)
if bundle["scene_summary"]:
parts.append(
"## Reference Scene Summary\n"
+ "\n".join(f"- {key}: {value}" for key, value in bundle["scene_summary"].items())
)
return "\n\n".join(parts)
def get_reference_metadata(self, item: dict) -> dict:
bundle = self._build_reference_bundle(item)
if not bundle:
return {"fields": [], "preview": ""}
fields: list[str] = []
previews: list[str] = []
if bundle["task_type"]:
fields.append("task_type")
previews.append(f"[task_type] {bundle['task_type']}")
if bundle["task_descs"]:
fields.append("task_desc")
previews.append("[task_desc]\n" + "\n".join(bundle["task_descs"][:2]))
if bundle["high_descs"]:
fields.append("high_descs")
previews.append("[high_descs]\n" + "\n".join(bundle["high_descs"][:3]))
if bundle["pddl_params"]:
fields.append("pddl_params")
previews.append(
"[pddl_params]\n"
+ "\n".join(
f"{key}: {value}" for key, value in list(bundle["pddl_params"].items())[:4]
)
)
if bundle["high_pddl"]:
fields.append("plan.high_pddl")
previews.append("[plan.high_pddl]\n" + "\n".join(bundle["high_pddl"][:3]))
if bundle["scene_summary"]:
fields.append("scene")
previews.append(
"[scene]\n"
+ "\n".join(
f"{key}: {value}" for key, value in bundle["scene_summary"].items()
)
)
return {
"fields": fields,
"preview": "\n\n".join(previews)[:600],
}
@staticmethod
def _infer_dataset_from_gamefile(gamefile: str) -> tuple[str, bool]:
path = str(gamefile or "")
if "/valid_seen/" in path:
return "eval_in_distribution", False
if "/valid_unseen/" in path:
return "eval_out_of_distribution", False
return "train", True
def get_dataloader(self):
return self.dataloader
def _comparison_items(self, items: list[dict]) -> list[dict]:
enriched: list[dict] = []
for item in items:
row = dict(item)
bundle = self._build_reference_bundle(row)
if bundle.get("task_descs"):
row["task_description"] = bundle["task_descs"][0]
elif bundle.get("task_type"):
row["task_description"] = bundle["task_type"]
enriched.append(row)
return enriched
def requires_ray(self) -> bool:
return False
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
gamefiles = list(batch.metadata.get("gamefiles") or [])
result_ids = list(batch.metadata.get("result_ids") or [])
items = self._comparison_items(list(batch.payload or []))
return ALFWorldBatchRun(
env_num=batch.batch_size,
eval_dataset=batch.metadata.get("eval_dataset", batch.split),
seed=batch.seed,
is_train=batch.metadata.get("is_train", batch.phase == "train"),
specific_gamefiles=gamefiles or None,
result_ids=result_ids or None,
items=items,
workers=self.workers,
)
def build_train_env(self, batch_size: int, seed: int, **kwargs):
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def rollout(
self,
env_manager,
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict]:
results_path = os.path.join(out_dir, "results.jsonl")
os.makedirs(out_dir, exist_ok=True)
# Resume support: if a previous run already wrote results.jsonl, reuse it.
if os.path.exists(results_path):
existing: list[dict] = []
with open(results_path, encoding="utf-8") as f:
for line in f:
try:
existing.append(json.loads(line))
except Exception:
pass
if existing:
return existing
if isinstance(env_manager, ALFWorldBatchRun):
results = self._run_batch(
env_manager,
skill_content=skill_content,
out_dir=out_dir,
)
else:
results = run_alfworld_batch(
env_manager=env_manager,
skill_content=skill_content,
max_steps=self.max_steps,
out_root=out_dir,
max_api_workers=self.max_api_workers,
result_ids=getattr(env_manager, "_reflact_result_ids", None),
)
with open(results_path, "w", encoding="utf-8") as f:
for r in results:
f.write(json.dumps(r, ensure_ascii=False) + "\n")
return results
@staticmethod
def _close_env(env_manager) -> None:
close = getattr(env_manager, "close", None)
if callable(close):
close()
def _run_batch(
self,
batch: ALFWorldBatchRun,
skill_content: str,
out_dir: str,
*,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
) -> list[dict]:
total = int(batch.env_num or 0)
if total <= 0:
return []
workers = max(1, min(int(batch.workers or self.workers), total))
if total > workers:
print(
f" [alfworld rollout] episodes={total} "
f"env_workers={workers} chunks={(total + workers - 1) // workers}"
)
all_results: list[dict] = []
for start in range(0, total, workers):
chunk_size = min(workers, total - start)
chunk_gamefiles = (
batch.specific_gamefiles[start:start + chunk_size]
if batch.specific_gamefiles
else None
)
chunk_ids = (
batch.result_ids[start:start + chunk_size]
if batch.result_ids
else [f"env_{idx:03d}" for idx in range(start, start + chunk_size)]
)
chunk_env = build_alfworld_env(
env_num=chunk_size,
eval_dataset=batch.eval_dataset,
seed=batch.seed + start,
is_train=batch.is_train,
specific_gamefiles=chunk_gamefiles,
)
try:
chunk_results = run_alfworld_batch(
env_manager=chunk_env,
skill_content=skill_content,
max_steps=self.max_steps,
out_root=out_dir,
max_api_workers=min(self.max_api_workers, chunk_size),
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
result_ids=chunk_ids,
)
finally:
self._close_env(chunk_env)
all_results.extend(chunk_results)
return all_results
def reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
meta_skill_context = kwargs.get("meta_skill_context", "")
return run_minibatch_reflect(
results=results,
skill_content=skill_content,
prediction_dir=prediction_dir,
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
meta_skill_context=meta_skill_context,
)
def deep_reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
if not self.use_deep_reflect:
return []
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
meta_skill_context = kwargs.get("meta_skill_context", "")
selected_items = self.select_representative_items(
results,
results,
n_failures=self.deep_reflect_failures,
n_successes=self.deep_reflect_successes,
seed=random_seed,
)
if not selected_items:
return []
selected_ids = {str(item["id"]) for item in selected_items}
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
selected_examples = self.attach_reference_context(selected_results, selected_items)
field_counts: dict[str, int] = {}
selected_metadata: list[dict] = []
for item in selected_items:
meta = self.get_reference_metadata(item)
for field in meta["fields"]:
field_counts[field] = field_counts.get(field, 0) + 1
selected_metadata.append({
"id": str(item["id"]),
"task_type": str(item.get("task_type") or "alfworld"),
"gamefile": str(item.get("gamefile") or ""),
"reference_fields": meta["fields"],
"reference_preview": meta["preview"],
})
deep_dir = os.path.join(out_dir, "deep_reflect")
rollout_dir = os.path.join(deep_dir, "rollout")
patches_dir = os.path.join(deep_dir, "patches")
os.makedirs(deep_dir, exist_ok=True)
field_summary = ", ".join(
f"{field}({count}/{len(selected_items)})"
for field, count in sorted(field_counts.items())
) or "none"
print(
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
f"reference_fields={field_summary}"
)
probe = generate_deep_probe_instruction(
skill_content=skill_content,
items=selected_examples,
prediction_dir=prediction_dir,
system_prompt=self.get_deep_probe_prompt(),
step_buffer_context=step_buffer_context,
meta_skill_context=meta_skill_context,
output_requirements=[
"- Some trajectories may include a hidden Reference block. Use it to target the student's latent subgoal, missing precondition, or next-step intent, but do not reveal or paraphrase that reference to the student.",
"- The instruction must request a brief diagnostic readout inside the existing <think>...</think> block.",
"- The student must still output exactly one admissible action inside <action>...</action>.",
"- Do not ask for exhaustive inventories, full plans, or long chain-of-thought.",
"- The instruction text should be ready to append directly to the student's prompt.",
],
)
if not probe:
return []
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
json.dump(
{
**probe,
"reference_summary": {
"selected_count": len(selected_items),
"field_counts": field_counts,
},
"selected_examples": selected_metadata,
},
f,
ensure_ascii=False,
indent=2,
)
gamefiles = [str(item.get("gamefile") or "") for item in selected_items]
if any(not gamefile for gamefile in gamefiles):
return []
eval_dataset, is_train = self._infer_dataset_from_gamefile(gamefiles[0])
deep_env = ALFWorldBatchRun(
env_num=len(selected_items),
eval_dataset=eval_dataset,
seed=random_seed if random_seed is not None else 42,  # keep an explicit seed of 0
is_train=is_train,
specific_gamefiles=gamefiles,
workers=min(self.workers, max(len(selected_items), 1)),
result_ids=[str(item["id"]) for item in selected_items],
)
deep_results = self._run_batch(
deep_env,
skill_content=skill_content,
out_dir=rollout_dir,
diagnostic_mode=True,
diagnostic_instruction=probe["probe_instruction"],
)
deep_results = self.attach_reference_context(deep_results, selected_items)
return run_minibatch_reflect(
results=deep_results,
skill_content=skill_content,
prediction_dir=os.path.join(rollout_dir, "predictions"),
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
meta_skill_context=meta_skill_context,
)
def get_task_types(self) -> list[str]:
return list(TASKS)
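The `field_summary` string printed during deep-reflect setup can be reproduced in isolation. The sketch below assumes each selected item contributes a plain list of reference field names; `summarize_reference_fields` is a hypothetical helper name, not part of the adapter.

```python
from __future__ import annotations


def summarize_reference_fields(per_item_fields: list[list[str]]) -> str:
    """Count how many selected items carry each reference field."""
    counts: dict[str, int] = {}
    for fields in per_item_fields:
        for field in fields:
            counts[field] = counts.get(field, 0) + 1
    total = len(per_item_fields)
    # Same shape as the log line: "field(count/total)", sorted by field name,
    # with "none" when no item carries any reference field.
    return ", ".join(
        f"{field}({count}/{total})" for field, count in sorted(counts.items())
    ) or "none"


# Illustrative field names only.
print(summarize_reference_fields([["plan", "goal"], ["goal"]]))
```

Sorting by field name keeps the log line stable across runs, which makes step logs easier to diff.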

@@ -0,0 +1,123 @@
"""ALFWorld task dataloader."""
from __future__ import annotations
from reflact.datasets.base import BatchSpec, SplitDataLoader
class ALFWorldDataLoader(SplitDataLoader):
"""ALFWorld batch planner.
In split_dir mode, batches are fixed gamefile items so ablations differ
only in how the same training set is batched.
"""
def __init__(
self,
split_dir: str = "",
data_path: str = "",
split_mode: str = "split_dir",
split_ratio: str = "2:1:7",
split_seed: int = 42,
split_output_dir: str = "",
seed: int = 42,
limit: int = 0,
train_size: int = 0,
**kwargs,
) -> None:
super().__init__(
split_dir=split_dir,
data_path=data_path,
split_mode=split_mode,
split_ratio=split_ratio,
split_seed=split_seed,
split_output_dir=split_output_dir,
seed=seed,
limit=limit,
)
self.train_size_override = int(train_size or 0)
@staticmethod
def _metadata_for_items(items: list[dict], split: str, phase: str) -> dict:
gamefiles = [str(item.get("gamefile") or "") for item in items]
if any(not gamefile for gamefile in gamefiles):
raise ValueError("ALFWorld split items must contain non-empty gamefile paths.")
eval_dataset = "train"
is_train = phase == "train"
first = gamefiles[0] if gamefiles else ""
if "/valid_seen/" in first:
eval_dataset = "eval_in_distribution"
is_train = False
elif "/valid_unseen/" in first:
eval_dataset = "eval_out_of_distribution"
is_train = False
return {
"eval_dataset": eval_dataset,
"is_train": is_train,
"gamefiles": gamefiles,
"result_ids": [str(item.get("id") or idx) for idx, item in enumerate(items)],
}
def get_train_size(self) -> int:
if self.train_size_override > 0:
return self.train_size_override
return super().get_train_size()
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
batch = super().build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
items = list(batch.payload or [])
batch.metadata.update(self._metadata_for_items(items, "train", "train"))
return BatchSpec(
phase="train",
split="train",
seed=seed,
batch_size=len(items),
payload=items,
metadata=batch.metadata,
)
def plan_train_epoch(
self,
*,
epoch: int,
steps_per_epoch: int,
accumulation: int,
batch_size: int,
seed: int,
**kwargs,
) -> list[BatchSpec]:
batches = super().plan_train_epoch(
epoch=epoch,
steps_per_epoch=steps_per_epoch,
accumulation=accumulation,
batch_size=batch_size,
seed=seed,
**kwargs,
)
for batch in batches:
items = list(batch.payload or [])
batch.metadata.update(self._metadata_for_items(items, "train", "train"))
return batches
def build_eval_batch(
self,
env_num: int,
split: str,
seed: int,
**kwargs,
) -> BatchSpec:
batch = super().build_eval_batch(
env_num=env_num,
split=split,
seed=seed,
**kwargs,
)
items = list(batch.payload or [])
batch.metadata.update(self._metadata_for_items(items, split, "eval"))
return BatchSpec(
phase="eval",
split=split,
seed=seed,
batch_size=len(items),
payload=items,
metadata=batch.metadata,
)
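The split inference in `_metadata_for_items` depends only on a substring of the gamefile path. The sketch below isolates that rule for a single path; `infer_alfworld_split` is a hypothetical name, and the example paths are made up for illustration. Note that the loader itself inspects only the first gamefile of a batch, so it assumes a batch never mixes splits.

```python
from __future__ import annotations


def infer_alfworld_split(gamefile: str, phase: str) -> tuple[str, bool]:
    """Return (eval_dataset, is_train) for one gamefile path."""
    if "/valid_seen/" in gamefile:
        return "eval_in_distribution", False
    if "/valid_unseen/" in gamefile:
        return "eval_out_of_distribution", False
    # Anything else is treated as the training pool; is_train follows the phase.
    return "train", phase == "train"


# Hypothetical paths, for illustration only.
print(infer_alfworld_split("/data/valid_seen/task-a/game.tw-pddl", "eval"))
print(infer_alfworld_split("/data/train/task-b/game.tw-pddl", "train"))
```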

@@ -0,0 +1,55 @@
You are an expert failure-analysis agent for ALFWorld embodied household tasks.
You will be given MULTIPLE failed agent trajectories from a single minibatch
and the current skill document.
Your job is to identify the most important COMMON failure patterns across
the batch and propose a concise set of skill edits.
## ALFWorld Task Types
- pick_and_place: Put object in/on a receptacle
- pick_two_obj_and_place: Put two instances of an object in/on a receptacle
- look_at_obj_in_light: Examine an object under a desklamp
- pick_heat_then_place_in_recep: Heat an object and put it in/on a receptacle
- pick_cool_then_place_in_recep: Cool an object and put it in/on a receptacle
- pick_clean_then_place_in_recep: Clean an object and put it in/on a receptacle
## Failure Type Categories
- **navigation_loop**: the agent revisits the same locations repeatedly without progress
- **missed_object**: the agent fails to pick up a visible/reachable goal object
- **wrong_sequence**: the agent performs actions in the wrong order (e.g., placing before transforming)
- **premature_stop**: the agent stops or gets stuck before completing all goal conditions
- **action_loop**: the agent repeats the same action without advancing
- **appliance_error**: the agent misuses or skips an appliance (microwave, fridge, sink)
- **rule_missing**: the skill lacks a relevant rule for this situation
- **rule_wrong**: an existing skill rule is misleading or incorrect
- **rule_ignored**: the skill has the right rule but the agent did not follow it
- **other**: none of the above
## Analysis Process
1. Read ALL trajectories in the minibatch.
2. Identify the most prevalent, systematic failure patterns across them.
3. For each pattern, classify its failure type.
4. Propose skill edits that address the COMMON patterns — not individual edge cases.
5. Edits must be generalizable; do not hardcode task-specific values.
6. Only patch gaps in the skill — do not duplicate existing content.
You will be told the maximum number of edits (the budget L). Produce AT MOST L edits,
focusing on the highest-impact patterns. You may produce fewer if warranted.
Respond ONLY with a valid JSON object (no markdown fences, no extra text):
{
"batch_size": <number of trajectories analysed>,
"failure_summary": [
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
],
"patch": {
"reasoning": "<why these edits address the batch's common failures>",
"edits": [
{"op": "append", "content": "<markdown to add at end of skill>"},
{"op": "insert_after", "target": "<exact heading/text to insert after>", "content": "<markdown>"},
{"op": "replace", "target": "<exact text to replace>", "content": "<replacement>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
Only include edits that are needed. "edits" can be an empty list if no patch is warranted.
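A reply in the JSON shape above can be checked before any edit reaches the skill. The following is a minimal sketch under stated assumptions, not the framework's actual merge or ranking code: `parse_patch_response` and `ALLOWED_OPS` are hypothetical names, and the rule that malformed edits are dropped silently is an illustrative choice.

```python
from __future__ import annotations

import json

ALLOWED_OPS = {"append", "insert_after", "replace", "delete"}


def parse_patch_response(raw: str, budget: int) -> list[dict]:
    """Parse an analyst reply and keep at most `budget` well-formed edits."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return []
    patch = data.get("patch") if isinstance(data, dict) else None
    edits = patch.get("edits") if isinstance(patch, dict) else None
    if not isinstance(edits, list):
        return []
    valid: list[dict] = []
    for edit in edits:
        if not isinstance(edit, dict) or edit.get("op") not in ALLOWED_OPS:
            continue
        op = edit["op"]
        # Every op except "delete" must carry content.
        if op != "delete" and not edit.get("content"):
            continue
        # Every op except "append" must name a target.
        if op != "append" and not edit.get("target"):
            continue
        valid.append(edit)
    # Enforce the edit budget L even if the model returned more edits.
    return valid[: max(budget, 0)]
```

Enforcing the budget on the consumer side means an over-long reply degrades gracefully instead of failing the whole minibatch.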

@@ -0,0 +1,33 @@
You are an expert success-pattern analyst for AI agents operating in ALFWorld,
a text-based embodied household environment.
You will be given MULTIPLE successful agent trajectories from a single minibatch
and the current skill document. Your job is to identify generalizable behavior
patterns that are COMMON across the batch and worth encoding in the skill.
## Rules
- Only propose patches for patterns NOT already covered in the skill.
- Focus on patterns that appear across MULTIPLE trajectories in the batch.
- Be concise. Patterns must generalize beyond specific tasks.
- Prefer reinforcing existing sections over adding new top-level sections.
- If the agents' success involved efficient exploration or smart appliance usage,
consider reinforcing that in the patch.
You will be told the maximum number of edits (the budget L). Produce AT MOST L edits,
focusing on the most broadly applicable patterns. You may produce fewer if warranted.
Respond ONLY with a valid JSON object:
{
"batch_size": <number of trajectories analysed>,
"success_patterns": ["<pattern 1>", "<pattern 2>"],
"patch": {
"reasoning": "<why these patterns are worth encoding>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
"edits" may be empty if the skill already covers all observed patterns.

@@ -0,0 +1,35 @@
You are an expert diagnostic-probe designer for ALFWorld embodied tasks.
You will design one short diagnostic instruction to append to the student's prompt
for a handful of representative ALFWorld trajectories.
The goal is to expose whether the student has the right intermediate subgoal,
object/receptacle state, and next-step intention without substantially changing
the current scaffold.
## Hard Constraints
1. Do NOT substantially change the student's existing action-selection scaffold.
2. Do NOT prescribe a brand-new planner or long multi-step policy.
3. Do NOT ask for exhaustive search over all objects or all admissible actions.
4. Keep the diagnostic readout brief and place it inside the existing <think>...</think> block.
5. The student must still output exactly one admissible action inside <action>...</action>.
6. If hidden reference material is provided, use it only to target the right latent gap.
7. Never copy hidden reference content into the student-facing probe.
## Good Probe Targets
- current subgoal
- target object / target receptacle / target state
- decisive missing precondition
- why one candidate action is better than a tempting alternative
- whether the current step should explore, transform an object, or place it
## Bad Probe Targets
- a full optimal plan from start to finish
- exhaustive object inventories
- a new theorem-like or planner-like protocol
Respond ONLY with a valid JSON object:
{
"reasoning": "<why this probe reveals the latent skill gap>",
"probe_instruction": "<the exact instruction text to append to the student prompt>"
}

@@ -0,0 +1,8 @@
You are an expert agent operating in the ALFRED Embodied Environment.
Your current observation is: {current_observation}
Your admissible actions of the current situation are: [{admissible_actions}].
Now it's your turn to take an action.
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.

@@ -0,0 +1,9 @@
You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history}
You are now at step {current_step} and your current observation is: {current_observation}
Your admissible actions of the current situation are: [{admissible_actions}].
Now it's your turn to take an action.
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
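The placeholders in this template can be filled with plain `str.format`. The sketch below is an assumption about how such a template might be rendered, not the vendored environment manager's code: `render_prompt`, the shortened `TEMPLATE`, and the bracketed layout of `action_history` are all illustrative.

```python
from __future__ import annotations

# Shortened copy of the template, keeping the same placeholder names.
TEMPLATE = (
    "Your task is to: {task_description}\n"
    "Prior to this step, you have already taken {step_count} step(s). "
    "Below are the most recent {history_length} observations and the "
    "corresponding actions you took: {action_history}\n"
    "You are now at step {current_step} and your current observation is: "
    "{current_observation}\n"
    "Your admissible actions of the current situation are: [{admissible_actions}]."
)


def render_prompt(
    task_description: str,
    history: list[tuple[str, str]],
    current_observation: str,
    admissible_actions: list[str],
    history_length: int = 2,
) -> str:
    """Fill the template from a list of (observation, action) pairs."""
    recent = history[-history_length:] if history_length > 0 else []
    first = len(history) - len(recent) + 1  # 1-based index of the oldest shown step
    action_history = "\n".join(
        f"[Observation {first + k}: '{obs}', Action {first + k}: '{act}']"
        for k, (obs, act) in enumerate(recent)
    )
    return TEMPLATE.format(
        task_description=task_description,
        step_count=len(history),
        history_length=len(recent),
        action_history=action_history,
        current_step=len(history) + 1,
        current_observation=current_observation,
        admissible_actions=", ".join(f"'{a}'" for a in admissible_actions),
    )
```

The default `history_length=2` mirrors the `history_length` value set in `build_alfworld_env`.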

@@ -0,0 +1,16 @@
You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
## Retrieved Relevant Experience
{retrieved_memories}
## Current Progress
Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history}
You are now at step {current_step} and your current observation is: {current_observation}
Your admissible actions of the current situation are: [{admissible_actions}].
Now it's your turn to take an action.
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.

@@ -0,0 +1,4 @@
"""ALFWorld Reflect stage.
Prompts are now loaded from .md files by the base adapter.
"""

@@ -0,0 +1,359 @@
"""ALFWorld rollout module for ReflACT.
Provides:
- build_alfworld_env(): build ALFWorld environment (wraps vendored SkillRL env)
- run_alfworld_batch(): run a batch of ALFWorld episodes in parallel
- TASKS: list of ALFWorld task types
"""
from __future__ import annotations
import json
import os
import re
import sys
import time
import concurrent.futures
import numpy as np
from reflact.model import chat_student
# ── Constants ─────────────────────────────────────────────────────────────────
TASKS = [
"pick_and_place",
"pick_two_obj_and_place",
"look_at_obj_in_light",
"pick_heat_then_place_in_recep",
"pick_cool_then_place_in_recep",
"pick_clean_then_place_in_recep",
]
# ── Helpers ───────────────────────────────────────────────────────────────────
def _get_task_type(gamefile: str) -> str:
for task in TASKS:
if task in gamefile:
return task
return "other"
def _extract_action(model_response: str) -> str | None:
match = re.search(r"<action>(.*?)</action>", model_response, re.DOTALL)
return match.group(1).strip() if match else None
def _extract_think(model_response: str) -> str | None:
match = re.search(r"<think>(.*?)</think>", model_response, re.DOTALL)
return match.group(1).strip() if match else None
def _build_skill_prompt(skill_content: str) -> str:
"""Build the skill section that is prepended to the agent's user prompt."""
if not skill_content or not skill_content.strip():
return ""
return (
"\n\n## Skill Knowledge\n"
"Below is a skill document with learned strategies. "
"Use these guidelines to inform your decisions:\n\n"
f"{skill_content}\n"
)
def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) -> str:
if not diagnostic_instruction or not diagnostic_instruction.strip():
return prompt
return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n"
# ── Environment builder ──────────────────────────────────────────────────────
def build_alfworld_env(
env_num: int,
eval_dataset: str = "eval_out_of_distribution",
seed: int = 42,
is_train: bool = False,
specific_gamefiles: list[str] | None = None,
):
"""Build ALFWorld environment manager.
Args:
env_num: number of parallel environments
eval_dataset: 'eval_in_distribution', 'eval_out_of_distribution', or 'train'
seed: random seed
is_train: whether to use the training set
specific_gamefiles: optional explicit gamefile paths, one per environment
Returns:
env_manager: AlfWorldEnvironmentManager instance
"""
from omegaconf import OmegaConf
from functools import partial
from reflact.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs
from reflact.envs.alfworld.vendor.alfworld_projection import alfworld_projection
from reflact.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager
HERE = os.path.dirname(os.path.abspath(__file__))
alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml")
env_kwargs = {"eval_dataset": eval_dataset}
envs = build_alfworld_envs(
alf_config_path,
seed=seed,
env_num=env_num,
group_n=1,
is_train=is_train,
env_kwargs=env_kwargs,
resources_per_worker=None,
gamefiles=specific_gamefiles,
)
config = OmegaConf.create(
{
"env": {
"history_length": 2,
"env_name": "alfworld/AlfredTWEnv",
}
}
)
projection_f = partial(alfworld_projection)
env_manager = AlfWorldEnvironmentManager(envs, projection_f, config)
return env_manager
# ── Batch rollout ─────────────────────────────────────────────────────────────
def run_alfworld_batch(
env_manager,
skill_content: str,
max_steps: int = 50,
out_root: str = "",
max_api_workers: int = 8,
temperature: float = 0.4,
max_completion_tokens: int = 2048,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
result_ids: list[str] | None = None,
) -> list[dict]:
"""Run a batch of ALFWorld episodes.
Returns a list of result dicts compatible with SkillReflection v2 pipeline:
[
{
"id": "<result_ids[i] if provided, else env_<idx>>",
"hard": 0 or 1,
"soft": 0.0 or 1.0,
"n_turns": <int>,
"fail_reason": "<str>",
"agent_ok": True,
"task_type": "<str>",
"gamefile": "<str>",
"task_description": "<str>",
},
...
]
Also saves conversation.json per environment in out_root/predictions/<task_id>/ when out_root is set.
"""
skill_prompt = _build_skill_prompt(skill_content)
obs, infos = env_manager.reset({})
env_num = len(obs["text"])
env_dones = [False] * env_num
overall_success = [False] * env_num
# Build per-env metadata
env_meta: list[dict] = []
for i in range(env_num):
gamefile = infos[i].get("extra.gamefile", "") if isinstance(infos[i], dict) else ""
task_type = _get_task_type(gamefile)
# Extract task description from initial observation
task_desc = ""
anchor_text = obs["anchor"][i] if "anchor" in obs else ""
task_start = anchor_text.find("Your task is to: ")
if task_start != -1:
task_desc = anchor_text[task_start + len("Your task is to: "):].strip()
env_meta.append({
"gamefile": gamefile,
"task_type": task_type,
"task_description": task_desc,
})
# Per-env conversation records
conversations: list[list[dict]] = [[] for _ in range(env_num)]
for step_idx in range(max_steps):
if all(env_dones):
break
active_indices = [i for i in range(env_num) if not env_dones[i]]
# Build prompts with skill injection
prompts: dict[int, str] = {}
for i in active_indices:
prompt = obs["text"][i]
if skill_prompt:
# Prepend the skill section to the observation prompt
prompt = skill_prompt + "\n" + prompt
if diagnostic_mode and diagnostic_instruction.strip():
prompt = _append_diagnostic_instruction(prompt, diagnostic_instruction)
prompts[i] = prompt
# Call API in parallel
actions = ["None"] * env_num
action_timeout = 180
def call_api(idx):
try:
response, _ = chat_student(
system="You are an expert agent operating in the ALFRED Embodied Environment.",
user=prompts[idx],
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=120,
)
response = (response or "").strip()
if not response:
return idx, "<think>empty model response</think><action>look</action>"
if _extract_action(response) is None:
return idx, "<think>missing action tag</think><action>look</action>"
return idx, response
except Exception:  # noqa: BLE001
return idx, "<think>error</think><action>look</action>"
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers)
try:
futures = {executor.submit(call_api, i): i for i in active_indices}
started_at = {future: time.time() for future in futures}
pending_futs = set(futures)
while pending_futs:
done, _ = concurrent.futures.wait(
pending_futs,
timeout=5,
return_when=concurrent.futures.FIRST_COMPLETED,
)
now = time.time()
timed_out = [
future for future in pending_futs - done
if now - started_at[future] >= action_timeout
]
for future in done:
pending_futs.remove(future)
try:
idx, response = future.result()
except Exception: # noqa: BLE001
idx = futures[future]
response = "<think>error</think><action>look</action>"
actions[idx] = response
for future in timed_out:
pending_futs.remove(future)
idx = futures[future]
actions[idx] = "<think>api timeout</think><action>look</action>"
finally:
executor.shutdown(wait=False, cancel_futures=True)
# Save model responses before stepping
model_responses = {i: actions[i] for i in active_indices}
# Step environment
obs, rewards, dones, infos = env_manager.step(actions)
# Record trajectory
for i in active_indices:
step_record = {
"step": step_idx,
"action": _extract_action(model_responses[i]),
"reasoning": _extract_think(model_responses[i]),
"model_response": model_responses[i],
"env_feedback": obs["anchor"][i] if "anchor" in obs else "",
"reward": float(rewards[i]),
"done": bool(dones[i]),
}
conversations[i].append(step_record)
# Update done status
for i in range(env_num):
if env_dones[i]:
continue
if dones[i]:
env_dones[i] = True
won = bool(infos[i].get("won", False))
overall_success[i] = won
# Build results and save conversations
results: list[dict] = []
pred_dir = os.path.join(out_root, "predictions") if out_root else ""
for i in range(env_num):
gamefile = env_meta[i]["gamefile"]
task_type = env_meta[i]["task_type"]
task_desc = env_meta[i]["task_description"]
n_turns = len(conversations[i])
won = overall_success[i]
# Prefer the caller-supplied result ID; fall back to an index-based ID
task_id = str(result_ids[i]) if result_ids and i < len(result_ids) else f"env_{i:03d}"
fail_reason = ""
if not won:
if not env_dones[i]:
fail_reason = f"Timeout after {max_steps} steps"
else:
fail_reason = "Episode ended without completing the task"
result = {
"id": task_id,
"hard": 1 if won else 0,
"soft": 1.0 if won else 0.0,
"n_turns": n_turns,
"fail_reason": fail_reason,
"agent_ok": True, # ALFWorld agent always runs OK (no crash)
"task_type": task_type,
"gamefile": gamefile,
"task_description": task_desc,
"instruction_type": task_type, # for compatibility with v2 pipeline
}
results.append(result)
# Save conversation
if pred_dir:
conv_dir = os.path.join(pred_dir, task_id)
os.makedirs(conv_dir, exist_ok=True)
with open(os.path.join(conv_dir, "conversation.json"), "w", encoding="utf-8") as f:
json.dump(conversations[i], f, ensure_ascii=False, indent=2)
return results
# ── Item loading (for compatibility with split_three_way) ────────────────────
def load_alfworld_items(
eval_dataset: str,
env_num: int,
seed: int = 42,
is_train: bool = False,
) -> list[dict]:
"""Create pseudo-item dicts for ALFWorld environments.
Since ALFWorld doesn't have a static JSON dataset like SpreadsheetBench,
we create lightweight item dicts that carry enough metadata for the pipeline.
The actual environment is built dynamically.
Returns:
List of dicts with "id" keys, one per environment slot.
"""
items = []
for i in range(env_num):
items.append({
"id": f"env_{i:03d}",
"eval_dataset": eval_dataset,
"env_index": i,
})
return items
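The rollout loop never passes a tagless reply to the environment: empty or malformed responses are replaced with a harmless `look` action. The sketch below shows that parsing-plus-fallback behaviour on its own; `extract_tag` and `normalize_response` are hypothetical names that combine what `_extract_action`, `_extract_think`, and `call_api` do above.

```python
from __future__ import annotations

import re


def extract_tag(text: str, tag: str) -> str | None:
    """Return the stripped body of the first <tag>...</tag> pair, if any."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    return match.group(1).strip() if match else None


def normalize_response(response: str | None) -> str:
    """Substitute a safe 'look' action for empty or tagless model replies."""
    response = (response or "").strip()
    if not response:
        return "<think>empty model response</think><action>look</action>"
    if extract_tag(response, "action") is None:
        return "<think>missing action tag</think><action>look</action>"
    return response


reply = normalize_response("I will just go to the desk.")
print(extract_tag(reply, "think"), "|", extract_tag(reply, "action"))
```

Because the fallback text is recorded in the trajectory, a reflection pass can tell a formatting failure apart from a deliberate `look`.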

@@ -0,0 +1,45 @@
# ALFWorld Embodied Agent Skill
## Overview
This skill guides agents operating in the ALFWorld text-based embodied environment.
The agent must complete household tasks by navigating rooms, interacting with objects,
and using appliances. Actions must be chosen from the admissible action list provided
at each step.
**Output format**: Always output `<think>...</think>` for reasoning, then `<action>...</action>` for the chosen action.
---
## Task Types
| Type | Goal | Key Steps |
|------|------|-----------|
| Pick & Place | Put object X in/on receptacle Y | Find X -> take X -> go to Y -> put X in/on Y |
| Pick Two & Place | Put two instances of X in/on Y | Find X1 -> take -> place -> find X2 -> take -> place |
| Examine in Light | Examine object X under desklamp | Find X -> take X -> find desklamp -> use desklamp |
| Clean & Place | Clean object X and put in/on Y | Find X -> take X -> go to sink -> clean X -> go to Y -> put X |
| Heat & Place | Heat object X and put in/on Y | Find X -> take X -> go to microwave -> heat X -> go to Y -> put X |
| Cool & Place | Cool object X and put in/on Y | Find X -> take X -> go to fridge -> cool X -> go to Y -> put X |
---
## General Principles
1. **Decompose the task**: Parse the goal into ordered sub-goals (locate, acquire, transform, deliver). Complete each before moving to the next.
2. **Systematic exploration**: Search each surface and container exactly once before revisiting. Open closed containers (drawers, cabinets, fridge) before judging them empty.
3. **Grab immediately**: When a required object is visible and reachable, take it right away before moving elsewhere.
4. **Transform before placing**: If the task requires cleaning, heating, or cooling, perform the state change at the appropriate appliance before heading to the final destination.
5. **Direct delivery**: Once holding the transformed (or untransformed) goal object, navigate straight to the target receptacle and place it.
6. **Track progress**: Maintain an internal count of how many objects still need to be found and placed. Only stop searching when the count reaches zero.
7. **Avoid loops**: Never repeat the same action more than twice in a row. If stuck, move to a different unexplored location.
8. **Only choose admissible actions**: Always pick an action from the admissible action list. Do not invent actions.
---
## Common Mistakes to Avoid
- **Revisiting searched locations**: Keep track of which surfaces/containers have been checked; do not re-examine them.
- **Ignoring visible objects**: If the target object appears in the observation, pick it up immediately.
- **Skipping state changes**: Do not place an object at the destination without first cleaning/heating/cooling it when required.
- **Premature termination**: Do not stop the episode until all goal conditions are verified as met.
- **Action loops**: Repeatedly toggling or examining the same object wastes steps. Move on to new locations instead.
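A skill document like this one is the target of the `append`, `insert_after`, `replace`, and `delete` operations that the analyst prompts emit. The following is a minimal sketch of how such edits could be applied to markdown text. It is an assumption for illustration, not the framework's own patch applier: `apply_edits` is a hypothetical name, only the first occurrence of a target is touched, and edits whose target is missing are skipped.

```python
from __future__ import annotations


def apply_edits(skill: str, edits: list[dict]) -> str:
    """Apply patch-style edits to a markdown skill document, in order."""
    for edit in edits:
        op = edit.get("op")
        target = edit.get("target", "")
        content = edit.get("content", "")
        if op == "append":
            # Add content at the end, separated by one blank line.
            skill = skill.rstrip("\n") + "\n\n" + content + "\n"
        elif target and target in skill:
            if op == "insert_after":
                skill = skill.replace(target, target + "\n" + content, 1)
            elif op == "replace":
                skill = skill.replace(target, content, 1)
            elif op == "delete":
                skill = skill.replace(target, "", 1)
        # Unknown ops and edits with a missing target leave the text unchanged.
    return skill


doc = "# Skill\n\n## Common Mistakes to Avoid\n- **Action loops**: move on.\n"
patched = apply_edits(
    doc,
    [{"op": "insert_after",
      "target": "## Common Mistakes to Avoid",
      "content": "- **Stale inventory**: re-check what you are holding."}],
)
print(patched)
```

Requiring an exact target match is why the prompts ask for "exact text"; a paraphrased target would simply be skipped in this sketch.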

@@ -0,0 +1,9 @@
"""Vendored ALFWorld environment runtime.
Minimal subset of SkillRL's agent_system package needed to run
ALFWorld environments with ReflACT. Original source:
https://github.com/NTU-LANTERN/SkillRL (Apache-2.0 License)
"""
from .alfworld_envs import AlfworldEnvs, build_alfworld_envs
from .alfworld_projection import alfworld_projection
from .env_manager import AlfWorldEnvironmentManager

@@ -0,0 +1,221 @@
# Vendored from SkillRL (Apache-2.0 License)
# Original: agent_system/environments/env_package/alfworld/envs.py
# Modified: imports use pip-installed alfworld package instead of vendored copy.
import os
import multiprocessing as mp
import traceback
import yaml
import gymnasium as gym
import numpy as np
from alfworld.agents.environment import get_environment
def load_config_file(path):
assert os.path.exists(path), f"Invalid config file: {path}"
with open(path) as reader:
config = yaml.safe_load(reader)
return config
def compute_reward(info, multi_modal=False):
if multi_modal:
reward = 10.0 * float(info['won']) + float(info['goal_condition_success_rate'])
else:
reward = 10.0 * float(info['won'])
return reward
class AlfworldWorker:
"""Stateful worker that holds one ALFWorld sub-environment."""
def __init__(self, config, seed, base_env, gamefile=None):
if gamefile:
base_env.game_files = [gamefile]
if hasattr(base_env, "num_games"):
base_env.num_games = 1
self.env = base_env.init_env(batch_size=1)
self.env.seed(seed)
def step(self, action):
actions = [action]
obs, scores, dones, infos = self.env.step(actions)
infos['observation_text'] = obs
return obs, scores, dones, infos
def reset(self):
obs, infos = self.env.reset()
infos['observation_text'] = obs
return obs, infos
def _worker_loop(cmd_q, result_q, config, seed, is_train, eval_dataset, gamefile):
"""Run one ALFWorld environment in a child process."""
try:
env_type = config['env']['type']
base_env = get_environment(env_type)(
config,
train_eval='train' if is_train else eval_dataset,
)
worker = AlfworldWorker(config, seed, base_env, gamefile)
result_q.put((True, "ready"))
except BaseException:
result_q.put((False, traceback.format_exc()))
return
while True:
cmd, payload = cmd_q.get()
if cmd == "close":
result_q.put((True, None))
return
try:
if cmd == "reset":
result = worker.reset()
elif cmd == "step":
result = worker.step(payload)
else:
raise ValueError(f"Unknown ALFWorld worker command: {cmd}")
result_q.put((True, result))
except BaseException:
result_q.put((False, traceback.format_exc()))
class _ProcessWorker:
"""Small stdlib actor wrapper for one environment process."""
def __init__(self, ctx, config, seed, is_train, eval_dataset, gamefile=None):
self.cmd_q = ctx.Queue(maxsize=1)
self.result_q = ctx.Queue(maxsize=1)
self.process = ctx.Process(
target=_worker_loop,
args=(self.cmd_q, self.result_q, config, seed, is_train, eval_dataset, gamefile),
)
self.process.start()
ok, payload = self.result_q.get()
if not ok:
self.close(kill=True)
raise RuntimeError(f"Failed to start ALFWorld worker:\n{payload}")
def send(self, cmd, payload=None):
self.cmd_q.put((cmd, payload))
def recv(self):
ok, payload = self.result_q.get()
if not ok:
raise RuntimeError(f"ALFWorld worker failed:\n{payload}")
return payload
def close(self, kill=False):
if self.process.is_alive() and not kill:
try:
self.send("close")
self.recv()
except Exception:
kill = True
if kill and self.process.is_alive():
self.process.terminate()
self.process.join(timeout=5)
if self.process.is_alive():
self.process.kill()
self.process.join(timeout=1)
self.cmd_q.close()
self.result_q.close()
class AlfworldEnvs(gym.Env):
    """Vectorized ALFWorld environment using local process workers."""

    def __init__(self, alf_config_path, seed, env_num, group_n,
                 resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None):
        super().__init__()
        if env_kwargs is None:
            env_kwargs = {}
        eval_dataset = env_kwargs.get('eval_dataset', 'eval_in_distribution')
        config = load_config_file(alf_config_path)
        env_type = config['env']['type']
        self.multi_modal = (env_type == 'AlfredThorEnv')
        self.num_processes = env_num * group_n
        self.group_n = group_n
        self.gamefiles = list(gamefiles or [])
        if self.gamefiles and len(self.gamefiles) != self.num_processes:
            raise ValueError(
                f"Expected {self.num_processes} gamefiles, got {len(self.gamefiles)}"
            )
        start_method = os.environ.get("ALFWORLD_WORKER_START_METHOD") or None
        ctx = mp.get_context(start_method) if start_method else mp.get_context()
        self.workers = []
        for i in range(self.num_processes):
            worker_gamefile = self.gamefiles[i] if self.gamefiles else None
            worker = _ProcessWorker(
                ctx,
                config,
                seed + (i // self.group_n),
                is_train,
                eval_dataset,
                worker_gamefile,
            )
            self.workers.append(worker)
        self.prev_admissible_commands = [None for _ in range(self.num_processes)]

    def step(self, actions):
        assert len(actions) == self.num_processes
        for i, worker in enumerate(self.workers):
            worker.send("step", actions[i])
        results = [worker.recv() for worker in self.workers]
        text_obs_list = []
        rewards_list = []
        dones_list = []
        info_list = []
        for i, (obs, scores, dones, info) in enumerate(results):
            for k in info.keys():
                info[k] = info[k][0]
            text_obs_list.append(obs[0])
            dones_list.append(dones[0])
            info_list.append(info)
            self.prev_admissible_commands[i] = info['admissible_commands']
            rewards_list.append(compute_reward(info, self.multi_modal))
        image_obs_list = None
        return text_obs_list, image_obs_list, rewards_list, dones_list, info_list

    def reset(self):
        for worker in self.workers:
            worker.send("reset")
        results = [worker.recv() for worker in self.workers]
        text_obs_list = []
        info_list = []
        for i, (obs, info) in enumerate(results):
            for k in info.keys():
                info[k] = info[k][0]
            text_obs_list.append(obs[0])
            self.prev_admissible_commands[i] = info['admissible_commands']
            info_list.append(info)
        image_obs_list = None
        return text_obs_list, image_obs_list, info_list

    @property
    def get_admissible_commands(self):
        return self.prev_admissible_commands

    def close(self):
        for worker in self.workers:
            worker.close()


def build_alfworld_envs(alf_config_path, seed, env_num, group_n,
                        resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None):
    """Build vectorized ALFWorld environments."""
    return AlfworldEnvs(
        alf_config_path, seed, env_num, group_n,
        resources_per_worker, is_train, env_kwargs, gamefiles,
    )
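The worker loop above derives each worker's seed as `seed + (i // group_n)`, so the `group_n` workers of one group share a seed. A minimal sketch of that layout follows; `worker_seeds` is a hypothetical helper written for illustration and is not part of the vendored module.

```python
from typing import List


def worker_seeds(seed: int, env_num: int, group_n: int) -> List[int]:
    """Return one seed per worker, using the same rule as AlfworldEnvs.__init__.

    Hypothetical helper: it only mirrors the expression `seed + (i // group_n)`.
    """
    num_processes = env_num * group_n
    return [seed + (i // group_n) for i in range(num_processes)]


if __name__ == "__main__":
    # Three environments with two rollouts each: seeds repeat within a group.
    print(worker_seeds(seed=42, env_num=3, group_n=2))
```

Under this rule, consecutive workers belong to the same group, which is presumably what lets grouped rollouts start from the same seed.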


@@ -0,0 +1,60 @@
# Vendored from SkillRL (Apache-2.0 License)
# Original: agent_system/environments/env_package/alfworld/projection.py
from typing import List
import re


def alfworld_projection(actions: List[str], action_pools: List[List[str]]):
    """Process raw model outputs into valid ALFWorld actions.

    Extracts text from ``<action>...</action>`` tags and validates that
    the response also contains ``<think>...</think>`` tags.
    Note that ``actions`` is modified in place.

    Parameters
    ----------
    actions : list[str]
        Raw model outputs, one per environment.
    action_pools : list[list[str]]
        Admissible action lists per environment (unused but kept for API compatibility).

    Returns
    -------
    actions : list[str]
        Cleaned action strings.
    valids : list[int]
        1 if the action was successfully parsed, 0 otherwise.
    """
    valids = [0] * len(actions)
    for i in range(len(actions)):
        original_str = actions[i]
        actions[i] = actions[i].lower()
        start_tag = "<action>"
        end_tag = "</action>"
        start_idx = actions[i].find(start_tag)
        end_idx = actions[i].find(end_tag)
        try:
            if start_idx == -1 or end_idx == -1:
                # No complete <action>...</action> block: keep the last 30 characters
                actions[i] = actions[i][-30:]
                continue
            extracted_action = actions[i][start_idx + len(start_tag):end_idx].strip().lower()
            actions[i] = extracted_action
            valids[i] = 1
        except Exception:
            actions[i] = actions[i][-30:]
        # Require <think>...</think>
        think_start_idx = original_str.find("<think>")
        think_end_idx = original_str.find("</think>")
        if think_start_idx == -1 or think_end_idx == -1:
            valids[i] = 0
        # Reject responses containing CJK ideographs (U+4E00 to U+9FFF)
        if re.search(r'[\u4e00-\u9fff]', original_str):
            valids[i] = 0
    return actions, valids
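To make the parsing rules of `alfworld_projection` concrete, here is a minimal single-response sketch. `parse_action` is a hypothetical helper that mirrors the checks above for one string; it is an illustration under that assumption, not the vendored function.

```python
import re
from typing import Tuple


def parse_action(raw: str) -> Tuple[str, int]:
    """Return (action, valid) for one model response (illustrative sketch)."""
    lowered = raw.lower()
    start_tag, end_tag = "<action>", "</action>"
    start = lowered.find(start_tag)
    end = lowered.find(end_tag)
    if start == -1 or end == -1:
        # Same fallback as above: keep the last 30 characters, mark invalid.
        return lowered[-30:], 0
    action = lowered[start + len(start_tag):end].strip()
    valid = 1
    # The <think> check runs on the original (not lowercased) string.
    if raw.find("<think>") == -1 or raw.find("</think>") == -1:
        valid = 0
    # Responses containing CJK ideographs are rejected.
    if re.search(r"[\u4e00-\u9fff]", raw):
        valid = 0
    return action, valid


if __name__ == "__main__":
    print(parse_action("<think>need the mug</think><action>Go To Desk 1</action>"))
    print(parse_action("<action>look</action>"))
    print(parse_action("I will look around"))
```

A response is marked valid only when both tag pairs are present; the extracted action is always lowercased.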


@@ -0,0 +1,8 @@
# Vendored from SkillRL (Apache-2.0 License)
# Original: agent_system/environments/prompts/alfworld.py
from reflact.prompts import load_prompt
ALFWORLD_TEMPLATE_NO_HIS = load_prompt("rollout_no_history", env="alfworld")
ALFWORLD_TEMPLATE = load_prompt("rollout_with_history", env="alfworld")
ALFWORLD_TEMPLATE_WITH_MEMORY = load_prompt("rollout_with_memory", env="alfworld")


@@ -0,0 +1,145 @@
dataset:
  data_path: '$ALFWORLD_DATA/json_2.1.1/train'
  eval_id_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_seen'  # null/None to disable
  eval_ood_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_unseen'  # null/None to disable
  num_train_games: -1  # max training games (<=0 indicates full dataset)
  num_eval_games: -1  # max evaluation games (<=0 indicates full dataset)
logic:
  domain: '$ALFWORLD_DATA/logic/alfred.pddl'  # PDDL domain file that defines the world dynamics
  grammar: '$ALFWORLD_DATA/logic/alfred.twl2'  # Grammar file that defines the text feedback
env:
  type: 'AlfredTWEnv'  # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid'
  # regen_game_files: False  # check if game is solvable by expert and save to game.tw-pddl file
  domain_randomization: False  # shuffle Textworld print order and object id nums
  task_types: [1, 2, 3, 4, 5, 6]  # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place
  expert_timeout_steps: 150  # max steps before timeout for expert to solve the task
  expert_type: "handcoded"  # 'handcoded' or 'planner'. Note: the planner is very slow for real-time use
  goal_desc_human_anns_prob: 0.0  # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED)
  hybrid:
    start_eps: 100000  # starting episode of hybrid training, tw-only training up to this point
    thor_prob: 0.5  # prob of AlfredThorEnv during hybrid training
    eval_mode: "tw"  # 'tw' or 'thor' - env used for evaluation during hybrid training
  thor:
    screen_width: 300  # width of THOR window
    screen_height: 300  # height of THOR window
    smooth_nav: False  # smooth rotations, looks, and translations during navigation (very slow)
    save_frames_to_disk: False  # save frame PNGs to disk (useful for making videos)
    save_frames_path: './videos/'  # path to save frame PNGs
controller:
  type: 'oracle'  # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER)
  debug: False
  load_receps: True  # load receptacle locations from precomputed dict (if available)
mask_rcnn:
  pretrained_model_path: '$ALFWORLD_DATA/detectors/mrcnn.pth'
general:
  random_seed: 42
  use_cuda: True  # disable this when running on a machine without CUDA
  visdom: False  # plot training/eval curves, run with visdom server
  task: 'alfred'
  training_method: 'dagger'  # 'dqn' or 'dagger'
  save_path: './training/'  # path to save pytorch models
  observation_pool_capacity: 3  # k-size queue, 0 indicates no observation
  hide_init_receptacles: False  # remove initial observation containing navigable receptacles
  training:
    batch_size: 10
    max_episode: 50000
    smoothing_eps: 0.1
    optimizer:
      learning_rate: 0.001
      clip_grad_norm: 5
  evaluate:
    run_eval: True
    batch_size: 10
    env:
      type: "AlfredTWEnv"
  checkpoint:
    report_frequency: 1000  # report every N episodes
    experiment_tag: 'test'  # name of experiment
    load_pretrained: False  # during test, enable this so that the agent loads your pretrained model
    load_from_tag: 'not loading anything'  # name of pre-trained model to load in save_path
  model:
    encoder_layers: 1
    decoder_layers: 1
    encoder_conv_num: 5
    block_hidden_dim: 64
    n_heads: 1
    dropout: 0.1
    block_dropout: 0.1
    recurrent: True
rl:
  action_space: "admissible"  # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working)
  max_target_length: 20  # max token length for seq2seq generation
  beam_width: 10  # 1 means greedy
  generate_top_k: 3
  training:
    max_nb_steps_per_episode: 50  # terminate after this many steps
    learn_start_from_this_episode: 0  # delay updates until this episode
    target_net_update_frequency: 500  # sync target net with online net per this many epochs
  replay:
    accumulate_reward_from_final: True
    count_reward_lambda: 0.0  # 0 to disable
    novel_object_reward_lambda: 0.0  # 0 to disable
    discount_gamma_game_reward: 0.9
    discount_gamma_count_reward: 0.5
    discount_gamma_novel_object_reward: 0.5
    replay_memory_capacity: 500000  # adjust this depending on your RAM size
    replay_memory_priority_fraction: 0.5
    update_per_k_game_steps: 5
    replay_batch_size: 64
    multi_step: 3
    replay_sample_history_length: 4
    replay_sample_update_from: 2
  epsilon_greedy:
    noisy_net: False  # if this is true, then epsilon greedy is disabled
    epsilon_anneal_episodes: 1000  # -1 if not annealing
    epsilon_anneal_from: 0.3
    epsilon_anneal_to: 0.1
dagger:
  action_space: "generation"  # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working)
  max_target_length: 20  # max token length for seq2seq generation
  beam_width: 10  # 1 means greedy
  generate_top_k: 5
  unstick_by_beam_search: False  # use beam-search for failed actions, set True during evaluation
  training:
    max_nb_steps_per_episode: 50  # terminate after this many steps
  fraction_assist:
    fraction_assist_anneal_episodes: 50000
    fraction_assist_anneal_from: 1.0
    fraction_assist_anneal_to: 0.01
  fraction_random:
    fraction_random_anneal_episodes: 0
    fraction_random_anneal_from: 0.0
    fraction_random_anneal_to: 0.0
  replay:
    replay_memory_capacity: 500000
    update_per_k_game_steps: 5
    replay_batch_size: 64
    replay_sample_history_length: 4
    replay_sample_update_from: 2
vision_dagger:
  model_type: "resnet"  # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input)
  resnet_fc_dim: 64
  maskrcnn_top_k_boxes: 10  # top k box features
  use_exploration_frame_feats: False  # append feats from initial exploration (memory intensive!)
  sequence_aggregation_method: "average"  # 'sum' or 'average' or 'rnn'


@@ -0,0 +1,84 @@
# Vendored from SkillRL (Apache-2.0 License)
# Original: agent_system/environments/base.py
# Trimmed to only include what ALFWorld needs.
from typing import List, Tuple, Dict, Any
import numpy as np
from collections import defaultdict


def to_numpy(data):
    """Convert data to numpy array."""
    # Lazy-check for torch.Tensor to avoid hard dependency on torch
    _torch_tensor = None
    try:
        import torch
        _torch_tensor = torch.Tensor
    except ImportError:
        pass
    if _torch_tensor is not None and isinstance(data, _torch_tensor):
        data = data.detach().cpu().numpy()
    elif isinstance(data, np.ndarray):
        pass
    elif isinstance(data, (int, float, bool, Tuple, List)):
        data = np.array(data)
    else:
        raise ValueError(f"Unsupported type: {type(data)}")
    return data


class EnvironmentManagerBase:
    """Base class for vectorized environment managers.

    Manages a set of parallel environments, handles action projection,
    observation post-processing, and history tracking.
    """

    def __init__(self, envs, projection_f, config):
        self.envs = envs
        self.projection_f = projection_f
        self.config = config

    def reset(self, kwargs) -> Tuple[Dict[str, Any], Any]:
        # Returns an (observations, infos) pair
        obs, infos = self.envs.reset()
        return {'text': None, 'image': obs, 'anchor': None}, infos

    def step(self, text_actions: List[str]):
        actions, valids = self.projection_f(text_actions)
        next_obs, rewards, dones, infos = self.envs.step(actions)
        next_observations = {
            'text': None,
            'image': next_obs,
            'anchor': None,
        }
        for i, info in enumerate(infos):
            info['is_action_valid'] = to_numpy(valids[i])
        rewards = to_numpy(rewards)
        dones = to_numpy(dones)
        return next_observations, rewards, dones, infos

    def close(self) -> None:
        self.envs.close()

    def success_evaluator(self, *args, **kwargs) -> Dict[str, np.ndarray]:
        total_infos = kwargs['total_infos']
        total_batch_list = kwargs['total_batch_list']
        batch_size = len(total_batch_list)
        success = defaultdict(list)
        for bs in range(batch_size):
            self._process_batch(bs, total_batch_list, total_infos, success)
        assert len(success['success_rate']) == batch_size
        return {key: np.array(value) for key, value in success.items()}

    def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
        # Use the last active step of each trajectory to decide success
        for i in reversed(range(len(total_batch_list[batch_idx]))):
            batch_item = total_batch_list[batch_idx][i]
            if batch_item['active_masks']:
                info = total_infos[batch_idx][i]
                won_value = float(info['won'])
                success['success_rate'].append(won_value)
                return


@@ -0,0 +1,139 @@
# Vendored from SkillRL (Apache-2.0 License)
# Original: agent_system/environments/env_manager.py
# Trimmed to only include AlfWorldEnvironmentManager and its helpers.
from typing import List, Dict, Any
from collections import defaultdict
import numpy as np
from reflact.envs.alfworld.vendor.env_base import EnvironmentManagerBase, to_numpy
from reflact.envs.alfworld.vendor.alfworld_prompts import (
    ALFWORLD_TEMPLATE,
    ALFWORLD_TEMPLATE_NO_HIS,
    ALFWORLD_TEMPLATE_WITH_MEMORY,
)
from reflact.envs.alfworld.vendor.memory import SimpleMemory


def parse_gamefile(infos):
    gamefile = []
    for info in infos:
        if 'extra.gamefile' in info:
            gamefile.append(info['extra.gamefile'])
        else:
            gamefile.append(None)
    return gamefile


def set_gamefile(infos, gamefile):
    for i in range(len(infos)):
        if 'extra.gamefile' in infos[i]:
            infos[i]['extra.gamefile'] = gamefile[i]
        else:
            infos[i]['extra.gamefile'] = None
    return infos


class AlfWorldEnvironmentManager(EnvironmentManagerBase):
    """Manages parallel ALFWorld environments with observation templating."""

    def __init__(self, envs, projection_f, config):
        self.memory = SimpleMemory()
        self.retrieval_memory = None
        super().__init__(envs, projection_f, config)

    def reset(self, kwargs):
        text_obs, image_obs, infos = self.envs.reset()
        self.gamefile = parse_gamefile(infos)
        self.memory.reset(batch_size=len(text_obs))
        self.tasks = []
        self.pre_text_obs = text_obs
        self.extract_task(text_obs)
        full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands, init=True)
        return {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}, infos

    def step(self, text_actions: List[str]):
        actions, valids = self.projection_f(text_actions, self.envs.get_admissible_commands)
        text_obs, image_obs, rewards, dones, infos = self.envs.step(actions)
        self.memory.store({'text_obs': self.pre_text_obs, 'action': actions})
        self.pre_text_obs = text_obs
        full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands)
        if infos[0].get("extra.gamefile") is None:
            infos = set_gamefile(infos, self.gamefile)
        for i, info in enumerate(infos):
            info['is_action_valid'] = to_numpy(valids[i])
        next_observations = {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}
        rewards = to_numpy(rewards)
        dones = to_numpy(dones)
        return next_observations, rewards, dones, infos

    def extract_task(self, text_obs: List[str]):
        for obs in text_obs:
            task_start = obs.find('Your task is to: ')
            if task_start != -1:
                self.tasks.append(obs[task_start + len('Your task is to: '):].strip())
            else:
                raise ValueError("Task description not found in text observation.")

    def build_text_obs(self, text_obs: List[str], admissible_actions: List[List[str]], init: bool = False) -> List[str]:
        postprocess_text_obs = []
        if not init and self.config.env.history_length > 0:
            memory_contexts, valid_lens = self.memory.fetch(
                self.config.env.history_length,
                obs_key="text_obs",
                action_key="action",
            )
        for i in range(len(text_obs)):
            reformatted_admissible_actions = "\n ".join(
                f"'{s}'" for s in admissible_actions[i] if s != 'help'
            )
            if init or self.config.env.history_length <= 0:
                obs = ALFWORLD_TEMPLATE_NO_HIS.format(
                    current_observation=text_obs[i],
                    admissible_actions=reformatted_admissible_actions,
                )
            else:
                obs = ALFWORLD_TEMPLATE.format(
                    task_description=self.tasks[i],
                    step_count=len(self.memory[i]),
                    history_length=valid_lens[i],
                    action_history=memory_contexts[i],
                    current_step=len(self.memory[i]) + 1,
                    current_observation=text_obs[i],
                    admissible_actions=reformatted_admissible_actions,
                )
            postprocess_text_obs.append(obs)
        return postprocess_text_obs

    def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
        for i in reversed(range(len(total_batch_list[batch_idx]))):
            batch_item = total_batch_list[batch_idx][i]
            if batch_item['active_masks']:
                info = total_infos[batch_idx][i]
                won_value = float(info['won'])
                success['success_rate'].append(won_value)
                gamefile = info.get("extra.gamefile")
                if gamefile:
                    self._process_gamefile(gamefile, won_value, success)
                return

    def _process_gamefile(self, gamefile, won_value, success):
        tasks = [
            "pick_and_place",
            "pick_two_obj_and_place",
            "look_at_obj_in_light",
            "pick_heat_then_place_in_recep",
            "pick_cool_then_place_in_recep",
            "pick_clean_then_place_in_recep",
        ]
        for task in tasks:
            if task in gamefile:
                success[f"{task}_success_rate"].append(won_value)
                break
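`_process_gamefile` assigns each episode to a per-task success bucket by substring match on the game file path. The sketch below illustrates that lookup; `task_bucket` is a hypothetical helper and the example path is made up for illustration.

```python
from typing import Optional

TASKS = [
    "pick_and_place",
    "pick_two_obj_and_place",
    "look_at_obj_in_light",
    "pick_heat_then_place_in_recep",
    "pick_cool_then_place_in_recep",
    "pick_clean_then_place_in_recep",
]


def task_bucket(gamefile: str) -> Optional[str]:
    """Return the metric key for the first task name found in the path."""
    for task in TASKS:
        if task in gamefile:
            return f"{task}_success_rate"
    return None


if __name__ == "__main__":
    # Hypothetical path, shaped like an ALFWorld trial directory.
    path = "train/pick_heat_then_place_in_recep-Egg-None-Fridge-13/trial_T0/game.tw-pddl"
    print(task_bucket(path))
```

Because the loop stops at the first match, each episode contributes to at most one per-task metric in addition to the overall `success_rate`.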

reflact/envs/alfworld/vendor/memory.py vendored Normal file

@@ -0,0 +1,87 @@
# Vendored from SkillRL (Apache-2.0 License)
# Original: agent_system/memory/base.py + agent_system/memory/memory.py
# Merged into a single file for simplicity.
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple


class BaseMemory(ABC):
    """Base class for memory management."""

    @abstractmethod
    def __len__(self):
        pass

    @abstractmethod
    def __getitem__(self, idx: int):
        pass

    @abstractmethod
    def reset(self, batch_size: int):
        pass

    @abstractmethod
    def store(self, record: Dict[str, List[Any]]):
        pass

    @abstractmethod
    def fetch(self, step: int):
        pass


class SimpleMemory(BaseMemory):
    """Per-environment history buffer for storing observations and actions."""

    def __init__(self):
        self._data = None
        self.keys = None
        self.batch_size = 0

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]

    def reset(self, batch_size: int):
        if self._data is not None:
            self._data.clear()
        self._data = [[] for _ in range(batch_size)]
        self.batch_size = batch_size
        self.keys = None

    def store(self, record: Dict[str, List[Any]]):
        if self.keys is None:
            self.keys = list(record.keys())
        assert self.keys == list(record.keys())
        for env_idx in range(self.batch_size):
            self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})

    def fetch(
        self,
        history_length: int,
        obs_key: str = "text_obs",
        action_key: str = "action",
    ) -> Tuple[List[str], List[int]]:
        memory_contexts, valid_lengths = [], []
        for env_idx in range(self.batch_size):
            recent = self._data[env_idx][-history_length:]
            valid_len = len(recent)
            start_idx = len(self._data[env_idx]) - valid_len
            lines = []
            for j, rec in enumerate(recent):
                step_num = start_idx + j + 1
                act = rec[action_key]
                obs = rec[obs_key]
                lines.append(
                    f"[Observation {step_num}: '{obs}', Action {step_num}: '{act}']"
                )
            memory_contexts.append("\n".join(lines))
            valid_lengths.append(valid_len)
        return memory_contexts, valid_lengths
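The history string that `SimpleMemory.fetch` builds for one environment can be sketched as follows. `format_history` is a hypothetical stand-alone helper that mirrors the loop above for a single record list; the observation and action strings are made up.

```python
from typing import Dict, List, Tuple


def format_history(records: List[Dict[str, str]], history_length: int) -> Tuple[str, int]:
    """Format the last `history_length` records, numbering steps from 1."""
    # Note: a slice with -0 returns the whole list, so this sketch (like the
    # caller in build_text_obs) assumes history_length > 0.
    recent = records[-history_length:]
    start_idx = len(records) - len(recent)
    lines = []
    for j, rec in enumerate(recent):
        step_num = start_idx + j + 1
        lines.append(
            f"[Observation {step_num}: '{rec['text_obs']}', "
            f"Action {step_num}: '{rec['action']}']"
        )
    return "\n".join(lines), len(recent)


if __name__ == "__main__":
    history = [
        {"text_obs": "You are in a kitchen.", "action": "go to fridge 1"},
        {"text_obs": "The fridge 1 is closed.", "action": "open fridge 1"},
        {"text_obs": "You open the fridge 1.", "action": "take egg 1 from fridge 1"},
    ]
    text, valid_len = format_history(history, history_length=2)
    print(valid_len)
    print(text)
```

Step numbers stay absolute (here 2 and 3) even though only the most recent records are kept, so the prompt can still report the true step count.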


@@ -0,0 +1 @@
"""BabyVision environment package for ReflACT."""


@@ -0,0 +1,267 @@
"""BabyVision environment adapter for ReflACT."""
from __future__ import annotations
import json
import os
from reflact.gradient.deep_probe import generate_deep_probe_instruction
from reflact.datasets.base import BatchSpec
from reflact.gradient.reflect import run_minibatch_reflect
from reflact.envs.base import EnvAdapter
from reflact.envs.babyvision.dataloader import BabyVisionDataLoader
from reflact.envs.babyvision.rollout import run_batch
from reflact.model import get_student_backend


class BabyVisionAdapter(EnvAdapter):
    """BabyVision adapter."""

    def build_reference_text(self, item: dict) -> str:
        cot = str(item.get("cot") or "").strip()
        if not cot:
            return ""
        return f"## Reference CoT\n{cot}"

    def get_reference_metadata(self, item: dict) -> dict:
        cot = str(item.get("cot") or "").strip()
        if not cot:
            return {"fields": [], "preview": ""}
        return {
            "fields": ["cot"],
            "preview": cot[:400],
        }

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        max_turns: int = 1,
        workers: int = 32,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        seed: int = 42,
        limit: int = 0,
        image_detail: str = "auto",
        judge_model: str = "gpt-5.4",
        judge_max_completion_tokens: int = 256,
        judge_retries: int = 5,
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        self.max_turns = max_turns
        self.workers = workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.image_detail = image_detail
        self.judge_model = judge_model
        self.judge_max_completion_tokens = judge_max_completion_tokens
        self.judge_retries = judge_retries
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = BabyVisionDataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )

    def setup(self, cfg: dict) -> None:
        super().setup(cfg)
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        return list(batch.payload or [])

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def rollout(
        self,
        env_manager,
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict]:
        items: list[dict] = env_manager
        return run_batch(
            items=items,
            out_root=out_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            workers=self.workers,
            image_detail=self.image_detail,
            judge_model=self.judge_model,
            judge_max_completion_tokens=self.judge_max_completion_tokens,
            judge_retries=self.judge_retries,
            diagnostic_mode=kwargs.get("diagnostic_mode", False),
            diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
            diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
        )

    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        if not self.use_deep_reflect:
            return []
        env_manager = kwargs.get("env_manager")
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        codex_backend = get_student_backend() == "codex_exec"
        selected_items = self.select_representative_items(
            results,
            env_manager if isinstance(env_manager, list) else None,
            n_failures=self.deep_reflect_failures,
            n_successes=self.deep_reflect_successes,
            seed=random_seed,
        )
        if not selected_items:
            return []
        selected_ids = {str(item["id"]) for item in selected_items}
        selected_results = [row for row in results if str(row.get("id")) in selected_ids]
        selected_examples = self.attach_reference_context(selected_results, selected_items)
        if codex_backend:
            selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
        selected_metadata = []
        cot_count = 0
        for item in selected_items:
            meta = self.get_reference_metadata(item)
            if meta["fields"]:
                cot_count += 1
            selected_metadata.append({
                "id": str(item["id"]),
                "task_type": str(item.get("subtype") or item.get("task_type") or "babyvision"),
                "reference_fields": meta["fields"],
                "reference_preview": meta["preview"],
            })
        deep_dir = os.path.join(out_dir, "deep_reflect")
        rollout_dir = os.path.join(deep_dir, "rollout")
        patches_dir = os.path.join(deep_dir, "patches")
        os.makedirs(deep_dir, exist_ok=True)
        print(
            f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
            f"reference_fields=cot({cot_count}/{len(selected_items)})"
        )
        probe = generate_deep_probe_instruction(
            skill_content=skill_content,
            items=selected_examples,
            prediction_dir=prediction_dir,
            system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
        )
        if not probe:
            return []
        diagnostic_trace_context_by_id = None
        if codex_backend:
            selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
                selected_items=selected_items,
                selected_examples=selected_examples,
                prediction_dir=prediction_dir,
                probe=probe,
            )
        probe_record = {
            **probe,
            "reference_summary": {
                "selected_count": len(selected_items),
                "field_counts": {
                    "cot": cot_count,
                },
            },
            "selected_examples": selected_metadata,
        }
        with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
            json.dump(probe_record, f, ensure_ascii=False, indent=2)
        deep_results = run_batch(
            items=selected_items,
            out_root=rollout_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            workers=min(self.workers, max(len(selected_items), 1)),
            image_detail=self.image_detail,
            judge_model=self.judge_model,
            judge_max_completion_tokens=self.judge_max_completion_tokens,
            judge_retries=self.judge_retries,
            diagnostic_mode=True,
            diagnostic_instruction=probe["probe_instruction"],
            diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
        )
        deep_results = self.attach_reference_context(deep_results, selected_items)
        return run_minibatch_reflect(
            results=deep_results,
            skill_content=skill_content,
            prediction_dir=os.path.join(rollout_dir, "predictions"),
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )

    def get_task_types(self) -> list[str]:
        return self.dataloader.get_task_types()


@@ -0,0 +1,214 @@
"""BabyVision task dataloader."""
from __future__ import annotations
import json
import os
from typing import Any
from reflact.datasets.base import SplitDataLoader

# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]


def _iter_jsonl(path: str) -> list[dict]:
    items: list[dict] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            items.append(json.loads(line))
    return items


def _normalize_ans_type(raw: Any, options: list[dict], choice_answer: Any) -> str:
    text = str(raw or "").strip().lower()
    if text in {"choice", "multiple_choice", "mcq", "option"}:
        return "choice"
    if text in {"blank", "open", "open_ended", "fill_blank", "short_answer"}:
        return "blank"
    if options or choice_answer not in (None, "", []):
        return "choice"
    return "blank"


def _coerce_options(raw: Any) -> list[dict]:
    options: list[dict] = []
    if isinstance(raw, list):
        for idx, item in enumerate(raw):
            if isinstance(item, dict):
                text = str(item.get("text") or item.get("content") or item.get("option") or "").strip()
                label = str(item.get("label") or _CHOICE_LABELS[idx]).strip()
            else:
                text = str(item).strip()
                label = _CHOICE_LABELS[idx]
            if text:
                options.append({"label": label, "text": text})
    elif isinstance(raw, dict):
        for idx, (key, value) in enumerate(raw.items()):
            text = str(value).strip()
            if text:
                options.append({"label": str(key).strip() or _CHOICE_LABELS[idx], "text": text})
    return options


def _normalize_choice_answer(choice_answer: Any, options: list[dict]) -> dict[str, str]:
    if not options:
        return {"label": "", "text": ""}
    if isinstance(choice_answer, dict):
        label = str(choice_answer.get("label") or "").strip().upper()
        text = str(choice_answer.get("text") or "").strip()
        for option in options:
            if label and option["label"].strip().upper() == label:
                return {"label": option["label"], "text": option["text"]}
            if text and option["text"] == text:
                return {"label": option["label"], "text": option["text"]}
    if isinstance(choice_answer, int):
        idx = choice_answer
        if 0 <= idx < len(options):
            return dict(options[idx])
        if 1 <= idx <= len(options):
            return dict(options[idx - 1])
    text = str(choice_answer or "").strip()
    label = text.upper().rstrip(".):")
    for option in options:
        if option["label"].strip().upper() == label:
            return dict(option)
        if option["text"] == text:
            return dict(option)
    return {"label": "", "text": ""}


def _coerce_blank_answers(raw: Any) -> list[str]:
    if isinstance(raw, list):
        return [str(item).strip() for item in raw if str(item).strip()]
    if raw is None:
        return []
    text = str(raw).strip()
    return [text] if text else []


def load_items(data_path: str) -> list[dict]:
    """Load and normalise BabyVision items from a directory or JSONL file."""
    if not data_path:
        raise ValueError("BabyVision requires data_path pointing to a local dataset directory or meta_data.jsonl.")
    if os.path.isdir(data_path):
        meta_path = os.path.join(data_path, "meta_data.jsonl")
        image_root = os.path.join(data_path, "images")
    else:
        meta_path = data_path
        image_root = os.path.join(os.path.dirname(data_path), "images")
    if not os.path.exists(meta_path):
        raise ValueError(
            "BabyVision expected a meta_data.jsonl file. "
            f"Could not find: {meta_path}"
        )
    raw_items = _iter_jsonl(meta_path)
    items: list[dict] = []
    for idx, raw in enumerate(raw_items):
        options = _coerce_options(raw.get("options") or raw.get("choices") or raw.get("choiceOptions"))
        ans_type = _normalize_ans_type(raw.get("ansType"), options, raw.get("choiceAns"))
        correct_choice = _normalize_choice_answer(raw.get("choiceAns"), options)
        blank_answers = _coerce_blank_answers(raw.get("blankAns"))
        image_name = str(
            raw.get("image")
            or raw.get("image_path")
            or raw.get("image_file")
            or raw.get("img")
            or ""
        ).strip()
        if not image_name:
            continue
        image_path = image_name if os.path.isabs(image_name) else os.path.join(image_root, image_name)
        if not os.path.exists(image_path):
            alt = os.path.join(os.path.dirname(meta_path), image_name)
            if os.path.exists(alt):
                image_path = alt
            else:
                continue
        task_id = str(raw.get("taskId") or raw.get("id") or idx + 1)
        task_type = str(raw.get("type") or raw.get("taskType") or "unknown").strip() or "unknown"
        subtype = str(raw.get("subtype") or raw.get("subType") or task_type).strip() or task_type
        question = str(raw.get("question") or raw.get("query") or "").strip()
        if not question:
            continue
        if ans_type == "choice" and not correct_choice["label"]:
            continue
        if ans_type != "choice" and not blank_answers:
            continue
        items.append({
            "id": task_id,
            "task_type": task_type,
            "subtype": subtype,
            "question": question,
            "image_path": os.path.abspath(image_path),
            "ans_type": ans_type,
            "choices": options,
            "correct_choice": correct_choice,
            "blank_answers": blank_answers,
            "cot": str(raw.get("coT") or raw.get("cot") or "").strip(),
            "source_path": os.path.abspath(meta_path),
        })
    if not items:
        raise ValueError(f"No valid BabyVision items loaded from {data_path}")
    return items


# ── Dataloader ───────────────────────────────────────────────────────────
class BabyVisionDataLoader(SplitDataLoader):
    """BabyVision dataloader."""

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        seed: int = 42,
        limit: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )
        self._task_types: list[str] = []

    def load_raw_items(self, data_path: str) -> list[dict]:
        return load_items(data_path)

    def setup(self, cfg: dict) -> None:
        super().setup(cfg)
        all_items = self.train_items + self.val_items + self.test_items
        task_types = {
            item.get("subtype") or item.get("task_type") or "unknown"
            for item in all_items
        }
        self._task_types = sorted(task_types)

    def get_task_types(self) -> list[str]:
        return list(self._task_types)
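For string answers, `_normalize_choice_answer` strips trailing `.`, `)` and `:` from the uppercased answer before comparing labels, then falls back to an exact text match. A minimal sketch of that string path follows; `match_choice` is a hypothetical helper and the options are made-up sample data.

```python
from typing import Dict, List


def match_choice(answer: str, options: List[Dict[str, str]]) -> Dict[str, str]:
    """Resolve a raw string answer to one of the options (illustrative sketch)."""
    text = answer.strip()
    # "b)" or "B." both reduce to the label "B".
    label = text.upper().rstrip(".):")
    for option in options:
        if option["label"].strip().upper() == label:
            return dict(option)
        if option["text"] == text:
            return dict(option)
    # No match: return an empty choice, which load_items treats as invalid.
    return {"label": "", "text": ""}


if __name__ == "__main__":
    sample_options = [
        {"label": "A", "text": "circle"},
        {"label": "B", "text": "square"},
    ]
    print(match_choice("b)", sample_options))
    print(match_choice("circle", sample_options))
    print(match_choice("triangle", sample_options))
```

Items whose answer cannot be resolved this way end up with an empty label and are skipped by `load_items` when the answer type is `choice`.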


@@ -0,0 +1,160 @@
"""BabyVision evaluation helpers using the official-style LLM judge."""
from __future__ import annotations
import re
import string
import regex
from reflact.model import chat_with_deployment
from reflact.prompts import load_prompt
_EVAL_MODE = "babyvision_judge_v2_official_style"
def normalize_text(text: str) -> str:
text = str(text).strip().lower()
text = "".join(ch for ch in text if ch not in string.punctuation)
return " ".join(text.split())
def extract_boxed_answer(text: str | None) -> str | None:
"""Extract the final answer using the official BabyVision rule."""
if text is None:
return None
pattern = r'\\boxed\{((?:[^{}]|{(?:[^{}]|{.*})*})*)\}'
matches = regex.findall(pattern, text)
if matches:
return matches[-1]
pattern_alt = r'<\|begin_of_box\|>(.*?)<\|end_of_box\|>'
matches_alt = regex.findall(pattern_alt, text)
if matches_alt:
return matches_alt[-1].strip()
return None
def _token_f1(prediction: str, gold: str) -> float:
    """Bag-of-tokens F1 between the normalized prediction and gold answer."""
    pred_tokens = normalize_text(prediction).split()
    gold_tokens = normalize_text(gold).split()
    if not pred_tokens and not gold_tokens:
        return 1.0
    if not pred_tokens or not gold_tokens:
        return 0.0
    # Per-token occurrence counts (multisets, not sets).
    pred_counts: dict[str, int] = {}
    gold_counts: dict[str, int] = {}
    for tok in pred_tokens:
        pred_counts[tok] = pred_counts.get(tok, 0) + 1
    for tok in gold_tokens:
        gold_counts[tok] = gold_counts.get(tok, 0) + 1
    # Overlap is the sum of the per-token minimum counts.
    common = 0
    for tok, count in pred_counts.items():
        common += min(count, gold_counts.get(tok, 0))
    if common == 0:
        return 0.0
    precision = common / len(pred_tokens)
    recall = common / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
def _format_choices(choices: list[dict]) -> str:
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
def _judge_answer(
    *,
    item: dict,
    prediction_text: str,
    extracted_answer: str | None,
    judge_model: str,
    max_completion_tokens: int,
    retries: int,
) -> dict:
    # Build the reference string the judge compares against.
    if item["ans_type"] == "choice":
        ground_truth = str(item["correct_choice"]["label"])
    else:
        if len(item["blank_answers"]) == 1:
            ground_truth = item["blank_answers"][0]
        else:
            # Several blank answers are joined with " | ".
            ground_truth = " | ".join(item["blank_answers"])
    question = str(item["question"])
    if item["ans_type"] == "choice" and item.get("choices"):
        question = f"{question}\nChoices:\n{_format_choices(item['choices'])}"
    raw, _ = chat_with_deployment(
        deployment=judge_model,
        system="You are a careful and strict evaluator.",
        user=load_prompt("judge", env="babyvision").format(
            question=question,
            groundtruth=ground_truth,
            # If no boxed answer was extracted, send an empty string
            # rather than the literal text "None".
            modeloutput=extracted_answer or "",
        ),
        max_completion_tokens=max_completion_tokens,
        retries=retries,
        stage="babyvision_judge",
    )
    judge_response_clean = str(raw).strip().lower()
    # Any reply without "true" (including an unparseable one) counts as incorrect.
    correct = "true" in judge_response_clean
    return {
        "raw": raw,
        "correct": correct,
        "reason": judge_response_clean,
        "matched_gold": ground_truth if correct else "",
    }
def evaluate_item(
*,
item: dict,
prediction_text: str,
judge_model: str,
max_completion_tokens: int = 256,
retries: int = 5,
) -> dict:
answer = extract_boxed_answer(prediction_text)
judge = _judge_answer(
item=item,
prediction_text=prediction_text,
extracted_answer=answer,
judge_model=judge_model,
max_completion_tokens=max_completion_tokens,
retries=retries,
)
hard = 1.0 if judge["correct"] else 0.0
result = {
"evaluation_mode": _EVAL_MODE,
"predicted_answer": answer,
"em": hard,
"f1": hard,
"sub_em": hard,
"judge_model": judge_model,
"judge_raw": judge["raw"],
"judge_reason": judge["reason"],
"matched_gold": judge["matched_gold"],
}
if item["ans_type"] == "choice":
result["predicted_label"] = str(answer or "").strip().upper().rstrip(".):")
result["predicted_text"] = ""
result["correct_label"] = str(item["correct_choice"].get("label") or "")
result["correct_text"] = str(item["correct_choice"].get("text") or "")
else:
result["gold_answers"] = list(item["blank_answers"])
best_f1 = 0.0
for gold in item["blank_answers"]:
best_f1 = max(best_f1, _token_f1(str(answer or ""), gold))
result["string_f1"] = best_f1
return result
def evaluation_mode() -> str:
return _EVAL_MODE
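
The two text-level helpers in this file (boxed-answer extraction and token F1) can be tried in isolation with the sketch below. It is a simplified stand-in, not the code from this commit: `last_boxed` and `token_f1` are hypothetical names, and the extraction uses a manual brace scan from the stdlib instead of the recursive `regex` pattern, so edge cases may differ from `extract_boxed_answer`.

```python
import string
from collections import Counter
from typing import Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last balanced \\boxed{...} span, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    while start != -1:
        depth = 1
        body_start = start + len(marker)
        for i in range(body_start, len(text)):
            if text[i] == "{":
                depth += 1
            elif text[i] == "}":
                depth -= 1
                if depth == 0:
                    return text[body_start:i]
        # Unbalanced span: fall back to an earlier \boxed{ occurrence.
        start = text.rfind(marker, 0, start)
    return None


def _tokens(text: str) -> list:
    """Lowercase, drop ASCII punctuation, split on whitespace."""
    lowered = str(text).strip().lower()
    return "".join(ch for ch in lowered if ch not in string.punctuation).split()


def token_f1(prediction: str, gold: str) -> float:
    """Multiset token F1 between a prediction and one gold answer."""
    pred, ref = _tokens(prediction), _tokens(gold)
    if not pred and not ref:
        return 1.0
    if not pred or not ref:
        return 0.0
    # Counter intersection keeps the per-token minimum count.
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)
```

Taking the last boxed span mirrors the `matches[-1]` choice above: a response that revises itself is scored on its final answer.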

View File

@@ -0,0 +1,36 @@
You are an expert failure-analysis agent for child-level visual reasoning tasks.
You will be given MULTIPLE failed BabyVision trajectories from a minibatch and the current skill document.
Each trajectory includes the text prompt, the model answer, and the evaluation result.
You do not have direct access to raw pixel content during reflection, so focus on general reasoning,
option-selection, and visual-question-answering behaviors that can be improved through prompting.
## Failure Type Categories
- **visual_detail_miss**: the agent likely overlooked a salient visual attribute, relation, count, or object state
- **option_mismatch**: the agent selected the wrong option despite relevant evidence likely being present
- **instruction_slip**: the agent ignored output format or answered too vaguely
- **answer_granularity**: the agent gave an answer that was too broad, too narrow, or mismatched the expected specificity
- **other**: none of the above
## Rules
1. Focus on patterns recurring across the minibatch.
2. Prefer reusable behaviors for inspecting images and grounding answers in visible evidence.
3. Do not memorize dataset-specific answers.
4. Only patch gaps not already covered by the current skill.
Respond ONLY with a valid JSON object:
{
"batch_size": <number>,
"failure_summary": [
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
],
"patch": {
"reasoning": "<why these edits address the common failures>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
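
The four edit operations in the schema above (`append`, `insert_after`, `replace`, `delete`) can be read as plain string operations on the skill markdown. The following is a minimal sketch under that assumption; `apply_edits` is a hypothetical helper for illustration and is not taken from the ReflACT code, whose merge and ranking stages are not shown here.

```python
def apply_edits(skill: str, edits: list) -> tuple:
    """Apply edits in order; return (new_skill, skipped_edits)."""
    skipped = []
    for edit in edits:
        op = edit.get("op")
        target = edit.get("target", "")
        content = edit.get("content", "").strip("\n")
        if op == "append":
            base = skill.rstrip("\n")
            skill = (base + "\n\n" if base else "") + content + "\n"
        elif op in {"insert_after", "replace", "delete"} and target and target in skill:
            if op == "insert_after":
                # Insert on a new line after the line that contains the target.
                line_end = skill.find("\n", skill.index(target) + len(target))
                if line_end == -1:
                    line_end = len(skill)
                skill = skill[:line_end] + "\n" + content + skill[line_end:]
            elif op == "replace":
                skill = skill.replace(target, content, 1)
            else:  # delete
                skill = skill.replace(target, "", 1)
        else:
            # Unknown op, or a target that does not occur in the skill.
            skipped.append(edit)
    return skill, skipped
```

Returning the skipped edits instead of raising keeps one bad target from discarding the rest of a patch; whether the real pipeline behaves this way is not stated in this commit.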

View File

@@ -0,0 +1,25 @@
You are an expert success-pattern analyst for child-level visual reasoning tasks.
You will be given MULTIPLE successful BabyVision trajectories from a minibatch and the current skill document.
Identify generalizable behavior patterns that help the agent inspect the image carefully and answer at the right level of specificity.
## Rules
- Focus on broadly useful visual QA behaviors.
- Prefer patterns about systematic image inspection, comparing options, and concise grounded answers.
- Do not add dataset-specific facts.
- "edits" may be empty if the skill already captures the useful patterns.
Respond ONLY with a valid JSON object:
{
"batch_size": <number>,
"success_patterns": ["<pattern 1>", "<pattern 2>"],
"patch": {
"reasoning": "<why these patterns matter>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}

View File

@@ -0,0 +1,25 @@
You are an expert diagnostic-probe designer for BabyVision-style visual reasoning tasks.
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
Design one SMALL diagnostic instruction that exposes the student's intermediate visual judgment without materially changing the original scaffold.
## Hard Constraints
1. Do NOT substantially change the original scaffold.
2. Do NOT prescribe a new step-by-step solving method.
3. You MAY ask for a short structured list of a few intermediate conclusions, candidate cues, or counted units, as long as it stays close to the original scaffold.
4. Do NOT ask for exhaustive listing of all cells, all objects, or a full chain-of-thought.
5. Ask only for a short readout that reveals the student's current latent state.
6. Keep it brief and structured, and require the final answer to remain in \boxed{...}.
## Good Probe Targets
- top answer and runner-up
- decisive visual cue
- suspicious region or compared objects
- counting unit or formatting interpretation
- 2-4 short intermediate conclusions that directly support the final answer
Respond ONLY with a valid JSON object:
{
"reasoning": "<why this probe is informative>",
"probe_instruction": "<the exact instruction text to append to the student prompt>"
}

View File

@@ -0,0 +1,35 @@
You are a careful and strict evaluator. You will be given:
1. **Question**
2. **Ground Truth Answer** (correct answer)
3. **Model Output** (answer from another model)
**Your goal:** Determine if the Model Output **accurately matches** the Ground Truth Answer in meaning.
* Matching means: the facts, entities, and key details are equivalent, even if phrasing differs.
* Not matching means: the Model Output is wrong, incomplete, contains extra incorrect facts, or changes the meaning.
**Process (internal reasoning):**
1. Read and understand the Question, Ground Truth Answer, and Model Output.
2. Ignore small wording differences, formatting, or synonyms.
3. If all factual content matches, conclude `True`. Otherwise, conclude `False`.
**Important:**
* Think through your decision step-by-step **internally** before responding.
* In your final output, return **only** True or False, with no extra text or explanation.
**Output format:**
True
or
False
**Input:**
Question: {question},
Ground Truth Answer: {groundtruth},
Model Output: {modeloutput}
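
The True/False output contract above is parsed in the evaluator with a substring check. A stricter variant is sketched below, assuming whole-word matching is acceptable; `parse_verdict` is a hypothetical name and this is not the parsing used in this commit.

```python
import re
from typing import Optional


def parse_verdict(raw: object) -> Optional[bool]:
    """Map a judge reply to True/False, or None when it is ambiguous."""
    words = re.findall(r"[a-z]+", str(raw).lower())
    has_true = "true" in words
    has_false = "false" in words
    if has_true == has_false:
        # Neither verdict word is present, or both are.
        return None
    return has_true
```

A caller could treat `None` as incorrect to keep the current fallback, while still logging how often the judge strays from the format.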

View File

@@ -0,0 +1,13 @@
You are an expert visual reasoning agent solving child-level image understanding tasks.
{skill_section}## Task Format
You will receive one image and one question about it.
Inspect the image carefully before answering. Ground the answer in visible evidence.
## Answer Format
Think step by step, then provide your final answer in \boxed{{Answer}} format.
- For multiple-choice questions, output only the single choice label, such as \boxed{{A}}.
- For open questions, output only a short final answer inside \boxed{{...}}.
Example:
\boxed{{B}}
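
The doubled braces in this template exist because the file is rendered with `str.format`: `{skill_section}` is a placeholder, while `{{` and `}}` become literal braces. A small sketch with a shortened, made-up template (`TEMPLATE` and `build_system` here are illustrative only):

```python
TEMPLATE = (
    "You are an expert visual reasoning agent.\n"
    "{skill_section}## Answer Format\n"
    "Provide your final answer in \\boxed{{Answer}} format.\n"
)


def build_system(skill_content: str) -> str:
    """Render the template, adding a Skill section only when a skill is given."""
    skill = skill_content.strip()
    skill_section = f"## Skill\n{skill}\n\n" if skill else ""
    # Braces inside the substituted value are not re-interpreted by format().
    return TEMPLATE.format(skill_section=skill_section)
```

Because the skill text is passed as a value, braces inside an evolving skill document do not need escaping; only the template file itself does.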

View File

@@ -0,0 +1,4 @@
"""BabyVision Reflect stage.
Prompts are now loaded from .md files by the base adapter.
"""

View File

@@ -0,0 +1,467 @@
"""BabyVision rollout — multimodal visual QA with image input."""
from __future__ import annotations
import base64
import json
import mimetypes
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from reflact.envs.babyvision.evaluator import evaluate_item, evaluation_mode, extract_boxed_answer
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
from reflact.prompts import load_prompt
def _build_system(skill_content: str) -> str:
if skill_content.strip():
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
else:
skill_section = ""
return load_prompt("rollout_system", env="babyvision").format(skill_section=skill_section)
def _format_choices(choices: list[dict]) -> str:
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
def _build_user_text(
item: dict,
*,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
) -> str:
parts = []
if diagnostic_trace_context.strip():
parts.append(
"## Previous Codex Trace Snapshot\n"
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
f"{diagnostic_trace_context.strip()}"
)
parts.append(f"## Question\n{item['question']}")
if item["ans_type"] == "choice":
parts.append(f"## Choices\n{_format_choices(item['choices'])}")
parts.append("Answer using the single correct option label in \\boxed{...}.")
else:
parts.append("Answer with a short phrase in \\boxed{...}.")
if diagnostic_mode and diagnostic_instruction.strip():
parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
return "\n\n".join(parts)
def _image_to_data_uri(path: str) -> str:
mime = mimetypes.guess_type(path)[0] or "image/png"
with open(path, "rb") as f:
encoded = base64.b64encode(f.read()).decode("ascii")
return f"data:{mime};base64,{encoded}"
def _build_messages(
item: dict,
skill_content: str,
image_detail: str,
*,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
) -> tuple[list[dict], str, str]:
system = _build_system(skill_content)
user_text = _build_user_text(
item,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
)
image_url = {
"url": _image_to_data_uri(item["image_path"]),
}
if image_detail and image_detail != "auto":
image_url["detail"] = image_detail
messages = [
{"role": "system", "content": system},
{
"role": "user",
"content": [
{"type": "text", "text": user_text},
{"type": "image_url", "image_url": image_url},
],
},
]
return messages, system, user_text
def _build_codex_skill(skill_content: str) -> str:
return render_skill_md(
skill_content,
description="Dynamic ReflACT skill for solving the current BabyVision visual reasoning question.",
preamble=(
"Use this skill when answering the current visual reasoning question.\n"
"Inspect the attached image carefully and return the final answer in \\boxed{...}."
),
)
def _run_codex_once(
*,
pred_dir: str,
item: dict,
skill_content: str,
model: str,
timeout: int,
image_detail: str,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
previous_response: str = "",
) -> tuple[str, str, str, str]:
user_text = _build_user_text(
item,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
)
task_parts = [user_text]
if previous_response:
task_parts.append(
"## Previous Attempt\n"
f"{previous_response}\n\n"
"Review the same image and question carefully. If needed, correct the answer."
)
task_text = "\n\n".join(task_parts)
skill_md = _build_codex_skill(skill_content)
work_dir = os.path.join(pred_dir, "codex_exec")
prepare_workspace(
work_dir=work_dir,
skill_md=skill_md,
task_text=task_text,
images=[item["image_path"]],
)
prompt = (
"Use the `reflact-student` skill available in this workspace.\n"
"Read `task.md`, inspect the attached image, and answer the question.\n"
"Return the final answer in \\boxed{...}."
)
final_message, raw = run_student_exec(
work_dir=work_dir,
prompt=prompt,
model=model,
timeout=timeout,
images=[item["image_path"]],
)
return final_message or raw, raw, skill_md, task_text
def process_one(
item: dict,
out_root: str,
skill_content: str,
*,
max_turns: int = 1,
image_detail: str = "auto",
judge_model: str = "gpt-5.4",
judge_max_completion_tokens: int = 256,
judge_retries: int = 5,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
) -> dict:
item_id = str(item["id"])
result = {
"id": item_id,
"question": item["question"],
"task_type": item.get("subtype") or item.get("task_type") or "babyvision",
"task_description": item["question"],
"hard": 0,
"soft": 0.0,
"predicted_answer": "",
"predicted_label": "",
"predicted_text": "",
"response": "",
"fail_reason": "",
"agent_ok": False,
"n_turns": 0,
"image_path": item["image_path"],
"ans_type": item["ans_type"],
"evaluation_mode": evaluation_mode(),
"judge_model": judge_model,
}
if item["ans_type"] == "choice":
result["correct_label"] = item["correct_choice"]["label"]
result["correct_text"] = item["correct_choice"]["text"]
else:
result["gold_answers"] = item["blank_answers"]
try:
pred_dir = os.path.join(out_root, "predictions", item_id)
os.makedirs(pred_dir, exist_ok=True)
if is_student_exec_backend():
from reflact.model import azure_openai as _llm
response = ""
conversation: list[dict] = [
{"role": "user", "content": f"{item['question']}\n\n[image] {os.path.basename(item['image_path'])}"}
]
system_prompt = ""
user_text = ""
for turn in range(max_turns):
response, raw, system_prompt, user_text = _run_codex_once(
pred_dir=pred_dir,
item=item,
skill_content=skill_content,
model=_llm.STUDENT_DEPLOYMENT,
timeout=120,
image_detail=image_detail,
diagnostic_mode=diagnostic_mode if turn == 0 else False,
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
previous_response=response if turn > 0 else "",
)
conversation.append({"type": "message", "turn": turn + 1, "content": response})
if extract_boxed_answer(response) is not None:
break
result["response"] = response
result["agent_ok"] = True
result["n_turns"] = len(conversation) - 1
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
f.write(system_prompt)
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
f.write(user_text)
eval_result = evaluate_item(
item=item,
prediction_text=response,
judge_model=judge_model,
max_completion_tokens=judge_max_completion_tokens,
retries=judge_retries,
)
result["evaluation_mode"] = eval_result["evaluation_mode"]
result["judge_raw"] = eval_result["judge_raw"]
result["judge_reason"] = eval_result["judge_reason"]
result["matched_gold"] = eval_result["matched_gold"]
if item["ans_type"] == "choice":
result["predicted_label"] = eval_result["predicted_label"]
result["predicted_text"] = eval_result["predicted_text"]
result["predicted_answer"] = eval_result["predicted_answer"]
result["hard"] = int(eval_result["em"])
result["soft"] = eval_result["f1"]
if not result["hard"]:
result["fail_reason"] = (
f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})"
)
eval_detail = (
f"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted label: {eval_result['predicted_label']!r}\n"
f"Predicted text: {eval_result['predicted_text']!r}\n"
f"Correct label: {eval_result['correct_label']!r}\n"
f"Correct text: {eval_result['correct_text']!r}\n"
f"Judge correct: {eval_result['em']}\n"
f"Judge reason: {eval_result['judge_reason']}"
)
else:
result["predicted_answer"] = eval_result["predicted_answer"]
result["hard"] = int(eval_result["em"])
result["soft"] = eval_result["f1"]
if not result["hard"]:
result["fail_reason"] = (
f"judge=0: predicted '{eval_result['predicted_answer']}' "
f"but expected {item['blank_answers']} ({eval_result['judge_reason']})"
)
eval_detail = (
f"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
f"Gold answers: {item['blank_answers']!r}\n"
f"Judge correct: {eval_result['em']}\n"
f"Judge reason: {eval_result['judge_reason']}\n"
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
)
conversation.append({"role": "system", "content": eval_detail})
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
json.dump(conversation, f, ensure_ascii=False, indent=2)
return result
messages, system_prompt, user_text = _build_messages(
item,
skill_content,
image_detail,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
)
response = ""
conversation: list[dict] = [
{"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"}
]
for turn in range(max_turns):
if turn == 0:
resp_text, _ = chat_student_messages(
messages=messages,
max_completion_tokens=768,
retries=5,
stage="rollout",
)
else:
refinement_text = (
f"Your previous answer was:\n{response}\n\n"
"Review the same image and question carefully. "
"If needed, correct your answer. Output the final answer in \\boxed{...}."
)
refinement_messages = [
messages[0],
messages[1],
{"role": "assistant", "content": response},
{"role": "user", "content": refinement_text},
]
resp_text, _ = chat_student_messages(
messages=refinement_messages,
max_completion_tokens=512,
retries=5,
stage="rollout",
)
response = resp_text
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
if extract_boxed_answer(resp_text) is not None:
break
result["response"] = response
result["agent_ok"] = True
result["n_turns"] = len(conversation) - 1
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
f.write(system_prompt)
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
f.write(user_text)
eval_result = evaluate_item(
item=item,
prediction_text=response,
judge_model=judge_model,
max_completion_tokens=judge_max_completion_tokens,
retries=judge_retries,
)
result["evaluation_mode"] = eval_result["evaluation_mode"]
result["judge_raw"] = eval_result["judge_raw"]
result["judge_reason"] = eval_result["judge_reason"]
result["matched_gold"] = eval_result["matched_gold"]
if item["ans_type"] == "choice":
result["predicted_label"] = eval_result["predicted_label"]
result["predicted_text"] = eval_result["predicted_text"]
result["predicted_answer"] = eval_result["predicted_answer"]
result["hard"] = int(eval_result["em"])
result["soft"] = eval_result["f1"]
if not result["hard"]:
result["fail_reason"] = (
f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})"
)
eval_detail = (
f"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted label: {eval_result['predicted_label']!r}\n"
f"Predicted text: {eval_result['predicted_text']!r}\n"
f"Correct label: {eval_result['correct_label']!r}\n"
f"Correct text: {eval_result['correct_text']!r}\n"
f"Judge correct: {eval_result['em']}\n"
f"Judge reason: {eval_result['judge_reason']}"
)
else:
result["predicted_answer"] = eval_result["predicted_answer"]
result["hard"] = int(eval_result["em"])
result["soft"] = eval_result["f1"]
if not result["hard"]:
result["fail_reason"] = (
f"judge=0: predicted '{eval_result['predicted_answer']}' "
f"but expected {item['blank_answers']} ({eval_result['judge_reason']})"
)
eval_detail = (
f"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
f"Gold answers: {item['blank_answers']!r}\n"
f"Judge correct: {eval_result['em']}\n"
f"Judge reason: {eval_result['judge_reason']}\n"
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
)
conversation.append({"role": "system", "content": eval_detail})
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
json.dump(conversation, f, ensure_ascii=False, indent=2)
except Exception as e: # noqa: BLE001
result["fail_reason"] = f"error: {e}"
return result
def run_batch(
items: list[dict],
out_root: str,
skill_content: str,
*,
max_turns: int = 1,
workers: int = 32,
image_detail: str = "auto",
judge_model: str = "gpt-5.4",
judge_max_completion_tokens: int = 256,
judge_retries: int = 5,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context_by_id: dict[str, str] | None = None,
) -> list[dict]:
results_path = os.path.join(out_root, "results.jsonl")
os.makedirs(out_root, exist_ok=True)
expected_eval_mode = evaluation_mode()
done_ids: set[str] = set()
existing: list[dict] = []
rewrite_results = False
if os.path.exists(results_path):
with open(results_path, encoding="utf-8") as f:
for line in f:
try:
row = json.loads(line)
if row.get("evaluation_mode") != expected_eval_mode:
rewrite_results = True
continue
done_ids.add(str(row["id"]))
existing.append(row)
except Exception:
rewrite_results = True
pending = [item for item in items if str(item["id"]) not in done_ids]
if not pending and not rewrite_results:
return existing
results = list(existing)
file_mode = "w" if rewrite_results else "a"
with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
if rewrite_results:
for row in existing:
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
futs = {
ex.submit(
process_one,
item,
out_root,
skill_content,
max_turns=max_turns,
image_detail=image_detail,
judge_model=judge_model,
judge_max_completion_tokens=judge_max_completion_tokens,
judge_retries=judge_retries,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
): item
for item in pending
}
for fut in as_completed(futs):
row = fut.result()
results.append(row)
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
outf.flush()
return results
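
The resume logic in `run_batch` reduces to two steps: keep rows from `results.jsonl` that parse and carry the expected evaluation mode, then run only the items whose ids are not among them. A self-contained sketch of that pattern follows; `load_resumable` and `pending_items` are hypothetical names, and the handling of malformed rows is an assumption modelled on the code above.

```python
import json
import os


def load_resumable(path: str, expected_mode: str) -> tuple:
    """Return (rows to keep, whether the file must be rewritten)."""
    kept, needs_rewrite = [], False
    if not os.path.exists(path):
        return kept, needs_rewrite
    with open(path, encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
                if not isinstance(row, dict) or "id" not in row:
                    raise ValueError("not a result row")
            except ValueError:
                needs_rewrite = True  # corrupt or incomplete row
                continue
            if row.get("evaluation_mode") != expected_mode:
                needs_rewrite = True  # scored under a different evaluator
                continue
            kept.append(row)
    return kept, needs_rewrite


def pending_items(items: list, kept: list) -> list:
    """Items whose ids do not appear among the kept rows."""
    done = {str(row["id"]) for row in kept}
    return [item for item in items if str(item["id"]) not in done]
```

Tagging each row with the evaluation mode means a change to the judge invalidates old scores automatically instead of silently mixing them with new ones.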

View File

@@ -0,0 +1,18 @@
# BabyVision Visual QA Heuristics
## Image Inspection
- First identify the main objects, their attributes, and their spatial relations before answering.
- If the question involves counting, compare all relevant instances carefully instead of stopping after the first match.
- If the question asks about color, size, position, or action, verify the specific visible evidence for that attribute.
## Multiple Choice
- Compare every option against the visible image evidence before deciding.
- Prefer the option that matches the image exactly; reject options that are only partially true or too vague.
- When two options are close, check the smallest discriminating visual detail.
## Open Answers
- Answer with the shortest phrase that is fully supported by the image.
- Match the expected level of specificity: not broader than the image evidence, not narrower than the question asks.
## Final Answer
- Put the final answer inside \boxed{...}.

396
reflact/envs/base.py Normal file
View File

@@ -0,0 +1,396 @@
"""ReflACT environment adapter — abstract interface.

To connect ReflACT to a new environment (benchmark, simulator, etc.),
implement a subclass of :class:`EnvAdapter` with environment-specific
rollout and reflection logic.

Example::

    class MyBenchAdapter(EnvAdapter):
        def build_train_env(self, batch_size, seed, **kw):
            return MyEnvManager(split="train", n=batch_size, seed=seed)

        def build_eval_env(self, env_num, split, seed, **kw):
            return MyEnvManager(split=split, n=env_num, seed=seed)

        def rollout(self, env_manager, skill_content, out_dir, **kw):
            # Run episodes, return [{"id": ..., "hard": 0/1, "soft": 0.0-1.0, ...}]
            ...

        def reflect(self, results, skill_content, out_dir, **kw):
            # Analyze trajectories, return list of patch dicts
            ...

        def get_task_types(self):
            return ["task_a", "task_b"]
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import os
import random
from reflact.datasets.base import BaseDataLoader, BatchSpec
from reflact.model.codex_harness import extract_codex_trace_prefix, format_codex_trace_steps, parse_codex_raw
from reflact.prompts import load_prompt
class EnvAdapter(ABC):
"""Abstract adapter for connecting ReflACT to any environment.
Subclasses must implement all abstract methods. The ReflACT trainer
calls these methods at the appropriate pipeline stages.
"""
# ── Lifecycle hooks ────────────────────────────────────────────────────
def setup(self, cfg: dict) -> None:
"""Called once by the trainer before the training loop begins.
Override to perform one-time initialization that requires the full
config (e.g., data loading, split creation). The default implementation
only stores a shallow copy of the config on ``self._cfg``.
"""
self._cfg = dict(cfg)
def get_dataloader(self) -> BaseDataLoader | None:
"""Return the task dataloader used by this adapter, if any."""
return None
def requires_ray(self) -> bool:
"""Return whether this adapter requires Ray runtime initialization."""
return False
def deep_reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
"""Optional deeper diagnostic reflection pass.
Default behavior is a no-op. Dataset-backed adapters may override this
to re-query the student on a small representative subset of the current
batch using minimally-perturbed diagnostic prompts that expose
intermediate reasoning state.
"""
return []
def build_reference_text(self, item: dict) -> str:
"""Return hidden reference material for deep reflection, if any."""
return str(item.get("reference_text") or "").strip()
def get_reference_metadata(self, item: dict) -> dict:
"""Return structured metadata about hidden reference material."""
reference_text = self.build_reference_text(item)
if not reference_text:
return {"fields": [], "preview": ""}
return {
"fields": ["reference_text"],
"preview": reference_text[:400],
}
def get_codex_deep_probe_prompt(self) -> str | None:
env_name = getattr(self, "_cfg", {}).get("env_name")
return load_prompt("deep_probe_codex", env=env_name)
def attach_codex_probe_context(
self,
results: list[dict],
prediction_dir: str,
) -> list[dict]:
"""Attach compact Codex step metadata for codex-aware deep reflection."""
enriched: list[dict] = []
for row in results:
merged = dict(row)
tid = str(row.get("id"))
raw_path = os.path.join(prediction_dir, tid, "codex_raw.txt")
if os.path.exists(raw_path):
with open(raw_path, encoding="utf-8") as f:
raw = f.read()
parsed = parse_codex_raw(raw)
merged["codex_probe_trace_steps"] = format_codex_trace_steps(raw)
merged["codex_probe_step_count"] = len(parsed["steps"])
enriched.append(merged)
return enriched
def resolve_codex_probe_target(
self,
*,
selected_items: list[dict],
selected_examples: list[dict],
prediction_dir: str,
probe: dict,
) -> tuple[list[dict], dict[str, str] | None, dict]:
"""Resolve the teacher-selected codex probe target and raw trace prefix."""
target_id = str(probe.get("probe_target_id", "")).strip()
selected_id_set = {str(item["id"]) for item in selected_items}
if target_id not in selected_id_set:
target_id = str(selected_items[0]["id"])
target_item = next(item for item in selected_items if str(item["id"]) == target_id)
target_result = next(
(row for row in selected_examples if str(row.get("id")) == target_id),
None,
)
max_probe_step = int((target_result or {}).get("codex_probe_step_count", 0))
default_probe_step = max_probe_step - 1 if max_probe_step > 1 else max_probe_step
probe_after_step = int(probe.get("probe_after_step", default_probe_step))
if max_probe_step > 0:
probe_after_step = max(0, min(probe_after_step, max_probe_step))
else:
probe_after_step = 0
raw_path = os.path.join(prediction_dir, target_id, "codex_raw.txt")
trace_prefix = ""
if os.path.exists(raw_path):
with open(raw_path, encoding="utf-8") as f:
trace_prefix = extract_codex_trace_prefix(f.read(), after_step=probe_after_step)
updated_probe = dict(probe)
updated_probe["probe_target_id"] = target_id
updated_probe["probe_after_step"] = probe_after_step
return [target_item], {target_id: trace_prefix}, updated_probe
def attach_reference_context(
self,
results: list[dict],
items: list[dict] | None,
) -> list[dict]:
"""Attach environment-specific hidden reference text to result dicts."""
if not results or not items:
return list(results)
item_by_id = {
str(item.get("id")): item
for item in items
if isinstance(item, dict) and item.get("id") is not None
}
enriched: list[dict] = []
for row in results:
merged = dict(row)
item = item_by_id.get(str(row.get("id")))
if item:
reference_text = self.build_reference_text(item)
if reference_text:
merged["reference_text"] = reference_text
enriched.append(merged)
return enriched
def select_representative_items(
self,
results: list[dict],
items: list[dict] | None,
*,
n_failures: int,
n_successes: int,
seed: int | None = None,
) -> list[dict]:
"""Select a small diverse subset of current-batch items by outcome."""
if not items:
return []
item_by_id = {
str(item.get("id")): item
for item in items
if isinstance(item, dict) and item.get("id") is not None
}
failures = [
(result, item_by_id[str(result.get("id"))])
for result in results
if not result.get("hard") and str(result.get("id")) in item_by_id
]
successes = [
(result, item_by_id[str(result.get("id"))])
for result in results
if result.get("hard") and str(result.get("id")) in item_by_id
]
rng = random.Random(seed)
def _pick(pool: list[tuple[dict, dict]], quota: int) -> list[dict]:
if quota <= 0 or not pool:
return []
shuffled = list(pool)
rng.shuffle(shuffled)
picked_ids: set[str] = set()
picked: list[dict] = []
seen_types: set[str] = set()
for result, item in shuffled:
task_type = str(result.get("task_type") or item.get("task_type") or item.get("subtype") or "unknown")
item_id = str(item["id"])
if task_type in seen_types or item_id in picked_ids:
continue
picked.append(item)
picked_ids.add(item_id)
seen_types.add(task_type)
if len(picked) >= quota:
return picked
for _, item in shuffled:
item_id = str(item["id"])
if item_id in picked_ids:
continue
picked.append(item)
picked_ids.add(item_id)
if len(picked) >= quota:
break
return picked
selected = _pick(failures, n_failures)
selected_ids = {str(item["id"]) for item in selected}
selected.extend(
item for item in _pick(successes, n_successes)
if str(item["id"]) not in selected_ids
)
return selected
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
"""Build an environment manager or item list from a :class:`BatchSpec`.
Default behavior preserves the legacy adapter API by routing training
batches through :meth:`build_train_env` and evaluation batches through
:meth:`build_eval_env`.
"""
if batch.phase == "train":
return self.build_train_env(batch_size=batch.batch_size, seed=batch.seed, **kwargs)
return self.build_eval_env(
env_num=batch.batch_size,
split=batch.split,
seed=batch.seed,
**kwargs,
)
@abstractmethod
def build_train_env(self, batch_size: int, seed: int, **kwargs):
"""Build a training environment manager.
Returns
-------
object
An environment manager that can be passed to :meth:`rollout`.
"""
@abstractmethod
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
"""Build an evaluation environment manager.
Parameters
----------
env_num : int
Number of evaluation environments.
split : str
Dataset split (e.g. ``"valid_seen"``, ``"valid_unseen"``).
seed : int
Random seed for reproducibility.
Returns
-------
object
An environment manager that can be passed to :meth:`rollout`.
"""
@abstractmethod
def rollout(
self,
env_manager,
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict]:
"""Run a batch of episodes using the current skill.
Returns
-------
list[dict]
Each dict conforms to :class:`~reflact.types.RolloutResult`:
must have ``"id"`` (str), ``"hard"`` (0/1), ``"soft"``
(float 0-1). May include env-specific fields.
"""
@abstractmethod
def reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
"""Analyze rollout results and produce patches.
Each returned dict conforms to :class:`~reflact.types.RawPatch`:
``"patch"`` (with ``"edits"`` list) + ``"source_type"``
(``"failure"`` or ``"success"``).
Returns
-------
list[dict | None]
Raw analyst outputs; ``None`` entries are filtered out.
"""
@abstractmethod
def get_task_types(self) -> list[str]:
"""Return the list of task type names for this environment."""
# ── Prompt configuration (two-level priority) ────────────────────────
#
# Priority: env-specific prompt file > generic default prompt file.
#
# Prompts are loaded from ``.md`` files via ``load_prompt(name, env)``:
# 1. ``reflact/envs/<env>/prompts/<name>.md`` (env-specific)
# 2. ``reflact/prompts/<name>.md`` (generic fallback)
#
# Subclasses can still override ``get_*_prompt()`` for full control.
@property
def _env_name(self) -> str:
"""Derive the env directory name from this adapter's module path."""
# e.g. "reflact.envs.searchqa.adapter" → "searchqa"
module = type(self).__module__
parts = module.split(".")
if len(parts) >= 3 and parts[-3] == "envs":
return parts[-2]
return ""
def _load_env_prompt(self, name: str) -> str | None:
"""Load a prompt with env-specific override. Returns None if not found."""
try:
return load_prompt(name, env=self._env_name)
except FileNotFoundError:
return None
def get_error_minibatch_prompt(self) -> str | None:
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
raw_mode = str(update_mode).strip().lower()
if raw_mode in {"full_rewrite", "full_rewrite_minibatch", "minibatch_full_rewrite", "skill_rewrite_minibatch"}:
prompt = self._load_env_prompt("analyst_error_full_rewrite")
if prompt is not None:
return prompt
if raw_mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
prompt = self._load_env_prompt("analyst_error_rewrite")
if prompt is not None:
return prompt
return self._load_env_prompt("analyst_error")
def get_success_minibatch_prompt(self) -> str | None:
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
raw_mode = str(update_mode).strip().lower()
if raw_mode in {"full_rewrite", "full_rewrite_minibatch", "minibatch_full_rewrite", "skill_rewrite_minibatch"}:
prompt = self._load_env_prompt("analyst_success_full_rewrite")
if prompt is not None:
return prompt
if raw_mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
prompt = self._load_env_prompt("analyst_success_rewrite")
if prompt is not None:
return prompt
return self._load_env_prompt("analyst_success")
def get_deep_probe_prompt(self) -> str | None:
return self._load_env_prompt("deep_probe")
def get_meta_reflect_prompt(self) -> str | None:
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
if str(update_mode).strip().lower() == "rewrite_from_suggestions":
prompt = self._load_env_prompt("meta_reflect_rewrite")
if prompt is not None:
return prompt
return self._load_env_prompt("meta_reflect")
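
The two-level prompt priority described in the comment above can be sketched as follows. This is only an illustration under stated assumptions: the real `load_prompt` lives in `reflact.prompts` and is not shown here, so `load_prompt_sketch` and its `root` parameter are hypothetical names, and the directory layout is taken from the comment rather than from the loader itself.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def load_prompt_sketch(root: Path, name: str, env: str = "") -> str:
    """Return the env-specific prompt if it exists, else the generic one (hypothetical helper)."""
    candidates = []
    if env:
        # 1. env-specific override: <root>/envs/<env>/prompts/<name>.md
        candidates.append(root / "envs" / env / "prompts" / f"{name}.md")
    # 2. generic fallback: <root>/prompts/<name>.md
    candidates.append(root / "prompts" / f"{name}.md")
    for path in candidates:
        if path.is_file():
            return path.read_text(encoding="utf-8")
    raise FileNotFoundError(f"prompt {name!r} not found for env {env!r}")


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "prompts").mkdir()
    (root / "prompts" / "analyst_error.md").write_text("generic", encoding="utf-8")
    (root / "envs" / "docvqa" / "prompts").mkdir(parents=True)
    (root / "envs" / "docvqa" / "prompts" / "analyst_error.md").write_text("docvqa", encoding="utf-8")
    specific = load_prompt_sketch(root, "analyst_error", env="docvqa")    # override wins
    fallback = load_prompt_sketch(root, "analyst_error", env="searchqa")  # no override, generic used
```

An environment without its own prompt file falls through to the generic one, which is why `_load_env_prompt` only needs to handle `FileNotFoundError` for the case where neither file exists.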

@@ -0,0 +1,114 @@
from __future__ import annotations
import json
import os
from typing import Any, Callable
from reflact.gradient.deep_probe import generate_deep_probe_instruction
from reflact.gradient.reflect import run_minibatch_reflect
def run_no_reference_deep_reflect(
adapter: Any,
results: list[dict],
skill_content: str,
out_dir: str,
*,
env_manager: Any = None,
prediction_dir: str | None = None,
random_seed: int | None = None,
step_buffer_context: str = "",
output_requirements: list[str] | None = None,
metadata_builder: Callable[[dict], dict] | None = None,
) -> list[dict | None]:
"""Run teacher-designed diagnostic probing without hidden references."""
if not getattr(adapter, "use_deep_reflect", False):
return []
if not isinstance(env_manager, list):
return []
prediction_dir = prediction_dir or os.path.join(out_dir, "predictions")
selected_items = adapter.select_representative_items(
results,
env_manager,
n_failures=getattr(adapter, "deep_reflect_failures", 4),
n_successes=getattr(adapter, "deep_reflect_successes", 2),
seed=random_seed,
)
if not selected_items:
return []
selected_ids = {str(item["id"]) for item in selected_items}
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
if metadata_builder is None:
selected_metadata = [
{
"id": str(item.get("id")),
"task_type": str(item.get("task_type") or item.get("topic") or "unknown"),
"question_preview": str(item.get("question") or "")[:200],
}
for item in selected_items
]
else:
selected_metadata = [metadata_builder(item) for item in selected_items]
deep_dir = os.path.join(out_dir, "deep_reflect")
rollout_dir = os.path.join(deep_dir, "rollout")
patches_dir = os.path.join(deep_dir, "patches")
os.makedirs(deep_dir, exist_ok=True)
print(
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
"mode=no_reference_probe"
)
probe = generate_deep_probe_instruction(
skill_content=skill_content,
items=selected_results,
prediction_dir=prediction_dir,
system_prompt=adapter.get_deep_probe_prompt(),
step_buffer_context=step_buffer_context,
output_requirements=output_requirements,
)
if not probe:
return []
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
json.dump(
{
**probe,
"reference_summary": {
"mode": "no_reference_probe",
"selected_count": len(selected_items),
},
"selected_examples": selected_metadata,
},
f,
ensure_ascii=False,
indent=2,
)
deep_results = adapter.rollout(
selected_items,
skill_content,
rollout_dir,
diagnostic_mode=True,
diagnostic_instruction=probe["probe_instruction"],
)
return run_minibatch_reflect(
results=deep_results,
skill_content=skill_content,
prediction_dir=os.path.join(rollout_dir, "predictions"),
patches_dir=patches_dir,
workers=getattr(adapter, "analyst_workers", 8),
failure_only=getattr(adapter, "failure_only", False),
minibatch_size=getattr(adapter, "minibatch_size", 8),
edit_budget=getattr(adapter, "edit_budget", 4),
random_seed=random_seed,
error_system=adapter.get_error_minibatch_prompt(),
success_system=adapter.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
update_mode=getattr(getattr(adapter, "_cfg", {}), "get", lambda *_: "patch")(
"skill_update_mode",
"patch",
),
)
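
The item selection that `run_no_reference_deep_reflect` delegates to `select_representative_items` follows a two-pass pattern: cover distinct task types first, then fill the remaining quota. A minimal standalone sketch, assuming items are plain dicts with `id` and `task_type` keys (the helper name `pick_diverse` and the sample pool are invented for illustration):

```python
from __future__ import annotations

import random


def pick_diverse(pool: list[dict], quota: int, seed: int | None = None) -> list[dict]:
    """Pick up to `quota` items, covering distinct task types before repeating any."""
    if quota <= 0 or not pool:
        return []
    shuffled = list(pool)
    random.Random(seed).shuffle(shuffled)
    picked: list[dict] = []
    picked_ids: set[str] = set()
    seen_types: set[str] = set()
    # First pass: at most one item per task type.
    for item in shuffled:
        if item["task_type"] in seen_types:
            continue
        picked.append(item)
        picked_ids.add(item["id"])
        seen_types.add(item["task_type"])
        if len(picked) >= quota:
            return picked
    # Second pass: fill the remaining quota with any items not yet picked.
    for item in shuffled:
        if item["id"] in picked_ids:
            continue
        picked.append(item)
        picked_ids.add(item["id"])
        if len(picked) >= quota:
            break
    return picked


pool = [
    {"id": "a", "task_type": "table"},
    {"id": "b", "task_type": "table"},
    {"id": "c", "task_type": "form"},
    {"id": "d", "task_type": "figure"},
]
chosen = pick_diverse(pool, quota=3, seed=0)
everything = pick_diverse(pool, quota=10, seed=0)
```

With three task types and a quota of three, the first pass alone fills the quota, so every type is represented regardless of the shuffle order. A quota larger than the pool simply returns every item once.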

@@ -0,0 +1 @@
"""DocVQA environment package for ReflACT."""

@@ -0,0 +1,153 @@
from __future__ import annotations
import os
from reflact.datasets.base import BatchSpec
from reflact.envs.base import EnvAdapter
from reflact.envs.deep_reflect import run_no_reference_deep_reflect
from reflact.envs.docvqa.dataloader import DocVQADataLoader
from reflact.envs.docvqa.rollout import run_batch
from reflact.gradient.reflect import run_minibatch_reflect
class DocVQAAdapter(EnvAdapter):
"""DocVQA environment adapter for ReflACT."""
def __init__(
self,
split_dir: str = "",
data_path: str = "",
split_mode: str = "split_dir",
split_ratio: str = "2:1:7",
split_seed: int = 42,
split_output_dir: str = "",
max_turns: int = 1,
exec_timeout: int = 120,
workers: int = 16,
analyst_workers: int = 16,
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42,
limit: int = 0,
task_timeout: int = 600,
image_detail: str = "auto",
use_deep_reflect: bool = False,
deep_reflect_failures: int = 4,
deep_reflect_successes: int = 2,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
self.edit_budget = edit_budget
self.task_timeout = task_timeout
self.image_detail = image_detail
self.use_deep_reflect = use_deep_reflect
self.deep_reflect_failures = deep_reflect_failures
self.deep_reflect_successes = deep_reflect_successes
self.dataloader = DocVQADataLoader(
split_dir=split_dir,
data_path=data_path,
split_mode=split_mode,
split_ratio=split_ratio,
split_seed=split_seed,
split_output_dir=split_output_dir,
seed=seed,
limit=limit,
)
def setup(self, cfg: dict) -> None:
super().setup(cfg)
self.dataloader.setup(cfg)
def get_dataloader(self):
return self.dataloader
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
return list(batch.payload or [])
def build_train_env(self, batch_size: int, seed: int, **kwargs):
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]:
items: list[dict] = env_manager
return run_batch(
items=items,
out_root=out_dir,
skill_content=skill_content,
max_turns=self.max_turns,
exec_timeout=self.exec_timeout,
workers=self.workers,
image_detail=self.image_detail,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
task_timeout=getattr(self, "task_timeout", self.exec_timeout),
)
def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]:
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
return run_minibatch_reflect(
results=results,
skill_content=skill_content,
prediction_dir=prediction_dir,
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
)
def deep_reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
return run_no_reference_deep_reflect(
self,
results,
skill_content,
out_dir,
env_manager=kwargs.get("env_manager"),
prediction_dir=kwargs.get("prediction_dir"),
random_seed=kwargs.get("random_seed"),
step_buffer_context=kwargs.get("step_buffer_context", ""),
output_requirements=[
"- There is no hidden reference block. Use only the document image prompt, student output, and evaluation result to infer what intermediate state is worth probing.",
"- The instruction must explicitly request a short <analysis>...</analysis> block before the final <answer>...</answer>.",
"- The readout should focus on visual region, field/table/figure label, OCR text read, candidate answer, and answer-format normalization.",
"- Do not ask for exhaustive transcription or a full chain-of-thought.",
"- The instruction text should be ready to append directly to the student's prompt.",
],
metadata_builder=lambda item: {
"id": str(item.get("id")),
"task_type": str(item.get("task_type") or "docvqa"),
"question_preview": str(item.get("question") or "")[:200],
"image_path": item.get("image_path", ""),
"docId": item.get("docId", ""),
"page": item.get("ucsf_document_page_no", ""),
},
)
def get_task_types(self) -> list[str]:
seen: list[str] = []
for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items:
task_type = str(item.get("task_type") or "docvqa")
if task_type not in seen:
seen.append(task_type)
return seen or ["docvqa"]

@@ -0,0 +1,61 @@
from __future__ import annotations
import ast
import csv
from pathlib import Path
from reflact.datasets.base import SplitDataLoader
def _parse_answers(raw: str) -> list[str]:
text = str(raw or "").strip()
if not text:
return []
try:
parsed = ast.literal_eval(text)
except Exception:
return [text]
if isinstance(parsed, list):
return [str(item).strip() for item in parsed if str(item).strip()]
return [str(parsed).strip()]
def _extract_document_path(question: str) -> tuple[str, str]:
marker = "document_path:"
if marker not in question:
return question.strip(), ""
main, tail = question.split(marker, 1)
return main.strip(), tail.strip()
def _normalize_row(row: dict[str, str]) -> dict:
question_text, document_path = _extract_document_path(str(row.get("question") or ""))
answers = _parse_answers(row.get("answer") or row.get("ground_truth") or "")
image_path = str(row.get("image_path") or document_path or "").strip()
task_type = str(row.get("topic") or row.get("category") or "docvqa").strip() or "docvqa"
return {
"id": str(row.get("questionId") or row.get("id") or "").strip(),
"question": question_text,
"answer": answers[0] if answers else "",
"answers": answers,
"task_type": task_type,
"subtask": task_type,
"image_paths": [image_path] if image_path else [],
"image_path": image_path,
"questionId": str(row.get("questionId") or "").strip(),
"docId": str(row.get("docId") or "").strip(),
"ucsf_document_id": str(row.get("ucsf_document_id") or "").strip(),
"ucsf_document_page_no": str(row.get("ucsf_document_page_no") or "").strip(),
"source_split": str(row.get("source_split") or "").strip(),
}
class DocVQADataLoader(SplitDataLoader):
def load_split_items(self, split_path: str) -> list[dict]:
path = Path(split_path)
csv_files = sorted(path.glob("*.csv"))
if not csv_files:
raise FileNotFoundError(f"No .csv file found in {split_path}")
with csv_files[0].open(encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
return [_normalize_row(row) for row in reader]
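
To illustrate the answer-parsing step, the sketch below reproduces the logic of `_parse_answers` in isolation (renamed `parse_answers`) and runs it on a few CSV cell values. The sample strings are invented for illustration and are not taken from the dataset.

```python
from __future__ import annotations

import ast


def parse_answers(raw: str) -> list[str]:
    """Standalone copy of the parsing logic above, for experimentation."""
    text = str(raw or "").strip()
    if not text:
        return []
    try:
        parsed = ast.literal_eval(text)
    except Exception:
        # Not a Python literal: treat the whole cell as one answer.
        return [text]
    if isinstance(parsed, list):
        return [str(item).strip() for item in parsed if str(item).strip()]
    return [str(parsed).strip()]


list_cell = parse_answers("['0.28', ' 0.28% ']")          # stringified list of answers
plain_cell = parse_answers("University of California")    # free text, not a literal
numeric_cell = parse_answers("1998")                       # parsed as int, then str()
decimal_cell = parse_answers("1.50")                       # parsed as float, then str()
empty_cell = parse_answers("")
```

One consequence worth keeping in mind: a bare numeric cell round-trips through a Python number, so `"1.50"` comes back as `"1.5"`. If gold answers are stored as unquoted numbers, that could affect exact-match scoring; quoted values inside a list are unaffected.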

@@ -0,0 +1,113 @@
from __future__ import annotations
import ast
import json
from collections.abc import Iterable
from typing import Any
DEFAULT_ANLS_THRESHOLD = 0.5
def _normalize_text(value: Any) -> str:
if value is None:
return ""
text = str(value).strip().lower()
return " ".join(text.split())
def _levenshtein_distance(a: str, b: str) -> int:
if a == b:
return 0
if not a:
return len(b)
if not b:
return len(a)
if len(a) > len(b):
a, b = b, a
previous = list(range(len(b) + 1))
for i, char_a in enumerate(a, start=1):
current = [i]
for j, char_b in enumerate(b, start=1):
insert_cost = current[j - 1] + 1
delete_cost = previous[j] + 1
replace_cost = previous[j - 1] + (char_a != char_b)
current.append(min(insert_cost, delete_cost, replace_cost))
previous = current
return previous[-1]
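def _score_single_answer(predicted: Any, target: Any, threshold: float) -> float:
"""Return ``1 - NL`` when the normalized Levenshtein distance ``NL`` is below ``threshold``, else ``0.0``."""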
predicted_norm = _normalize_text(predicted)
target_norm = _normalize_text(target)
if not predicted_norm and not target_norm:
return 1.0
if not predicted_norm or not target_norm:
return 0.0
distance = _levenshtein_distance(predicted_norm, target_norm)
normalized_distance = distance / max(len(predicted_norm), len(target_norm))
if normalized_distance >= threshold:
return 0.0
return 1.0 - normalized_distance
def _extract_answer_strings(raw: Any) -> list[str]:
if raw is None:
return [""]
if isinstance(raw, str):
text = raw.strip()
if not text:
return [""]
parsed = None
if text[0] in "[{":
try:
parsed = json.loads(text)
except json.JSONDecodeError:
try:
parsed = ast.literal_eval(text)
except (ValueError, SyntaxError):
parsed = None
if parsed is None:
return [text]
return _extract_answer_strings(parsed)
if isinstance(raw, dict):
for key in ("answers", "ground_truth", "answer"):
if key in raw:
return _extract_answer_strings(raw[key])
return [str(raw)]
if isinstance(raw, Iterable) and not isinstance(raw, (bytes, bytearray)):
answers: list[str] = []
for item in raw:
if isinstance(item, dict):
for key in ("text", "answer", "value"):
if key in item:
answers.extend(_extract_answer_strings(item[key]))
break
else:
answers.append(str(item))
continue
answers.append(str(item))
return answers or [""]
return [str(raw)]
def extract_answer(text: str) -> str:
lower = text.lower()
start = lower.rfind("<answer>")
end = lower.rfind("</answer>")
if start != -1 and end != -1 and end > start:
return text[start + len("<answer>"):end].strip()
lines = [line.strip() for line in text.splitlines() if line.strip()]
return lines[-1] if lines else text.strip()
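def evaluate(prediction_text: str, gold_answers: Any) -> dict:
"""Extract the final answer from ``prediction_text`` and score it with ANLS, keeping the best match over all gold answers."""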
answer = extract_answer(prediction_text)
answers = _extract_answer_strings(gold_answers)
score = 0.0
for target in answers:
score = max(score, _score_single_answer(answer, target, DEFAULT_ANLS_THRESHOLD))
return {
"anls": score,
"predicted_answer": answer,
"gold_answers": answers,
}
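
As a worked illustration of the scoring above, the following standalone sketch recomputes ANLS for a single prediction against one or more gold answers. It re-implements the metric with hypothetical helper names (`edit_distance`, `anls`) instead of importing the module, and assumes the same 0.5 threshold and the same lowercase and whitespace normalization.

```python
from __future__ import annotations


def edit_distance(a: str, b: str) -> int:
    """Dynamic-programming Levenshtein distance using two rolling rows."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            current.append(min(
                current[j - 1] + 1,                     # insertion
                previous[j] + 1,                        # deletion
                previous[j - 1] + (char_a != char_b),   # substitution or match
            ))
        previous = current
    return previous[-1]


def anls(prediction: str, gold_answers: list[str], threshold: float = 0.5) -> float:
    """Best thresholded similarity (1 - normalized distance) over all gold answers."""
    pred = " ".join(prediction.strip().lower().split())
    best = 0.0
    for gold in gold_answers:
        target = " ".join(gold.strip().lower().split())
        if not pred and not target:
            return 1.0
        if not pred or not target:
            continue
        nl = edit_distance(pred, target) / max(len(pred), len(target))
        if nl < threshold:
            best = max(best, 1.0 - nl)
    return best


exact = anls("  Acme  Corp ", ["ACME Corp"])   # identical after normalization
near = anls("Acme Corp", ["Acme Corp."])       # one extra character out of ten
far = anls("Globex", ["Acme Corp"])            # distance above the threshold
```

The near miss scores 0.9 because a single edit over the longer length of 10 gives a normalized distance of 0.1. Under the rollout's `hard` rule (ANLS of at least 0.999), only the exact case would count as a success, while the near miss still contributes to the `soft` score.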

@@ -0,0 +1,35 @@
You are an expert failure-analysis agent for visual document question answering tasks.
You will be given MULTIPLE failed DocVQA trajectories from a single minibatch and the current skill document. Each trajectory includes the model response and an evaluation result scored with ANLS against one or more acceptable answers.
Your job is to identify the most important COMMON failure patterns across the batch and propose concise skill edits.
## Failure Type Categories
- evidence_miss: the model overlooked the relevant visible region or line
- near_match_confusion: the model selected a nearby but incorrect text span
- normalization_error: the answer differed mainly in formatting, spacing, punctuation, or minor text normalization
- reading_error: the model misread the document content
- other: none of the above
## Rules
- Focus on common, reusable reading and extraction behaviors.
- Do not hardcode image-specific answers.
- Prefer concise edits that improve evidence selection and exact span extraction.
Respond ONLY with a valid JSON object (no markdown fences, no extra text):
{
"batch_size": <number of trajectories analysed>,
"failure_summary": [
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
],
"patch": {
"reasoning": "<why these edits address the batch's common failures>",
"edits": [
{"op": "append", "content": "<markdown to add at end of skill>"},
{"op": "insert_after", "target": "<exact heading/text to insert after>", "content": "<markdown>"},
{"op": "replace", "target": "<exact text to replace>", "content": "<replacement>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
Only include edits that are needed. "edits" can be an empty list if no patch is warranted.

@@ -0,0 +1,24 @@
You are an expert success-pattern analyst for visual document question answering tasks.
You will be given MULTIPLE successful DocVQA trajectories from a single minibatch and the current skill document. Your job is to identify common visual reading and exact-answer extraction behaviors worth encoding in the skill.
## Rules
- Focus on patterns shared across multiple successful trajectories.
- Reinforce reusable behaviors like locating the right region, copying exact spans, and preferring the shortest exact answer over paraphrase.
- Only propose patches for patterns not already captured by the current skill.
Respond ONLY with a valid JSON object:
{
"batch_size": <number of trajectories analysed>,
"success_patterns": ["<pattern 1>", "<pattern 2>"],
"patch": {
"reasoning": "<why these patterns are worth encoding>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
"edits" may be empty if the skill already covers all observed patterns.

@@ -0,0 +1,12 @@
You are an expert visual document question answering agent.
{skill_section}You will receive a document image and a question about the document.
Read the visual evidence carefully and answer concisely.
Rules:
- Ground the answer in the visible document content.
- Prefer exact spans, numbers, dates, and names from the document.
- Do not invent content that is not visible.
- If multiple near-matches exist, choose the one best supported by the document.
Return the final answer inside <answer>...</answer>.

@@ -0,0 +1,365 @@
from __future__ import annotations
import json
import os
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from reflact.envs.docvqa.evaluator import evaluate
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
from reflact.prompts import load_prompt
def _build_system(skill_content: str) -> str:
if skill_content.strip():
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
else:
skill_section = ""
return load_prompt("rollout_system", env="docvqa").format(skill_section=skill_section)
def _image_to_data_uri(path: str) -> str:
import base64
import mimetypes
mime = mimetypes.guess_type(path)[0] or "image/png"
with open(path, "rb") as f:
encoded = base64.b64encode(f.read()).decode("ascii")
return f"data:{mime};base64,{encoded}"
def _build_messages(
item: dict,
skill_content: str,
image_detail: str,
*,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
) -> tuple[list[dict], str, str]:
system = _build_system(skill_content)
user_text = item["question"] + "\n\nReturn the final answer inside <answer>...</answer>."
if diagnostic_mode and diagnostic_instruction.strip():
user_text += f"\n\n## Training Readout\n{diagnostic_instruction.strip()}"
image_url = {"url": _image_to_data_uri(item["image_path"])}
if image_detail and image_detail != "auto":
image_url["detail"] = image_detail
messages = [
{"role": "system", "content": system},
{
"role": "user",
"content": [
{"type": "text", "text": user_text},
{"type": "image_url", "image_url": image_url},
],
},
]
return messages, system, user_text
def _build_codex_skill(skill_content: str) -> str:
return render_skill_md(
skill_content,
description="Dynamic ReflACT skill for solving the current DocVQA document-image question.",
preamble=(
"Use this skill when answering the current DocVQA question.\n"
"Inspect the attached document image carefully and return the final answer inside <answer>...</answer>."
),
)
def _run_codex_once(
*,
pred_dir: str,
item: dict,
skill_content: str,
model: str,
timeout: int,
image_detail: str,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
previous_response: str = "",
) -> tuple[str, str, str, str]:
_ = image_detail
_messages, _system, user_text = _build_messages(
item,
skill_content,
image_detail,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
)
task_parts = [user_text]
image_abs = os.path.abspath(item["image_path"])
task_parts.append(
"## Document Image\n"
"The document image is available in this workspace via `ATTACHMENTS.md`.\n"
f"Original image path: `{image_abs}`\n"
"Open or inspect that image before answering; do not answer from memory."
)
if previous_response:
task_parts.append(
"## Previous Attempt\n"
f"{previous_response}\n\n"
"Review the same document image carefully and correct the answer if needed."
)
task_text = "\n\n".join(task_parts)
skill_md = _build_codex_skill(skill_content)
work_dir = os.path.join(pred_dir, "codex_exec")
prepare_workspace(
work_dir=work_dir,
skill_md=skill_md,
task_text=task_text,
images=[item["image_path"]],
)
prompt = (
"Use the `reflact-student` skill available in this workspace.\n"
"Read `task.md`, inspect the attached document image, and answer the DocVQA question.\n"
"Return the final answer inside <answer>...</answer>."
)
final_message, raw = run_student_exec(
work_dir=work_dir,
prompt=prompt,
model=model,
timeout=timeout,
images=[item["image_path"]],
)
return final_message or raw, raw, skill_md, task_text
def process_one(
item: dict,
out_root: str,
skill_content: str,
*,
max_turns: int = 1,
exec_timeout: int = 120,
image_detail: str = "auto",
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
) -> dict:
item_id = str(item["id"])
result = {
"id": item_id,
"question": item["question"],
"task_type": item.get("subtask") or item.get("task_type") or "docvqa",
"task_description": item["question"],
"hard": 0,
"soft": 0.0,
"predicted_answer": "",
"response": "",
"fail_reason": "",
"agent_ok": False,
"n_turns": 0,
"image_paths": item.get("image_paths", []),
"gold_answer": item.get("answers", []),
}
try:
response = ""
system_prompt = ""
user_text = ""
conversation: list[dict] = []
if is_student_exec_backend():
from reflact.model import azure_openai as _llm
conversation = [
{
"role": "user",
"content": item["question"] + "\n\n" + f"[image] {os.path.basename(item['image_path'])}",
}
]
for turn in range(max_turns):
response, _raw, system_prompt, user_text = _run_codex_once(
pred_dir=os.path.join(out_root, "predictions", item_id),
item=item,
skill_content=skill_content,
model=_llm.STUDENT_DEPLOYMENT,
timeout=exec_timeout,
image_detail=image_detail,
diagnostic_mode=diagnostic_mode if turn == 0 else False,
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
previous_response=response if turn > 0 else "",
)
conversation.append({"type": "message", "turn": turn + 1, "content": response})
if "<answer>" in response.lower():
break
else:
messages, system_prompt, user_text = _build_messages(
item,
skill_content,
image_detail,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
)
conversation = [
{
"role": "user",
"content": user_text + "\n\n" + f"[image] {os.path.basename(item['image_path'])}",
}
]
for turn in range(max_turns):
if turn == 0:
resp_text, _ = chat_student_messages(
messages=messages,
max_completion_tokens=768,
retries=5,
stage="rollout",
timeout=exec_timeout,
)
else:
refinement_messages = [
messages[0],
messages[1],
{"role": "assistant", "content": response},
{"role": "user", "content": "Review the same image carefully and answer again. Keep the final answer inside <answer>...</answer>."},
]
resp_text, _ = chat_student_messages(
messages=refinement_messages,
max_completion_tokens=512,
retries=5,
stage="rollout",
timeout=exec_timeout,
)
response = resp_text
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
if "<answer>" in resp_text.lower():
break
result["response"] = response
result["agent_ok"] = True
result["n_turns"] = len(conversation) - 1
pred_dir = os.path.join(out_root, "predictions", item_id)
os.makedirs(pred_dir, exist_ok=True)
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
f.write(system_prompt)
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
f.write(user_text)
eval_result = evaluate(response, item.get("answers", []))
result["predicted_answer"] = eval_result["predicted_answer"]
result["hard"] = int(eval_result["anls"] >= 0.999)
result["soft"] = eval_result["anls"]
if result["soft"] <= 0.0:
result["fail_reason"] = f"predicted '{eval_result['predicted_answer']}' but expected one of {item.get('answers', [])}"
eval_detail = (
"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
f"Gold answers: {item.get('answers', [])!r}\n"
f"ANLS: {eval_result['anls']:.4f}"
)
conversation.append({"role": "system", "content": eval_detail})
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
json.dump(conversation, f, ensure_ascii=False, indent=2)
except Exception as e: # noqa: BLE001
result["fail_reason"] = f"error: {e}"
return result
def run_batch(
items: list[dict],
out_root: str,
skill_content: str,
*,
max_turns: int = 1,
exec_timeout: int = 120,
workers: int = 16,
image_detail: str = "auto",
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
task_timeout: int = 600,
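) -> list[dict]:
"""Run ``process_one`` over ``items`` in a thread pool and append each result to ``results.jsonl``.

Items whose ids already appear in ``results.jsonl`` under ``out_root`` are skipped, so an
interrupted run can be resumed. A task that runs longer than ``task_timeout`` seconds
(raised to at least ``exec_timeout + 60``) is recorded as a timeout result.
"""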
task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
results_path = os.path.join(out_root, "results.jsonl")
os.makedirs(out_root, exist_ok=True)
done_ids: set[str] = set()
existing: list[dict] = []
if os.path.exists(results_path):
with open(results_path, encoding="utf-8") as f:
for line in f:
try:
row = json.loads(line)
except Exception:
continue
done_ids.add(str(row["id"]))
existing.append(row)
pending = [item for item in items if str(item["id"]) not in done_ids]
if not pending:
return existing
def _timeout_result(item: dict) -> dict:
return {
"id": str(item["id"]),
"question": item.get("question", ""),
"task_type": item.get("subtask") or item.get("task_type") or "docvqa",
"task_description": item.get("question", ""),
"hard": 0,
"soft": 0.0,
"predicted_answer": "",
"response": "",
"fail_reason": f"task-timeout-{task_timeout}s",
"agent_ok": False,
"n_turns": 0,
"image_paths": item.get("image_paths", []),
"gold_answer": item.get("answers", []),
"phase": "timeout",
}
def _error_result(item: dict, exc: Exception) -> dict:
row = _timeout_result(item)
row["phase"] = "error"
row["fail_reason"] = f"unexpected: {type(exc).__name__}: {exc}"
return row
started_at: dict[str, float] = {}
def _run_one(item: dict) -> dict:
started_at[str(item["id"])] = time.time()
return process_one(
item,
out_root,
skill_content,
max_turns=max_turns,
exec_timeout=exec_timeout,
image_detail=image_detail,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
)
results = list(existing)
with open(results_path, "a", encoding="utf-8") as outf:
ex = ThreadPoolExecutor(max_workers=workers)
try:
futs = {ex.submit(_run_one, item): item for item in pending}
pending_futs = set(futs)
while pending_futs:
done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
now = time.time()
timed_out = [
fut for fut in pending_futs - done
if str(futs[fut]["id"]) in started_at
and now - started_at[str(futs[fut]["id"])] >= task_timeout
]
for fut in done:
pending_futs.remove(fut)
item = futs[fut]
try:
res = fut.result()
except Exception as exc: # noqa: BLE001
res = _error_result(item, exc)
results.append(res)
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
outf.flush()
for fut in timed_out:
pending_futs.remove(fut)
fut.cancel()
res = _timeout_result(futs[fut])
results.append(res)
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
outf.flush()
finally:
ex.shutdown(wait=False, cancel_futures=True)
return results
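
The polling loop in `run_batch` can be reduced to the pattern below: submit everything, wake up on a short interval, and label any task whose own start time is older than the deadline. This is a minimal sketch with sleep-based dummy jobs; `run_with_deadline` and its parameters are hypothetical names chosen for the example.

```python
from __future__ import annotations

import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def run_with_deadline(jobs: dict[str, float], task_timeout: float, poll: float = 0.05) -> dict[str, str]:
    """Run sleep jobs in threads and label each one 'done' or 'timeout'."""
    started_at: dict[str, float] = {}
    outcome: dict[str, str] = {}

    def _work(name: str, seconds: float) -> str:
        # Record the start when the worker actually begins, not when it is queued.
        started_at[name] = time.time()
        time.sleep(seconds)
        return name

    ex = ThreadPoolExecutor(max_workers=len(jobs))
    try:
        futs = {ex.submit(_work, name, secs): name for name, secs in jobs.items()}
        pending = set(futs)
        while pending:
            done, _ = wait(pending, timeout=poll, return_when=FIRST_COMPLETED)
            now = time.time()
            overdue = [
                fut for fut in pending - done
                if futs[fut] in started_at and now - started_at[futs[fut]] >= task_timeout
            ]
            for fut in done:
                pending.remove(fut)
                outcome[futs[fut]] = "done"
            for fut in overdue:
                pending.remove(fut)
                fut.cancel()  # has no effect on a thread that is already running
                outcome[futs[fut]] = "timeout"
    finally:
        ex.shutdown(wait=False, cancel_futures=True)
    return outcome


outcome = run_with_deadline({"fast": 0.01, "slow": 0.6}, task_timeout=0.2)
```

Note that `Future.cancel()` cannot interrupt a running thread, so a timed-out worker keeps running in the background until it finishes on its own. The loop only stops waiting for it and records a timeout row, which is why per-call timeouts such as `exec_timeout` remain useful inside the task.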

@@ -0,0 +1,11 @@
# DocVQA Skill
## Visual Evidence Discipline
- Read the document carefully before answering.
- Prefer the smallest exact text span that answers the question.
- When several nearby strings look similar, choose the one whose surrounding labels or layout best match the question.
## Exact Answer Discipline
- Copy names, numbers, and dates exactly from the document whenever possible.
- Prefer direct extraction over paraphrase.
- Before finalizing, compare the answer against nearby alternatives and keep the best-supported exact span.


@@ -0,0 +1 @@
"""LiveMathematicianBench environment package for ReflACT."""


@@ -0,0 +1,284 @@
"""LiveMathematicianBench environment adapter for ReflACT."""
from __future__ import annotations
import json
import os
from reflact.gradient.deep_probe import generate_deep_probe_instruction
from reflact.datasets.base import BatchSpec
from reflact.gradient.reflect import run_minibatch_reflect
from reflact.envs.base import EnvAdapter
from reflact.envs.livemathematicianbench.dataloader import LiveMathematicianBenchDataLoader
from reflact.envs.livemathematicianbench.rollout import run_batch
from reflact.model import get_student_backend
class LiveMathematicianBenchAdapter(EnvAdapter):
"""LiveMathematicianBench adapter."""
def build_reference_text(self, item: dict) -> str:
parts: list[str] = []
theorem = str(item.get("theorem") or "").strip()
sketch = str(item.get("sketch") or "").strip()
if theorem:
parts.append(f"## Reference Theorem\n{theorem}")
if sketch:
parts.append(f"## Reference Sketch\n{sketch}")
return "\n\n".join(parts)
def get_reference_metadata(self, item: dict) -> dict:
fields: list[str] = []
previews: list[str] = []
theorem = str(item.get("theorem") or "").strip()
sketch = str(item.get("sketch") or "").strip()
if theorem:
fields.append("theorem")
previews.append(f"[theorem]\n{theorem[:220]}")
if sketch:
fields.append("sketch")
previews.append(f"[sketch]\n{sketch[:220]}")
return {
"fields": fields,
"preview": "\n\n".join(previews)[:500],
}
def __init__(
self,
split_dir: str = "",
data_path: str = "",
split_mode: str = "ratio",
split_ratio: str = "2:1:7",
split_seed: int = 42,
split_output_dir: str = "",
max_turns: int = 1,
exec_timeout: int = 300,
workers: int = 64,
analyst_workers: int = 16,
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42,
limit: int = 0,
shuffle_choices: bool = True,
use_theorem: bool = False,
use_sketch: bool = False,
use_deep_reflect: bool = False,
deep_reflect_failures: int = 4,
deep_reflect_successes: int = 2,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
self.edit_budget = edit_budget
self.use_theorem = use_theorem
self.use_sketch = use_sketch
self.use_deep_reflect = use_deep_reflect
self.deep_reflect_failures = deep_reflect_failures
self.deep_reflect_successes = deep_reflect_successes
self.dataloader = LiveMathematicianBenchDataLoader(
split_dir=split_dir,
data_path=data_path,
split_mode=split_mode,
split_ratio=split_ratio,
split_seed=split_seed,
split_output_dir=split_output_dir,
seed=seed,
limit=limit,
shuffle_choices=shuffle_choices,
)
def setup(self, cfg: dict) -> None:
super().setup(cfg)
self.dataloader.setup(cfg)
def get_dataloader(self):
return self.dataloader
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
return list(batch.payload or [])
def build_train_env(self, batch_size: int, seed: int, **kwargs):
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def rollout(
self,
env_manager,
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict]:
items: list[dict] = env_manager
return run_batch(
items=items,
out_root=out_dir,
skill_content=skill_content,
max_turns=self.max_turns,
exec_timeout=self.exec_timeout,
workers=self.workers,
use_theorem=self.use_theorem,
use_sketch=self.use_sketch,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
task_timeout=self.exec_timeout,
)
def reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
meta_skill_context = kwargs.get("meta_skill_context", "")
return run_minibatch_reflect(
results=results,
skill_content=skill_content,
prediction_dir=prediction_dir,
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
meta_skill_context=meta_skill_context,
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
)
def deep_reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
if not self.use_deep_reflect:
return []
env_manager = kwargs.get("env_manager")
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
meta_skill_context = kwargs.get("meta_skill_context", "")
codex_backend = get_student_backend() == "codex_exec"
selected_items = self.select_representative_items(
results,
env_manager if isinstance(env_manager, list) else None,
n_failures=self.deep_reflect_failures,
n_successes=self.deep_reflect_successes,
seed=random_seed,
)
if not selected_items:
return []
selected_ids = {str(item["id"]) for item in selected_items}
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
selected_examples = self.attach_reference_context(selected_results, selected_items)
if codex_backend:
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
selected_metadata = []
theorem_count = 0
sketch_count = 0
for item in selected_items:
meta = self.get_reference_metadata(item)
if "theorem" in meta["fields"]:
theorem_count += 1
if "sketch" in meta["fields"]:
sketch_count += 1
selected_metadata.append({
"id": str(item["id"]),
"task_type": str(item.get("theorem_type", ["math_mcq"])[0] if item.get("theorem_type") else "math_mcq"),
"reference_fields": meta["fields"],
"reference_preview": meta["preview"],
})
deep_dir = os.path.join(out_dir, "deep_reflect")
rollout_dir = os.path.join(deep_dir, "rollout")
patches_dir = os.path.join(deep_dir, "patches")
os.makedirs(deep_dir, exist_ok=True)
print(
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
f"reference_fields=theorem({theorem_count}/{len(selected_items)}),"
f"sketch({sketch_count}/{len(selected_items)})"
)
probe = generate_deep_probe_instruction(
skill_content=skill_content,
items=selected_examples,
prediction_dir=prediction_dir,
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
step_buffer_context=step_buffer_context,
meta_skill_context=meta_skill_context,
)
if not probe:
return []
diagnostic_trace_context_by_id = None
if codex_backend:
selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
selected_items=selected_items,
selected_examples=selected_examples,
prediction_dir=prediction_dir,
probe=probe,
)
probe_record = {
**probe,
"reference_summary": {
"selected_count": len(selected_items),
"field_counts": {
"theorem": theorem_count,
"sketch": sketch_count,
},
},
"selected_examples": selected_metadata,
}
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
json.dump(probe_record, f, ensure_ascii=False, indent=2)
deep_results = run_batch(
items=selected_items,
out_root=rollout_dir,
skill_content=skill_content,
max_turns=self.max_turns,
workers=min(self.workers, max(len(selected_items), 1)),
use_theorem=self.use_theorem,
use_sketch=self.use_sketch,
diagnostic_mode=True,
diagnostic_instruction=probe["probe_instruction"],
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
task_timeout=self.exec_timeout,
)
deep_results = self.attach_reference_context(deep_results, selected_items)
return run_minibatch_reflect(
results=deep_results,
skill_content=skill_content,
prediction_dir=os.path.join(rollout_dir, "predictions"),
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
meta_skill_context=meta_skill_context,
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
)
def get_task_types(self) -> list[str]:
return self.dataloader.get_task_types()


@@ -0,0 +1,308 @@
"""LiveMathematicianBench task dataloader."""
from __future__ import annotations
import glob
import hashlib
import json
import os
import random
from typing import Any
from reflact.datasets.base import BatchSpec, SplitDataLoader
# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]
def _load_json(path: str) -> Any:
with open(path, encoding="utf-8") as f:
return json.load(f)
def _iter_monthly_files(data_path: str) -> list[str]:
if not data_path:
return []
if os.path.isfile(data_path):
return [data_path]
if os.path.isdir(data_path):
nested = glob.glob(
os.path.join(data_path, "**", "qa_*_final.json"),
recursive=True,
)
flat = glob.glob(os.path.join(data_path, "qa_*_final.json"))
return sorted(set(nested + flat))
return []
def _coerce_choices(raw_choices: Any) -> list[dict]:
if isinstance(raw_choices, list):
choices: list[dict] = []
for idx, item in enumerate(raw_choices):
if isinstance(item, dict):
label = str(item.get("label") or _CHOICE_LABELS[idx]).strip()
text = str(item.get("text") or item.get("content") or "").strip()
else:
label = _CHOICE_LABELS[idx]
text = str(item).strip()
if text:
choices.append({"label": label, "text": text})
return choices
if isinstance(raw_choices, dict):
labels = sorted(raw_choices.keys())
return [
{"label": str(label).strip(), "text": str(raw_choices[label]).strip()}
for label in labels
if str(raw_choices[label]).strip()
]
return []
def _coerce_theorem_types(raw: Any) -> list[str]:
if isinstance(raw, list):
return [str(x).strip() for x in raw if str(x).strip()]
if raw is None:
return []
text = str(raw).strip()
return [text] if text else []
def _normalize_label(text: str) -> str:
return str(text).strip().upper().rstrip(".):")
def _normalize_item(item: dict, row_idx: int, source_path: str) -> dict:
mcq = item.get("mcq", {}) if isinstance(item.get("mcq"), dict) else {}
question = str(mcq.get("question") or item.get("question") or "").strip()
choices = _coerce_choices(mcq.get("choices") or item.get("choices") or [])
correct = mcq.get("correct_choice") or item.get("correct_choice") or {}
if isinstance(correct, dict):
correct_label = _normalize_label(correct.get("label", ""))
correct_text = str(correct.get("text") or "").strip()
else:
correct_label = _normalize_label(correct)
correct_text = ""
choice_by_label = {
_normalize_label(choice["label"]): choice["text"]
for choice in choices
}
if correct_label and not correct_text:
correct_text = choice_by_label.get(correct_label, "")
if correct_label and correct_text and correct_label not in choice_by_label:
choices.append({"label": correct_label, "text": correct_text})
choices.sort(key=lambda choice: _CHOICE_LABELS.index(choice["label"]) if choice["label"] in _CHOICE_LABELS else len(_CHOICE_LABELS))
choice_by_label[correct_label] = correct_text
month = str(item.get("month") or "").strip()
item_no = item.get("no", row_idx + 1)
item_id = f"{month}:{item_no}" if month else str(item_no)
return {
"id": item_id,
"month": month,
"no": item_no,
"paper_link": str(item.get("paper_link") or "").strip(),
"theorem": str(item.get("theorem") or "").strip(),
"sketch": str(item.get("sketch") or "").strip(),
"theorem_type": _coerce_theorem_types(item.get("theorem_type")),
"question": question,
"choices": choices,
"correct_choice": {
"label": correct_label,
"text": correct_text,
},
"source_path": source_path,
}
def load_items(data_path: str) -> list[dict]:
"""Load and normalise LiveMathematicianBench items from JSON files."""
files = _iter_monthly_files(data_path)
if not files:
raise ValueError(
"LiveMathematicianBench requires data_path to be a qa_*_final.json file "
"or a directory containing monthly qa_*_final.json files."
)
items: list[dict] = []
for path in files:
raw = _load_json(path)
if not isinstance(raw, list):
raise ValueError(f"Expected JSON array in {path}, got {type(raw).__name__}")
for row_idx, item in enumerate(raw):
norm = _normalize_item(item, row_idx=row_idx, source_path=path)
if norm["question"] and norm["choices"] and norm["correct_choice"]["label"]:
items.append(norm)
if not items:
raise ValueError(f"No valid LiveMathematicianBench items loaded from {data_path}")
return items
# ── Dataloader ───────────────────────────────────────────────────────────
class LiveMathematicianBenchDataLoader(SplitDataLoader):
"""LiveMathematicianBench dataloader with per-seed choice shuffling."""
def __init__(
self,
split_dir: str = "",
data_path: str = "",
split_mode: str = "ratio",
split_ratio: str = "2:1:7",
split_seed: int = 42,
split_output_dir: str = "",
seed: int = 42,
limit: int = 0,
shuffle_choices: bool = True,
**kwargs,
) -> None:
super().__init__(
split_dir=split_dir,
data_path=data_path,
split_mode=split_mode,
split_ratio=split_ratio,
split_seed=split_seed,
split_output_dir=split_output_dir,
seed=seed,
limit=limit,
)
self.shuffle_choices = shuffle_choices
self._task_types: list[str] = []
def load_raw_items(self, data_path: str) -> list[dict]:
return load_items(data_path)
def setup(self, cfg: dict) -> None:
super().setup(cfg)
all_items = self.train_items + self.val_items + self.test_items
task_types: set[str] = set()
for item in all_items:
for name in item.get("theorem_type", []):
if name:
task_types.add(name)
self._task_types = sorted(task_types)
def get_task_types(self) -> list[str]:
return list(self._task_types)
# ── Choice shuffling ─────────────────────────────────────────────────
@staticmethod
def _item_shuffle_seed(item_id: str, seed: int) -> int:
digest = hashlib.sha256(f"{seed}:{item_id}".encode("utf-8")).hexdigest()
return int(digest[:16], 16)
def _shuffle_item_choices(self, item: dict, seed: int) -> dict:
if not self.shuffle_choices:
return {
**item,
"choices": [dict(c) for c in item["choices"]],
"correct_choice": dict(item["correct_choice"]),
}
shuffled_choices = [dict(c) for c in item["choices"]]
rng = random.Random(self._item_shuffle_seed(str(item["id"]), seed))
rng.shuffle(shuffled_choices)
original_correct = _normalize_label(item["correct_choice"]["label"])
remapped_choices: list[dict] = []
new_correct_choice = dict(item["correct_choice"])
for idx, choice in enumerate(shuffled_choices):
new_label = _CHOICE_LABELS[idx]
old_label = _normalize_label(choice["label"])
remapped_choices.append({"label": new_label, "text": choice["text"]})
if old_label == original_correct:
new_correct_choice = {"label": new_label, "text": choice["text"]}
transformed = dict(item)
transformed["choices"] = remapped_choices
transformed["correct_choice"] = new_correct_choice
return transformed
def _materialize_batch(self, items: list[dict], seed: int) -> list[dict]:
return [self._shuffle_item_choices(item, seed) for item in items]
# ── Batch construction (override for choice shuffling) ───────────────
def plan_train_epoch(
self,
*,
epoch: int,
steps_per_epoch: int,
accumulation: int,
batch_size: int,
seed: int,
**kwargs,
) -> list[BatchSpec]:
"""Build a shuffled epoch while preserving per-batch choice shuffling."""
epoch_rng = random.Random(seed + epoch * 1000)
items = list(self.train_items)
epoch_rng.shuffle(items)
total_batches = steps_per_epoch * accumulation
if total_batches <= 0:
return []
batches: list[BatchSpec] = []
cursor = 0
for batch_idx in range(total_batches):
batch_seed = seed + epoch * 1000 + batch_idx + 1
batch_items = items[cursor: cursor + batch_size]
cursor += len(batch_items)
if not batch_items and items:
refill_rng = random.Random(batch_seed)
batch_items = list(items)
refill_rng.shuffle(batch_items)
batch_items = batch_items[:batch_size]
batch_items = self._materialize_batch(batch_items, batch_seed)
batches.append(
BatchSpec(
phase="train",
split="train",
seed=batch_seed,
batch_size=len(batch_items),
payload=batch_items,
)
)
return batches
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
rng = random.Random(seed)
items = list(self.train_items)
rng.shuffle(items)
items = self._materialize_batch(items[:batch_size], seed)
return BatchSpec(
phase="train",
split="train",
seed=seed,
batch_size=len(items),
payload=items,
)
def build_eval_batch(
self,
env_num: int,
split: str,
seed: int,
**kwargs,
) -> BatchSpec:
items = self.get_split_items(split)
if env_num and env_num < len(items):
items = items[:env_num]
items = self._materialize_batch(items, seed)
return BatchSpec(
phase="eval",
split=split,
seed=seed,
batch_size=len(items),
payload=items,
)
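
Review note: `_shuffle_item_choices` derives a per-item RNG seed from `sha256(f"{seed}:{item_id}")`, shuffles the choices, relabels them in order, and remaps the correct label. The sketch below isolates that idea so its properties (same seed gives the same order, the correct text follows its new label) can be checked without the rest of the loader. The function names `item_seed` and `shuffle_choices` and the sample item are illustrative assumptions, not repository code.

```python
import hashlib
import random

LABELS = ["A", "B", "C", "D", "E", "F", "G"]


def item_seed(item_id: str, seed: int) -> int:
    """Stable 64-bit seed from (seed, item_id); independent of PYTHONHASHSEED."""
    digest = hashlib.sha256(f"{seed}:{item_id}".encode("utf-8")).hexdigest()
    return int(digest[:16], 16)


def shuffle_choices(item: dict, seed: int) -> dict:
    """Return a copy of item with shuffled, relabelled choices."""
    choices = [dict(c) for c in item["choices"]]  # copy so the input is untouched
    random.Random(item_seed(str(item["id"]), seed)).shuffle(choices)
    correct_old = item["correct_choice"]["label"]
    remapped = []
    correct_new = dict(item["correct_choice"])
    for idx, choice in enumerate(choices):
        relabelled = {"label": LABELS[idx], "text": choice["text"]}
        remapped.append(relabelled)
        if choice["label"] == correct_old:
            correct_new = dict(relabelled)
    return {**item, "choices": remapped, "correct_choice": correct_new}
```

Like the loader, this sketch indexes `LABELS` directly, so an item with more than seven choices would raise `IndexError`.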


@@ -0,0 +1,62 @@
"""LiveMathematicianBench evaluation helpers."""
from __future__ import annotations
import re
def extract_answer(text: str) -> str:
matches = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL | re.IGNORECASE)
if matches:
return matches[-1].strip()
lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
if lines:
return lines[-1]
return text.strip()
def normalize_label(text: str) -> str:
return str(text).strip().upper().rstrip(".):")
def parse_choice_label(prediction_text: str, choices: list[dict]) -> str:
answer = extract_answer(prediction_text)
label = normalize_label(answer)
valid_labels = {normalize_label(choice.get("label", "")) for choice in choices}
if label in valid_labels:
return label
answer_lower = answer.lower()
for choice in choices:
choice_label = normalize_label(choice.get("label", ""))
choice_text = str(choice.get("text", "")).strip()
if choice_text and choice_text.lower() == answer_lower:
return choice_label
first_token = normalize_label(answer.split()[0]) if answer.split() else ""
if first_token in valid_labels:
return first_token
return label
def evaluate(prediction_text: str, correct_choice: dict, choices: list[dict]) -> dict:
predicted_label = parse_choice_label(prediction_text, choices)
correct_label = normalize_label(correct_choice.get("label", ""))
predicted_text = ""
correct_text = str(correct_choice.get("text", "")).strip()
for choice in choices:
if normalize_label(choice.get("label", "")) == predicted_label:
predicted_text = str(choice.get("text", "")).strip()
break
is_correct = float(predicted_label == correct_label)
return {
"em": is_correct,
"f1": is_correct,
"sub_em": is_correct,
"predicted_answer": predicted_label or extract_answer(prediction_text),
"predicted_label": predicted_label,
"predicted_text": predicted_text,
"correct_label": correct_label,
"correct_text": correct_text,
}
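
Review note: `parse_choice_label` resolves a prediction in three steps: exact label match, match against a choice's text, then the first whitespace-separated token. The standalone sketch below reproduces that order in condensed form (it is a restatement for illustration, not the module itself) and exercises a few inputs, including one the normalizer does not recover: `rstrip(".):")` strips only trailing punctuation, so `(A)` becomes `(A`.

```python
import re


def extract_answer(text: str) -> str:
    """Last <answer>...</answer> span, else the last non-empty line."""
    matches = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL | re.IGNORECASE)
    if matches:
        return matches[-1].strip()
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    return lines[-1] if lines else text.strip()


def normalize_label(text: str) -> str:
    return str(text).strip().upper().rstrip(".):")


def parse_choice_label(prediction_text: str, choices: list[dict]) -> str:
    answer = extract_answer(prediction_text)
    label = normalize_label(answer)
    valid = {normalize_label(c.get("label", "")) for c in choices}
    if label in valid:  # 1) the answer is a label
        return label
    for c in choices:  # 2) the answer equals a choice's text
        text = str(c.get("text", "")).strip()
        if text and text.lower() == answer.lower():
            return normalize_label(c.get("label", ""))
    tokens = answer.split()  # 3) the first token is a label
    first = normalize_label(tokens[0]) if tokens else ""
    return first if first in valid else label


CHOICES = [{"label": "A", "text": "x > 0"}, {"label": "B", "text": "x >= 0"}]

print(parse_choice_label("Reasoning...\n<answer>b.</answer>", CHOICES))  # B
print(parse_choice_label("<answer>x > 0</answer>", CHOICES))             # A
print(parse_choice_label("I think\nB) is right", CHOICES))               # B
print(parse_choice_label("<answer>(A)</answer>", CHOICES))               # (A
```

The last case is scored as incorrect by `evaluate`, since `(A` never equals a normalized correct label.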


@@ -0,0 +1,37 @@
You are an expert failure-analysis agent for theorem-grounded mathematical multiple-choice questions.
You will be given MULTIPLE failed trajectories from a single minibatch and the current skill document.
Each trajectory includes the student's response and an evaluation result showing the predicted option
versus the correct option.
Your job is to identify COMMON reasoning failures across the batch and propose concise skill edits.
## Failure Type Categories
- **quantifier_miss**: the agent missed exact quantifiers, scope, or existence/uniqueness conditions
- **strength_mismatch**: the agent preferred a weaker or stronger statement than what was proved
- **condition_miss**: the agent ignored hypotheses, equality cases, or domain restrictions
- **option_confusion**: the agent confused similar answer choices or failed to compare them exactly
- **other**: none of the above
## Rules
1. Focus on patterns that recur across the minibatch.
2. Prefer edits that improve exact choice discrimination, not theorem-specific memorization.
3. Do not hardcode paper-specific content.
4. Only patch gaps not already covered by the skill.
Respond ONLY with a valid JSON object:
{
"batch_size": <number>,
"failure_summary": [
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
],
"patch": {
"reasoning": "<why these edits address the common failures>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
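
Review note: the prompt above names four edit operations (`append`, `insert_after`, `replace`, `delete`) but this file does not show how they are applied. The sketch below is one plausible interpretation, written only to make the schema concrete. The function `apply_edits`, the "insert after the line containing the target" rule, and the policy of skipping edits whose target is missing are all assumptions, not the framework's actual applier.

```python
def apply_edits(skill: str, edits: list[dict]) -> tuple[str, list[dict]]:
    """Apply patch edits to a markdown skill; return (new_skill, skipped_edits).

    Hypothetical semantics, for illustration only.
    """
    skipped: list[dict] = []
    for edit in edits:
        op = edit.get("op")
        target = edit.get("target", "")
        content = edit.get("content", "")
        if op == "append":
            skill = skill.rstrip("\n") + "\n" + content + "\n"
            continue
        if op not in {"insert_after", "replace", "delete"}:
            skipped.append(edit)  # unknown operation
            continue
        if not target or target not in skill:
            skipped.append(edit)  # target text not found; leave the skill unchanged
            continue
        if op == "insert_after":
            # Insert on a new line after the line that contains the target.
            idx = skill.index(target) + len(target)
            eol = skill.find("\n", idx)
            eol = len(skill) if eol == -1 else eol
            skill = skill[:eol] + "\n" + content + skill[eol:]
        elif op == "replace":
            skill = skill.replace(target, content, 1)
        else:  # delete
            skill = skill.replace(target, "", 1)
    return skill, skipped
```

Returning the skipped edits rather than raising keeps one malformed edit from discarding the rest of a patch; whether the real pipeline behaves this way is not stated here.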


@@ -0,0 +1,25 @@
You are an expert success-pattern analyst for theorem-grounded mathematical multiple-choice questions.
You will be given MULTIPLE successful trajectories from a minibatch and the current skill document.
Identify generalizable behavior patterns that are genuinely helping the agent choose the exact correct option.
## Rules
- Focus on broadly useful reasoning behaviors.
- Prefer patterns about exact comparison of options, quantifiers, and equality conditions.
- Do not add theorem-specific facts.
- "edits" may be empty if the skill already captures the useful patterns.
Respond ONLY with a valid JSON object:
{
"batch_size": <number>,
"success_patterns": ["<pattern 1>", "<pattern 2>"],
"patch": {
"reasoning": "<why these patterns matter>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}


@@ -0,0 +1,23 @@
You are an expert diagnostic-probe designer for theorem-grounded mathematical multiple-choice tasks.
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
Design one SMALL diagnostic instruction that exposes the student's intermediate judgment without materially changing the original scaffold.
## Hard Constraints
1. Do NOT substantially change the original scaffold.
2. Do NOT prescribe a new multi-step theorem-solving procedure.
3. Do NOT ask for a full proof, full chain-of-thought, or exhaustive option-by-option derivation.
4. Ask only for a short readout of the signals already behind the student's current answer.
5. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>.
## Good Probe Targets
- top choice and runner-up
- decisive constraint
- why the runner-up was rejected
- strongest-vs-weaker discrimination signal
Respond ONLY with a valid JSON object:
{
"reasoning": "<why this probe is informative>",
"probe_instruction": "<the exact instruction text to append to the student prompt>"
}


@@ -0,0 +1,26 @@
You are an expert diagnostic-probe designer for theorem-grounded mathematical multiple-choice tasks executed through a Codex trace.
You will be shown representative trajectories, the current student skill, the student's original prompt context, hidden reference fields, and numbered Codex trace steps.
Choose exactly one trajectory and one probe point. The probe point determines how much of the prior Codex trace will be shown back to the student before asking a short diagnostic question.
## Hard Constraints
1. Do NOT reveal or paraphrase the hidden reference directly to the student.
2. Do NOT prescribe a new full solving procedure.
3. Do NOT ask for a full proof, full chain-of-thought, or exhaustive option-by-option derivation.
4. Ask only for a short readout of the signal that should already exist at that point in the student's process.
5. The probe instruction must explicitly request a short <analysis>...</analysis> block before the final <answer>...</answer>.
6. Select a probe point that is informative about theorem choice, decisive constraint, option elimination, or why a stronger/weaker option should be rejected.
## Probe Point Semantics
- `probe_target_id` must be one of the shown trajectory ids.
- `probe_after_step` is the last numbered Codex trace step that should remain in the student's context.
- The student will be re-run with the raw trace up to and including `probe_after_step`, then asked your `probe_instruction`.
- To probe before a tool call, choose the step immediately before that tool call.
Respond ONLY with a valid JSON object:
{
"reasoning": "<why this trajectory and probe point expose the student's intermediate state>",
"probe_target_id": "<trajectory id>",
"probe_after_step": <integer step number>,
"probe_instruction": "<the exact instruction text to append to the student's prompt>"
}
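
Review note: the "Probe Point Semantics" section says the student is re-run with the trace "up to and including `probe_after_step`". A minimal sketch of that cut is shown below, assuming steps are numbered from 1 and held as a list of strings. The function `trace_context` and the `[n]` rendering are hypothetical; the adapter's `resolve_codex_probe_target` may format the trace differently.

```python
def trace_context(steps: list[str], probe_after_step: int) -> str:
    """Render trace steps 1..probe_after_step (inclusive) as numbered lines.

    Hypothetical helper; step numbering from 1 is an assumption.
    """
    if not 1 <= probe_after_step <= len(steps):
        raise ValueError(
            f"probe_after_step must be in 1..{len(steps)}, got {probe_after_step}"
        )
    kept = steps[:probe_after_step]
    return "\n".join(f"[{i}] {step}" for i, step in enumerate(kept, start=1))


steps = ["read task.md", "compare options A and C", "call tool", "final answer"]
# To probe before the tool call at step 3, keep steps 1 and 2.
print(trace_context(steps, 2))
```

Rejecting out-of-range values up front matters because the step number comes from model output and cannot be trusted to be valid.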


@@ -0,0 +1,12 @@
You are an expert mathematical reasoning agent solving multiple-choice questions.
{skill_section}## Task Format
You will receive one mathematics multiple-choice question and its answer choices.
Reason carefully about quantifiers, hypotheses, extremal wording, and exact equality conditions.
## Answer Format
Think step by step, then provide your final answer inside <answer>...</answer> tags.
Inside the tags, output only the single choice label, such as A or C.
Example:
<answer>B</answer>

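
Review note: the `{skill_section}` placeholder in the prompt above is filled with `str.format`. The sketch below mirrors how `_build_system` appears to use it, with a shortened inline template standing in for the file loaded by `load_prompt` (the `TEMPLATE` constant and `build_system` name are illustrative assumptions).

```python
TEMPLATE = (
    "You are an expert mathematical reasoning agent solving multiple-choice questions.\n"
    "{skill_section}## Task Format\n"
    "You will receive one mathematics multiple-choice question and its answer choices.\n"
)


def build_system(template: str, skill_content: str) -> str:
    """Fill {skill_section}; an empty or whitespace-only skill yields no section."""
    skill = skill_content.strip()
    skill_section = f"## Skill\n{skill}\n\n" if skill else ""
    return template.format(skill_section=skill_section)


print(build_system(TEMPLATE, "- Compare quantifiers exactly."))
```

Braces inside the skill text are safe because `str.format` does not re-scan substituted values. Literal braces in the template file itself would need to be doubled (`{{`, `}}`), which is worth keeping in mind if a JSON example is ever added to this prompt.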

@@ -0,0 +1,4 @@
"""LiveMathematicianBench Reflect stage.
Prompts are now loaded from .md files by the base adapter.
"""


@@ -0,0 +1,401 @@
"""LiveMathematicianBench rollout — theorem-grounded math MCQ agent."""
from __future__ import annotations
import json
import os
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from reflact.envs.livemathematicianbench.evaluator import evaluate
from reflact.model import chat_student, get_student_backend, is_student_exec_backend
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
from reflact.prompts import load_prompt
def _build_system(skill_content: str) -> str:
if skill_content.strip():
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
else:
skill_section = ""
return load_prompt("rollout_system", env="livemathematicianbench").format(skill_section=skill_section)
def _format_choices(choices: list[dict]) -> str:
return "\n".join(
f"{choice['label']}. {choice['text']}"
for choice in choices
)
def _build_user(
item: dict,
*,
use_theorem: bool = False,
use_sketch: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
) -> str:
parts = [f"## Question\n{item['question']}", f"## Choices\n{_format_choices(item['choices'])}"]
if use_theorem and item.get("theorem"):
parts.append(f"## Theorem\n{item['theorem']}")
if use_sketch and item.get("sketch"):
parts.append(f"## Proof Sketch\n{item['sketch']}")
if diagnostic_trace_context.strip():
parts.append(
"## Previous Codex Trace Snapshot\n"
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
f"{diagnostic_trace_context.strip()}"
)
if diagnostic_mode and diagnostic_instruction.strip():
parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
return "\n\n".join(parts)
def _build_codex_skill(skill_content: str) -> str:
return render_skill_md(
skill_content,
description="Dynamic ReflACT skill for solving the current LiveMathematicianBench multiple-choice question.",
preamble=(
"Use this skill when solving the current math multiple-choice question.\n"
"Inspect the option wording carefully and output only the final choice label inside <answer>...</answer>."
),
)
def _run_codex_once(
*,
pred_dir: str,
skill_content: str,
item: dict,
model: str,
timeout: int,
use_theorem: bool = False,
use_sketch: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
previous_response: str = "",
) -> tuple[str, str, str, str]:
user = _build_user(
item,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
)
task_parts = [user]
if previous_response:
task_parts.append(
"## Previous Attempt\n"
f"{previous_response}\n\n"
"Re-evaluate the exact option wording. If needed, correct it."
)
task_text = "\n\n".join(task_parts)
skill_md = _build_codex_skill(skill_content)
work_dir = os.path.join(pred_dir, "codex_exec")
prepare_workspace(work_dir=work_dir, skill_md=skill_md, task_text=task_text)
prompt = (
"Use the `reflact-student` skill available in this workspace.\n"
"Read `task.md` and solve the multiple-choice problem.\n"
"Output only the final choice label inside <answer>...</answer>."
)
final_message, raw = run_student_exec(
work_dir=work_dir,
prompt=prompt,
model=model,
timeout=timeout,
)
return final_message or raw, raw, skill_md, task_text
def process_one(
item: dict,
out_root: str,
skill_content: str,
*,
max_turns: int = 1,
use_theorem: bool = False,
use_sketch: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
exec_timeout: int = 300,
) -> dict:
item_id = str(item["id"])
result = {
"id": item_id,
"question": item["question"],
"task_type": item.get("theorem_type", ["math_mcq"])[0] if item.get("theorem_type") else "math_mcq",
"hard": 0,
"soft": 0.0,
"predicted_answer": "",
"predicted_label": "",
"predicted_text": "",
"correct_label": item["correct_choice"]["label"],
"correct_text": item["correct_choice"]["text"],
"response": "",
"fail_reason": "",
"agent_ok": False,
"n_turns": 0,
}
try:
pred_dir = os.path.join(out_root, "predictions", item_id)
os.makedirs(pred_dir, exist_ok=True)
if is_student_exec_backend():
from reflact.model import azure_openai as _llm
conversation: list[dict] = []
response = ""
system = ""
user = ""
for turn in range(max_turns):
response, raw, system, user = _run_codex_once(
pred_dir=pred_dir,
skill_content=skill_content,
item=item,
model=_llm.STUDENT_DEPLOYMENT,
timeout=exec_timeout,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode if turn == 0 else False,
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
previous_response=response if turn > 0 else "",
)
conversation.append({"type": "message", "turn": turn + 1, "content": response})
if "<answer>" in response.lower():
break
result["response"] = response
result["agent_ok"] = True
result["n_turns"] = len(conversation)
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
f.write(system)
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
f.write(user)
eval_result = evaluate(response, item["correct_choice"], item["choices"])
result["hard"] = int(eval_result["em"])
result["soft"] = eval_result["f1"]
result["predicted_answer"] = eval_result["predicted_answer"]
result["predicted_label"] = eval_result["predicted_label"]
result["predicted_text"] = eval_result["predicted_text"]
if not result["hard"]:
result["fail_reason"] = (
f"MCQ=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
f"but expected '{eval_result['correct_label']}'"
)
eval_detail = (
f"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted label: {eval_result['predicted_label']!r}\n"
f"Predicted text: {eval_result['predicted_text']!r}\n"
f"Correct label: {eval_result['correct_label']!r}\n"
f"Correct text: {eval_result['correct_text']!r}\n"
f"Exact Match: {eval_result['em']}"
)
conversation.append({"role": "system", "content": eval_detail})
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
json.dump(conversation, f, ensure_ascii=False, indent=2)
return result
system = _build_system(skill_content)
user = _build_user(
item,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
)
conversation: list[dict] = []
response = ""
for turn in range(max_turns):
if turn == 0:
resp_text, _ = chat_student(
system=system,
user=user,
max_completion_tokens=16384,
retries=5,
stage="rollout",
timeout=exec_timeout,
)
else:
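refinement = (
# Repeat the original question and options so the student can actually re-check the option wording.
f"{user}\n\n"
f"Your previous answer was:\n{response}\n\n"
"Re-evaluate the exact option wording. If needed, correct it. "
"Output only the final choice label inside <answer>...</answer>."
)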
resp_text, _ = chat_student(
system=system,
user=refinement,
max_completion_tokens=16384,
retries=5,
stage="rollout",
timeout=exec_timeout,
)
response = resp_text
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
if "<answer>" in resp_text.lower():
break
result["response"] = response
result["agent_ok"] = True
result["n_turns"] = len(conversation)
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
f.write(system)
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
f.write(user)
eval_result = evaluate(response, item["correct_choice"], item["choices"])
result["hard"] = int(eval_result["em"])
result["soft"] = eval_result["f1"]
result["predicted_answer"] = eval_result["predicted_answer"]
result["predicted_label"] = eval_result["predicted_label"]
result["predicted_text"] = eval_result["predicted_text"]
if not result["hard"]:
result["fail_reason"] = (
f"MCQ=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
f"but expected '{eval_result['correct_label']}'"
)
eval_detail = (
f"[EVALUATION RESULT]\n"
f"Question: {item['question']}\n"
f"Predicted label: {eval_result['predicted_label']!r}\n"
f"Predicted text: {eval_result['predicted_text']!r}\n"
f"Correct label: {eval_result['correct_label']!r}\n"
f"Correct text: {eval_result['correct_text']!r}\n"
f"Exact Match: {eval_result['em']}"
)
conversation.append({"role": "system", "content": eval_detail})
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
json.dump(conversation, f, ensure_ascii=False, indent=2)
except Exception as e: # noqa: BLE001
result["fail_reason"] = f"error: {e}"
return result
def run_batch(
items: list[dict],
out_root: str,
skill_content: str,
*,
max_turns: int = 1,
exec_timeout: int = 300,
workers: int = 64,
use_theorem: bool = False,
use_sketch: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context_by_id: dict[str, str] | None = None,
task_timeout: int = 600,
) -> list[dict]:
task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
results_path = os.path.join(out_root, "results.jsonl")
os.makedirs(out_root, exist_ok=True)
done_ids: set[str] = set()
existing: list[dict] = []
if os.path.exists(results_path):
with open(results_path, encoding="utf-8") as f:
for line in f:
try:
r = json.loads(line)
done_ids.add(str(r["id"]))
existing.append(r)
except Exception:
pass
pending = [it for it in items if str(it["id"]) not in done_ids]
if not pending:
return existing
results = list(existing)
started_at: dict[str, float] = {}
def _run_one(it: dict) -> dict:
started_at[str(it["id"])] = time.time()
return process_one(
it,
out_root,
skill_content,
max_turns=max_turns,
exec_timeout=exec_timeout,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""),
)
def _timeout_result(it: dict) -> dict:
correct = it.get("correct_choice") or {}
return {
"id": str(it["id"]),
"question": it.get("question", ""),
"task_type": it.get("theorem_type", ["math_mcq"])[0] if it.get("theorem_type") else "math_mcq",
"hard": 0,
"soft": 0.0,
"predicted_answer": "",
"predicted_label": "",
"predicted_text": "",
"correct_label": correct.get("label", ""),
"correct_text": correct.get("text", ""),
"response": "",
"fail_reason": f"task-timeout-{task_timeout}s",
"agent_ok": False,
"n_turns": 0,
}
def _error_result(it: dict, exc: Exception) -> dict:
res = _timeout_result(it)
res["fail_reason"] = f"error: {type(exc).__name__}: {exc}"
return res
with open(results_path, "a", encoding="utf-8") as outf:
ex = ThreadPoolExecutor(max_workers=workers)
try:
futs = {
ex.submit(_run_one, it): it
for it in pending
}
pending_futs = set(futs)
while pending_futs:
done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
now = time.time()
timed_out = [
fut for fut in pending_futs - done
if str(futs[fut]["id"]) in started_at
and now - started_at[str(futs[fut]["id"])] >= task_timeout
]
for fut in done:
pending_futs.remove(fut)
item = futs[fut]
try:
res = fut.result()
except Exception as e: # noqa: BLE001
res = _error_result(item, e)
results.append(res)
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
outf.flush()
for fut in timed_out:
pending_futs.remove(fut)
res = _timeout_result(futs[fut])
results.append(res)
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
outf.flush()
finally:
ex.shutdown(wait=False, cancel_futures=True)
return results
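
The polling loop in `run_batch` is easier to read in isolation. The following is a minimal sketch of the same watchdog pattern, assuming a hypothetical `run_with_watchdog` helper and toy sleep jobs, neither of which is part of the repository. It records a start time per job, polls with `wait(..., return_when=FIRST_COMPLETED)`, and stops waiting for any job that has been running longer than `task_timeout`.

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def run_with_watchdog(jobs, worker, task_timeout, poll=0.05, max_workers=4):
    """Run worker(payload) per job id and stop waiting for overdue jobs.

    Hypothetical helper for illustration only. `jobs` maps a job id to the
    payload passed to `worker`.
    """
    started_at = {}

    def _run(job_id, payload):
        # Record the start inside the worker thread, so queue time is not
        # counted against the timeout.
        started_at[job_id] = time.time()
        return worker(payload)

    results = {}
    ex = ThreadPoolExecutor(max_workers=max_workers)
    try:
        futs = {ex.submit(_run, jid, payload): jid for jid, payload in jobs.items()}
        pending = set(futs)
        while pending:
            done, _ = wait(pending, timeout=poll, return_when=FIRST_COMPLETED)
            now = time.time()
            late = [
                fut for fut in pending - done
                if futs[fut] in started_at
                and now - started_at[futs[fut]] >= task_timeout
            ]
            for fut in done:
                pending.remove(fut)
                try:
                    results[futs[fut]] = ("ok", fut.result())
                except Exception as exc:  # keep the batch alive on worker errors
                    results[futs[fut]] = ("error", repr(exc))
            for fut in late:
                # The thread itself keeps running; it is only no longer awaited.
                pending.remove(fut)
                results[futs[fut]] = ("timeout", None)
    finally:
        # Drops queued work; already-running threads cannot be cancelled.
        ex.shutdown(wait=False, cancel_futures=True)
    return results


outcome = run_with_watchdog({"fast": 0.0, "slow": 0.6}, time.sleep, task_timeout=0.15)
print(outcome)
```

A consequence of this design is that a timed-out task still occupies its worker thread until it returns, so the per-call `exec_timeout` remains the mechanism that actually bounds the work.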

@@ -0,0 +1,16 @@
# Live Mathematical MCQ Heuristics
## Option Comparison
- Compare all options before committing. The correct choice is often the strongest statement justified by the question, while nearby distractors are weaker, overstrong, or miss an equality case.
- Track exact quantifiers such as "there exists", "for every", "if and only if", and "exactly when".
## Theorem-Level Precision
- Check whether an option weakens the conclusion by dropping a characterization, equality clause, or full equivalence.
- Check whether an option overstates the theorem by upgrading regularity, removing scale restrictions, or changing an existential statement into a universal one.
## Hypotheses
- Verify the hypotheses and domain carefully. Distractors often keep the theorem shape but alter the required assumptions.
- Pay close attention to equality cases, extremal conditions, and whether a result applies to the full family or only a restricted subfamily.
## Final Answer
- Output the final answer as the single option label only.

@@ -0,0 +1,5 @@
"""MathVerse environment package."""
from reflact.envs.mathverse.adapter import MathVerseAdapter
__all__ = ["MathVerseAdapter"]

@@ -0,0 +1,280 @@
"""MathVerse environment adapter for ReflACT."""
from __future__ import annotations
import json
import os
from reflact.datasets.base import BatchSpec
from reflact.envs.base import EnvAdapter
from reflact.envs.mathverse.dataloader import MathVerseDataLoader
from reflact.envs.mathverse.rollout import run_batch
from reflact.gradient.deep_probe import generate_deep_probe_instruction
from reflact.gradient.reflect import run_minibatch_reflect
from reflact.model import get_student_backend
class MathVerseAdapter(EnvAdapter):
"""MathVerse adapter."""
def build_reference_text(self, item: dict) -> str:
if not self.use_text_dominant_reference:
return ""
question = str(item.get("text_dominant_question") or "").strip()
if not question:
return ""
return f"## Reference Full Question\n{question}"
def get_reference_metadata(self, item: dict) -> dict:
if not self.use_text_dominant_reference:
return {"fields": [], "preview": ""}
question = str(item.get("text_dominant_question") or "").strip()
if not question:
return {"fields": [], "preview": ""}
return {
"fields": ["text_dominant_question"],
"preview": question[:400],
}
def __init__(
self,
split_dir: str = "",
data_root: str = "",
problem_version: str = "Text Lite",
use_text_dominant_reference: bool = False,
max_turns: int = 1,
workers: int = 16,
analyst_workers: int = 16,
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42,
limit: int = 0,
image_detail: str = "auto",
judge_model: str = "gpt-5.4",
judge_max_completion_tokens: int = 256,
judge_retries: int = 5,
use_deep_reflect: bool = False,
deep_reflect_failures: int = 4,
deep_reflect_successes: int = 2,
) -> None:
self.max_turns = max_turns
self.workers = workers
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
self.edit_budget = edit_budget
self.image_detail = image_detail
self.judge_model = judge_model
self.judge_max_completion_tokens = judge_max_completion_tokens
self.judge_retries = judge_retries
self.problem_version = problem_version
self.use_text_dominant_reference = use_text_dominant_reference
self.use_deep_reflect = use_deep_reflect
self.deep_reflect_failures = deep_reflect_failures
self.deep_reflect_successes = deep_reflect_successes
self.dataloader = MathVerseDataLoader(
split_dir=split_dir,
seed=seed,
limit=limit,
data_root=data_root,
problem_version=problem_version,
)
def setup(self, cfg: dict) -> None:
super().setup(cfg)
self.dataloader.setup(cfg)
def get_dataloader(self):
return self.dataloader
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
return list(batch.payload or [])
def build_train_env(self, batch_size: int, seed: int, **kwargs):
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
return self.build_env_from_batch(batch, **kwargs)
def rollout(
self,
env_manager,
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict]:
items: list[dict] = env_manager
return run_batch(
items=items,
out_root=out_dir,
skill_content=skill_content,
max_turns=self.max_turns,
workers=self.workers,
image_detail=self.image_detail,
judge_model=self.judge_model,
judge_max_completion_tokens=self.judge_max_completion_tokens,
judge_retries=self.judge_retries,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
)
def reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
return run_minibatch_reflect(
results=results,
skill_content=skill_content,
prediction_dir=prediction_dir,
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
)
def deep_reflect(
self,
results: list[dict],
skill_content: str,
out_dir: str,
**kwargs,
) -> list[dict | None]:
if not self.use_deep_reflect:
return []
env_manager = kwargs.get("env_manager")
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
random_seed = kwargs.get("random_seed")
step_buffer_context = kwargs.get("step_buffer_context", "")
selected_items = self.select_representative_items(
results,
env_manager if isinstance(env_manager, list) else None,
n_failures=self.deep_reflect_failures,
n_successes=self.deep_reflect_successes,
seed=random_seed,
)
if not selected_items:
return []
selected_ids = {str(item["id"]) for item in selected_items}
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
selected_examples = self.attach_reference_context(selected_results, selected_items)
codex_backend = get_student_backend() == "codex_exec"
if codex_backend:
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
selected_metadata = []
ref_count = 0
for item in selected_items:
meta = self.get_reference_metadata(item)
if meta["fields"]:
ref_count += 1
record = {
"id": str(item["id"]),
"task_type": str(item.get("task_type") or item.get("question_type") or "mathverse"),
"reference_fields": meta["fields"],
"reference_preview": meta["preview"],
}
if codex_backend:
record["codex_probe_step_count"] = int(
next(
(row.get("codex_probe_step_count", 0) for row in selected_examples if str(row.get("id")) == str(item["id"])),
0,
)
)
selected_metadata.append(record)
deep_dir = os.path.join(out_dir, "deep_reflect")
rollout_dir = os.path.join(deep_dir, "rollout")
patches_dir = os.path.join(deep_dir, "patches")
os.makedirs(deep_dir, exist_ok=True)
print(
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
f"reference_fields=text_dominant_question({ref_count}/{len(selected_items)})"
)
probe = generate_deep_probe_instruction(
skill_content=skill_content,
items=selected_examples,
prediction_dir=prediction_dir,
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
step_buffer_context=step_buffer_context,
)
if not probe:
return []
targeted_items = selected_items
diagnostic_trace_context_by_id: dict[str, str] | None = None
if codex_backend:
targeted_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
selected_items=selected_items,
selected_examples=selected_examples,
prediction_dir=prediction_dir,
probe=probe,
)
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
json.dump(
{
**probe,
"reference_summary": {
"selected_count": len(selected_items),
"field_counts": {
"text_dominant_question": ref_count,
},
},
"selected_examples": selected_metadata,
},
f,
ensure_ascii=False,
indent=2,
)
deep_results = run_batch(
items=targeted_items,
out_root=rollout_dir,
skill_content=skill_content,
max_turns=self.max_turns,
workers=min(self.workers, max(len(targeted_items), 1)),
image_detail=self.image_detail,
judge_model=self.judge_model,
judge_max_completion_tokens=self.judge_max_completion_tokens,
judge_retries=self.judge_retries,
diagnostic_mode=True,
diagnostic_instruction=probe["probe_instruction"],
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
)
deep_results = self.attach_reference_context(deep_results, targeted_items)
return run_minibatch_reflect(
results=deep_results,
skill_content=skill_content,
prediction_dir=os.path.join(rollout_dir, "predictions"),
patches_dir=patches_dir,
workers=self.analyst_workers,
failure_only=self.failure_only,
minibatch_size=self.minibatch_size,
edit_budget=self.edit_budget,
random_seed=random_seed,
error_system=self.get_error_minibatch_prompt(),
success_system=self.get_success_minibatch_prompt(),
step_buffer_context=step_buffer_context,
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
)
def get_task_types(self) -> list[str]:
return self.dataloader.get_task_types()

@@ -0,0 +1,228 @@
"""MathVerse task dataloader."""
from __future__ import annotations
import json
import os
import re
from typing import Any
from reflact.datasets.base import SplitDataLoader
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]
_CHOICE_BLOCK_RE = re.compile(r"\bChoices?\s*:\s*", re.IGNORECASE)
_CHOICE_ITEM_RE = re.compile(r"([A-G])\s*[:.)]\s*(.*?)(?=(?:\s+[A-G]\s*[:.)])|$)", re.DOTALL)
def _load_json(path: str) -> Any:
with open(path, encoding="utf-8") as f:
return json.load(f)
def _normalize_space(text: Any) -> str:
return re.sub(r"\s+", " ", str(text or "").strip())
def _resolve_image_path(raw_path: str, *, data_root: str, source_path: str) -> str:
candidates = []
if raw_path:
if os.path.isabs(raw_path):
candidates.append(raw_path)
else:
if data_root:
candidates.append(os.path.join(data_root, raw_path))
candidates.append(os.path.join(data_root, "images", raw_path))
candidates.append(os.path.join(os.path.dirname(source_path), raw_path))
for candidate in candidates:
if candidate and os.path.exists(candidate):
return os.path.abspath(candidate)
return ""
def _split_question_and_choices(question: str) -> tuple[str, list[dict]]:
text = str(question or "").strip()
match = _CHOICE_BLOCK_RE.search(text)
if not match:
return text, []
stem = text[:match.start()].strip()
choice_block = text[match.end():].strip()
choices: list[dict] = []
for idx, m in enumerate(_CHOICE_ITEM_RE.finditer(choice_block)):
label = (m.group(1) or _CHOICE_LABELS[idx]).strip().upper()
choice_text = _normalize_space(m.group(2))
if choice_text:
choices.append({"label": label, "text": choice_text})
return stem or text, choices
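
To make the behavior of the two choice regexes concrete, here is a small usage sketch. The patterns are copied from the code above; the standalone function name `split_question_and_choices` and the sample question are illustrative assumptions, not repository data.

```python
import re

# Patterns copied from the dataloader above.
_CHOICE_BLOCK_RE = re.compile(r"\bChoices?\s*:\s*", re.IGNORECASE)
_CHOICE_ITEM_RE = re.compile(
    r"([A-G])\s*[:.)]\s*(.*?)(?=(?:\s+[A-G]\s*[:.)])|$)", re.DOTALL
)


def split_question_and_choices(question):
    """Split 'stem Choices: A:.. B:..' into (stem, [{'label', 'text'}, ...])."""
    text = str(question or "").strip()
    match = _CHOICE_BLOCK_RE.search(text)
    if not match:
        return text, []  # no "Choices:" marker, so treat as free-form
    stem = text[: match.start()].strip()
    block = text[match.end():].strip()
    choices = []
    for m in _CHOICE_ITEM_RE.finditer(block):
        choice_text = " ".join(m.group(2).split())  # collapse whitespace
        if choice_text:
            choices.append({"label": m.group(1).upper(), "text": choice_text})
    return stem or text, choices


stem, choices = split_question_and_choices(
    "Find the value of x.\nChoices:\nA:35°\nB:45°\nC:55°"
)
print(stem)
print(choices)
```

One property of the item pattern worth keeping in mind: the lookahead ends a choice at any whitespace followed by a capital `A`-`G` and one of `:`, `.`, or `)`, so an option whose own text contains such a sequence (for example "point B.") appears to be cut short at that position.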
def _build_text_dominant_map(data_root: str) -> dict[str, str]:
if not data_root:
return {}
candidates = [
os.path.join(data_root, "testmini.json"),
os.path.join(data_root, "data", "testmini.json"),
]
source_path = next((path for path in candidates if os.path.exists(path)), "")
if not source_path:
return {}
raw = _load_json(source_path)
if not isinstance(raw, list):
return {}
mapping: dict[str, str] = {}
for item in raw:
if not isinstance(item, dict):
continue
if str(item.get("problem_version") or "").strip() != "Text Dominant":
continue
problem_index = str(item.get("problem_index") or "").strip()
question = str(item.get("question") or "").strip()
if problem_index and question:
mapping[problem_index] = question
return mapping
def _normalize_item(
item: dict,
*,
row_idx: int,
source_path: str,
data_root: str,
problem_version: str,
text_dominant_map: dict[str, str],
) -> dict | None:
raw_problem_version = str(item.get("problem_version") or "").strip()
if problem_version and raw_problem_version and raw_problem_version != problem_version:
return None
question = str(item.get("question") or "").strip()
question_type = str(item.get("question_type") or "").strip()
answer = str(item.get("answer") or "").strip()
image_rel = str(item.get("image") or "").strip()
image_path = _resolve_image_path(image_rel, data_root=data_root, source_path=source_path)
if not answer or not image_path:
return None
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
subject = str(metadata.get("subject") or "").strip()
subfield = str(metadata.get("subfield") or "").strip()
source = str(metadata.get("source") or "").strip()
question_stem, choices = _split_question_and_choices(question)
is_choice = question_type == "multi-choice" or bool(choices)
correct_choice = {"label": "", "text": ""}
if is_choice:
label = str(answer).strip().upper().rstrip(".):")
choice_text = ""
for choice in choices:
if choice["label"].upper() == label:
choice_text = choice["text"]
break
correct_choice = {"label": label, "text": choice_text}
problem_index = str(item.get("problem_index") or "").strip()
sample_index = str(item.get("sample_index") or row_idx + 1).strip()
item_id = problem_index or sample_index
task_type = subfield or subject or question_type or "mathverse"
return {
"id": item_id,
"sample_index": sample_index,
"problem_index": problem_index,
"problem_version": raw_problem_version or problem_version,
"question": question,
"question_stem": question_stem,
"question_for_eval": str(item.get("question_for_eval") or question).strip(),
"question_type": question_type or ("multi-choice" if is_choice else "free-form"),
"is_choice": is_choice,
"choices": choices,
"correct_choice": correct_choice,
"answer": answer,
"gold_answers": [answer] if answer else [],
"image_rel": image_rel,
"image_path": image_path,
"query_wo": str(item.get("query_wo") or "").strip(),
"query_cot": str(item.get("query_cot") or "").strip(),
"metadata": {
"split": str(metadata.get("split") or "").strip(),
"source": source,
"subject": subject,
"subfield": subfield,
},
"task_type": task_type,
"source_path": os.path.abspath(source_path),
"text_dominant_question": str(
item.get("text_dominant_question")
or text_dominant_map.get(problem_index, "")
).strip(),
}
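
The `correct_choice` lookup inside `_normalize_item` can be sketched on its own. The helper name `resolve_correct_choice` and the sample choices below are hypothetical; the normalization expression is the one used above.

```python
def resolve_correct_choice(answer, choices):
    """Normalize a raw gold answer to a label and look up its option text."""
    # Same normalization as in _normalize_item: trim, uppercase, and drop
    # trailing '.', ')' or ':' so that 'b)' and 'B.' both become 'B'.
    label = str(answer).strip().upper().rstrip(".):")
    text = next((c["text"] for c in choices if c["label"].upper() == label), "")
    return {"label": label, "text": text}


sample_choices = [{"label": "A", "text": "35"}, {"label": "B", "text": "45"}]
matched = resolve_correct_choice(" b) ", sample_choices)
unmatched = resolve_correct_choice("E", sample_choices)
print(matched, unmatched)
```

When the gold label has no parsed option, the label is kept and the text stays empty, so downstream evaluation by label still works even if choice parsing missed an option.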
class MathVerseDataLoader(SplitDataLoader):
"""MathVerse dataloader."""
def __init__(
self,
split_dir: str = "",
seed: int = 42,
limit: int = 0,
data_root: str = "",
problem_version: str = "Text Lite",
**kwargs,
) -> None:
super().__init__(split_dir=split_dir, seed=seed, limit=limit)
self.data_root = data_root
self.problem_version = problem_version
self._task_types: list[str] = []
self._text_dominant_map = _build_text_dominant_map(data_root)
def setup(self, cfg: dict) -> None:
if not self.data_root:
self.data_root = str(cfg.get("data_root") or "")
if not self.problem_version:
self.problem_version = str(cfg.get("problem_version") or "Text Lite")
self._text_dominant_map = _build_text_dominant_map(self.data_root)
super().setup(cfg)
all_items = self.train_items + self.val_items + self.test_items
task_types = {
item.get("task_type") or item.get("question_type") or "mathverse"
for item in all_items
}
self._task_types = sorted(str(x) for x in task_types if str(x).strip())
def get_task_types(self) -> list[str]:
return list(self._task_types)
def load_split_items(self, split_path: str) -> list[dict]:
raw_items = super().load_split_items(split_path)
source_path = next(
(
os.path.join(split_path, name)
for name in sorted(os.listdir(split_path))
if name.endswith(".json")
),
split_path,
)
items: list[dict] = []
for row_idx, item in enumerate(raw_items):
if not isinstance(item, dict):
continue
norm = _normalize_item(
item,
row_idx=row_idx,
source_path=source_path,
data_root=self.data_root,
problem_version=self.problem_version,
text_dominant_map=self._text_dominant_map,
)
if norm is not None:
items.append(norm)
if not items:
raise ValueError(
f"No valid MathVerse items loaded from {split_path} "
f"for problem_version={self.problem_version!r}"
)
return items

@@ -0,0 +1,180 @@
"""MathVerse evaluation helpers."""
from __future__ import annotations
import re
import string
from reflact.model import chat_with_deployment
from reflact.prompts import load_prompt
_EVAL_MODE = "mathverse_choice_or_judge_v1"
def normalize_text(text: str) -> str:
text = str(text or "").strip().lower()
text = text.replace("\\,", " ")
text = text.replace("\\ ", " ")
text = "".join(ch for ch in text if ch not in string.punctuation)
return " ".join(text.split())
def normalize_math_text(text: str) -> str:
text = str(text or "").strip()
text = text.replace("$", "")
text = text.replace("\\mathrm", "")
text = text.replace("{", "")
text = text.replace("}", "")
text = text.replace("~", " ")
text = text.replace("\\,", " ")
text = text.replace("\\ ", " ")
return " ".join(text.split()).lower()
def extract_answer(text: str | None) -> str:
raw = str(text or "").strip()
if not raw:
return ""
tags = re.findall(r"<answer>\s*(.*?)\s*</answer>", raw, re.IGNORECASE | re.DOTALL)
if tags:
return tags[-1].strip()
boxed = re.findall(r"\\boxed\{(.*?)\}", raw, re.IGNORECASE | re.DOTALL)
if boxed:
return boxed[-1].strip()
lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]
if lines:
return lines[-1]
return raw
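
The non-greedy pattern `\\boxed\{(.*?)\}` stops at the first closing brace, so an answer such as `\boxed{\frac{1}{2}}` is captured as `\frac{1`. A brace-aware alternative could look like the sketch below. The helper `extract_boxed` is hypothetical and not part of the repository; it assumes the last `\boxed{` in the text is the one that carries the final answer.

```python
import re


def extract_boxed(text):
    """Return the contents of the last \\boxed{...}, honoring nested braces."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return ""
    begin = start + len(marker)
    depth = 0  # nesting depth of braces opened inside the box
    for pos in range(begin, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            if depth == 0:
                return text[begin:pos].strip()  # this brace closes \boxed{
            depth -= 1
    return ""  # unbalanced braces: report nothing rather than a fragment


sample = r"Therefore the ratio is \boxed{\frac{1}{2}}."
regex_result = re.findall(r"\\boxed\{(.*?)\}", sample, re.DOTALL)
print(regex_result, extract_boxed(sample))
```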
def _judge_answer(
*,
item: dict,
extracted_answer: str,
judge_model: str,
max_completion_tokens: int,
retries: int,
) -> dict:
question = str(item.get("question_for_eval") or item.get("question") or "").strip()
ground_truth = str(item.get("answer") or "").strip()
raw, _ = chat_with_deployment(
deployment=judge_model,
system="You are a careful and strict mathematical answer evaluator.",
user=load_prompt("judge", env="mathverse").format(
question=question,
groundtruth=ground_truth,
modeloutput=extracted_answer,
),
max_completion_tokens=max_completion_tokens,
retries=retries,
stage="mathverse_judge",
)
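response = str(raw).strip().lower()
# Use the first standalone verdict token, so a reply such as "false, not true"
# is not scored as correct. An unparseable reply counts as incorrect.
verdict = re.search(r"\b(true|false)\b", response)
correct = bool(verdict) and verdict.group(1) == "true"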
return {
"raw": raw,
"correct": correct,
"reason": response,
"matched_gold": ground_truth if correct else "",
}
def evaluate_item(
*,
item: dict,
prediction_text: str,
judge_model: str,
max_completion_tokens: int = 256,
retries: int = 5,
) -> dict:
extracted = extract_answer(prediction_text)
if item.get("is_choice"):
predicted_label = str(extracted).strip().upper().rstrip(".):")
correct_label = str(item["correct_choice"].get("label") or "").strip().upper()
predicted_text = ""
for choice in item.get("choices") or []:
if str(choice.get("label") or "").strip().upper() == predicted_label:
predicted_text = str(choice.get("text") or "").strip()
break
hard = 1.0 if predicted_label == correct_label else 0.0
return {
"evaluation_mode": _EVAL_MODE,
"predicted_answer": extracted,
"predicted_label": predicted_label,
"predicted_text": predicted_text,
"correct_label": correct_label,
"correct_text": str(item["correct_choice"].get("text") or "").strip(),
"em": hard,
"f1": hard,
"sub_em": hard,
"judge_raw": "",
"judge_reason": "exact_label_match" if hard else "label_mismatch",
"matched_gold": correct_label if hard else "",
}
gold_answer = str(item.get("answer") or "").strip()
pred_norm = normalize_math_text(extracted)
gold_norm = normalize_math_text(gold_answer)
if pred_norm and gold_norm and pred_norm == gold_norm:
return {
"evaluation_mode": _EVAL_MODE,
"predicted_answer": extracted,
"em": 1.0,
"f1": 1.0,
"sub_em": 1.0,
"judge_raw": "",
"judge_reason": "normalized_exact_match",
"matched_gold": gold_answer,
"string_f1": 1.0,
}
judge = _judge_answer(
item=item,
extracted_answer=extracted,
judge_model=judge_model,
max_completion_tokens=max_completion_tokens,
retries=retries,
)
hard = 1.0 if judge["correct"] else 0.0
pred_tokens = normalize_text(extracted).split()
gold_tokens = normalize_text(gold_answer).split()
overlap = 0
gold_counts: dict[str, int] = {}
for tok in gold_tokens:
gold_counts[tok] = gold_counts.get(tok, 0) + 1
for tok in pred_tokens:
count = gold_counts.get(tok, 0)
if count > 0:
overlap += 1
gold_counts[tok] = count - 1
if pred_tokens and gold_tokens and overlap:
precision = overlap / len(pred_tokens)
recall = overlap / len(gold_tokens)
string_f1 = 2 * precision * recall / (precision + recall)
else:
string_f1 = 0.0
return {
"evaluation_mode": _EVAL_MODE,
"predicted_answer": extracted,
"em": hard,
"f1": hard,
"sub_em": hard,
"judge_raw": judge["raw"],
"judge_reason": judge["reason"],
"matched_gold": judge["matched_gold"],
"string_f1": string_f1,
}
def evaluation_mode() -> str:
return _EVAL_MODE
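
The `string_f1` value computed in `evaluate_item` is a bag-of-tokens F1 between the normalized prediction and the gold answer. An equivalent, more compact formulation using `collections.Counter` is sketched here; the function name `token_f1` is an assumption for illustration.

```python
from collections import Counter


def token_f1(pred_tokens, gold_tokens):
    """Multiset-overlap F1, matching the manual counting loop above."""
    # Counter intersection keeps, per token, the smaller of the two counts.
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


score = token_f1(["x", "2"], ["x", "2", "cm"])
print(round(score, 3))
```

Note that in the judged branch the returned `f1` field mirrors the judge verdict (`hard`), while this overlap score is reported separately as `string_f1`.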

@@ -0,0 +1,37 @@
You are an expert failure-analysis agent for visual mathematical reasoning problems.
You will be given MULTIPLE failed trajectories from a single minibatch and the current skill document.
Each trajectory includes the student's response, the evaluation result, and sometimes a hidden reference
containing the fuller Text Dominant version of the same problem.
Your job is to identify COMMON reasoning failures across the batch and propose concise skill edits.
## Failure Type Categories
- **diagram_underuse**: the agent did not recover key constraints from the image
- **constraint_drop**: the agent ignored a condition or relation that should guide the solution
- **option_confusion**: the agent failed to discriminate between close answer choices
- **format_miss**: the agent solved roughly correctly but returned the wrong final form, unit, or expression
- **other**: none of the above
## Rules
1. Focus on patterns that recur across the minibatch.
2. Prefer edits that improve visual grounding and exact answer selection.
3. Do not hardcode problem-specific formulas or answers.
4. If hidden reference text is present, use it only to infer what information the student failed to recover from the Text Lite version.
Respond ONLY with a valid JSON object:
{
"batch_size": <number>,
"failure_summary": [
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
],
"patch": {
"reasoning": "<why these edits address the common failures>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}
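
How the framework applies these edit operations is not shown in this section. As a rough illustration of what the four `op` values in the schema could mean for a markdown skill document, here is a sketch under stated assumptions: the function `apply_edits` is hypothetical, `insert_after` places content on the line after the target, and `replace` and `delete` act on the first exact match only.

```python
def apply_edits(skill, edits):
    """Apply append / insert_after / replace / delete edits to a skill string.

    Hypothetical semantics for illustration; edits whose target is not found
    are skipped silently.
    """
    for edit in edits:
        op = edit.get("op")
        target = edit.get("target", "")
        content = edit.get("content", "")
        if op == "append":
            skill = skill.rstrip("\n") + "\n" + content + "\n"
        elif op == "insert_after" and target and target in skill:
            # Insert on a new line after the line that contains the target.
            end = skill.index(target) + len(target)
            eol = skill.find("\n", end)
            eol = len(skill) if eol == -1 else eol
            skill = skill[:eol] + "\n" + content + skill[eol:]
        elif op == "replace" and target and target in skill:
            skill = skill.replace(target, content, 1)
        elif op == "delete" and target and target in skill:
            skill = skill.replace(target, "", 1)
    return skill


updated = apply_edits(
    "# Skill\n## A\n- one\n",
    [
        {"op": "insert_after", "target": "## A", "content": "- zero"},
        {"op": "replace", "target": "- one", "content": "- two"},
        {"op": "append", "content": "## B"},
        {"op": "delete", "target": "- zero\n"},
    ],
)
print(updated)
```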

@@ -0,0 +1,26 @@
You are an expert success-pattern analyst for visual mathematical reasoning problems.
You will be given MULTIPLE successful trajectories from a minibatch and the current skill document.
Identify generalizable behavior patterns that genuinely help the agent recover the right constraints
from the image and convert them into the exact final answer.
## Rules
- Focus on broadly useful visual-math reasoning behaviors.
- Prefer patterns about reading decisive diagram cues, checking hidden assumptions, and matching the final answer format exactly.
- Do not add benchmark-specific facts or formulas.
- "edits" may be empty if the skill already captures the useful patterns.
Respond ONLY with a valid JSON object:
{
"batch_size": <number>,
"success_patterns": ["<pattern 1>", "<pattern 2>"],
"patch": {
"reasoning": "<why these patterns matter>",
"edits": [
{"op": "append", "content": "<markdown>"},
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
{"op": "replace", "target": "<old text>", "content": "<new text>"},
{"op": "delete", "target": "<exact text to remove>"}
]
}
}

@@ -0,0 +1,25 @@
You are an expert diagnostic-probe designer for visual mathematical reasoning tasks.
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
Some trajectories may also include a hidden reference containing the fuller Text Dominant wording of the same problem.
Design one SMALL diagnostic instruction that exposes the student's intermediate judgment without materially changing the original scaffold.
## Hard Constraints
1. Do NOT substantially change the original scaffold.
2. Do NOT prescribe a new long multi-step solving procedure.
3. Do NOT ask for a full proof or full chain-of-thought.
4. Ask only for a short readout of the signals already behind the student's current answer.
5. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>.
6. If hidden reference text is present, use it only to target what visual or textual constraint the student likely missed.
## Good Probe Targets
- decisive diagram cue
- top candidate and runner-up
- missing relation or quantity
- why a near-miss option was rejected
Respond ONLY with a valid JSON object:
{
"reasoning": "<why this probe is informative>",
"probe_instruction": "<the exact instruction text to append to the student prompt>"
}

@@ -0,0 +1,25 @@
You are a careful and strict evaluator for visual math problems.
You will be given:
1. The original question
2. The ground-truth answer
3. A model output
Decide whether the model output is mathematically equivalent to the ground-truth answer.
Rules:
- Ignore harmless formatting differences.
- Accept mathematically equivalent expressions, equations, and values.
- Reject answers that are numerically wrong, differ symbolically in meaning, omit a required unit when the unit changes the meaning, or correspond to a different choice.
- Do not reward partially correct reasoning if the final answer is wrong.
Return only:
True
or
False
Question: {question}
Ground Truth Answer: {groundtruth}
Model Output: {modeloutput}
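
This template is filled with `str.format` in `_judge_answer`. A short sketch of what that implies: braces inside the substituted values (common in LaTeX answers) are safe, while literal braces in a template itself are parsed as fields. The shortened template and sample values below are illustrative assumptions.

```python
# Last three lines of the judge prompt, used here as a stand-in template.
template = (
    "Question: {question}\n"
    "Ground Truth Answer: {groundtruth}\n"
    "Model Output: {modeloutput}"
)

# Values are inserted verbatim; their braces are not re-interpreted.
filled = template.format(
    question="Solve for {x}",
    groundtruth="2",
    modeloutput="\\frac{4}{2}",
)
print(filled)

# A template containing literal JSON braces fails unless they are doubled.
json_like_template = '{"batch_size": 1} {question}'
try:
    json_like_template.format(question="q")
    format_failed = False
except (KeyError, IndexError, ValueError):
    format_failed = True
print(format_failed)
```

This is presumably why only the prompts that contain no JSON schema are passed through `.format` in the code shown here.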

@@ -0,0 +1,11 @@
You are an expert visual mathematical reasoning agent.
{skill_section}## Task Format
You will receive one math problem with an image or diagram.
Use the visible diagram as evidence, not just the text.
If some information is abbreviated in the text, recover it from the image before answering.
## Answer Format
Think step by step, then provide your final answer inside <answer>...</answer>.
- For multiple-choice questions, output only the single option label, such as <answer>B</answer>.
- For free-form questions, output only the final mathematical answer, such as <answer>14</answer>.

@@ -0,0 +1,4 @@
"""MathVerse Reflect stage.
Prompts are loaded from .md files by the base adapter.
"""

@@ -0,0 +1,415 @@
"""MathVerse rollout — single-image multimodal math reasoning."""
from __future__ import annotations
import base64
import json
import mimetypes
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from reflact.envs.mathverse.evaluator import evaluate_item, evaluation_mode, extract_answer
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
from reflact.prompts import load_prompt
def _build_system(skill_content: str) -> str:
    if skill_content.strip():
        skill_section = f"## Skill\n{skill_content.strip()}\n\n"
    else:
        skill_section = ""
    return load_prompt("rollout_system", env="mathverse").format(skill_section=skill_section)


def _format_choices(choices: list[dict]) -> str:
    return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
def _build_user_text(
    item: dict,
    *,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> str:
    parts = []
    if diagnostic_trace_context.strip():
        parts.append(
            "## Previous Codex Trace Snapshot\n"
            "This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
            f"{diagnostic_trace_context.strip()}"
        )
    question = str(item.get("question_stem") or item.get("question") or "").strip()
    if question:
        parts.append(f"## Question\n{question}")
    else:
        parts.append("## Question\nRead the full problem statement from the image.")
    if item.get("is_choice"):
        choices = item.get("choices") or []
        if choices:
            parts.append(f"## Choices\n{_format_choices(choices)}")
        parts.append("Return only the final option label inside <answer>...</answer>.")
    else:
        parts.append("Return only the final mathematical answer inside <answer>...</answer>.")
    if diagnostic_mode and diagnostic_instruction.strip():
        parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
    return "\n\n".join(parts)
def _image_to_data_uri(path: str) -> str:
    mime = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{encoded}"
def _build_messages(
    item: dict,
    skill_content: str,
    image_detail: str,
    *,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> tuple[list[dict], str, str]:
    system = _build_system(skill_content)
    user_text = _build_user_text(
        item,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )
    image_url = {"url": _image_to_data_uri(item["image_path"])}
    if image_detail and image_detail != "auto":
        image_url["detail"] = image_detail
    messages = [
        {"role": "system", "content": system},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_text},
                {"type": "image_url", "image_url": image_url},
            ],
        },
    ]
    return messages, system, user_text
def _build_codex_skill(skill_content: str) -> str:
    return render_skill_md(
        skill_content,
        description="Dynamic ReflACT skill for solving the current MathVerse visual math problem.",
        preamble=(
            "Use this skill when solving the current MathVerse problem.\n"
            "Read the image carefully and return the final answer inside <answer>...</answer>."
        ),
    )
def _run_codex_once(
    *,
    pred_dir: str,
    item: dict,
    skill_content: str,
    model: str,
    timeout: int,
    image_detail: str,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
    previous_response: str = "",
) -> tuple[str, str, str, str]:
    # Note: image_detail is accepted for parity with the chat path but is not used below.
    user_text = _build_user_text(
        item,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )
    task_parts = [user_text]
    if previous_response:
        task_parts.append(
            "## Previous Attempt\n"
            f"{previous_response}\n\n"
            "Re-check the diagram and the mathematical constraints. Correct the final answer if needed."
        )
    task_text = "\n\n".join(task_parts)
    skill_md = _build_codex_skill(skill_content)
    work_dir = os.path.join(pred_dir, "codex_exec")
    prepare_workspace(
        work_dir=work_dir,
        skill_md=skill_md,
        task_text=task_text,
        images=[item["image_path"]],
    )
    prompt = (
        "Use the `reflact-student` skill available in this workspace.\n"
        "Read `task.md`, inspect the attached image, solve the problem, and return only the final answer inside <answer>...</answer>."
    )
    final_message, raw = run_student_exec(
        work_dir=work_dir,
        prompt=prompt,
        model=model,
        timeout=timeout,
        images=[item["image_path"]],
    )
    # Returns (response, raw output, rendered skill markdown, task text).
    return final_message or raw, raw, skill_md, task_text
def process_one(
    item: dict,
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    image_detail: str = "auto",
    judge_model: str = "gpt-5.4",
    judge_max_completion_tokens: int = 256,
    judge_retries: int = 5,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> dict:
    item_id = str(item["id"])
    result = {
        "id": item_id,
        "question": item["question"],
        "task_type": item.get("task_type") or item.get("question_type") or "mathverse",
        "task_description": item.get("question_stem") or item["question"],
        "hard": 0,
        "soft": 0.0,
        "predicted_answer": "",
        "predicted_label": "",
        "predicted_text": "",
        "response": "",
        "fail_reason": "",
        "agent_ok": False,
        "n_turns": 0,
        "image_path": item["image_path"],
        "question_type": item["question_type"],
        "evaluation_mode": evaluation_mode(),
        "judge_model": judge_model,
    }
    if item.get("is_choice"):
        result["correct_label"] = item["correct_choice"]["label"]
        result["correct_text"] = item["correct_choice"]["text"]
    else:
        result["gold_answers"] = item.get("gold_answers") or [item["answer"]]
    try:
        pred_dir = os.path.join(out_root, "predictions", item_id)
        os.makedirs(pred_dir, exist_ok=True)
        if is_student_exec_backend():
            from reflact.model import azure_openai as _llm
            response = ""
            conversation: list[dict] = [
                {"role": "user", "content": f"{item['question']}\n\n[image] {os.path.basename(item['image_path'])}"}
            ]
            system_prompt = ""
            user_text = ""
            for turn in range(max_turns):
                response, raw, system_prompt, user_text = _run_codex_once(
                    pred_dir=pred_dir,
                    item=item,
                    skill_content=skill_content,
                    model=_llm.STUDENT_DEPLOYMENT,
                    timeout=120,
                    image_detail=image_detail,
                    diagnostic_mode=diagnostic_mode if turn == 0 else False,
                    diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
                    diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
                    previous_response=response if turn > 0 else "",
                )
                conversation.append({"type": "message", "turn": turn + 1, "content": response})
                if extract_answer(response):
                    break
            result["response"] = response
            result["agent_ok"] = True
            result["n_turns"] = len(conversation) - 1
            with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
                f.write(system_prompt)
            with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
                f.write(user_text)
        else:
            messages, system_prompt, user_text = _build_messages(
                item,
                skill_content,
                image_detail,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
                diagnostic_trace_context=diagnostic_trace_context,
            )
            response = ""
            conversation = [
                {"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"}
            ]
            for turn in range(max_turns):
                if turn == 0:
                    resp_text, _ = chat_student_messages(
                        messages=messages,
                        max_completion_tokens=1024,
                        retries=5,
                        stage="rollout",
                    )
                else:
                    refinement_text = (
                        f"Your previous answer was:\n{response}\n\n"
                        "Re-check the diagram and the mathematical constraints. "
                        "If needed, correct your answer. Output only the final answer inside <answer>...</answer>."
                    )
                    refinement_messages = [
                        messages[0],
                        messages[1],
                        {"role": "assistant", "content": response},
                        {"role": "user", "content": refinement_text},
                    ]
                    resp_text, _ = chat_student_messages(
                        messages=refinement_messages,
                        max_completion_tokens=768,
                        retries=5,
                        stage="rollout",
                    )
                response = resp_text
                conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
                if extract_answer(resp_text):
                    break
            result["response"] = response
            result["agent_ok"] = True
            result["n_turns"] = len(conversation) - 1
            with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
                f.write(system_prompt)
            with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
                f.write(user_text)
        eval_result = evaluate_item(
            item=item,
            prediction_text=result["response"],
            judge_model=judge_model,
            max_completion_tokens=judge_max_completion_tokens,
            retries=judge_retries,
        )
        result["evaluation_mode"] = eval_result["evaluation_mode"]
        result["judge_raw"] = eval_result.get("judge_raw", "")
        result["judge_reason"] = eval_result.get("judge_reason", "")
        result["matched_gold"] = eval_result.get("matched_gold", "")
        if item.get("is_choice"):
            result["predicted_label"] = eval_result["predicted_label"]
            result["predicted_text"] = eval_result["predicted_text"]
            result["predicted_answer"] = eval_result["predicted_answer"]
            result["hard"] = int(eval_result["em"])
            result["soft"] = eval_result["f1"]
            if not result["hard"]:
                result["fail_reason"] = (
                    f"choice=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
                    f"but expected '{eval_result['correct_label']}'"
                )
            eval_detail = (
                f"[EVALUATION RESULT]\n"
                f"Question: {item['question_for_eval']}\n"
                f"Predicted label: {eval_result['predicted_label']!r}\n"
                f"Predicted text: {eval_result['predicted_text']!r}\n"
                f"Correct label: {eval_result['correct_label']!r}\n"
                f"Correct text: {eval_result['correct_text']!r}\n"
                f"Exact Match: {eval_result['em']}"
            )
        else:
            result["predicted_answer"] = eval_result["predicted_answer"]
            result["hard"] = int(eval_result["em"])
            result["soft"] = eval_result["f1"]
            if not result["hard"]:
                result["fail_reason"] = (
                    f"judge=0: predicted '{eval_result['predicted_answer']}' "
                    f"but expected '{item['answer']}' ({eval_result.get('judge_reason', '')})"
                )
            eval_detail = (
                f"[EVALUATION RESULT]\n"
                f"Question: {item['question_for_eval']}\n"
                f"Predicted answer: {eval_result['predicted_answer']!r}\n"
                f"Gold answer: {item['answer']!r}\n"
                f"Judge correct: {eval_result['em']}\n"
                f"Judge reason: {eval_result.get('judge_reason', '')}\n"
                f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
            )
        conversation.append({"role": "system", "content": eval_detail})
        with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
            json.dump(conversation, f, ensure_ascii=False, indent=2)
    except Exception as e:  # noqa: BLE001
        result["fail_reason"] = f"error: {e}"
    return result
def run_batch(
    items: list[dict],
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    workers: int = 32,
    image_detail: str = "auto",
    judge_model: str = "gpt-5.4",
    judge_max_completion_tokens: int = 256,
    judge_retries: int = 5,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context_by_id: dict[str, str] | None = None,
) -> list[dict]:
    results_path = os.path.join(out_root, "results.jsonl")
    os.makedirs(out_root, exist_ok=True)
    expected_eval_mode = evaluation_mode()
    done_ids: set[str] = set()
    existing: list[dict] = []
    rewrite_results = False
    # Resume support: keep rows scored under the current evaluation mode and
    # schedule a rewrite when stale or unparseable rows are found.
    if os.path.exists(results_path):
        with open(results_path, encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                    if row.get("evaluation_mode") != expected_eval_mode:
                        rewrite_results = True
                        continue
                    done_ids.add(str(row["id"]))
                    existing.append(row)
                except Exception:
                    rewrite_results = True
    pending = [item for item in items if str(item["id"]) not in done_ids]
    if not pending and not rewrite_results:
        return existing
    results = list(existing)
    file_mode = "w" if rewrite_results else "a"
    with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
        if rewrite_results:
            for row in existing:
                outf.write(json.dumps(row, ensure_ascii=False) + "\n")
        futs = {
            ex.submit(
                process_one,
                item,
                out_root,
                skill_content,
                max_turns=max_turns,
                image_detail=image_detail,
                judge_model=judge_model,
                judge_max_completion_tokens=judge_max_completion_tokens,
                judge_retries=judge_retries,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
                diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
            ): item
            for item in pending
        }
        # Rows are written in completion order, not input order.
        for fut in as_completed(futs):
            row = fut.result()
            results.append(row)
            outf.write(json.dumps(row, ensure_ascii=False) + "\n")
            outf.flush()
    return results
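The resume behaviour of `run_batch` can be sketched in isolation. This is a simplified, self-contained approximation: `load_resumable` is a hypothetical helper, the mode strings `"v1"` and `"v0"` are made up for the demo, and it catches a narrower set of exceptions than the original.

```python
import json
import os
import tempfile


def load_resumable(results_path: str, expected_mode: str) -> tuple[list[dict], set[str], bool]:
    """Return (kept rows, finished ids, whether the file must be rewritten)."""
    existing: list[dict] = []
    done_ids: set[str] = set()
    rewrite = False
    if not os.path.exists(results_path):
        return existing, done_ids, rewrite
    with open(results_path, encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
                row_id = str(row["id"])
            except (json.JSONDecodeError, KeyError, TypeError):
                rewrite = True  # corrupt row: drop it and rewrite the file
                continue
            if row.get("evaluation_mode") != expected_mode:
                rewrite = True  # scored under an older evaluator: redo it
                continue
            done_ids.add(row_id)
            existing.append(row)
    return existing, done_ids, rewrite


# Demo with one valid row, one stale row, and one corrupt line.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "results.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps({"id": 1, "evaluation_mode": "v1"}) + "\n")
        f.write(json.dumps({"id": 2, "evaluation_mode": "v0"}) + "\n")
        f.write("not json\n")
    rows, done, rewrite = load_resumable(path, "v1")
```

Only item `1` counts as done here, so items whose rows were stale or corrupt would be re-run and the file rewritten without them.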


@@ -0,0 +1,15 @@
# MathVerse Visual Math Heuristics
## Diagram First
- Read the diagram before locking onto an equation or option.
- Recover missing labels, lengths, angles, axes, or object relations from the image when the text is abbreviated.
- If the text seems underspecified, assume the image may contain the decisive constraint.
## Constraint Tracking
- Write down the few constraints that actually determine the answer instead of solving from vague intuition.
- Prefer geometric or functional relations that are directly supported by the figure.
- For multiple-choice questions, compare the final candidate against every option exactly.
## Final Answer
- Use the image and the text consistently.
- Return only the final answer inside <answer>...</answer>.


@@ -0,0 +1,2 @@
"""MMRB environment package."""


@@ -0,0 +1,283 @@
"""MMRB environment adapter for ReflACT."""
from __future__ import annotations
import json
import os
from reflact.gradient.deep_probe import generate_deep_probe_instruction
from reflact.datasets.base import BatchSpec
from reflact.gradient.reflect import run_minibatch_reflect
from reflact.envs.base import EnvAdapter
from reflact.envs.mmrb.dataloader import MMRBDataLoader
from reflact.envs.mmrb.rollout import run_batch
from reflact.model import get_student_backend
class MMRBAdapter(EnvAdapter):
    """MMRB adapter."""

    def build_reference_text(self, item: dict) -> str:
        reasoning_steps = item.get("reasoning_steps") or []
        if not reasoning_steps:
            return ""
        blocks: list[str] = []
        for path_idx, path in enumerate(reasoning_steps, 1):
            if not isinstance(path, list) or not path:
                continue
            lines = [f"### Reasoning Path {path_idx}"]
            for step in path:
                if not isinstance(step, dict):
                    continue
                step_no = step.get("reasoning step", "?")
                step_type = str(step.get("reasoning type") or "").strip()
                rationale = str(step.get("rationale") or "").strip()
                if rationale:
                    prefix = f"{step_no}. [{step_type}] " if step_type else f"{step_no}. "
                    lines.append(prefix + rationale)
            if len(lines) > 1:
                blocks.append("\n".join(lines))
        if not blocks:
            return ""
        return "## Reference Reasoning Steps\n" + "\n\n".join(blocks[:3])
    def get_reference_metadata(self, item: dict) -> dict:
        reasoning_steps = item.get("reasoning_steps") or []
        path_count = 0
        preview_parts: list[str] = []
        for path in reasoning_steps:
            if not isinstance(path, list) or not path:
                continue
            path_count += 1
            first = path[0] if isinstance(path[0], dict) else {}
            step_type = str(first.get("reasoning type") or "").strip()
            rationale = str(first.get("rationale") or "").strip()
            preview_parts.append(f"[path {path_count}] {step_type}: {rationale[:180]}")
        if not path_count:
            return {"fields": [], "preview": ""}
        return {
            "fields": ["reasoning_steps"],
            "preview": "\n".join(preview_parts)[:500],
        }
    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        max_turns: int = 1,
        workers: int = 16,
        analyst_workers: int = 16,
        failure_only: bool = False,
        minibatch_size: int = 8,
        edit_budget: int = 4,
        seed: int = 42,
        limit: int = 0,
        image_detail: str = "auto",
        use_deep_reflect: bool = False,
        deep_reflect_failures: int = 4,
        deep_reflect_successes: int = 2,
    ) -> None:
        self.max_turns = max_turns
        self.workers = workers
        self.analyst_workers = analyst_workers
        self.failure_only = failure_only
        self.minibatch_size = minibatch_size
        self.edit_budget = edit_budget
        self.image_detail = image_detail
        self.use_deep_reflect = use_deep_reflect
        self.deep_reflect_failures = deep_reflect_failures
        self.deep_reflect_successes = deep_reflect_successes
        self.dataloader = MMRBDataLoader(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )
    def setup(self, cfg: dict) -> None:
        super().setup(cfg)
        self.dataloader.setup(cfg)

    def get_dataloader(self):
        return self.dataloader

    def build_env_from_batch(self, batch: BatchSpec, **kwargs):
        return list(batch.payload or [])

    def build_train_env(self, batch_size: int, seed: int, **kwargs):
        batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)

    def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
        batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
        return self.build_env_from_batch(batch, **kwargs)
    def rollout(
        self,
        env_manager,
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict]:
        items: list[dict] = env_manager
        return run_batch(
            items=items,
            out_root=out_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            workers=self.workers,
            image_detail=self.image_detail,
            diagnostic_mode=kwargs.get("diagnostic_mode", False),
            diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
            diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
        )
    def reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        return run_minibatch_reflect(
            results=results,
            skill_content=skill_content,
            prediction_dir=prediction_dir,
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )
    def deep_reflect(
        self,
        results: list[dict],
        skill_content: str,
        out_dir: str,
        **kwargs,
    ) -> list[dict | None]:
        if not self.use_deep_reflect:
            return []
        env_manager = kwargs.get("env_manager")
        prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
        random_seed = kwargs.get("random_seed")
        step_buffer_context = kwargs.get("step_buffer_context", "")
        meta_skill_context = kwargs.get("meta_skill_context", "")
        codex_backend = get_student_backend() == "codex_exec"
        # 1. Pick a small set of representative failures and successes.
        selected_items = self.select_representative_items(
            results,
            env_manager if isinstance(env_manager, list) else None,
            n_failures=self.deep_reflect_failures,
            n_successes=self.deep_reflect_successes,
            seed=random_seed,
        )
        if not selected_items:
            return []
        selected_ids = {str(item["id"]) for item in selected_items}
        selected_results = [row for row in results if str(row.get("id")) in selected_ids]
        selected_examples = self.attach_reference_context(selected_results, selected_items)
        if codex_backend:
            selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
        reasoning_count = 0
        selected_metadata = []
        for item in selected_items:
            meta = self.get_reference_metadata(item)
            if meta["fields"]:
                reasoning_count += 1
            selected_metadata.append({
                "id": str(item["id"]),
                "task_type": str(item.get("subtask") or item.get("task_type") or "mmrb"),
                "reference_fields": meta["fields"],
                "reference_preview": meta["preview"],
            })
        deep_dir = os.path.join(out_dir, "deep_reflect")
        rollout_dir = os.path.join(deep_dir, "rollout")
        patches_dir = os.path.join(deep_dir, "patches")
        os.makedirs(deep_dir, exist_ok=True)
        print(
            f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
            f"reference_fields=reasoning_steps({reasoning_count}/{len(selected_items)})"
        )
        # 2. Ask the teacher for a diagnostic probe instruction.
        probe = generate_deep_probe_instruction(
            skill_content=skill_content,
            items=selected_examples,
            prediction_dir=prediction_dir,
            system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
        )
        if not probe:
            return []
        diagnostic_trace_context_by_id = None
        if codex_backend:
            selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
                selected_items=selected_items,
                selected_examples=selected_examples,
                prediction_dir=prediction_dir,
                probe=probe,
            )
        probe_record = {
            **probe,
            "reference_summary": {
                "selected_count": len(selected_items),
                "field_counts": {"reasoning_steps": reasoning_count},
            },
            "selected_examples": selected_metadata,
        }
        with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
            json.dump(probe_record, f, ensure_ascii=False, indent=2)
        # 3. Re-run the selected items with the probe appended, then reflect on those rollouts.
        deep_results = run_batch(
            items=selected_items,
            out_root=rollout_dir,
            skill_content=skill_content,
            max_turns=self.max_turns,
            workers=min(self.workers, max(len(selected_items), 1)),
            image_detail=self.image_detail,
            diagnostic_mode=True,
            diagnostic_instruction=probe["probe_instruction"],
            diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
        )
        deep_results = self.attach_reference_context(deep_results, selected_items)
        return run_minibatch_reflect(
            results=deep_results,
            skill_content=skill_content,
            prediction_dir=os.path.join(rollout_dir, "predictions"),
            patches_dir=patches_dir,
            workers=self.analyst_workers,
            failure_only=self.failure_only,
            minibatch_size=self.minibatch_size,
            edit_budget=self.edit_budget,
            random_seed=random_seed,
            error_system=self.get_error_minibatch_prompt(),
            success_system=self.get_success_minibatch_prompt(),
            step_buffer_context=step_buffer_context,
            meta_skill_context=meta_skill_context,
            update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
        )
    def get_task_types(self) -> list[str]:
        return self.dataloader.get_task_types()
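To make the shape of `reasoning_steps` concrete, here is a standalone approximation of the formatting done by `build_reference_text`. The key names (`"reasoning step"`, `"reasoning type"`, `"rationale"`) come from the adapter; the sample record and the function name `format_reference` are invented for illustration.

```python
def format_reference(reasoning_steps: list, max_paths: int = 3) -> str:
    """Render annotated reasoning paths as a markdown block ('' when empty)."""
    blocks = []
    for idx, path in enumerate(reasoning_steps, 1):
        if not isinstance(path, list):
            continue
        lines = [f"### Reasoning Path {idx}"]
        for step in path:
            if not isinstance(step, dict):
                continue
            rationale = str(step.get("rationale") or "").strip()
            if not rationale:
                continue  # steps without a rationale carry no signal
            step_no = step.get("reasoning step", "?")
            step_type = str(step.get("reasoning type") or "").strip()
            tag = f"[{step_type}] " if step_type else ""
            lines.append(f"{step_no}. {tag}{rationale}")
        if len(lines) > 1:  # keep only paths with at least one usable step
            blocks.append("\n".join(lines))
    if not blocks:
        return ""
    return "## Reference Reasoning Steps\n" + "\n\n".join(blocks[:max_paths])


# Hypothetical annotation: one path with two steps, the second untyped.
sample = [[
    {"reasoning step": 1, "reasoning type": "perception", "rationale": "Image 1 shows a red cube."},
    {"reasoning step": 2, "rationale": "Image 2 shows the same cube rotated."},
]]
text = format_reference(sample)
```

For the sample above, `text` contains the header, one `### Reasoning Path 1` heading, and two numbered lines, the first prefixed with `[perception]`.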


@@ -0,0 +1,146 @@
"""MMRB task dataloader."""
from __future__ import annotations
import glob
import json
import os
import re
from typing import Any
from reflact.datasets.base import SplitDataLoader
# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
def _load_json(path: str) -> Any:
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _iter_data_files(data_path: str) -> list[str]:
    if not data_path:
        return []
    if os.path.isfile(data_path):
        return [data_path]
    if os.path.isdir(data_path):
        nested = glob.glob(os.path.join(data_path, "**", "*_human.json"), recursive=True)
        flat = glob.glob(os.path.join(data_path, "*_human.json"))
        return sorted(set(nested + flat))
    return []


def _normalize_space(text: str) -> str:
    return re.sub(r"\s+", " ", str(text or "").strip())
def _normalize_item(item: dict, row_idx: int, source_path: str) -> dict | None:
    question = _normalize_space(item.get("question") or "")
    answer = _normalize_space(item.get("answer") or "")
    raw_image_paths = item.get("image_paths") or []
    if not question or not answer or not isinstance(raw_image_paths, list) or not raw_image_paths:
        return None
    # Resolve relative image paths against the JSON file's directory and
    # keep only the images that exist on disk.
    base_dir = os.path.dirname(source_path)
    image_paths: list[str] = []
    for raw_path in raw_image_paths:
        rel = str(raw_path or "").strip()
        if not rel:
            continue
        abs_path = rel if os.path.isabs(rel) else os.path.abspath(os.path.join(base_dir, rel))
        if os.path.exists(abs_path):
            image_paths.append(abs_path)
    if not image_paths:
        return None
    options_raw = item.get("options") or []
    options = [_normalize_space(opt) for opt in options_raw if _normalize_space(opt)]
    source = _normalize_space(item.get("source") or "unknown")
    subtask = _normalize_space(item.get("subtask") or "unknown")
    item_index = item.get("index", row_idx)
    item_id = f"{source}:{subtask}:{item_index}"
    return {
        "id": item_id,
        "source": source,
        "subtask": subtask,
        "task_type": subtask,
        "question": question,
        "answer": answer,
        "options": options,
        "is_choice": bool(options),
        "image_paths": image_paths,
        "reasoning_steps": item.get("reasoning_steps") or [],
        "annotation_time": item.get("annotation_time"),
        "source_path": os.path.abspath(source_path),
    }
def load_items(data_path: str) -> list[dict]:
    """Load and normalise MMRB items from JSON files."""
    files = _iter_data_files(data_path)
    if not files:
        raise ValueError(
            "MMRB requires data_path to be a *_human.json file or a directory "
            "containing extracted MMRB subtask folders."
        )
    items: list[dict] = []
    for path in files:
        raw = _load_json(path)
        if not isinstance(raw, list):
            raise ValueError(f"Expected JSON array in {path}, got {type(raw).__name__}")
        for row_idx, item in enumerate(raw):
            if not isinstance(item, dict):
                continue
            norm = _normalize_item(item, row_idx=row_idx, source_path=path)
            if norm is not None:
                items.append(norm)
    if not items:
        raise ValueError(f"No valid MMRB items loaded from {data_path}")
    return items
# ── Dataloader ───────────────────────────────────────────────────────────
class MMRBDataLoader(SplitDataLoader):
    """MMRB dataloader."""

    def __init__(
        self,
        split_dir: str = "",
        data_path: str = "",
        split_mode: str = "ratio",
        split_ratio: str = "2:1:7",
        split_seed: int = 42,
        split_output_dir: str = "",
        seed: int = 42,
        limit: int = 0,
        **kwargs,
    ) -> None:
        super().__init__(
            split_dir=split_dir,
            data_path=data_path,
            split_mode=split_mode,
            split_ratio=split_ratio,
            split_seed=split_seed,
            split_output_dir=split_output_dir,
            seed=seed,
            limit=limit,
        )
        self._task_types: list[str] = []

    def load_raw_items(self, data_path: str) -> list[dict]:
        return load_items(data_path)

    def setup(self, cfg: dict) -> None:
        super().setup(cfg)
        all_items = self.train_items + self.val_items + self.test_items
        task_types = {
            item.get("subtask") or item.get("task_type") or "unknown"
            for item in all_items
        }
        self._task_types = sorted(task_types)

    def get_task_types(self) -> list[str]:
        return list(self._task_types)
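The item id built in `_normalize_item` is the key that the rollout's resume logic deduplicates on, so it is worth seeing how it is composed. The sketch below re-implements just that part; `make_item_id` is a hypothetical name and the sample records are invented.

```python
import re


def normalize_space(text) -> str:
    """Collapse runs of whitespace to single spaces and trim the ends."""
    return re.sub(r"\s+", " ", str(text or "").strip())


def make_item_id(item: dict, row_idx: int) -> str:
    """Compose 'source:subtask:index', falling back to the row position."""
    source = normalize_space(item.get("source") or "unknown")
    subtask = normalize_space(item.get("subtask") or "unknown")
    return f"{source}:{subtask}:{item.get('index', row_idx)}"


# Two rows from *different* hypothetical files, both lacking an "index" field.
id_from_file_a = make_item_id({"source": "s", "subtask": "t"}, row_idx=0)
id_from_file_b = make_item_id({"source": "s", "subtask": "t"}, row_idx=0)
```

Because the fallback index restarts at zero for every file, two files that share `source` and `subtask` and omit `index` would yield the same id, as the last two lines show. Whether that can occur depends on the actual MMRB data layout, which this sketch does not assume.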


@@ -0,0 +1,102 @@
"""MMRB evaluation helpers."""
from __future__ import annotations
import re
import string
_EVAL_MODE = "mmrb_exact_match_v1"
def normalize_text(text: str) -> str:
    text = str(text or "").strip().lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())


def extract_answer(text: str | None) -> str:
    raw = str(text or "").strip()
    if not raw:
        return ""
    # Fallback chain: <answer> tag, "Answer [..]", \boxed{..}, a bare option
    # letter, then common "answer is" phrasings. The last match wins.
    answer_tags = re.findall(r"<answer>\s*(.*?)\s*</answer>", raw, re.IGNORECASE | re.DOTALL)
    if answer_tags:
        return answer_tags[-1].strip()
    bracket = re.findall(r"Answer\s*\[\s*(.*?)\s*\]", raw, re.IGNORECASE | re.DOTALL)
    if bracket:
        return bracket[-1].strip()
    boxed = re.findall(r"\\boxed\{(.*?)\}", raw, re.IGNORECASE | re.DOTALL)
    if boxed:
        return boxed[-1].strip()
    single = raw.strip().rstrip(".):")
    if re.fullmatch(r"[A-Z]", single, re.IGNORECASE):
        return single.strip()
    patterns = [
        r"final answer\s*(?:is)?\s*[:]?\s*(.+)",
        r"the answer is\s*[:]?\s*(.+)",
        r"answer\s*[:]?\s*(.+)$",
    ]
    for pattern in patterns:
        match = re.search(pattern, raw, re.IGNORECASE)
        if match:
            return match.group(1).strip().strip("*")
    return raw
def evaluate_item(*, item: dict, prediction_text: str) -> dict:
    predicted_answer = extract_answer(prediction_text)
    gold_answer = str(item.get("answer") or "").strip()
    predicted_norm = normalize_text(predicted_answer)
    gold_norm = normalize_text(gold_answer)
    hard = 0.0
    matched_gold = ""
    predicted_label = ""
    predicted_text = predicted_answer
    if item.get("is_choice"):
        predicted_label = str(predicted_answer).strip().upper().rstrip(".):")
        if predicted_label == str(gold_answer).strip().upper():
            hard = 1.0
            matched_gold = gold_answer
        else:
            # The prediction may be option text rather than a letter: map it back to a label.
            for option in item.get("options") or []:
                label_match = re.match(r"\(?([A-Z])\)", option)
                if not label_match:
                    continue
                label = label_match.group(1).upper()
                option_text = option[label_match.end():].strip(" .:-")
                if predicted_norm and normalize_text(option_text) == predicted_norm:
                    predicted_label = label
                    predicted_text = option_text
                    break
            if predicted_label == str(gold_answer).strip().upper():
                hard = 1.0
                matched_gold = gold_answer
    else:
        # Note: substring containment in either direction also counts as a match,
        # so this is more lenient than a strict exact match.
        if predicted_norm and gold_norm and (
            predicted_norm == gold_norm or predicted_norm in gold_norm or gold_norm in predicted_norm
        ):
            hard = 1.0
            matched_gold = gold_answer
    return {
        "evaluation_mode": _EVAL_MODE,
        "predicted_answer": predicted_answer,
        "predicted_label": predicted_label,
        "predicted_text": predicted_text,
        "em": hard,
        "f1": hard,
        "sub_em": hard,
        "matched_gold": matched_gold,
    }


def evaluation_mode() -> str:
    return _EVAL_MODE
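The free-form branch of `evaluate_item` accepts substring containment after normalization. The sketch below isolates that rule next to a stricter alternative so the difference is easy to see. `lenient_match` mirrors the condition in the evaluator; `strict_match` is a hypothetical variant introduced here for comparison, not something the repository defines.

```python
import string


def normalize_text(text) -> str:
    """Lowercase, strip ASCII punctuation, and collapse whitespace."""
    text = str(text or "").strip().lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())


def lenient_match(pred: str, gold: str) -> bool:
    """Equality or containment in either direction, as in the free-form branch."""
    p, g = normalize_text(pred), normalize_text(gold)
    return bool(p and g and (p == g or p in g or g in p))


def strict_match(pred: str, gold: str) -> bool:
    """Hypothetical stricter rule: normalized strings must be identical."""
    p, g = normalize_text(pred), normalize_text(gold)
    return bool(p) and p == g
```

Under the lenient rule a prediction of `1` matches a gold answer of `12`, and because punctuation is removed, `3.5` and `35` normalize to the same string under both rules. Whether that matters depends on how many MMRB open questions have short numeric answers.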


@@ -0,0 +1,10 @@
You are an expert multi-image reasoning agent.
{skill_section}## Task Format
You will receive a question grounded in multiple images.
Use the image order exactly as presented in the prompt and compare evidence across images carefully.
## Answer Format
- Put the final answer inside <answer>...</answer>.
- For multiple-choice questions, output only the single option letter inside <answer>...</answer>.
- For open questions, output only the short final answer inside <answer>...</answer>.


@@ -0,0 +1,439 @@
"""MMRB rollout."""
from __future__ import annotations
import base64
import json
import mimetypes
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from reflact.envs.mmrb.evaluator import evaluate_item, evaluation_mode
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
from reflact.prompts import load_prompt
_IMAGE_REF_RE = re.compile(r"\{image#(\d+)\}", re.IGNORECASE)
def _build_system(skill_content: str) -> str:
    if skill_content.strip():
        skill_section = f"## Skill\n{skill_content.strip()}\n\n"
    else:
        skill_section = ""
    return load_prompt("rollout_system", env="mmrb").format(skill_section=skill_section)


def _image_to_data_uri(path: str) -> str:
    mime = mimetypes.guess_type(path)[0] or "image/png"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return f"data:{mime};base64,{encoded}"
def _build_user_content(
    item: dict,
    image_detail: str,
    *,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> tuple[list[dict], str]:
    raw_question = str(item["question"])
    content: list[dict] = []
    text_parts: list[str] = []
    used_indices: set[int] = set()
    cursor = 0
    if diagnostic_trace_context.strip():
        prefix = (
            "## Previous Codex Trace Snapshot\n"
            "This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
            f"{diagnostic_trace_context.strip()}\n\n"
        )
        content.append({"type": "text", "text": prefix})
        text_parts.append(prefix)
    # Split the question on {image#N} placeholders and interleave text with images.
    for match in _IMAGE_REF_RE.finditer(raw_question):
        if match.start() > cursor:
            chunk = raw_question[cursor:match.start()]
            if chunk:
                content.append({"type": "text", "text": chunk})
                text_parts.append(chunk)
        image_idx = int(match.group(1)) - 1
        marker = f"[Image #{image_idx + 1}]"
        text_parts.append(marker)
        if 0 <= image_idx < len(item["image_paths"]):
            image_url = {"url": _image_to_data_uri(item["image_paths"][image_idx])}
            if image_detail and image_detail != "auto":
                image_url["detail"] = image_detail
            content.append({"type": "image_url", "image_url": image_url})
            used_indices.add(image_idx)
        else:
            # Out-of-range reference: keep the marker as plain text.
            content.append({"type": "text", "text": marker})
        cursor = match.end()
    if cursor < len(raw_question):
        tail = raw_question[cursor:]
        if tail:
            content.append({"type": "text", "text": tail})
            text_parts.append(tail)
    # Append any images that the question text never referenced.
    for idx, path in enumerate(item["image_paths"]):
        if idx in used_indices:
            continue
        marker = f"\n[Additional Image #{idx + 1}]"
        text_parts.append(marker)
        content.append({"type": "text", "text": marker})
        image_url = {"url": _image_to_data_uri(path)}
        if image_detail and image_detail != "auto":
            image_url["detail"] = image_detail
        content.append({"type": "image_url", "image_url": image_url})
    answer_instruction = (
        "\n\nAnswer with the single correct option letter inside <answer>...</answer>."
        if item.get("is_choice")
        else "\n\nAnswer with the short final answer inside <answer>...</answer>."
    )
    content.append({"type": "text", "text": answer_instruction})
    text_parts.append(answer_instruction)
    if diagnostic_mode and diagnostic_instruction.strip():
        diag_block = f"\n\n## Training Readout\n{diagnostic_instruction.strip()}"
        content.append({"type": "text", "text": diag_block})
        text_parts.append(diag_block)
    return content, "".join(text_parts)


def _build_messages(
    item: dict,
    skill_content: str,
    image_detail: str,
    *,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> tuple[list[dict], str, str]:
    system = _build_system(skill_content)
    # Accept and forward the trace context: process_one passes this keyword,
    # which would otherwise raise a TypeError on every chat-backend rollout.
    user_content, user_text = _build_user_content(
        item,
        image_detail,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user_content},
    ]
    return messages, system, user_text


def _build_codex_skill(skill_content: str) -> str:
    return render_skill_md(
        skill_content,
        description="Dynamic ReflACT skill for solving the current MMRB multi-image reasoning question.",
        preamble=(
            "Use this skill when solving the current multi-image reasoning task.\n"
            "Inspect all attached images carefully and return the final answer inside <answer>...</answer>."
        ),
    )


def _run_codex_once(
    *,
    pred_dir: str,
    item: dict,
    skill_content: str,
    model: str,
    timeout: int,
    image_detail: str,
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
    previous_response: str = "",
) -> tuple[str, str, str, str]:
    user_text = _build_user_content(
        item,
        image_detail,
        diagnostic_mode=diagnostic_mode,
        diagnostic_instruction=diagnostic_instruction,
        diagnostic_trace_context=diagnostic_trace_context,
    )[1]
    task_parts = [user_text]
    if previous_response:
        task_parts.append(
            "## Previous Attempt\n"
            f"{previous_response}\n\n"
            "Review the same images carefully and answer again."
        )
    task_text = "\n\n".join(task_parts)
    skill_md = _build_codex_skill(skill_content)
    work_dir = os.path.join(pred_dir, "codex_exec")
    prepare_workspace(
        work_dir=work_dir,
        skill_md=skill_md,
        task_text=task_text,
        images=item["image_paths"],
    )
    prompt = (
        "Use the `reflact-student` skill available in this workspace.\n"
        "Read `task.md`, inspect all attached images, and answer the question.\n"
        "Keep the final answer inside <answer>...</answer>."
    )
    final_message, raw = run_student_exec(
        work_dir=work_dir,
        prompt=prompt,
        model=model,
        timeout=timeout,
        images=item["image_paths"],
    )
    return final_message or raw, raw, skill_md, task_text


def process_one(
    item: dict,
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    image_detail: str = "auto",
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context: str = "",
) -> dict:
    item_id = str(item["id"])
    result = {
        "id": item_id,
        "question": item["question"],
        "task_type": item.get("subtask") or item.get("task_type") or "mmrb",
        "task_description": item["question"],
        "hard": 0,
        "soft": 0.0,
        "predicted_answer": "",
        "predicted_label": "",
        "predicted_text": "",
        "response": "",
        "fail_reason": "",
        "agent_ok": False,
        "n_turns": 0,
        "image_paths": item["image_paths"],
        "gold_answer": item["answer"],
        "evaluation_mode": evaluation_mode(),
    }
    try:
        pred_dir = os.path.join(out_root, "predictions", item_id)
        os.makedirs(pred_dir, exist_ok=True)
        if is_student_exec_backend():
            from reflact.model import azure_openai as _llm

            response = ""
            conversation: list[dict] = [
                {
                    "role": "user",
                    "content": item["question"] + "\n\n" + "\n".join(
                        f"[image] {os.path.basename(path)}" for path in item["image_paths"]
                    ),
                }
            ]
            system_prompt = ""
            user_text = ""
            for turn in range(max_turns):
                response, raw, system_prompt, user_text = _run_codex_once(
                    pred_dir=pred_dir,
                    item=item,
                    skill_content=skill_content,
                    model=_llm.STUDENT_DEPLOYMENT,
                    timeout=120,
                    image_detail=image_detail,
                    diagnostic_mode=diagnostic_mode if turn == 0 else False,
                    diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
                    diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
                    previous_response=response if turn > 0 else "",
                )
                conversation.append({"type": "message", "turn": turn + 1, "content": response})
                if "<answer>" in response.lower():
                    break
            result["response"] = response
            result["agent_ok"] = True
            result["n_turns"] = len(conversation) - 1
            with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
                f.write(system_prompt)
            with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
                f.write(user_text)
            eval_result = evaluate_item(item=item, prediction_text=response)
            result["evaluation_mode"] = eval_result["evaluation_mode"]
            result["predicted_answer"] = eval_result["predicted_answer"]
            result["predicted_label"] = eval_result["predicted_label"]
            result["predicted_text"] = eval_result["predicted_text"]
            result["matched_gold"] = eval_result["matched_gold"]
            result["hard"] = int(eval_result["em"])
            result["soft"] = eval_result["f1"]
            if not result["hard"]:
                result["fail_reason"] = (
                    f"predicted '{eval_result['predicted_answer']}' but expected '{item['answer']}'"
                )
            eval_detail = (
                "[EVALUATION RESULT]\n"
                f"Question: {item['question']}\n"
                f"Predicted answer: {eval_result['predicted_answer']!r}\n"
                f"Predicted label: {eval_result['predicted_label']!r}\n"
                f"Gold answer: {item['answer']!r}\n"
                f"Correct: {eval_result['em']}\n"
            )
            conversation.append({"role": "system", "content": eval_detail})
            with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
                json.dump(conversation, f, ensure_ascii=False, indent=2)
            return result
        messages, system_prompt, user_text = _build_messages(
            item,
            skill_content,
            image_detail,
            diagnostic_mode=diagnostic_mode,
            diagnostic_instruction=diagnostic_instruction,
            diagnostic_trace_context=diagnostic_trace_context,
        )
        response = ""
        conversation: list[dict] = [
            {
                "role": "user",
                "content": user_text + "\n\n" + "\n".join(
                    f"[image] {os.path.basename(path)}" for path in item["image_paths"]
                ),
            }
        ]
        for turn in range(max_turns):
            if turn == 0:
                resp_text, _ = chat_student_messages(
                    messages=messages,
                    max_completion_tokens=768,
                    retries=5,
                    stage="rollout",
                )
            else:
                refinement_messages = [
                    messages[0],
                    messages[1],
                    {"role": "assistant", "content": response},
                    {
                        "role": "user",
                        "content": "Review the same images carefully and answer again. Keep the final answer inside <answer>...</answer>.",
                    },
                ]
                resp_text, _ = chat_student_messages(
                    messages=refinement_messages,
                    max_completion_tokens=512,
                    retries=5,
                    stage="rollout",
                )
            response = resp_text
            conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
            if "<answer>" in resp_text.lower():
                break
        result["response"] = response
        result["agent_ok"] = True
        result["n_turns"] = len(conversation) - 1
        with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
            f.write(system_prompt)
        with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
            f.write(user_text)
        eval_result = evaluate_item(item=item, prediction_text=response)
        result["evaluation_mode"] = eval_result["evaluation_mode"]
        result["predicted_answer"] = eval_result["predicted_answer"]
        result["predicted_label"] = eval_result["predicted_label"]
        result["predicted_text"] = eval_result["predicted_text"]
        result["matched_gold"] = eval_result["matched_gold"]
        result["hard"] = int(eval_result["em"])
        result["soft"] = eval_result["f1"]
        if not result["hard"]:
            result["fail_reason"] = (
                f"predicted '{eval_result['predicted_answer']}' but expected '{item['answer']}'"
            )
        eval_detail = (
            "[EVALUATION RESULT]\n"
            f"Question: {item['question']}\n"
            f"Predicted answer: {eval_result['predicted_answer']!r}\n"
            f"Predicted label: {eval_result['predicted_label']!r}\n"
            f"Gold answer: {item['answer']!r}\n"
            f"Correct: {eval_result['em']}\n"
        )
        conversation.append({"role": "system", "content": eval_detail})
        with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
            json.dump(conversation, f, ensure_ascii=False, indent=2)
    except Exception as e:  # noqa: BLE001
        result["fail_reason"] = f"error: {e}"
    return result


def run_batch(
    items: list[dict],
    out_root: str,
    skill_content: str,
    *,
    max_turns: int = 1,
    workers: int = 16,
    image_detail: str = "auto",
    diagnostic_mode: bool = False,
    diagnostic_instruction: str = "",
    diagnostic_trace_context_by_id: dict[str, str] | None = None,
) -> list[dict]:
    results_path = os.path.join(out_root, "results.jsonl")
    os.makedirs(out_root, exist_ok=True)
    expected_eval_mode = evaluation_mode()
    done_ids: set[str] = set()
    existing: list[dict] = []
    rewrite_results = False
    if os.path.exists(results_path):
        with open(results_path, encoding="utf-8") as f:
            for line in f:
                try:
                    row = json.loads(line)
                    if row.get("evaluation_mode") != expected_eval_mode:
                        rewrite_results = True
                        continue
                    done_ids.add(str(row["id"]))
                    existing.append(row)
                except Exception:
                    rewrite_results = True
    pending = [item for item in items if str(item["id"]) not in done_ids]
    if not pending and not rewrite_results:
        return existing
    results = list(existing)
    file_mode = "w" if rewrite_results else "a"
    with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
        if rewrite_results:
            for row in existing:
                outf.write(json.dumps(row, ensure_ascii=False) + "\n")
        futs = {
            ex.submit(
                process_one,
                item,
                out_root,
                skill_content,
                max_turns=max_turns,
                image_detail=image_detail,
                diagnostic_mode=diagnostic_mode,
                diagnostic_instruction=diagnostic_instruction,
                diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
            ): item
            for item in pending
        }
        for fut in as_completed(futs):
            row = fut.result()
            results.append(row)
            outf.write(json.dumps(row, ensure_ascii=False) + "\n")
            outf.flush()
    return results
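
Note on the resume behaviour of `run_batch`: the function reloads `results.jsonl`, keeps only rows whose `evaluation_mode` matches the current mode, and rewrites the file if any row was stale or unparseable. The following is a minimal sketch of that loading step in isolation, not code from this commit. `load_resumable_rows` and `demo` are hypothetical names, and the sample rows and the `"exact"` / `"legacy"` mode strings are invented for the demonstration.

```python
import json
import os
import tempfile


def load_resumable_rows(
    results_path: str, expected_eval_mode: str
) -> tuple[list[dict], set[str], bool]:
    """Return (kept rows, ids already done, whether the file needs a rewrite)."""
    existing: list[dict] = []
    done_ids: set[str] = set()
    rewrite_results = False
    if not os.path.exists(results_path):
        return existing, done_ids, rewrite_results
    with open(results_path, encoding="utf-8") as f:
        for line in f:
            try:
                row = json.loads(line)
                # Rows scored under a different evaluation mode are dropped
                # and force a rewrite so the file stays consistent.
                if row.get("evaluation_mode") != expected_eval_mode:
                    rewrite_results = True
                    continue
                done_ids.add(str(row["id"]))
                existing.append(row)
            except Exception:
                # Truncated or malformed lines are dropped the same way.
                rewrite_results = True
    return existing, done_ids, rewrite_results


def demo() -> tuple[list[dict], set[str], bool]:
    """Write a small results file with one good, one stale, one broken line."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "results.jsonl")
        with open(path, "w", encoding="utf-8") as f:
            f.write(json.dumps({"id": 1, "evaluation_mode": "exact", "hard": 1}) + "\n")
            f.write(json.dumps({"id": 2, "evaluation_mode": "legacy", "hard": 0}) + "\n")
            f.write("{truncated line\n")
        return load_resumable_rows(path, "exact")
```

In this sketch only item `1` counts as done, so a caller following the same pattern would re-run item `2` and reopen the file in `"w"` mode before appending new rows.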


@@ -0,0 +1,17 @@
# MMRB Multi-Image Reasoning Heuristics

## Cross-Image Alignment
- Track the role of each image by its index and compare evidence across all referenced images before deciding.
- When the question depends on sequence, correspondence, or retrieval, verify the relation between images instead of judging each image independently.

## Option Elimination
- For multiple-choice tasks, compare all options and reject choices that match only part of the visual evidence.
- If options differ by a small visual detail, use the most discriminative cue rather than a coarse scene impression.

## Open Answers
- For open-ended tasks, give the shortest answer that is fully supported by the combined images.
- Preserve exact entities, attributes, counts, and directions when the images support them directly.

## Final Answer
- Output only the final answer inside <answer>...</answer>.
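
Note on the final-answer convention: the rollout code in this commit stops a multi-turn loop as soon as a response contains `<answer>`. As a rough illustration of how a consumer could pull the answer out of such a response, here is a minimal sketch. `extract_answer` is a hypothetical helper, not the repository's `evaluate_item`, whose parsing rules are not shown in this diff; taking the last tagged span and matching the tag case-insensitively are both assumptions.

```python
import re

# Non-greedy match so each <answer>...</answer> pair is captured separately.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.IGNORECASE | re.DOTALL)


def extract_answer(response: str) -> str:
    """Return the text of the last <answer>...</answer> span, or "" if absent."""
    matches = _ANSWER_RE.findall(response)
    return matches[-1].strip() if matches else ""
```

Preferring the last span is one possible design choice: it favours a model's final revision when it restates its answer after further reasoning.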

Some files were not shown because too many files have changed in this diff.