mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-05 15:27:51 +08:00
Initial commit
This commit is contained in:
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
data/
|
||||
outputs/
|
||||
logs/
|
||||
external/
|
||||
/BabyVision/
|
||||
/MMRB/
|
||||
/SpreadsheetBench/
|
||||
/dl4ir-searchQA/
|
||||
configs/local/
|
||||
configs/**/*.local.yaml
|
||||
*.local.md
|
||||
.secrets/
|
||||
.codex_azure*/
|
||||
557
README.md
Normal file
557
README.md
Normal file
@@ -0,0 +1,557 @@
|
||||
# ReflACT: Reflective Agent Tuning
|
||||
|
||||
ReflACT is a framework for optimizing an external skill document through iterative rollout, reflection, editing, and gated validation.
|
||||
|
||||
It does **not** fine-tune model weights. Instead, it treats the skill document as the optimization target:
|
||||
|
||||
- the **student** model executes tasks with the current skill
|
||||
- the **teacher** model analyzes trajectories and proposes edits
|
||||
- the framework merges, ranks, applies, and validates those edits
|
||||
- only validated skill updates are kept
|
||||
|
||||
This branch implements a full training loop with step-level skill optimization and optional epoch-level memory mechanisms (`slow_update`, `meta_skill`, `meta_reflect`).
|
||||
|
||||
## Method Overview
|
||||
|
||||
### Optimization Target
|
||||
|
||||
Each run maintains a mutable markdown skill document. The framework repeatedly improves that document instead of changing model parameters.
|
||||
|
||||
This gives a training-style loop for prompt / policy optimization:
|
||||
|
||||
1. Roll out the current skill on a batch of tasks.
|
||||
2. Reflect on failures and successes.
|
||||
3. Merge patch proposals into a coherent candidate update.
|
||||
4. Rank and select a bounded number of edits.
|
||||
5. Apply those edits to produce a candidate skill.
|
||||
6. Validate the candidate skill on a held-out selection split.
|
||||
7. Keep the update only if the gate accepts it.
|
||||
|
||||
### Per-Step Pipeline
|
||||
|
||||
Every training step executes the following pipeline in `reflact/engine/trainer.py`:
|
||||
|
||||
1. **Rollout**
|
||||
The student model runs a batch of tasks using the current skill.
|
||||
|
||||
2. **Reflect**
|
||||
The teacher analyzes minibatches of trajectories and emits raw patches.
|
||||
Failure-driven and success-driven patches are tracked separately.
|
||||
|
||||
3. **Aggregate**
|
||||
Raw patches are merged hierarchically. Metadata such as `support_count` and `source_type` is carried into the merged patch so later ranking can use it.
|
||||
|
||||
4. **Select**
|
||||
The teacher ranks the merged edit pool and keeps up to `edit_budget` edits.
|
||||
|
||||
5. **Update**
|
||||
The selected edits are applied to the skill document. The framework records an `edit_apply_report.json` so you can see which edits actually landed, which were skipped, and why.
|
||||
|
||||
6. **Evaluate / Gate**
|
||||
The candidate skill is evaluated on the selection split. Gate validation is mandatory in this branch. A candidate update is accepted only if it improves over the current selection score; a new global best is tracked separately.
|
||||
|
||||
### Within-Epoch Memory
|
||||
|
||||
Inside an epoch, the trainer maintains a step buffer containing:
|
||||
|
||||
- compact failure-pattern summaries from previous steps
|
||||
- rejected edits and their score deltas
|
||||
|
||||
That context is fed back into later reflection calls so the teacher can avoid repeating ineffective edits and can focus on unsolved error patterns.
|
||||
|
||||
### Epoch-Level Mechanisms
|
||||
|
||||
This branch supports three optional epoch-level mechanisms.
|
||||
|
||||
#### Slow Update
|
||||
|
||||
At the end of each epoch, `slow_update` compares the previous epoch’s terminal skill and current epoch’s terminal skill on a sampled train subset. It then writes longitudinal guidance into a protected slow-update region inside the skill document.
|
||||
|
||||
Importantly, this guidance is **not** blindly written through. It is converted into a candidate skill and sent through the same selection gate as step-level updates.
|
||||
|
||||
#### Meta Skill
|
||||
|
||||
`meta_skill` is teacher-side cross-epoch memory. It does not directly edit the current skill. Instead, it writes a compact memory artifact describing longer-term patterns across adjacent epochs. That memory is loaded into later reflection / merge / ranking calls as extra context.
|
||||
|
||||
#### Meta Reflect
|
||||
|
||||
`meta_reflect` runs at epoch end over the step history of the current epoch. It looks at accepted and rejected directions from the whole epoch, proposes higher-level patch edits, applies them to a meta candidate, and then sends that candidate through the same selection gate.
|
||||
|
||||
## What This Branch Guarantees
|
||||
|
||||
The current implementation assumes the following as the mainline method contract:
|
||||
|
||||
- gate validation is always on
|
||||
- the current skill, current score, best skill, and best score stay aligned
|
||||
- `slow_update` is gated before being committed
|
||||
- patch provenance (`source_type`, `support_count`) reaches selection
|
||||
- patch application is observable through per-edit reports
|
||||
- resume state is restored from `runtime_state.json` rather than inferred only from history
|
||||
- all benchmark model calls go through the unified backend router
|
||||
|
||||
## Model Backends
|
||||
|
||||
All model access now goes through the split teacher/student model layer in `reflact.model`.
|
||||
|
||||
Supported teacher backends:
|
||||
|
||||
- `openai_chat`
|
||||
- `claude_chat`
|
||||
|
||||
Supported student backends:
|
||||
|
||||
- `openai_chat`
|
||||
- `claude_chat`
|
||||
- `codex_exec`
|
||||
- `claude_code_exec`
|
||||
|
||||
Recommended config shape:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
teacher_backend: openai_chat
|
||||
student_backend: codex_exec
|
||||
teacher: gpt-5.4
|
||||
student: gpt-5.4-codex
|
||||
reasoning_effort: medium
|
||||
```
|
||||
|
||||
Legacy `model.backend` and CLI flags like `--backend codex` still work. They are mapped onto the split backend model for backward compatibility.
|
||||
|
||||
The same routing is used by:
|
||||
|
||||
- training (`scripts/train.py`)
|
||||
- eval-only runs (`scripts/eval_only.py`)
|
||||
- SpreadsheetBench standalone prompt eval scripts
|
||||
- LiveMathematicianBench baseline eval script
|
||||
- benchmark rollout code inside the main framework
|
||||
|
||||
### Azure OpenAI
|
||||
|
||||
If you use `openai_chat`, configure either environment variables or config values:
|
||||
|
||||
```bash
|
||||
export AZURE_OPENAI_ENDPOINT="https://your-endpoint.openai.azure.com/"
|
||||
export AZURE_OPENAI_API_KEY="your-api-key"
|
||||
export AZURE_OPENAI_API_VERSION="2025-04-01-preview"
|
||||
```
|
||||
|
||||
The config supports both the old keys and the new explicit names:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
azure_openai_endpoint: "..."
|
||||
azure_openai_api_version: "..."
|
||||
azure_openai_api_key: ""
|
||||
azure_openai_auth_mode: api_key
|
||||
azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
azure_openai_managed_identity_client_id: ""
|
||||
```
|
||||
|
||||
`azure_openai_auth_mode` can be used for API-key auth or Azure AD / managed identity flows.
|
||||
|
||||
### Exec Harness
|
||||
|
||||
`codex_exec` and `claude_code_exec` run the student inside a workspace harness instead of a plain chat call. The harness writes task files, renders a dynamic `SKILL.md`, runs the student CLI, and saves raw execution artifacts such as:
|
||||
|
||||
- `codex_raw.txt`
|
||||
- `codex_trace_summary.txt`
|
||||
- workspace-local task / skill files
|
||||
|
||||
This branch keeps `meta_skill` and `apply_patch_with_report`, while upgrading the student path to the more realistic workspace-exec setup.
|
||||
|
||||
### Trace-Aware Deep Reflect
|
||||
|
||||
When `student_backend=codex_exec` and `gradient.use_deep_reflect=true`, deep reflection can probe a specific earlier Codex attempt:
|
||||
|
||||
- the teacher sees a compact Codex trace summary
|
||||
- deep probe can target `probe_target_id`
|
||||
- the follow-up rollout can resume from `probe_after_step`
|
||||
|
||||
This is wired for the dataset-backed environments in this branch.
|
||||
|
||||
### Rewrite Mode
|
||||
|
||||
Skill updates support two modes:
|
||||
|
||||
- `optimizer.skill_update_mode=patch`
|
||||
- `optimizer.skill_update_mode=rewrite_from_suggestions`
|
||||
|
||||
`patch` keeps the existing fine-grained edit application path and still records `edit_apply_report.json`.
|
||||
|
||||
`rewrite_from_suggestions` asks the teacher to emit higher-level rewrite suggestions, then rewrites the whole skill in one pass. This is useful when patch edits become too fragmented.
|
||||
|
||||
## Repository Layout
|
||||
|
||||
```text
|
||||
reflact/
|
||||
engine/
|
||||
trainer.py main training loop
|
||||
gradient/
|
||||
reflect.py minibatch reflection
|
||||
aggregate.py hierarchical patch merge
|
||||
deep_probe.py diagnostic probing for deep reflect
|
||||
optimizer/
|
||||
clip.py edit ranking / selection
|
||||
skill.py patch application + apply report
|
||||
slow_update.py epoch-level longitudinal guidance
|
||||
meta_skill.py teacher-side cross-epoch memory
|
||||
meta_reflect.py epoch-level macro editing
|
||||
evaluation/
|
||||
gate.py pure gate decision logic
|
||||
model/
|
||||
backend_config.py teacher/student backend routing
|
||||
azure_openai.py Azure backend
|
||||
codex_harness.py workspace exec harness + Codex trace parsing
|
||||
claude_backend.py Claude backend
|
||||
envs/
|
||||
... environment adapters and rollout logic
|
||||
scripts/
|
||||
train.py unified training entry
|
||||
eval_only.py evaluate one skill without training
|
||||
configs/
|
||||
_base_/default.yaml shared defaults
|
||||
<env>/default.yaml environment-specific configs
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Configs use structured YAML with `_base_` inheritance.
|
||||
|
||||
The base config is `configs/_base_/default.yaml`. Key defaults in this branch are:
|
||||
|
||||
- `model.teacher_backend = openai_chat`
|
||||
- `model.student_backend = openai_chat`
|
||||
- `model.reasoning_effort = medium`
|
||||
- `optimizer.use_slow_update = true`
|
||||
- `optimizer.use_meta_skill = true`
|
||||
- `optimizer.use_meta_reflect = false`
|
||||
- `gradient.use_deep_reflect = false`
|
||||
- `optimizer.skill_update_mode = patch`
|
||||
|
||||
Default setting snapshot:
|
||||
|
||||
```yaml
|
||||
model:
|
||||
backend: azure_openai
|
||||
teacher: gpt-5.4
|
||||
student: gpt-5.4
|
||||
teacher_backend: openai_chat
|
||||
student_backend: openai_chat
|
||||
reasoning_effort: medium
|
||||
rewrite_reasoning_effort: ""
|
||||
rewrite_max_completion_tokens: 64000
|
||||
codex_exec_path: codex
|
||||
codex_exec_sandbox: workspace-write
|
||||
codex_exec_profile: ""
|
||||
codex_exec_full_auto: false
|
||||
codex_exec_reasoning_effort: none
|
||||
claude_code_exec_path: claude
|
||||
claude_code_exec_profile: ""
|
||||
codex_trace_to_teacher: true
|
||||
|
||||
train:
|
||||
num_epochs: 4
|
||||
train_size: 0
|
||||
batch_size: 80
|
||||
accumulation: 1
|
||||
seed: 42
|
||||
|
||||
gradient:
|
||||
minibatch_size: 16
|
||||
merge_batch_size: 16
|
||||
analyst_workers: 16
|
||||
max_analyst_rounds: 3
|
||||
failure_only: false
|
||||
use_deep_reflect: false
|
||||
deep_reflect_failures: 4
|
||||
deep_reflect_successes: 2
|
||||
|
||||
optimizer:
|
||||
learning_rate: 8
|
||||
min_learning_rate: 2
|
||||
lr_scheduler: cosine
|
||||
skill_update_mode: patch
|
||||
use_meta_reflect: false
|
||||
meta_learning_rate: 8
|
||||
use_slow_update: true
|
||||
slow_update_samples: 20
|
||||
use_meta_skill: true
|
||||
|
||||
evaluation:
|
||||
use_gate: true
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
eval_test: true
|
||||
|
||||
env:
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_seed: 42
|
||||
```
|
||||
|
||||
For the full source of truth, see [configs/_base_/default.yaml](/home/azureuser/workspace-yqh/skillopt_final/configs/_base_/default.yaml).
|
||||
|
||||
Selected fields:
|
||||
|
||||
| Section | Key | Meaning |
|
||||
|---|---|---|
|
||||
| `model` | `teacher_backend` | teacher backend: `openai_chat` or `claude_chat` |
|
||||
| `model` | `student_backend` | student backend: chat backend or exec backend |
|
||||
| `model` | `teacher` | teacher model / deployment |
|
||||
| `model` | `student` | student model / deployment |
|
||||
| `model` | `reasoning_effort` | reasoning budget passed to the backend when supported |
|
||||
| `model` | `codex_trace_to_teacher` | include Codex trace summaries in teacher reflection context |
|
||||
| `train` | `num_epochs` | number of epochs |
|
||||
| `train` | `train_size` | expected train split size, or `0` to infer |
|
||||
| `train` | `batch_size` | tasks per rollout batch |
|
||||
| `train` | `accumulation` | number of rollout/reflect minibatches per step |
|
||||
| `gradient` | `minibatch_size` | trajectories per analyst minibatch |
|
||||
| `gradient` | `merge_batch_size` | patches per aggregate batch |
|
||||
| `gradient` | `use_deep_reflect` | enable diagnostic probe rollouts |
|
||||
| `gradient` | `max_analyst_rounds` | teacher reflection retries / refinement budget |
|
||||
| `optimizer` | `learning_rate` | max edits kept after selection |
|
||||
| `optimizer` | `lr_scheduler` | edit-budget scheduler |
|
||||
| `optimizer` | `use_slow_update` | epoch-level longitudinal guidance |
|
||||
| `optimizer` | `use_meta_skill` | teacher-side epoch memory |
|
||||
| `optimizer` | `use_meta_reflect` | epoch-level macro editing |
|
||||
| `optimizer` | `skill_update_mode` | `patch` or `rewrite_from_suggestions` |
|
||||
| `evaluation` | `sel_env_num` | selection set size (`0` means full split) |
|
||||
| `evaluation` | `test_env_num` | test set size (`0` means full split) |
|
||||
|
||||
### Important Branch Rule
|
||||
|
||||
`use_gate=false` is intentionally not supported in this branch. Gate validation is part of the method contract here.
|
||||
|
||||
If an old config still contains `evaluation.use_gate: false`, the loader / trainer will raise instead of silently continuing.
|
||||
|
||||
## Supported Environments
|
||||
|
||||
The main training entry and eval-only entry now register 11 environments:
|
||||
|
||||
| Env | Default rollout shape | Current default split / data setting | Branch alignment |
|
||||
|---|---|---|---|
|
||||
| `alfworld` | environment-backed episodic rollout | native ALFWorld train/eval splits | in `reflact_new_zzw` |
|
||||
| `babyvision` | single-round multimodal QA | `split_mode=ratio` from raw metadata/images, or prepared `split_dir` | in `reflact_new_zzw` |
|
||||
| `docvqa` | single-round multimodal QA | `split_dir: data/docvqa_split` | in `reflact_new_zzw` |
|
||||
| `livemathematicianbench` | single-round QA | `split_mode=ratio` or prepared `split_dir` | in `reflact_new_zzw` |
|
||||
| `mathverse` | single-round multimodal math QA | `data_root: data/MathVerse`, split files loaded from `split_dir` when provided | in `reflact_new_zzw` |
|
||||
| `mmrb` | single-round multimodal reasoning QA | `split_mode=ratio` or prepared `split_dir` | in `reflact_new_zzw` |
|
||||
| `officeqa` | multi-turn tool loop | `split_dir: data/officeqa_split` plus `data_dirs: [data/officeqa_docs_official]` | in `reflact_new_zzw` |
|
||||
| `sealqa` | multi-turn tool loop | `split_dir: data/sealqa_split` | in `reflact_new_zzw` |
|
||||
| `searchqa` | single-round QA (`max_turns=1`) | `split_dir: data/searchqa_split` | in `reflact_new_zzw` |
|
||||
| `spreadsheetbench` | codegen loop, default `mode=multi`, `max_turns=30` | `split_dir: data/spreadsheetbench_split`, `data_root: data/spreadsheetbench_verified_400` | in `reflact_new_zzw`, default adjusted here to multi-round |
|
||||
| `swebench` | mini-swe-agent multi-step bug-fixing rollout | `split_mode=ratio`, `dataset_name=lite`, repo-stratified `2:1:7` split materialized under `out_root/_generated_splits/...` unless `split_dir` is provided | added here, aligned to `swe-bench-old` |
|
||||
|
||||
## Data Expectations
|
||||
|
||||
The standard two-mode dataset entry path is:
|
||||
|
||||
- `split_mode: ratio`
|
||||
- load raw data from `env.data_path`
|
||||
- build a deterministic `train/`, `val/`, `test/` split under `env.split_output_dir` (or under `out_root/_generated_splits/` if unset)
|
||||
- default ratio is explicitly `2:1:7`
|
||||
- `split_mode: split_dir`
|
||||
- load an existing `env.split_dir` with `train/`, `val/`, `test/` subdirectories
|
||||
|
||||
This currently applies to:
|
||||
|
||||
- `searchqa`
|
||||
- `spreadsheetbench`
|
||||
- `babyvision`
|
||||
- `livemathematicianbench`
|
||||
- `mmrb`
|
||||
- `swebench`
|
||||
|
||||
`ALFWorld` is the exception: it is environment-backed rather than JSON split-backed.
|
||||
|
||||
The following environments currently expect prepared split directories or extra rooted assets rather than the generic ratio-split path:
|
||||
|
||||
- `docvqa`
|
||||
- `mathverse`
|
||||
- `officeqa`
|
||||
- `sealqa`
|
||||
|
||||
At a high level:
|
||||
|
||||
- `SearchQA`: raw QA json / jsonl or pre-split QA json files
|
||||
- `SpreadsheetBench`: raw task manifest json plus spreadsheet task directory, or a pre-split task manifest
|
||||
- `ALFWorld`: installed game environment and configured eval/train splits
|
||||
- `BabyVision`: raw `meta_data.jsonl` plus images, or a pre-split directory
|
||||
- `DocVQA`: pre-split CSV / JSON data under `split_dir`
|
||||
- `LiveMathematicianBench`: raw monthly QA json files, or a pre-split directory
|
||||
- `MathVerse`: split files plus `data_root` image assets
|
||||
- `MMRB`: raw extracted dataset json files, or a pre-split directory
|
||||
- `OfficeQA`: pre-split metadata plus resolved office document directories
|
||||
- `SealQA`: pre-split metadata for tool-augmented QA tasks
|
||||
- `SWEBench`: HuggingFace SWE-bench dataset alias (`lite` / `verified` / `full`) or a prepared split directory
|
||||
|
||||
### Split References Across Branches
|
||||
|
||||
The split-related defaults are not identical across `skillopt-final`, `reflact_new_zzw`, `gepa`, and `swe-bench-old`. The practical reference points are:
|
||||
|
||||
| Source branch | Explicit split settings / dirs |
|
||||
|---|---|
|
||||
| `skillopt-final` | `searchqa -> data/searchqa_split`; `spreadsheetbench -> data/spreadsheetbench_split`; `docvqa -> data/docvqa_split`; `officeqa -> data/officeqa_split`; `sealqa -> data/sealqa_split`; `swebench -> ratio split 2:1:7 over the default lite dataset, materialized under out_root/_generated_splits/...` |
|
||||
| `reflact_new_zzw` | Same 10-benchmark env set as above except no `swebench`; explicit split dirs are `data/searchqa_split`, `data/spreadsheetbench_split`, `data/docvqa_split`, `data/officeqa_split`, `data/sealqa_split`; `spreadsheetbench` there defaults to `mode=single`; `officeqa` uses `max_tool_turns=24`; `sealqa` uses `max_tool_turns=12` |
|
||||
| `gepa` | `configs/spreadsheetbench.yaml` uses `data.splits_dir = data/spreadsheetbench/splits`, `eval.mode = react`, `eval.max_turns = 20`; `configs/swebench.yaml` uses `dataset = SWE-bench/SWE-bench_Verified` with `train_size = 100`, `val_size = 50`, `test_size = 350` |
|
||||
| `swe-bench-old` | Repo-stratified `2:1:7` split over `SWE-Bench_Lite`, persisted as `outputs/.../split/train.json`, `selection.json`, `test.json`; the example split in that branch is `train=60`, `selection=33`, `test=207` |
|
||||
|
||||
For the 10 benches shared with `reflact_new_zzw`, the current branch is now aligned on env coverage. The main intentional delta is `spreadsheetbench`: this branch defaults to multi-round codegen, while `reflact_new_zzw` kept `mode=single` by default.
|
||||
|
||||
## Running Training
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python scripts/train.py --config configs/searchqa/default.yaml
|
||||
```
|
||||
|
||||
Explicit 2:1:7 split from raw data:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--split_mode ratio \
|
||||
--data_path /path/to/searchqa_train_2000.json
|
||||
```
|
||||
|
||||
Directly consume a prepared split directory:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--split_mode split_dir \
|
||||
--split_dir /path/to/searchqa_split
|
||||
```
|
||||
|
||||
You can override structured config keys from the CLI:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/spreadsheetbench/default.yaml \
|
||||
--cfg-options model.teacher_backend=openai_chat model.student_backend=codex_exec train.batch_size=40 optimizer.learning_rate=4
|
||||
```
|
||||
|
||||
Legacy flat overrides still work for common keys:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--backend azure_openai \
|
||||
--teacher_model gpt-5.4 \
|
||||
--student_model gpt-5.4 \
|
||||
--reasoning_effort medium
|
||||
```
|
||||
|
||||
Exec harness example:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--teacher_backend openai_chat \
|
||||
--student_backend codex_exec \
|
||||
--teacher_model gpt-5.4 \
|
||||
--student_model gpt-5.4-codex \
|
||||
--use_deep_reflect true \
|
||||
--skill_update_mode rewrite_from_suggestions
|
||||
```
|
||||
|
||||
SWEBench example:
|
||||
|
||||
```bash
|
||||
python scripts/train.py \
|
||||
--config configs/swebench/default.yaml \
|
||||
--cfg-options env.dataset_name=lite env.split_ratio=2:1:7
|
||||
```
|
||||
|
||||
## Eval-Only and Standalone Evaluation
|
||||
|
||||
Evaluate a specific skill without training:
|
||||
|
||||
```bash
|
||||
python scripts/eval_only.py \
|
||||
--config configs/searchqa/default.yaml \
|
||||
--skill reflact/envs/searchqa/skills/initial.md
|
||||
```
|
||||
|
||||
The same dataset entry modes apply in eval-only runs:
|
||||
|
||||
- `--split_mode ratio --data_path ...`
|
||||
- `--split_mode split_dir --split_dir ...`
|
||||
|
||||
Standalone scripts also exist for benchmark-specific comparisons, including:
|
||||
|
||||
- `scripts/eval_prompt_custom.py`
|
||||
- `scripts/eval_prompt_official.py`
|
||||
- `scripts/eval_livemathematicianbench_baseline.py`
|
||||
|
||||
These scripts now also support backend selection through the unified model layer.
|
||||
|
||||
## Output Structure
|
||||
|
||||
Each run writes a structured output directory under `out_root`.
|
||||
|
||||
Important top-level artifacts:
|
||||
|
||||
- `config.json` — flattened runtime config
|
||||
- `history.json` — per-step history records
|
||||
- `runtime_state.json` — resume state for current/best skill tracking
|
||||
- `best_skill.md` — current best validated skill
|
||||
- `skills/skill_vXXXX.md` — persisted skill snapshot per step
|
||||
|
||||
Per-step artifacts live under `steps/step_XXXX/`, including:
|
||||
|
||||
- `merged_patch.json`
|
||||
- `ranked_edits.json`
|
||||
- `candidate_skill.md`
|
||||
- `edit_apply_report.json`
|
||||
- `rewrite_result.json` when rewrite mode is enabled
|
||||
- `selection_eval/`
|
||||
- `trajectory_digest.json`
|
||||
- rollout and patch subdirectories
|
||||
|
||||
Epoch-level artifacts live under:
|
||||
|
||||
- `slow_update/epoch_XX/`
|
||||
- `meta_skill/epoch_XX/`
|
||||
- `meta_reflect/epoch_XX/`
|
||||
|
||||
## Resume Behavior
|
||||
|
||||
The trainer resumes from `runtime_state.json` when present. That state tracks:
|
||||
|
||||
- last completed step
|
||||
- current skill path
|
||||
- current score
|
||||
- best skill path
|
||||
- best score
|
||||
- origin tags for current and best skill
|
||||
|
||||
This is important because skill state can change at both step level and epoch level; resuming only from `history.json` is not sufficient for this branch’s method logic.
|
||||
|
||||
## Notes
|
||||
|
||||
- This repository focuses on skill optimization logic; datasets are not included.
|
||||
- Patch application is intentionally observable. Inspect `edit_apply_report.json` when candidate skills do not behave as expected.
|
||||
- `SpreadsheetBench` now defaults to `mode=multi`. If you run an exec student backend there, override back to `env.mode=single` because exec backends are still only wired for SpreadsheetBench single-mode rollout.
|
||||
- `SWEBench` follows the older mini-swe-agent + `swebench.harness.run_evaluation` path, so it requires the SWE-bench / Docker toolchain rather than the generic chat-only stack.
|
||||
- `slow_update` writes into a protected skill region and normal edits are prevented from overwriting that region directly.
|
||||
- `meta_skill` is context memory, not a direct skill edit.
|
||||
- `meta_reflect` is a gated skill edit stage, not just logging.
|
||||
|
||||
## Minimal Setup
|
||||
|
||||
```bash
|
||||
conda create -n reflact python=3.11
|
||||
conda activate reflact
|
||||
pip install openai pyyaml openpyxl
|
||||
```
|
||||
|
||||
Depending on the environment, you may also need:
|
||||
|
||||
```bash
|
||||
pip install datasets gymnasium numpy ray regex
|
||||
```
|
||||
|
||||
For `SWEBench`, you also need a working Docker environment plus the SWE-bench / mini-swe-agent dependencies used in `swe-bench-old`.
|
||||
93
configs/_base_/default.yaml
Normal file
93
configs/_base_/default.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
# ReflACT default configuration — base for all environments.
|
||||
# Environment configs should inherit via: _base_: default.yaml
|
||||
|
||||
model:
|
||||
backend: azure_openai
|
||||
teacher: gpt-5.5
|
||||
student: gpt-5.5
|
||||
teacher_backend: openai_chat
|
||||
student_backend: openai_chat
|
||||
reasoning_effort: medium
|
||||
rewrite_reasoning_effort: ""
|
||||
rewrite_max_completion_tokens: 64000
|
||||
codex_exec_path: codex
|
||||
codex_exec_sandbox: workspace-write
|
||||
codex_exec_profile: ""
|
||||
codex_exec_full_auto: false
|
||||
codex_exec_reasoning_effort: none
|
||||
codex_exec_use_sdk: auto
|
||||
codex_exec_network_access: false
|
||||
codex_exec_web_search: false
|
||||
codex_exec_approval_policy: never
|
||||
claude_code_exec_path: claude
|
||||
claude_code_exec_profile: ""
|
||||
claude_code_exec_use_sdk: auto
|
||||
claude_code_exec_effort: medium
|
||||
claude_code_exec_max_thinking_tokens: 16384
|
||||
codex_trace_to_teacher: true
|
||||
azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
azure_openai_api_version: "2024-12-01-preview"
|
||||
azure_openai_api_key: "" # Fill locally if you do not export AZURE_OPENAI_API_KEY
|
||||
azure_openai_auth_mode: azure_cli
|
||||
azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
azure_openai_managed_identity_client_id: ""
|
||||
teacher_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
teacher_azure_openai_api_version: "2024-12-01-preview"
|
||||
teacher_azure_openai_api_key: ""
|
||||
teacher_azure_openai_auth_mode: azure_cli
|
||||
teacher_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
teacher_azure_openai_managed_identity_client_id: ""
|
||||
student_azure_openai_endpoint: "https://t2vgoaigpt4o3.openai.azure.com/"
|
||||
student_azure_openai_api_version: "2024-12-01-preview"
|
||||
student_azure_openai_api_key: ""
|
||||
student_azure_openai_auth_mode: azure_cli
|
||||
student_azure_openai_ad_scope: "https://cognitiveservices.azure.com/.default"
|
||||
student_azure_openai_managed_identity_client_id: ""
|
||||
|
||||
train:
|
||||
num_epochs: 4
|
||||
train_size: 0 # 0 = derive from dataset split when available
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
seed: 42
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
analyst_workers: 16
|
||||
max_analyst_rounds: 3
|
||||
failure_only: false
|
||||
use_deep_reflect: false
|
||||
deep_reflect_failures: 4
|
||||
deep_reflect_successes: 2
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4 # max edits per step (edit_budget)
|
||||
min_learning_rate: 2 # min edits for decay schedulers
|
||||
lr_scheduler: cosine # constant / linear / cosine / autonomous
|
||||
lr_control_mode: fixed # fixed / autonomous / none
|
||||
skill_update_mode: patch # patch / rewrite_from_suggestions / full_rewrite_minibatch
|
||||
use_meta_reflect: false
|
||||
meta_learning_rate: 4 # max edits per epoch-level meta-reflect
|
||||
use_slow_update: true
|
||||
slow_update_samples: 20
|
||||
longitudinal_pair_policy: mixed # mixed / changed / unchanged
|
||||
use_meta_skill: true
|
||||
|
||||
evaluation:
|
||||
use_gate: true
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
eval_test: true
|
||||
|
||||
env:
|
||||
name: ""
|
||||
skill_init: ""
|
||||
split_mode: ratio # ratio = build deterministic split from data_path; split_dir = use pre-split train/val/test
|
||||
split_ratio: "2:1:7" # explicit default for dataset-backed benchmarks: train:val:test
|
||||
split_seed: 42
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
exec_timeout: 120 # per student model/code-agent call timeout in seconds
|
||||
out_root: ""
|
||||
305
configs/ablation_study/README.md
Normal file
305
configs/ablation_study/README.md
Normal file
@@ -0,0 +1,305 @@
|
||||
# Ablation Study Configuration Manifest
|
||||
|
||||
This folder records the final, reproducible settings for the ablation runs used
|
||||
in `docs/ablation_paper_tables.md`.
|
||||
|
||||
It is intentionally separate from the benchmark default configs. The benchmark
|
||||
configs under `configs/<benchmark>/default.yaml` remain the source task configs;
|
||||
this folder records the exact matrix-level overrides, run roots, launch commands,
|
||||
and validation rules used for the paper ablations.
|
||||
|
||||
## Files
|
||||
|
||||
- `matrix.yaml`: canonical ablation matrix, common overrides, benchmark splits,
|
||||
token/output caps, and invalid-run rules.
|
||||
- `launch_commands.sh`: exact launcher commands for the valid run roots.
|
||||
- `validation.md`: monitoring, result extraction, and invalidation checklist.
|
||||
|
||||
## Source Of Truth
|
||||
|
||||
Use the matrix launcher:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python scripts/run_ablation_matrix.py
|
||||
```
|
||||
|
||||
The launcher builds runs from the same defaults and values recorded in
|
||||
`matrix.yaml`. It skips completed runs by checking `summary.json` and skips
|
||||
active runs by checking `env.out_root` in active `scripts/train.py` processes.
|
||||
|
||||
Do not manually rerun a completed run into the same `env.out_root`. If a run is
|
||||
invalid, archive or remove its output directory first, then let the launcher
|
||||
start it cleanly.
|
||||
|
||||
## Current Correct Run Roots
|
||||
|
||||
- SearchQA / SpreadsheetBench original ablations:
|
||||
`outputs/ablation_20260502_040604_unique48`
|
||||
- SearchQA / SpreadsheetBench batch-size ablations:
|
||||
`outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run`
|
||||
- LiveMathBench / ALFWorld clean ablations:
|
||||
`outputs/ablation_livemath_alfworld_clean_20260503_155155_run`
|
||||
- DocVQA ablations:
|
||||
`outputs/ablation_docvqa_20260503_160225_run`
|
||||
|
||||
Archived, superseded, misaligned, dry-run, or pre-fix directories must not be
|
||||
used for paper tables.
|
||||
|
||||
## End-To-End Runbook
|
||||
|
||||
### Environment
|
||||
|
||||
Run from the repository root:
|
||||
|
||||
```bash
|
||||
cd /home/azureuser/workspace-gzy/SkillReflection
|
||||
```
|
||||
|
||||
Always use:
|
||||
|
||||
```bash
|
||||
PY=/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python
|
||||
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
|
||||
```
|
||||
|
||||
Default model/auth settings are generated by `scripts/run_ablation_matrix.py`:
|
||||
|
||||
```text
|
||||
teacher=gpt-5.5
|
||||
student=gpt-5.5
|
||||
teacher_backend=openai_chat
|
||||
student_backend=openai_chat
|
||||
reasoning_effort=medium
|
||||
teacher/student endpoint=https://t2vgoaigpt4o3.openai.azure.com/
|
||||
teacher/student api_version=2024-12-01-preview
|
||||
teacher/student auth_mode=azure_cli
|
||||
```
|
||||
|
||||
Core training settings:
|
||||
|
||||
```text
|
||||
train.num_epochs=4
|
||||
train.train_size=0
|
||||
train.batch_size=40
|
||||
train.accumulation=1
|
||||
train.seed=42
|
||||
gradient.minibatch_size=8
|
||||
gradient.merge_batch_size=8
|
||||
gradient.analyst_workers=16
|
||||
gradient.use_deep_reflect=false
|
||||
optimizer.learning_rate=4
|
||||
optimizer.min_learning_rate=2
|
||||
optimizer.lr_scheduler=cosine
|
||||
optimizer.lr_control_mode=fixed
|
||||
optimizer.use_slow_update=true
|
||||
optimizer.slow_update_samples=20
|
||||
optimizer.use_meta_skill=true
|
||||
optimizer.use_meta_reflect=false
|
||||
optimizer.longitudinal_pair_policy=mixed
|
||||
evaluation.use_gate=true
|
||||
evaluation.eval_test=true
|
||||
env.split_mode=split_dir
|
||||
```
|
||||
|
||||
`train.train_size=0` is intentional. The dataloader derives the train size from
|
||||
the fixed split. Batch-size ablations rely on the default `ceil(train_size /
|
||||
batch_size)` behavior; the last batch can be smaller than `train.batch_size`.
|
||||
|
||||
### Fixed Splits
|
||||
|
||||
Default split directories:
|
||||
|
||||
```text
|
||||
searchqa: data/ablation_splits/searchqa/2-1-7_seed42
|
||||
spreadsheetbench: data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
livemathematicianbench: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
alfworld: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
docvqa: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
```
|
||||
|
||||
Default train/val/test sizes:
|
||||
|
||||
| Benchmark | Train | Val | Test |
|
||||
| --- | ---: | ---: | ---: |
|
||||
| SearchQA | 400 | 200 | 1400 |
|
||||
| SpreadsheetBench | 80 | 40 | 280 |
|
||||
| LiveMathBench | 35 | 18 | 124 |
|
||||
| ALFWorld | 39 | 18 | 134 |
|
||||
| DocVQA | 1070 | 535 | 3744 |
|
||||
|
||||
DocVQA images are not copied. The valid setup uses:
|
||||
|
||||
```text
|
||||
data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images
|
||||
```
|
||||
|
||||
2026-05-05 DocVQA data correction: all DocVQA final reruns should use the zisu
|
||||
10% split above and a fresh output root such as
|
||||
`outputs/ablation_docvqa_zisu10pct_20260505_run`. The older local
|
||||
`data/ablation_splits/docvqa/2-1-7_seed42` contains the same 5349 questionId
|
||||
pool but a different train/val/test assignment, so its completed summaries are
|
||||
historical only.
|
||||
|
||||
### Matrix Groups
|
||||
|
||||
Use these group names with `scripts/run_ablation_matrix.py`:
|
||||
|
||||
```text
|
||||
default split batch mbs lr sched slown mod smodel longpair lrctrl
|
||||
```
|
||||
|
||||
`longpair` is the slow-update/meta-skill comparison-example ablation. It keeps
|
||||
all prompts and training settings unchanged and only overrides:
|
||||
|
||||
```text
|
||||
optimizer.longitudinal_pair_policy=changed
|
||||
optimizer.longitudinal_pair_policy=unchanged
|
||||
```
|
||||
|
||||
The default paper setting remains `mixed`.
|
||||
|
||||
`lrctrl` contains the two learning-rate-control baselines:
|
||||
|
||||
```text
|
||||
optimizer.lr_control_mode=autonomous
|
||||
optimizer.lr_control_mode=none + optimizer.skill_update_mode=full_rewrite_minibatch
|
||||
```
|
||||
|
||||
The autonomous run logs the chosen integer per step in `lr_decision.json` and
|
||||
`lr_history.jsonl`. The full-rewrite run removes the LR/edit-selection concept:
|
||||
each minibatch analyst produces a complete skill candidate, and aggregate/merge
|
||||
produces the candidate skill directly.
|
||||
|
||||
Batch-size values are:
|
||||
|
||||
```text
|
||||
8 / 24 / 40 / 56 / full
|
||||
```
|
||||
|
||||
`40` is the default point. `full` expands to the benchmark train size.
|
||||
|
||||
### Launch Commands Used In This Session
|
||||
|
||||
The exact commands are recorded in `launch_commands.sh` and in
|
||||
`docs/ablation_plan.md`. The important current policy is:
|
||||
|
||||
- SearchQA / SpreadsheetBench batch-only matrix can run at `--max-parallel 8`.
|
||||
- DocVQA matrix can run with its launcher at `--max-parallel 8`; later top-up used `--max-parallel 16` only because completed runs were skipped and active roots were checked.
|
||||
- LiveMathBench is safe as API-only benchmark after the token cap fix.
|
||||
- ALFWorld must not be mixed into a 24-way run on this shared machine. Use `--bench alfworld --max-parallel 1` only after memory is available.
|
||||
|
||||
### Token And Timeout Fixes
|
||||
|
||||
LiveMathBench must use a large student completion cap:
|
||||
|
||||
```text
|
||||
max_completion_tokens=16384
|
||||
timeout=300
|
||||
```
|
||||
|
||||
The old 768/512 cap produced many empty visible responses because hidden
|
||||
reasoning consumed the budget.
|
||||
|
||||
ALFWorld must use:
|
||||
|
||||
```text
|
||||
max_completion_tokens=2048
|
||||
empty response fallback -> <action>look</action>
|
||||
missing action fallback -> <action>look</action>
|
||||
```
|
||||
|
||||
### Invalid Runs
|
||||
|
||||
Never fill paper tables from these archive directories:
|
||||
|
||||
```text
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517/
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_serial_lowmem_20260504_1300/
|
||||
```
|
||||
|
||||
### Current Resource Notes
|
||||
|
||||
ALFWorld model calls are API calls. The ablation branch now creates local
|
||||
ALFWorld/TextWorld environments through multiprocessing workers ported from
|
||||
`skillopt_final_zzw`, not through Ray actors. Old Ray-based archived runs are
|
||||
not valid for table fill. The observed historical failure mode in this session
|
||||
was system RAM pressure and Ray OOM prevention, not model GPU memory.
|
||||
|
||||
GPU memory currently shown by `nvidia-smi` came from unrelated Ray Serve visual
|
||||
models under:
|
||||
|
||||
```text
|
||||
/home/azureuser/workspace-gzy/zyf/gca-skill
|
||||
```
|
||||
|
||||
Those processes are `GroundingDINOModel` / `DA3Model`, not the
|
||||
SkillReflection ablation ALFWorld run.
|
||||
|
||||
There are also unrelated ALFWorld jobs under:
|
||||
|
||||
```text
|
||||
/home/azureuser/zisu/skill_distill
|
||||
```
|
||||
|
||||
Do not confuse those with this repository's ablation outputs.
|
||||
|
||||
### Monitoring
|
||||
|
||||
Active run and duplicate output-root check:
|
||||
|
||||
```bash
|
||||
$PY - <<'PY'
|
||||
import subprocess, re, collections, time
|
||||
try:
|
||||
raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
|
||||
except subprocess.CalledProcessError:
|
||||
raw = ""
|
||||
roots = []
|
||||
for line in raw.splitlines():
|
||||
m = re.search(r"env\.out_root=([^\s]+)", line)
|
||||
if m:
|
||||
roots.append(m.group(1))
|
||||
ctr = collections.Counter(roots)
|
||||
print("time", time.strftime("%F %T"))
|
||||
print("active_count", len(roots))
|
||||
print("duplicates", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1])
|
||||
for root in sorted(roots):
|
||||
print(root.rsplit("/", 1)[-1])
|
||||
PY
|
||||
```
|
||||
|
||||
Error scan:
|
||||
|
||||
```bash
|
||||
rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|CUDA out of memory|\\[FAIL\\]|LLM call failed" \
|
||||
outputs/ablation_docvqa_20260503_160225_run/logs \
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \
|
||||
outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \
|
||||
-g '*.log' | tail -100 || true
|
||||
```
|
||||
|
||||
Resource checks:
|
||||
|
||||
```bash
|
||||
free -h | sed -n '1,3p'
|
||||
df -h /tmp
|
||||
du -sh /tmp/ray 2>/dev/null || true
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
### Filling Tables
|
||||
|
||||
Only use top-level `summary.json` from valid run roots. Fill
|
||||
`docs/ablation_paper_tables.md` from:
|
||||
|
||||
```text
|
||||
best_selection_hard
|
||||
baseline_test_hard
|
||||
test_hard
|
||||
test_delta_hard
|
||||
token_summary._total.total_tokens
|
||||
```
|
||||
81
configs/ablation_study/launch_commands.sh
Executable file
81
configs/ablation_study/launch_commands.sh
Executable file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
cd /home/azureuser/workspace-gzy/SkillReflection
|
||||
|
||||
PY=/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python
|
||||
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
|
||||
|
||||
# Original SearchQA / SpreadsheetBench full matrix reproduction command.
|
||||
# Do not run this into the existing root unless intentionally reproducing from
|
||||
# scratch; the current valid root is already populated:
|
||||
# outputs/ablation_20260502_040604_unique48
|
||||
#
|
||||
# setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
# --groups default split mbs lr sched slown mod smodel \
|
||||
# --bench searchqa spreadsheetbench \
|
||||
# --run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48 \
|
||||
# --max-parallel 24 \
|
||||
# --execute \
|
||||
# > /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_20260502_040604_unique48/launcher_reproduce_full_matrix.log 2>&1 < /dev/null &
|
||||
#
|
||||
# SearchQA / SpreadsheetBench batch-size ablations only.
|
||||
# Original non-batch SearchQA/SpreadsheetBench ablations live in:
|
||||
# outputs/ablation_20260502_040604_unique48
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups batch \
|
||||
--bench searchqa spreadsheetbench \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/launcher_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# DocVQA full matrix.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups default split batch mbs lr sched slown mod smodel \
|
||||
--bench docvqa \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# LiveMathBench clean matrix. ALFWorld should be launched separately at lower
|
||||
# concurrency because Ray OOM occurred when many ALFWorld runs were mixed into a
|
||||
# 24-way run.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups default split batch mbs lr sched slown mod smodel \
|
||||
--bench livemathematicianbench \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# ALFWorld clean matrix. Increase to 2 only after checking memory, /tmp/ray,
|
||||
# and that no other ALFWorld run is active. Do not use 8/16/24 for ALFWorld on
|
||||
# the current shared machine unless resources are explicitly reserved.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups default split batch mbs lr sched slown mod smodel \
|
||||
--bench alfworld \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run \
|
||||
--max-parallel 1 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>&1 < /dev/null &
|
||||
|
||||
# Longitudinal comparison-example policy ablations. This intentionally excludes
|
||||
# ALFWorld. The only varied setting is optimizer.longitudinal_pair_policy.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups longpair \
|
||||
--bench searchqa spreadsheetbench livemathematicianbench docvqa \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_longpair_20260504_run/launcher_longpair_parallel8.log 2>&1 < /dev/null &
|
||||
|
||||
# Learning-rate-control baselines. This intentionally excludes ALFWorld.
|
||||
setsid "$PY" scripts/run_ablation_matrix.py \
|
||||
--groups lrctrl \
|
||||
--bench searchqa spreadsheetbench livemathematicianbench docvqa \
|
||||
--run-root /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run \
|
||||
--max-parallel 8 \
|
||||
--execute \
|
||||
> /home/azureuser/workspace-gzy/SkillReflection/outputs/ablation_lrctrl_20260504_run/launcher_lrctrl_parallel8.log 2>&1 < /dev/null &
|
||||
257
configs/ablation_study/matrix.yaml
Normal file
257
configs/ablation_study/matrix.yaml
Normal file
@@ -0,0 +1,257 @@
|
||||
version: 2026-05-04
|
||||
purpose: "Canonical paper ablation settings matching the valid current runs."
|
||||
|
||||
launcher:
|
||||
script: scripts/run_ablation_matrix.py
|
||||
python: /home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python
|
||||
skip_completed_by: summary.json
|
||||
skip_active_by: "active scripts/train.py env.out_root"
|
||||
|
||||
environment:
|
||||
working_directory: /home/azureuser/workspace-gzy/SkillReflection
|
||||
required_env:
|
||||
ALFWORLD_DATA: /home/azureuser/.cache/alfworld
|
||||
docvqa_images_symlink: "data/docvqa_images -> /home/azureuser/zisu/SkillReflection/data/docvqa_images"
|
||||
|
||||
common_overrides:
|
||||
model.teacher_backend: openai_chat
|
||||
model.student_backend: openai_chat
|
||||
model.teacher: gpt-5.5
|
||||
model.student: gpt-5.5
|
||||
model.teacher_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.teacher_azure_openai_api_version: 2024-12-01-preview
|
||||
model.teacher_azure_openai_auth_mode: azure_cli
|
||||
model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.student_azure_openai_api_version: 2024-12-01-preview
|
||||
model.student_azure_openai_auth_mode: azure_cli
|
||||
model.reasoning_effort: medium
|
||||
train.num_epochs: 4
|
||||
train.train_size: 0
|
||||
train.batch_size: 40
|
||||
train.accumulation: 1
|
||||
train.seed: 42
|
||||
gradient.minibatch_size: 8
|
||||
gradient.merge_batch_size: 8
|
||||
gradient.analyst_workers: 16
|
||||
gradient.use_deep_reflect: false
|
||||
optimizer.learning_rate: 4
|
||||
optimizer.min_learning_rate: 2
|
||||
optimizer.lr_scheduler: cosine
|
||||
optimizer.skill_update_mode: patch
|
||||
optimizer.use_slow_update: true
|
||||
optimizer.slow_update_samples: 20
|
||||
optimizer.use_meta_skill: true
|
||||
optimizer.use_meta_reflect: false
|
||||
evaluation.use_gate: true
|
||||
evaluation.eval_test: true
|
||||
env.split_mode: split_dir
|
||||
|
||||
benchmarks:
|
||||
searchqa:
|
||||
config: configs/searchqa/default.yaml
|
||||
run_roots:
|
||||
original_matrix: outputs/ablation_20260502_040604_unique48
|
||||
batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run
|
||||
default_split: data/ablation_splits/searchqa/2-1-7_seed42
|
||||
train: 400
|
||||
val: 200
|
||||
test: 1400
|
||||
student_rollout:
|
||||
function: reflact/envs/searchqa/rollout.py::chat_student
|
||||
max_completion_tokens:
|
||||
first_turn: 512
|
||||
refinement: 512
|
||||
rationale: "Short-answer QA; sampled empties are low and not LiveMath-like."
|
||||
|
||||
spreadsheetbench:
|
||||
config: configs/spreadsheetbench/default.yaml
|
||||
run_roots:
|
||||
original_matrix: outputs/ablation_20260502_040604_unique48
|
||||
batch_matrix: outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run
|
||||
default_split: data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
train: 80
|
||||
val: 40
|
||||
test: 280
|
||||
student_rollout:
|
||||
function: reflact/envs/spreadsheetbench/codegen_agent.py::run_multi
|
||||
max_output_tokens: 16384
|
||||
result_note: "results.jsonl stores execution fields, not a response field."
|
||||
|
||||
livemathematicianbench:
|
||||
config: configs/livemathematicianbench/default.yaml
|
||||
run_roots:
|
||||
clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run
|
||||
default_split: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
train: 35
|
||||
val: 18
|
||||
test: 124
|
||||
student_rollout:
|
||||
function: reflact/envs/livemathematicianbench/rollout.py::chat_student
|
||||
max_completion_tokens:
|
||||
first_turn: 16384
|
||||
refinement: 16384
|
||||
timeout_seconds: 300
|
||||
invalid_old_caps:
|
||||
first_turn: 768
|
||||
refinement: 512
|
||||
invalid_archive: outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_livemath_token768_20260504_022258
|
||||
rationale: "GPT-5 reasoning consumed small budgets and produced many empty visible responses."
|
||||
|
||||
alfworld:
|
||||
config: configs/alfworld/default.yaml
|
||||
run_roots:
|
||||
clean_matrix: outputs/ablation_livemath_alfworld_clean_20260503_155155_run
|
||||
default_split: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
train: 39
|
||||
val: 18
|
||||
test: 134
|
||||
student_rollout:
|
||||
function: reflact/envs/alfworld/rollout.py::chat_student
|
||||
max_completion_tokens: 2048
|
||||
timeout_seconds: 120
|
||||
max_steps: 50
|
||||
fallback_action: look
|
||||
invalid_old_cap: 512
|
||||
invalid_archives:
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_token512_20260504_021417
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_empty_action_20260504_025311
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_prefallback_20260504_025402
|
||||
- outputs/ablation_livemath_alfworld_clean_20260503_155155_run/archive_alfworld_oom_partial_20260504_050517
|
||||
concurrency_note: "Do not mix many ALFWorld runs into 24-way total concurrency; Ray OOM occurred. Prefer 1-2 ALFWorld runs at a time unless resources are clearly free."
|
||||
|
||||
docvqa:
|
||||
config: configs/docvqa/default.yaml
|
||||
run_roots:
|
||||
matrix: outputs/ablation_docvqa_zisu10pct_20260505_run
|
||||
default_split: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
train: 1070
|
||||
val: 535
|
||||
test: 3744
|
||||
data_note: "2026-05-05: use zisu-provided 5349-item DocVQA split directly; previous local 2-1-7_seed42 used the same item pool but a different train/val/test assignment and must not be used for final DocVQA reruns."
|
||||
student_rollout:
|
||||
function: reflact/envs/docvqa/rollout.py::chat_student_messages
|
||||
max_completion_tokens:
|
||||
first_turn: 768
|
||||
refinement: 512
|
||||
rationale: "Short-answer VQA output; preserve current setting for alignment unless explicitly rerunning all affected DocVQA."
|
||||
|
||||
splits:
|
||||
tags:
|
||||
1shot:
|
||||
extra_overrides:
|
||||
optimizer.slow_update_samples: 1
|
||||
1-1-8: {}
|
||||
2-1-7:
|
||||
default: true
|
||||
4-1-5: {}
|
||||
paths:
|
||||
searchqa:
|
||||
1shot: data/ablation_splits/searchqa/1shot_seed42
|
||||
1-1-8: data/ablation_splits/searchqa/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/searchqa/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/searchqa/4-1-5_seed42
|
||||
spreadsheetbench:
|
||||
1shot: data/ablation_splits/spreadsheetbench/1shot_seed42
|
||||
1-1-8: data/ablation_splits/spreadsheetbench/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/spreadsheetbench/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/spreadsheetbench/4-1-5_seed42
|
||||
livemathematicianbench:
|
||||
1shot: data/ablation_splits/livemathematicianbench/1shot_seed42
|
||||
1-1-8: data/ablation_splits/livemathematicianbench/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/livemathematicianbench/4-1-5_seed42
|
||||
alfworld:
|
||||
1shot: data/ablation_splits/alfworld/1shot_seed42
|
||||
1-1-8: data/ablation_splits/alfworld/1-1-8_seed42
|
||||
2-1-7: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
4-1-5: data/ablation_splits/alfworld/4-1-5_seed42
|
||||
docvqa:
|
||||
1shot: data/ablation_splits/docvqa/1shot_seed42
|
||||
1-1-8: data/ablation_splits/docvqa/1-1-8_seed42
|
||||
2-1-7: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
4-1-5: data/ablation_splits/docvqa/4-1-5_seed42
|
||||
|
||||
groups:
|
||||
default:
|
||||
run_id: "DEFAULT-{benchmark}-5.5"
|
||||
overrides: {}
|
||||
split:
|
||||
values: [1shot, 1-1-8, 4-1-5]
|
||||
skip_default_2_1_7: true
|
||||
override_template: "env.split_dir={split_path}"
|
||||
batch:
|
||||
values: [8, 24, 56, full]
|
||||
default_value_reused: 40
|
||||
full_values:
|
||||
searchqa: 400
|
||||
spreadsheetbench: 80
|
||||
livemathematicianbench: 35
|
||||
alfworld: 39
|
||||
docvqa: 1070
|
||||
fixed_overrides:
|
||||
gradient.minibatch_size: 8
|
||||
mbs:
|
||||
values: [1, 2, 4, 16, 32]
|
||||
default_value_reused: 8
|
||||
override_template: "gradient.minibatch_size={value}"
|
||||
lr:
|
||||
values: [1, 2, 4, 8, 16]
|
||||
fixed_overrides:
|
||||
optimizer.lr_scheduler: constant
|
||||
optimizer.min_learning_rate: 1
|
||||
override_template: "optimizer.learning_rate={value}"
|
||||
sched:
|
||||
values: [constant, linear]
|
||||
default_value_reused: cosine
|
||||
override_template: "optimizer.lr_scheduler={value}"
|
||||
slown:
|
||||
values: [5, 10, 40]
|
||||
default_value_reused: 20
|
||||
override_template: "optimizer.slow_update_samples={value}"
|
||||
mod:
|
||||
values:
|
||||
slow-only:
|
||||
optimizer.use_slow_update: true
|
||||
optimizer.use_meta_skill: false
|
||||
meta-only:
|
||||
optimizer.use_slow_update: false
|
||||
optimizer.use_meta_skill: true
|
||||
none:
|
||||
optimizer.use_slow_update: false
|
||||
optimizer.use_meta_skill: false
|
||||
default_value_reused: slow-meta
|
||||
longpair:
|
||||
values: [changed, unchanged]
|
||||
default_value_reused: mixed
|
||||
override_template: "optimizer.longitudinal_pair_policy={value}"
|
||||
note: "Only changes slow-update/meta-skill comparison examples; prompts and other settings remain unchanged."
|
||||
lrctrl:
|
||||
values:
|
||||
autonomous:
|
||||
optimizer.lr_control_mode: autonomous
|
||||
full-rewrite:
|
||||
optimizer.lr_control_mode: none
|
||||
optimizer.skill_update_mode: full_rewrite_minibatch
|
||||
default_value_reused: "fixed patch learning_rate=4"
|
||||
note: "autonomous records lr_decision.json/lr_history.jsonl; full-rewrite removes LR/select/apply-edit and uses full skill candidates."
|
||||
smodel:
|
||||
values:
|
||||
"5.4":
|
||||
model.student: gpt-5.4-pro
|
||||
model.student_azure_openai_endpoint: https://t2vgoaigpt4o3.openai.azure.com/
|
||||
model.student_azure_openai_api_version: 2025-03-01-preview
|
||||
model.student_azure_openai_auth_mode: azure_cli
|
||||
"5.4-mini":
|
||||
model.student: gpt-5.4-mini
|
||||
model.student_azure_openai_endpoint: https://searchagent5.cognitiveservices.azure.com/
|
||||
model.student_azure_openai_api_version: 2024-12-01-preview
|
||||
model.student_azure_openai_auth_mode: azure_cli
|
||||
default_value_reused: "5.5"
|
||||
|
||||
validity_rules:
|
||||
use_for_tables:
|
||||
- "Only runs with summary.json in valid run roots."
|
||||
- "Do not use archive, archived, MISALIGNED, SUPERSEDED, dryrun, smoke, or debug directories."
|
||||
- "Do not use ALFWorld runs started before empty/missing-action fallback."
|
||||
- "Do not use old LiveMath runs with 768/512 token caps."
|
||||
rerun_rule: "Archive or remove invalid out_root before relaunch; never write a rerun into a polluted output directory."
|
||||
141
configs/ablation_study/validation.md
Normal file
141
configs/ablation_study/validation.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Ablation Validation Checklist
|
||||
|
||||
Use this checklist before launch, during monitoring, and before filling
|
||||
`docs/ablation_paper_tables.md`.
|
||||
|
||||
## Before Launch
|
||||
|
||||
Run from repo root:
|
||||
|
||||
```bash
|
||||
cd /home/azureuser/workspace-gzy/SkillReflection
|
||||
export ALFWORLD_DATA=/home/azureuser/.cache/alfworld
|
||||
```
|
||||
|
||||
Verify syntax for edited files:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python -m py_compile \
|
||||
scripts/run_ablation_matrix.py \
|
||||
scripts/train.py \
|
||||
reflact/model/azure_openai.py \
|
||||
reflact/envs/searchqa/rollout.py \
|
||||
reflact/envs/spreadsheetbench/rollout.py \
|
||||
reflact/envs/livemathematicianbench/rollout.py \
|
||||
reflact/envs/alfworld/rollout.py \
|
||||
reflact/envs/docvqa/rollout.py
|
||||
```
|
||||
|
||||
Check active runs and duplicate `env.out_root` before starting more:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python - <<'PY'
|
||||
import subprocess, re, collections
|
||||
try:
|
||||
raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
|
||||
except subprocess.CalledProcessError:
|
||||
raw = ""
|
||||
roots = []
|
||||
for line in raw.splitlines():
|
||||
m = re.search(r"env\.out_root=([^\s]+)", line)
|
||||
if m:
|
||||
roots.append(m.group(1))
|
||||
ctr = collections.Counter(roots)
|
||||
print("train_count", len(roots))
|
||||
print("duplicate_roots", [r.rsplit("/", 1)[-1] for r, c in ctr.items() if c > 1])
|
||||
for root in sorted(roots):
|
||||
print(root.rsplit("/", 1)[-1])
|
||||
PY
|
||||
```
|
||||
|
||||
## During Monitoring
|
||||
|
||||
Check launchers:
|
||||
|
||||
```bash
|
||||
pgrep -af 'scripts/run_ablation_matrix.py' || true
|
||||
tail -80 outputs/ablation_docvqa_20260503_160225_run/launcher_parallel8.log 2>/dev/null || true
|
||||
tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_livemath_parallel8.log 2>/dev/null || true
|
||||
tail -80 outputs/ablation_livemath_alfworld_clean_20260503_155155_run/launcher_alfworld_parallel1.log 2>/dev/null || true
|
||||
```
|
||||
|
||||
Scan current logs for new hard failures:
|
||||
|
||||
```bash
|
||||
rg -n "Traceback|ERROR|Error code|AuthenticationError|BadRequest|RateLimit|content_filter|Killed|OutOfMemory|\\[FAIL\\]|\\[RETRY\\]" \
|
||||
outputs/ablation_docvqa_20260503_160225_run/logs \
|
||||
outputs/ablation_livemath_alfworld_clean_20260503_155155_run/logs \
|
||||
outputs/ablation_batch_searchqa_spreadsheet_20260503_153902_run/logs \
|
||||
-g '*.log' | tail -160 || true
|
||||
```
|
||||
|
||||
Check resource pressure:
|
||||
|
||||
```bash
|
||||
df -h /tmp
|
||||
du -sh /tmp/ray 2>/dev/null || true
|
||||
free -h | sed -n '1,3p'
|
||||
```
|
||||
|
||||
## Quality Checks
|
||||
|
||||
LiveMathBench current valid runs should not look like old 768/512 runs:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python - <<'PY'
|
||||
import json, pathlib
|
||||
root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run")
|
||||
for run in sorted(root.glob("*livemathematicianbench*")):
|
||||
if not run.is_dir() or "archive" in str(run):
|
||||
continue
|
||||
for rel in ["test_eval_baseline/results.jsonl", "test_eval/results.jsonl"]:
|
||||
p = run / rel
|
||||
if not p.exists():
|
||||
continue
|
||||
rows = [json.loads(l) for l in p.open(errors="ignore") if l.strip()]
|
||||
empty = sum(1 for r in rows if not str(r.get("response", "")).strip())
|
||||
answer = sum(1 for r in rows if "<answer>" in str(r.get("response", "")).lower())
|
||||
if empty:
|
||||
print(run.name, rel, "empty", empty, "answer", answer, "n", len(rows))
|
||||
PY
|
||||
```
|
||||
|
||||
ALFWorld valid runs must not contain empty action or missing action:
|
||||
|
||||
```bash
|
||||
/home/azureuser/workspace-gzy/miniconda3/envs/reflact/bin/python - <<'PY'
|
||||
import json, pathlib
|
||||
root = pathlib.Path("outputs/ablation_livemath_alfworld_clean_20260503_155155_run")
|
||||
for run in sorted(root.glob("*alfworld*")):
|
||||
if not run.is_dir() or "archive" in str(run):
|
||||
continue
|
||||
bad = []
|
||||
fallback = 0
|
||||
for c in run.glob("**/conversation.json"):
|
||||
data = json.load(c.open(errors="ignore"))
|
||||
for step in data:
|
||||
if step.get("step") is None:
|
||||
continue
|
||||
if not step.get("action"):
|
||||
bad.append(str(c.relative_to(run)))
|
||||
break
|
||||
mr = str(step.get("model_response", ""))
|
||||
if "empty model response" in mr or "missing action tag" in mr:
|
||||
fallback += 1
|
||||
print(run.name, "bad_action_files", len(bad), "fallback", fallback)
|
||||
PY
|
||||
```
|
||||
|
||||
## Filling Tables
|
||||
|
||||
Use only `summary.json` fields:
|
||||
|
||||
- `best_selection_hard` -> Best Sel
|
||||
- `baseline_test_hard` -> Base Test
|
||||
- `test_hard` -> Best Test
|
||||
- `test_delta_hard` -> Delta
|
||||
- `total_accepts` -> Accept
|
||||
- `total_rejects` -> Reject
|
||||
- `token_summary._total.total_tokens` -> Tokens
|
||||
|
||||
Do not fill table rows from logs alone.
|
||||
30
configs/alfworld/default.yaml
Normal file
30
configs/alfworld/default.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
train_size: 0
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
use_meta_reflect: false
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: alfworld
|
||||
skill_init: reflact/envs/alfworld/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/ablation_splits/alfworld/2-1-7_seed42
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_steps: 50
|
||||
workers: 8
|
||||
max_api_workers: 8
|
||||
limit: 0
|
||||
4
configs/alfworld/meta_reflect.yaml
Normal file
4
configs/alfworld/meta_reflect.yaml
Normal file
@@ -0,0 +1,4 @@
|
||||
_base_: default.yaml
|
||||
|
||||
optimizer:
|
||||
use_meta_reflect: true
|
||||
21
configs/babyvision/default.yaml
Normal file
21
configs/babyvision/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: babyvision
|
||||
skill_init: reflact/envs/babyvision/skills/initial.md
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
limit: 0
|
||||
image_detail: auto
|
||||
judge_model: gpt-5.4
|
||||
judge_max_completion_tokens: 256
|
||||
judge_retries: 5
|
||||
28
configs/docvqa/default.yaml
Normal file
28
configs/docvqa/default.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
env:
|
||||
name: docvqa
|
||||
skill_init: reflact/envs/docvqa/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: /home/azureuser/zisu/SkillReflection/data/docvqa/splits
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
image_detail: auto
|
||||
limit: 0
|
||||
22
configs/livemathematicianbench/default.yaml
Normal file
22
configs/livemathematicianbench/default.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
train_size: 0
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: livemathematicianbench
|
||||
skill_init: reflact/envs/livemathematicianbench/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/ablation_splits/livemathematicianbench/2-1-7_seed42
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
exec_timeout: 300
|
||||
workers: 64
|
||||
limit: 0
|
||||
shuffle_choices: true
|
||||
use_theorem: false
|
||||
use_sketch: false
|
||||
23
configs/mathverse/default.yaml
Normal file
23
configs/mathverse/default.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
codex_exec_sandbox: danger-full-access
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: mathverse
|
||||
skill_init: reflact/envs/mathverse/skills/initial.md
|
||||
split_dir: ""
|
||||
data_root: data/MathVerse
|
||||
problem_version: Text Lite
|
||||
use_text_dominant_reference: false
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
limit: 0
|
||||
image_detail: auto
|
||||
judge_model: gpt-5.4
|
||||
judge_max_completion_tokens: 256
|
||||
judge_retries: 5
|
||||
18
configs/mmrb/default.yaml
Normal file
18
configs/mmrb/default.yaml
Normal file
@@ -0,0 +1,18 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
train:
|
||||
batch_size: 128
|
||||
accumulation: 1
|
||||
|
||||
env:
|
||||
name: mmrb
|
||||
skill_init: reflact/envs/mmrb/skills/initial.md
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 16
|
||||
limit: 0
|
||||
image_detail: auto
|
||||
25
configs/officeqa/default.yaml
Normal file
25
configs/officeqa/default.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
env:
|
||||
name: officeqa
|
||||
skill_init: reflact/envs/officeqa/skills/initial.md
|
||||
split_dir: data/officeqa_split
|
||||
data_dirs:
|
||||
- data/officeqa_docs_official
|
||||
workers: 4
|
||||
max_tool_turns: 24
|
||||
limit: 0
|
||||
23
configs/sealqa/default.yaml
Normal file
23
configs/sealqa/default.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 10
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
env:
|
||||
name: sealqa
|
||||
skill_init: reflact/envs/sealqa/skills/initial.md
|
||||
split_dir: data/sealqa_split
|
||||
workers: 4
|
||||
max_tool_turns: 12
|
||||
limit: 0
|
||||
32
configs/searchqa/default.yaml
Normal file
32
configs/searchqa/default.yaml
Normal file
@@ -0,0 +1,32 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
train_size: 400
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: searchqa
|
||||
skill_init: reflact/envs/searchqa/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/searchqa_split
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
max_turns: 1
|
||||
workers: 24
|
||||
limit: 0
|
||||
34
configs/spreadsheetbench/default.yaml
Normal file
34
configs/spreadsheetbench/default.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
train_size: 80
|
||||
batch_size: 40
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 8
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: spreadsheetbench
|
||||
skill_init: reflact/envs/spreadsheetbench/skills/initial.md
|
||||
split_mode: split_dir
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: data/spreadsheetbench_split
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
data_root: data/spreadsheetbench_verified_400
|
||||
mode: multi
|
||||
max_turns: 30
|
||||
exec_timeout: 600
|
||||
workers: 24
|
||||
36
configs/swebench/default.yaml
Normal file
36
configs/swebench/default.yaml
Normal file
@@ -0,0 +1,36 @@
|
||||
_base_: ../_base_/default.yaml
|
||||
|
||||
model:
|
||||
reasoning_effort: medium
|
||||
|
||||
train:
|
||||
batch_size: 20
|
||||
accumulation: 1
|
||||
|
||||
gradient:
|
||||
minibatch_size: 4
|
||||
merge_batch_size: 8
|
||||
|
||||
optimizer:
|
||||
learning_rate: 4
|
||||
|
||||
evaluation:
|
||||
sel_env_num: 0
|
||||
test_env_num: 0
|
||||
|
||||
env:
|
||||
name: swebench
|
||||
skill_init: reflact/envs/swebench/skills/initial.md
|
||||
split_mode: ratio
|
||||
split_ratio: "2:1:7"
|
||||
split_dir: ""
|
||||
data_path: ""
|
||||
split_output_dir: ""
|
||||
dataset_name: lite
|
||||
hf_split: test
|
||||
workers: 8
|
||||
eval_workers: 8
|
||||
step_limit: 50
|
||||
cost_limit: 3.0
|
||||
timeout_per_instance: 600
|
||||
limit: 0
|
||||
29
reflact/__init__.py
Normal file
29
reflact/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""ReflACT: Reflective Agent Tuning.
|
||||
|
||||
A general-purpose framework for iteratively optimizing LLM agent skills
|
||||
through structured reflection and self-improvement.
|
||||
|
||||
Pipeline stages:
|
||||
1. Rollout — execute episodes with current skill
|
||||
2. Reflect — analyze trajectories, generate patches
|
||||
3. Aggregate — hierarchical merge of patches
|
||||
4. Select — rank and select top edits
|
||||
5. Update — apply edits to skill document
|
||||
6. Evaluate — validate candidate skill, accept/reject
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from reflact.types import ( # noqa: F401
|
||||
BatchSpec,
|
||||
Edit,
|
||||
EditOp,
|
||||
FailureSummaryEntry,
|
||||
GateAction,
|
||||
GateResult,
|
||||
MetaReflectResult,
|
||||
Patch,
|
||||
RawPatch,
|
||||
RolloutResult,
|
||||
SlowUpdateResult,
|
||||
)
|
||||
263
reflact/config.py
Normal file
263
reflact/config.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""ReflACT config loading engine — structured YAML with inheritance.
|
||||
|
||||
Supports two config formats:
|
||||
1. **Structured** (new): sections like ``model``, ``train``, ``gradient``,
|
||||
``optimizer``, ``evaluation``, ``env`` — with ``_base_`` inheritance.
|
||||
2. **Flat** (legacy): all keys at top level — fully backward compatible.
|
||||
|
||||
Usage::
|
||||
|
||||
from reflact.config import load_config, flatten_config
|
||||
|
||||
cfg = load_config("configs/searchqa_default.yaml")
|
||||
flat = flatten_config(cfg) # always returns flat dict for trainer
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
# ── Section names that indicate a structured config ──────────────────────
|
||||
|
||||
_STRUCTURED_SECTIONS = frozenset({
|
||||
"model", "train", "gradient", "optimizer", "evaluation", "env",
|
||||
})
|
||||
|
||||
# ── Structured → flat key mapping ────────────────────────────────────────
|
||||
|
||||
_FLATTEN_MAP: dict[str, str] = {
|
||||
"model.backend": "model_backend",
|
||||
"model.teacher": "teacher_model",
|
||||
"model.student": "student_model",
|
||||
"model.teacher_backend": "teacher_backend",
|
||||
"model.student_backend": "student_backend",
|
||||
"model.reasoning_effort": "reasoning_effort",
|
||||
"model.rewrite_reasoning_effort": "rewrite_reasoning_effort",
|
||||
"model.rewrite_max_completion_tokens": "rewrite_max_completion_tokens",
|
||||
"model.codex_exec_path": "codex_exec_path",
|
||||
"model.codex_exec_sandbox": "codex_exec_sandbox",
|
||||
"model.codex_exec_profile": "codex_exec_profile",
|
||||
"model.codex_exec_full_auto": "codex_exec_full_auto",
|
||||
"model.codex_exec_reasoning_effort": "codex_exec_reasoning_effort",
|
||||
"model.codex_exec_use_sdk": "codex_exec_use_sdk",
|
||||
"model.codex_exec_network_access": "codex_exec_network_access",
|
||||
"model.codex_exec_web_search": "codex_exec_web_search",
|
||||
"model.codex_exec_approval_policy": "codex_exec_approval_policy",
|
||||
"model.claude_code_exec_path": "claude_code_exec_path",
|
||||
"model.claude_code_exec_profile": "claude_code_exec_profile",
|
||||
"model.claude_code_exec_use_sdk": "claude_code_exec_use_sdk",
|
||||
"model.claude_code_exec_effort": "claude_code_exec_effort",
|
||||
"model.claude_code_exec_max_thinking_tokens": "claude_code_exec_max_thinking_tokens",
|
||||
"model.codex_trace_to_teacher": "codex_trace_to_teacher",
|
||||
"model.azure_endpoint": "azure_endpoint",
|
||||
"model.azure_api_version": "azure_api_version",
|
||||
"model.azure_api_key": "azure_api_key",
|
||||
"model.azure_openai_endpoint": "azure_openai_endpoint",
|
||||
"model.azure_openai_api_version": "azure_openai_api_version",
|
||||
"model.azure_openai_api_key": "azure_openai_api_key",
|
||||
"model.azure_openai_auth_mode": "azure_openai_auth_mode",
|
||||
"model.azure_openai_ad_scope": "azure_openai_ad_scope",
|
||||
"model.azure_openai_managed_identity_client_id": "azure_openai_managed_identity_client_id",
|
||||
"model.teacher_azure_openai_endpoint": "teacher_azure_openai_endpoint",
|
||||
"model.teacher_azure_openai_api_version": "teacher_azure_openai_api_version",
|
||||
"model.teacher_azure_openai_api_key": "teacher_azure_openai_api_key",
|
||||
"model.teacher_azure_openai_auth_mode": "teacher_azure_openai_auth_mode",
|
||||
"model.teacher_azure_openai_ad_scope": "teacher_azure_openai_ad_scope",
|
||||
"model.teacher_azure_openai_managed_identity_client_id": "teacher_azure_openai_managed_identity_client_id",
|
||||
"model.student_azure_openai_endpoint": "student_azure_openai_endpoint",
|
||||
"model.student_azure_openai_api_version": "student_azure_openai_api_version",
|
||||
"model.student_azure_openai_api_key": "student_azure_openai_api_key",
|
||||
"model.student_azure_openai_auth_mode": "student_azure_openai_auth_mode",
|
||||
"model.student_azure_openai_ad_scope": "student_azure_openai_ad_scope",
|
||||
"model.student_azure_openai_managed_identity_client_id": "student_azure_openai_managed_identity_client_id",
|
||||
"train.num_epochs": "num_epochs",
|
||||
"train.train_size": "train_size",
|
||||
"train.steps_per_epoch": "steps_per_epoch",
|
||||
"train.batch_size": "batch_size",
|
||||
"train.accumulation": "accumulation",
|
||||
"train.seed": "seed",
|
||||
"gradient.minibatch_size": "minibatch_size",
|
||||
"gradient.merge_batch_size": "merge_batch_size",
|
||||
"gradient.analyst_workers": "analyst_workers",
|
||||
"gradient.failure_only": "failure_only",
|
||||
"gradient.use_deep_reflect": "use_deep_reflect",
|
||||
"gradient.deep_reflect_failures": "deep_reflect_failures",
|
||||
"gradient.deep_reflect_successes": "deep_reflect_successes",
|
||||
"gradient.max_analyst_rounds": "max_analyst_rounds",
|
||||
"optimizer.learning_rate": "edit_budget",
|
||||
"optimizer.min_learning_rate": "min_edit_budget",
|
||||
"optimizer.lr_scheduler": "lr_scheduler",
|
||||
"optimizer.lr_control_mode": "lr_control_mode",
|
||||
"optimizer.skill_update_mode": "skill_update_mode",
|
||||
"optimizer.use_meta_reflect": "use_meta_reflect",
|
||||
"optimizer.meta_learning_rate": "meta_edit_budget",
|
||||
"optimizer.use_slow_update": "use_slow_update",
|
||||
"optimizer.slow_update_samples": "slow_update_samples",
|
||||
"optimizer.longitudinal_pair_policy": "longitudinal_pair_policy",
|
||||
"optimizer.use_meta_skill": "use_meta_skill",
|
||||
"evaluation.use_gate": "use_gate",
|
||||
"evaluation.sel_env_num": "sel_env_num",
|
||||
"evaluation.test_env_num": "test_env_num",
|
||||
"evaluation.eval_test": "eval_test",
|
||||
"env.name": "env",
|
||||
"env.skill_init": "skill_init",
|
||||
"env.out_root": "out_root",
|
||||
}
|
||||
|
||||
|
||||
# ── Deep merge ───────────────────────────────────────────────────────────
|
||||
|
||||
def _deep_merge(base: dict, override: dict) -> dict:
|
||||
"""Recursively merge *override* into *base* (returns new dict)."""
|
||||
result = copy.deepcopy(base)
|
||||
for key, val in override.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(val, dict):
|
||||
result[key] = _deep_merge(result[key], val)
|
||||
else:
|
||||
result[key] = copy.deepcopy(val)
|
||||
return result
|
||||
|
||||
|
||||
# ── YAML loading with _base_ inheritance ─────────────────────────────────
|
||||
|
||||
def _load_yaml(path: str, _visited: set[str] | None = None) -> dict:
|
||||
"""Load a YAML file, resolving ``_base_`` inheritance recursively."""
|
||||
abs_path = os.path.abspath(path)
|
||||
if _visited is None:
|
||||
_visited = set()
|
||||
if abs_path in _visited:
|
||||
raise ValueError(f"Circular _base_ inheritance: {abs_path}")
|
||||
_visited.add(abs_path)
|
||||
|
||||
with open(abs_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
|
||||
base_ref = cfg.pop("_base_", None)
|
||||
if base_ref:
|
||||
base_path = os.path.join(os.path.dirname(abs_path), base_ref)
|
||||
base_cfg = _load_yaml(base_path, _visited)
|
||||
cfg = _deep_merge(base_cfg, cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
# ── Format detection ─────────────────────────────────────────────────────
|
||||
|
||||
def is_structured(cfg: dict) -> bool:
|
||||
"""Return True if *cfg* uses the new structured section format."""
|
||||
return any(
|
||||
key in _STRUCTURED_SECTIONS and isinstance(cfg.get(key), dict)
|
||||
for key in cfg
|
||||
)
|
||||
|
||||
|
||||
# ── Flatten ──────────────────────────────────────────────────────────────
|
||||
|
||||
def flatten_config(cfg: dict) -> dict:
|
||||
"""Convert a structured config to the flat dict expected by the trainer.
|
||||
|
||||
If *cfg* is already flat, returns a shallow copy unchanged.
|
||||
"""
|
||||
if not is_structured(cfg):
|
||||
return dict(cfg)
|
||||
|
||||
flat: dict[str, Any] = {}
|
||||
|
||||
evaluation_section = cfg.get("evaluation", {})
|
||||
if isinstance(evaluation_section, dict) and evaluation_section.get("use_gate") is False:
|
||||
raise ValueError(
|
||||
"Gate validation is mandatory in this branch. Remove "
|
||||
"`evaluation.use_gate: false` from the config."
|
||||
)
|
||||
|
||||
# Apply the explicit mapping
|
||||
for dotted, flat_key in _FLATTEN_MAP.items():
|
||||
section, key = dotted.split(".", 1)
|
||||
section_dict = cfg.get(section, {})
|
||||
if isinstance(section_dict, dict) and key in section_dict:
|
||||
flat[flat_key] = section_dict[key]
|
||||
|
||||
# Pass through env-specific keys not in the explicit mapping
|
||||
env_section = cfg.get("env", {})
|
||||
if isinstance(env_section, dict):
|
||||
mapped_env_keys = {
|
||||
k.split(".", 1)[1]
|
||||
for k in _FLATTEN_MAP
|
||||
if k.startswith("env.")
|
||||
}
|
||||
for key, val in env_section.items():
|
||||
if key not in mapped_env_keys:
|
||||
flat[key] = val
|
||||
|
||||
return flat
|
||||
|
||||
|
||||
# ── Override application ─────────────────────────────────────────────────
|
||||
|
||||
def _cast_value(val_str: str) -> Any:
|
||||
"""Auto-cast a CLI string value to int / float / bool / str."""
|
||||
if val_str.lower() in ("true", "yes"):
|
||||
return True
|
||||
if val_str.lower() in ("false", "no"):
|
||||
return False
|
||||
try:
|
||||
return int(val_str)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(val_str)
|
||||
except ValueError:
|
||||
pass
|
||||
return val_str
|
||||
|
||||
|
||||
def apply_overrides(cfg: dict, overrides: list[str]) -> None:
|
||||
"""Apply ``key=value`` overrides to a structured config (in place).
|
||||
|
||||
Supports both ``section.key=value`` (for structured configs) and
|
||||
``key=value`` (for flat configs or flat keys in env section).
|
||||
"""
|
||||
for item in overrides:
|
||||
if "=" not in item:
|
||||
raise ValueError(f"Invalid override (expected key=value): {item!r}")
|
||||
key, val_str = item.split("=", 1)
|
||||
val = _cast_value(val_str)
|
||||
|
||||
if "." in key:
|
||||
section, subkey = key.split(".", 1)
|
||||
if section in cfg and isinstance(cfg[section], dict):
|
||||
cfg[section][subkey] = val
|
||||
else:
|
||||
cfg.setdefault(section, {})[subkey] = val
|
||||
else:
|
||||
# Flat key — apply to top level (for legacy compat)
|
||||
cfg[key] = val
|
||||
|
||||
|
||||
# ── Public API ───────────────────────────────────────────────────────────
|
||||
|
||||
def load_config(
|
||||
path: str,
|
||||
overrides: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Load a config file with ``_base_`` inheritance and optional overrides.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
Path to the YAML config file.
|
||||
overrides : list[str] | None
|
||||
``key=value`` strings from ``--cfg-options``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The merged config (structured or flat depending on the YAML).
|
||||
"""
|
||||
cfg = _load_yaml(path)
|
||||
if overrides:
|
||||
apply_overrides(cfg, overrides)
|
||||
return cfg
|
||||
7
reflact/datasets/__init__.py
Normal file
7
reflact/datasets/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""ReflACT Datasets -- task batch planning and data loading.
|
||||
|
||||
Analogous to the datasets and dataloaders in neural network training:
|
||||
provides batch sampling, epoch planning, and data management for the
|
||||
ReflACT training pipeline.
|
||||
"""
|
||||
from reflact.datasets.base import BaseDataLoader, BatchSpec, SplitDataLoader # noqa: F401
|
||||
512
reflact/datasets/base.py
Normal file
512
reflact/datasets/base.py
Normal file
@@ -0,0 +1,512 @@
|
||||
"""Generic task dataloader abstractions for ReflACT.
|
||||
|
||||
ReflACT does not train model parameters directly. Instead, it iterates over
|
||||
task batches, rolls out the current skill, reflects on failures/successes,
|
||||
and updates the skill document. Because of that, the "dataloader" abstraction
|
||||
here is closer to a batch sampler / episode planner than a tensor loader.
|
||||
|
||||
Class hierarchy::
|
||||
|
||||
BaseDataLoader # abstract — simulator-backed envs (e.g. ALFWorld)
|
||||
└── SplitDataLoader # abstract — dataset-backed envs with split_dir
|
||||
|
||||
SplitDataLoader supports two dataset entry modes:
|
||||
|
||||
1. ``split_mode="split_dir"``: consume an existing split directory.
|
||||
2. ``split_mode="ratio"``: build a deterministic split directory from a raw
|
||||
dataset path using an explicit train:val:test ratio.
|
||||
|
||||
In either case, the standardised split layout is:
|
||||
|
||||
split_dir/
|
||||
├── train/ # training items
|
||||
├── val/ # validation / selection items (gate)
|
||||
└── test/ # held-out test items
|
||||
|
||||
Each subdirectory's contents are benchmark-specific. Subclasses only need
|
||||
to implement ``load_split_items(split_path)`` to teach the loader how to
|
||||
read items from one of those directories.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class BatchSpec:
|
||||
"""A concrete batch request consumed by the training loop.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
phase : str
|
||||
``"train"`` or ``"eval"``.
|
||||
split : str
|
||||
Dataset split name, typically ``"train"`` or an eval split.
|
||||
seed : int
|
||||
Random seed used to construct the batch deterministically.
|
||||
batch_size : int
|
||||
Requested number of items / episodes in this batch.
|
||||
payload : object | None
|
||||
Environment-specific batch payload. For dataset-backed environments
|
||||
this is often a list of sampled items; for simulator-backed
|
||||
environments this may be ``None`` and the seed alone can define the
|
||||
batch.
|
||||
metadata : dict[str, Any]
|
||||
Optional structured metadata for logging, resume, or curriculum logic.
|
||||
"""
|
||||
|
||||
phase: str
|
||||
split: str
|
||||
seed: int
|
||||
batch_size: int
|
||||
payload: object | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class BaseDataLoader(ABC):
|
||||
"""Abstract base class for task batch planning in ReflACT.
|
||||
|
||||
Subclasses are responsible for defining how a train or eval batch is
|
||||
sampled. The default implementation here provides deterministic epoch seed
|
||||
planning so all loaders share the same reproducibility behavior.
|
||||
"""
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
"""Optional one-time initialization with the full trainer config."""
|
||||
|
||||
def set_out_root(self, out_root: str) -> None:
|
||||
"""Optional hook for loaders that persist split files or state."""
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
"""Return serializable loader state for resume support."""
|
||||
return {}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
"""Restore loader state from :meth:`state_dict` output."""
|
||||
|
||||
def get_train_size(self) -> int | None:
|
||||
"""Return the size of the training pool when known."""
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def make_base_seeds(steps_per_epoch: int, accumulation: int, seed: int) -> list[int]:
|
||||
"""Return the deterministic seed pool used to define train batches."""
|
||||
batches_per_epoch = steps_per_epoch * accumulation
|
||||
return [seed + i + 1 for i in range(batches_per_epoch)]
|
||||
|
||||
@staticmethod
|
||||
def shuffle_epoch_seeds(base_seeds: list[int], epoch: int, seed: int) -> list[int]:
|
||||
"""Return the per-epoch deterministic shuffle of *base_seeds*."""
|
||||
epoch_rng = random.Random(seed + epoch * 1000)
|
||||
shuffled = list(base_seeds)
|
||||
epoch_rng.shuffle(shuffled)
|
||||
return shuffled
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
"""Build the full list of training batches for one epoch."""
|
||||
base_seeds = self.make_base_seeds(
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
accumulation=accumulation,
|
||||
seed=seed,
|
||||
)
|
||||
shuffled_seeds = self.shuffle_epoch_seeds(base_seeds, epoch=epoch, seed=seed)
|
||||
return [
|
||||
self.build_train_batch(batch_size=batch_size, seed=batch_seed, **kwargs)
|
||||
for batch_seed in shuffled_seeds
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
"""Construct one training batch specification."""
|
||||
|
||||
@abstractmethod
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
"""Construct one evaluation batch specification."""
|
||||
|
||||
|
||||
# ── Split-based dataloader for dataset-backed environments ──────────────
|
||||
|
||||
# Canonical split names expected under split_dir/
|
||||
SPLIT_NAMES = ("train", "val", "test")
|
||||
|
||||
# Maps legacy / trainer split names → canonical directory names
|
||||
_SPLIT_ALIAS: dict[str, str] = {
|
||||
"train": "train",
|
||||
"valid_seen": "val",
|
||||
"selection": "val",
|
||||
"val": "val",
|
||||
"valid_unseen": "test",
|
||||
"test": "test",
|
||||
}
|
||||
|
||||
|
||||
def _load_json_or_jsonl(path: str) -> list[dict]:
|
||||
"""Load a list of items from a JSON or JSONL file."""
|
||||
with open(path, encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
if not content:
|
||||
return []
|
||||
|
||||
try:
|
||||
data = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
data = None
|
||||
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
if isinstance(data, dict):
|
||||
nested = data.get("data")
|
||||
if isinstance(nested, list):
|
||||
return nested
|
||||
return list(data.values())
|
||||
|
||||
items: list[dict] = []
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
items.append(json.loads(line))
|
||||
return items
|
||||
|
||||
|
||||
def _parse_split_ratio(text: str) -> tuple[int, int, int]:
|
||||
parts = [part.strip() for part in str(text or "").split(":") if part.strip()]
|
||||
if len(parts) != 3:
|
||||
raise ValueError(
|
||||
f"split_ratio must be in train:val:test form, got {text!r}"
|
||||
)
|
||||
try:
|
||||
train, val, test = (int(part) for part in parts)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"split_ratio must contain integers, got {text!r}"
|
||||
) from exc
|
||||
if min(train, val, test) <= 0:
|
||||
raise ValueError(f"split_ratio parts must be positive, got {text!r}")
|
||||
return train, val, test
|
||||
|
||||
|
||||
def _compute_split_counts(total: int, ratio: tuple[int, int, int]) -> tuple[int, int, int]:
|
||||
weights = list(ratio)
|
||||
denom = sum(weights)
|
||||
raw = [total * weight / denom for weight in weights]
|
||||
counts = [int(value) for value in raw]
|
||||
remaining = total - sum(counts)
|
||||
order = sorted(
|
||||
range(len(raw)),
|
||||
key=lambda idx: (raw[idx] - counts[idx], weights[idx]),
|
||||
reverse=True,
|
||||
)
|
||||
for idx in order[:remaining]:
|
||||
counts[idx] += 1
|
||||
return counts[0], counts[1], counts[2]
|
||||
|
||||
|
||||
class SplitDataLoader(BaseDataLoader):
|
||||
"""Base class for dataset-backed environments.
|
||||
|
||||
Supported modes:
|
||||
|
||||
- ``split_mode="split_dir"``: load an existing ``train/``, ``val/``,
|
||||
``test/`` directory tree.
|
||||
- ``split_mode="ratio"``: load raw items from ``data_path`` and materialize
|
||||
a deterministic split directory with the requested ratio.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.split_dir = split_dir
|
||||
self.data_path = data_path
|
||||
self.split_mode = split_mode
|
||||
self.split_ratio = split_ratio
|
||||
self.split_seed = int(split_seed)
|
||||
self.split_output_dir = split_output_dir
|
||||
self.seed = seed
|
||||
self.limit = limit
|
||||
self._splits: dict[str, list[dict]] = {}
|
||||
|
||||
# ── Setup ────────────────────────────────────────────────────────────
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
if not self.split_mode:
|
||||
self.split_mode = str(cfg.get("split_mode", "ratio") or "ratio")
|
||||
if not self.split_dir:
|
||||
self.split_dir = cfg.get("split_dir", "")
|
||||
if not self.data_path:
|
||||
self.data_path = cfg.get("data_path", "")
|
||||
if not self.split_output_dir:
|
||||
self.split_output_dir = cfg.get("split_output_dir", "")
|
||||
if "split_seed" in cfg and not self.split_seed:
|
||||
self.split_seed = int(cfg.get("split_seed", 0) or 0)
|
||||
if not self.split_seed:
|
||||
self.split_seed = self.seed
|
||||
if not self.split_ratio:
|
||||
self.split_ratio = str(cfg.get("split_ratio", "2:1:7") or "2:1:7")
|
||||
|
||||
mode = str(self.split_mode or "ratio").strip().lower()
|
||||
if mode not in {"ratio", "split_dir"}:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} split_mode must be 'ratio' or 'split_dir', "
|
||||
f"got {self.split_mode!r}"
|
||||
)
|
||||
self.split_mode = mode
|
||||
|
||||
if self.split_mode == "ratio":
|
||||
self.split_dir = self._materialize_ratio_split(cfg)
|
||||
if not self.split_dir:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} requires either "
|
||||
"`split_mode=ratio` with `data_path`, or `split_mode=split_dir` "
|
||||
f"with `split_dir` pointing to {'/'.join(SPLIT_NAMES)}/."
|
||||
)
|
||||
self._load_all_splits()
|
||||
|
||||
def _resolve_split_output_dir(self, cfg: dict) -> str:
|
||||
if self.split_output_dir:
|
||||
return os.path.abspath(self.split_output_dir)
|
||||
out_root = os.path.abspath(str(cfg.get("out_root") or os.getcwd()))
|
||||
env_name = str(cfg.get("env") or type(self).__name__.replace("DataLoader", "").lower())
|
||||
ratio_tag = str(self.split_ratio or "2:1:7").replace(":", "-")
|
||||
return os.path.join(out_root, "_generated_splits", f"{env_name}_{ratio_tag}_seed{self.split_seed}")
|
||||
|
||||
def load_raw_items(self, data_path: str) -> list[dict]:
|
||||
"""Load raw items from a dataset path before ratio splitting.
|
||||
|
||||
Subclasses can override when the raw dataset is not a single JSON/JSONL
|
||||
file or when directory layouts require custom normalization.
|
||||
"""
|
||||
if os.path.isdir(data_path):
|
||||
if any(os.path.isdir(os.path.join(data_path, name)) for name in SPLIT_NAMES):
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} got a split directory as data_path. "
|
||||
"Use split_mode=split_dir and pass it as split_dir instead."
|
||||
)
|
||||
candidates = sorted(glob.glob(os.path.join(data_path, "*.json")))
|
||||
candidates += sorted(glob.glob(os.path.join(data_path, "*.jsonl")))
|
||||
if len(candidates) != 1:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} expected data_path to be one JSON/JSONL file "
|
||||
f"or a directory containing exactly one such file, got: {data_path}"
|
||||
)
|
||||
return _load_json_or_jsonl(candidates[0])
|
||||
return _load_json_or_jsonl(data_path)
|
||||
|
||||
def write_split_items(self, split_path: str, items: list[dict]) -> None:
|
||||
os.makedirs(split_path, exist_ok=True)
|
||||
out_path = os.path.join(split_path, "items.json")
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(items, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def _materialize_ratio_split(self, cfg: dict) -> str:
|
||||
data_path = os.path.abspath(str(self.data_path or "").strip())
|
||||
if not data_path:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} requires data_path when split_mode=ratio."
|
||||
)
|
||||
|
||||
ratio = _parse_split_ratio(self.split_ratio)
|
||||
items = self.load_raw_items(data_path)
|
||||
if not isinstance(items, list) or not items:
|
||||
raise ValueError(f"No raw items available for ratio split from {data_path}")
|
||||
|
||||
shuffled = list(items)
|
||||
rng = random.Random(self.split_seed)
|
||||
rng.shuffle(shuffled)
|
||||
|
||||
train_n, val_n, test_n = _compute_split_counts(len(shuffled), ratio)
|
||||
train_items = shuffled[:train_n]
|
||||
val_items = shuffled[train_n: train_n + val_n]
|
||||
test_items = shuffled[train_n + val_n: train_n + val_n + test_n]
|
||||
|
||||
split_dir = self._resolve_split_output_dir(cfg)
|
||||
manifest = {
|
||||
"source_data_path": data_path,
|
||||
"split_mode": "ratio",
|
||||
"split_ratio": self.split_ratio,
|
||||
"split_seed": self.split_seed,
|
||||
"counts": {
|
||||
"train": len(train_items),
|
||||
"val": len(val_items),
|
||||
"test": len(test_items),
|
||||
},
|
||||
}
|
||||
os.makedirs(split_dir, exist_ok=True)
|
||||
self.write_split_items(os.path.join(split_dir, "train"), train_items)
|
||||
self.write_split_items(os.path.join(split_dir, "val"), val_items)
|
||||
self.write_split_items(os.path.join(split_dir, "test"), test_items)
|
||||
with open(os.path.join(split_dir, "split_manifest.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(manifest, f, ensure_ascii=False, indent=2)
|
||||
print(
|
||||
f" [{type(self).__name__}] generated ratio split {self.split_ratio} "
|
||||
f"at {split_dir} from {data_path}"
|
||||
)
|
||||
return split_dir
|
||||
|
||||
def _load_all_splits(self) -> None:
|
||||
for name in SPLIT_NAMES:
|
||||
split_path = os.path.join(self.split_dir, name)
|
||||
if not os.path.isdir(split_path):
|
||||
raise ValueError(
|
||||
f"Missing '{name}/' subdirectory in split_dir: {self.split_dir}"
|
||||
)
|
||||
items = self.load_split_items(split_path)
|
||||
if self.limit:
|
||||
items = items[: self.limit]
|
||||
self._splits[name] = items
|
||||
|
||||
counts = " ".join(f"{k}={len(v)}" for k, v in self._splits.items())
|
||||
print(f" [{type(self).__name__}] {counts} (from {self.split_dir})")
|
||||
|
||||
def load_split_items(self, split_path: str) -> list[dict]:
|
||||
"""Load items from one split directory (e.g. ``split_dir/train/``).
|
||||
|
||||
Default: finds the first ``.json`` file in the directory and loads it
|
||||
as a JSON array. Subclasses can override for custom formats.
|
||||
"""
|
||||
json_files = sorted(glob.glob(os.path.join(split_path, "*.json")))
|
||||
if not json_files:
|
||||
raise FileNotFoundError(
|
||||
f"No .json file found in {split_path}"
|
||||
)
|
||||
with open(json_files[0], encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
if not isinstance(items, list):
|
||||
raise ValueError(
|
||||
f"Expected JSON array in {json_files[0]}, got {type(items).__name__}"
|
||||
)
|
||||
return items
|
||||
|
||||
# ── Accessors ────────────────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def train_items(self) -> list[dict]:
|
||||
return self._splits.get("train", [])
|
||||
|
||||
@property
|
||||
def val_items(self) -> list[dict]:
|
||||
return self._splits.get("val", [])
|
||||
|
||||
@property
|
||||
def test_items(self) -> list[dict]:
|
||||
return self._splits.get("test", [])
|
||||
|
||||
def get_split_items(self, split: str) -> list[dict]:
|
||||
"""Resolve a split name (including legacy aliases) to its item list."""
|
||||
canonical = _SPLIT_ALIAS.get(split, split)
|
||||
return list(self._splits.get(canonical, self.val_items))
|
||||
|
||||
def get_train_size(self) -> int:
|
||||
return len(self.train_items)
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
"""Build one full epoch that covers the train split in shuffled order.
|
||||
|
||||
For split-backed datasets, an epoch should correspond to one pass over
|
||||
the available training items rather than repeated independent sampling.
|
||||
"""
|
||||
epoch_rng = random.Random(seed + epoch * 1000)
|
||||
items = list(self.train_items)
|
||||
epoch_rng.shuffle(items)
|
||||
|
||||
total_batches = steps_per_epoch * accumulation
|
||||
if total_batches <= 0:
|
||||
return []
|
||||
|
||||
batches: list[BatchSpec] = []
|
||||
cursor = 0
|
||||
for batch_idx in range(total_batches):
|
||||
batch_items = items[cursor: cursor + batch_size]
|
||||
cursor += len(batch_items)
|
||||
|
||||
# Extremely small datasets can leave trailing empty microbatches
|
||||
# when accumulation > 1. Reuse the shuffled prefix in that case so
|
||||
# the trainer still receives the expected batch count.
|
||||
if not batch_items and items:
|
||||
refill_rng = random.Random(seed + epoch * 1000 + batch_idx + 1)
|
||||
batch_items = list(items)
|
||||
refill_rng.shuffle(batch_items)
|
||||
batch_items = batch_items[:batch_size]
|
||||
|
||||
batches.append(
|
||||
BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed + epoch * 1000 + batch_idx + 1,
|
||||
batch_size=len(batch_items),
|
||||
payload=batch_items,
|
||||
)
|
||||
)
|
||||
|
||||
return batches
|
||||
|
||||
# ── Batch construction ───────────────────────────────────────────────
|
||||
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
rng = random.Random(seed)
|
||||
items = list(self.train_items)
|
||||
rng.shuffle(items)
|
||||
items = items[:batch_size]
|
||||
return BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
)
|
||||
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
items = self.get_split_items(split)
|
||||
if env_num and env_num < len(items):
|
||||
items = items[:env_num]
|
||||
return BatchSpec(
|
||||
phase="eval",
|
||||
split=split,
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
)
|
||||
9
reflact/engine/__init__.py
Normal file
9
reflact/engine/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""ReflACT Engine -- the training runner.
|
||||
|
||||
Analogous to the Runner in mmengine: orchestrates the full training pipeline
|
||||
including rollout, gradient computation, aggregation, optimization, and
|
||||
evaluation.
|
||||
"""
|
||||
from reflact.engine.trainer import ReflACTTrainer # noqa: F401
|
||||
|
||||
__all__ = ["ReflACTTrainer"]
|
||||
2195
reflact/engine/trainer.py
Normal file
2195
reflact/engine/trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
1
reflact/envs/__init__.py
Normal file
1
reflact/envs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""ReflACT environment adapters."""
|
||||
5
reflact/envs/alfworld/__init__.py
Normal file
5
reflact/envs/alfworld/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""ALFWorld environment adapter for ReflACT."""
|
||||
|
||||
from reflact.envs.alfworld.adapter import ALFWorldAdapter
|
||||
|
||||
__all__ = ["ALFWorldAdapter"]
|
||||
585
reflact/envs/alfworld/adapter.py
Normal file
585
reflact/envs/alfworld/adapter.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""ALFWorld environment adapter for ReflACT.
|
||||
|
||||
Connects the ReflACT training loop to ALFWorld by implementing
|
||||
:class:`~reflact.envs.base.EnvAdapter`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
|
||||
from reflact.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from reflact.datasets.base import BatchSpec
|
||||
from reflact.envs.base import EnvAdapter
|
||||
from reflact.envs.alfworld.dataloader import ALFWorldDataLoader
|
||||
from reflact.envs.alfworld.rollout import (
|
||||
build_alfworld_env,
|
||||
run_alfworld_batch,
|
||||
TASKS,
|
||||
)
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
from reflact.utils import compute_score
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ALFWorldBatchRun:
|
||||
"""Lazy ALFWorld batch description.
|
||||
|
||||
The adapter materializes this in rollout chunks so a large evaluation set
|
||||
does not keep every ALFWorld simulator open at once.
|
||||
"""
|
||||
|
||||
env_num: int
|
||||
eval_dataset: str
|
||||
seed: int
|
||||
is_train: bool
|
||||
workers: int
|
||||
specific_gamefiles: list[str] | None = None
|
||||
result_ids: list[str] | None = None
|
||||
items: list[dict] | None = None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items or [])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return int(self.env_num or 0)
|
||||
|
||||
|
||||
class ALFWorldAdapter(EnvAdapter):
|
||||
"""ALFWorld environment adapter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_steps : int
|
||||
Maximum steps per ALFWorld episode (default 50).
|
||||
max_api_workers : int
|
||||
Maximum concurrent API calls during rollout (default 8).
|
||||
analyst_workers : int
|
||||
Parallel workers for analyst stage (default 16).
|
||||
failure_only : bool
|
||||
If True, only run error analyst (skip success analyst).
|
||||
minibatch_size : int
|
||||
Trajectories per analyst group, M (default 8).
|
||||
edit_budget : int
|
||||
Maximum edits per minibatch, L (default 4).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "split_dir",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
train_size: int = 0,
|
||||
max_steps: int = 50,
|
||||
workers: int = 8,
|
||||
max_api_workers: int = 8,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_steps = max_steps
|
||||
self.workers = max(int(workers or 1), 1)
|
||||
self.max_api_workers = max_api_workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = ALFWorldDataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
train_size=train_size,
|
||||
)
|
||||
self._traj_cache: dict[str, dict | None] = {}
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def _load_traj_data(self, item: dict) -> dict | None:
|
||||
gamefile = str(item.get("gamefile") or "").strip()
|
||||
if not gamefile:
|
||||
return None
|
||||
if gamefile in self._traj_cache:
|
||||
return self._traj_cache[gamefile]
|
||||
|
||||
traj_path = os.path.join(os.path.dirname(gamefile), "traj_data.json")
|
||||
try:
|
||||
with open(traj_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
data = None
|
||||
self._traj_cache[gamefile] = data
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _unique_lines(values: list[str], *, limit: int = 0) -> list[str]:
|
||||
lines: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for raw in values:
|
||||
line = str(raw or "").strip()
|
||||
if not line or line in seen:
|
||||
continue
|
||||
seen.add(line)
|
||||
lines.append(line)
|
||||
if limit > 0 and len(lines) >= limit:
|
||||
break
|
||||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _format_high_pddl(high_pddl: list[dict]) -> list[str]:
|
||||
steps: list[str] = []
|
||||
for idx, step in enumerate(high_pddl or [], start=1):
|
||||
discrete = step.get("discrete_action") or {}
|
||||
action = str(discrete.get("action") or "").strip()
|
||||
args = [str(arg).strip() for arg in (discrete.get("args") or []) if str(arg).strip()]
|
||||
if action and args:
|
||||
text = f"{action}({', '.join(args)})"
|
||||
elif action:
|
||||
text = action
|
||||
else:
|
||||
planner_action = step.get("planner_action") or {}
|
||||
text = str(planner_action.get("action") or "").strip()
|
||||
if text:
|
||||
steps.append(f"{idx}. {text}")
|
||||
return steps
|
||||
|
||||
def _build_reference_bundle(self, item: dict) -> dict:
|
||||
data = self._load_traj_data(item)
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
anns = ((data.get("turk_annotations") or {}).get("anns") or [])
|
||||
task_descs = self._unique_lines(
|
||||
[ann.get("task_desc", "") for ann in anns],
|
||||
limit=3,
|
||||
)
|
||||
high_descs = self._unique_lines(
|
||||
[step for ann in anns for step in (ann.get("high_descs") or [])],
|
||||
limit=12,
|
||||
)
|
||||
pddl_params = {
|
||||
key: value
|
||||
for key, value in (data.get("pddl_params") or {}).items()
|
||||
if value not in ("", None, [], {})
|
||||
}
|
||||
scene = data.get("scene") or {}
|
||||
scene_summary = {
|
||||
key: scene.get(key)
|
||||
for key in ("floor_plan", "scene_num", "dirty_and_empty")
|
||||
if scene.get(key) not in ("", None, [], {})
|
||||
}
|
||||
high_pddl = self._format_high_pddl((data.get("plan") or {}).get("high_pddl") or [])
|
||||
task_type = str(data.get("task_type") or item.get("task_type") or "").strip()
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task_descs": task_descs,
|
||||
"high_descs": high_descs,
|
||||
"pddl_params": pddl_params,
|
||||
"high_pddl": high_pddl,
|
||||
"scene_summary": scene_summary,
|
||||
}
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
bundle = self._build_reference_bundle(item)
|
||||
if not bundle:
|
||||
return ""
|
||||
|
||||
parts: list[str] = []
|
||||
if bundle["task_type"]:
|
||||
parts.append(f"## Reference Task Type\n{bundle['task_type']}")
|
||||
if bundle["task_descs"]:
|
||||
parts.append(
|
||||
"## Reference Human Task Descriptions\n"
|
||||
+ "\n".join(f"- {line}" for line in bundle["task_descs"])
|
||||
)
|
||||
if bundle["high_descs"]:
|
||||
parts.append(
|
||||
"## Reference Human High-Level Steps\n"
|
||||
+ "\n".join(f"{idx}. {line}" for idx, line in enumerate(bundle["high_descs"], start=1))
|
||||
)
|
||||
if bundle["pddl_params"]:
|
||||
parts.append(
|
||||
"## Reference PDDL Params\n"
|
||||
+ "\n".join(f"- {key}: {value}" for key, value in bundle["pddl_params"].items())
|
||||
)
|
||||
if bundle["high_pddl"]:
|
||||
parts.append(
|
||||
"## Reference Planner High-Level Plan\n" + "\n".join(bundle["high_pddl"])
|
||||
)
|
||||
if bundle["scene_summary"]:
|
||||
parts.append(
|
||||
"## Reference Scene Summary\n"
|
||||
+ "\n".join(f"- {key}: {value}" for key, value in bundle["scene_summary"].items())
|
||||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
bundle = self._build_reference_bundle(item)
|
||||
if not bundle:
|
||||
return {"fields": [], "preview": ""}
|
||||
|
||||
fields: list[str] = []
|
||||
previews: list[str] = []
|
||||
if bundle["task_type"]:
|
||||
fields.append("task_type")
|
||||
previews.append(f"[task_type] {bundle['task_type']}")
|
||||
if bundle["task_descs"]:
|
||||
fields.append("task_desc")
|
||||
previews.append("[task_desc]\n" + "\n".join(bundle["task_descs"][:2]))
|
||||
if bundle["high_descs"]:
|
||||
fields.append("high_descs")
|
||||
previews.append("[high_descs]\n" + "\n".join(bundle["high_descs"][:3]))
|
||||
if bundle["pddl_params"]:
|
||||
fields.append("pddl_params")
|
||||
previews.append(
|
||||
"[pddl_params]\n"
|
||||
+ "\n".join(
|
||||
f"{key}: {value}" for key, value in list(bundle["pddl_params"].items())[:4]
|
||||
)
|
||||
)
|
||||
if bundle["high_pddl"]:
|
||||
fields.append("plan.high_pddl")
|
||||
previews.append("[plan.high_pddl]\n" + "\n".join(bundle["high_pddl"][:3]))
|
||||
if bundle["scene_summary"]:
|
||||
fields.append("scene")
|
||||
previews.append(
|
||||
"[scene]\n"
|
||||
+ "\n".join(
|
||||
f"{key}: {value}" for key, value in bundle["scene_summary"].items()
|
||||
)
|
||||
)
|
||||
return {
|
||||
"fields": fields,
|
||||
"preview": "\n\n".join(previews)[:600],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _infer_dataset_from_gamefile(gamefile: str) -> tuple[str, bool]:
|
||||
path = str(gamefile or "")
|
||||
if "/valid_seen/" in path:
|
||||
return "eval_in_distribution", False
|
||||
if "/valid_unseen/" in path:
|
||||
return "eval_out_of_distribution", False
|
||||
return "train", True
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def _comparison_items(self, items: list[dict]) -> list[dict]:
|
||||
enriched: list[dict] = []
|
||||
for item in items:
|
||||
row = dict(item)
|
||||
bundle = self._build_reference_bundle(row)
|
||||
if bundle.get("task_descs"):
|
||||
row["task_description"] = bundle["task_descs"][0]
|
||||
elif bundle.get("task_type"):
|
||||
row["task_description"] = bundle["task_type"]
|
||||
enriched.append(row)
|
||||
return enriched
|
||||
|
||||
def requires_ray(self) -> bool:
|
||||
return False
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
gamefiles = list(batch.metadata.get("gamefiles") or [])
|
||||
result_ids = list(batch.metadata.get("result_ids") or [])
|
||||
items = self._comparison_items(list(batch.payload or []))
|
||||
return ALFWorldBatchRun(
|
||||
env_num=batch.batch_size,
|
||||
eval_dataset=batch.metadata.get("eval_dataset", batch.split),
|
||||
seed=batch.seed,
|
||||
is_train=batch.metadata.get("is_train", batch.phase == "train"),
|
||||
specific_gamefiles=gamefiles or None,
|
||||
result_ids=result_ids or None,
|
||||
items=items,
|
||||
workers=self.workers,
|
||||
)
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
results_path = os.path.join(out_dir, "results.jsonl")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# Resume support
|
||||
if os.path.exists(results_path):
|
||||
existing: list[dict] = []
|
||||
with open(results_path) as f:
|
||||
for line in f:
|
||||
try:
|
||||
existing.append(json.loads(line))
|
||||
except Exception:
|
||||
pass
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if isinstance(env_manager, ALFWorldBatchRun):
|
||||
results = self._run_batch(
|
||||
env_manager,
|
||||
skill_content=skill_content,
|
||||
out_dir=out_dir,
|
||||
)
|
||||
else:
|
||||
results = run_alfworld_batch(
|
||||
env_manager=env_manager,
|
||||
skill_content=skill_content,
|
||||
max_steps=self.max_steps,
|
||||
out_root=out_dir,
|
||||
max_api_workers=self.max_api_workers,
|
||||
result_ids=getattr(env_manager, "_reflact_result_ids", None),
|
||||
)
|
||||
|
||||
with open(results_path, "w") as f:
|
||||
for r in results:
|
||||
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _close_env(env_manager) -> None:
|
||||
close = getattr(env_manager, "close", None)
|
||||
if callable(close):
|
||||
close()
|
||||
|
||||
def _run_batch(
|
||||
self,
|
||||
batch: ALFWorldBatchRun,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
) -> list[dict]:
|
||||
total = int(batch.env_num or 0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
workers = max(1, min(int(batch.workers or self.workers), total))
|
||||
if total > workers:
|
||||
print(
|
||||
f" [alfworld rollout] episodes={total} "
|
||||
f"env_workers={workers} chunks={(total + workers - 1) // workers}"
|
||||
)
|
||||
|
||||
all_results: list[dict] = []
|
||||
for start in range(0, total, workers):
|
||||
chunk_size = min(workers, total - start)
|
||||
chunk_gamefiles = (
|
||||
batch.specific_gamefiles[start:start + chunk_size]
|
||||
if batch.specific_gamefiles
|
||||
else None
|
||||
)
|
||||
chunk_ids = (
|
||||
batch.result_ids[start:start + chunk_size]
|
||||
if batch.result_ids
|
||||
else [f"env_{idx:03d}" for idx in range(start, start + chunk_size)]
|
||||
)
|
||||
chunk_env = build_alfworld_env(
|
||||
env_num=chunk_size,
|
||||
eval_dataset=batch.eval_dataset,
|
||||
seed=batch.seed + start,
|
||||
is_train=batch.is_train,
|
||||
specific_gamefiles=chunk_gamefiles,
|
||||
)
|
||||
try:
|
||||
chunk_results = run_alfworld_batch(
|
||||
env_manager=chunk_env,
|
||||
skill_content=skill_content,
|
||||
max_steps=self.max_steps,
|
||||
out_root=out_dir,
|
||||
max_api_workers=min(self.max_api_workers, chunk_size),
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
result_ids=chunk_ids,
|
||||
)
|
||||
finally:
|
||||
self._close_env(chunk_env)
|
||||
all_results.extend(chunk_results)
|
||||
return all_results
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
results,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
|
||||
field_counts: dict[str, int] = {}
|
||||
selected_metadata: list[dict] = []
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
for field in meta["fields"]:
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
selected_metadata.append({
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("task_type") or "alfworld"),
|
||||
"gamefile": str(item.get("gamefile") or ""),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
})
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
field_summary = ", ".join(
|
||||
f"{field}({count}/{len(selected_items)})"
|
||||
for field, count in sorted(field_counts.items())
|
||||
) or "none"
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields={field_summary}"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
output_requirements=[
|
||||
"- Some trajectories may include a hidden Reference block. Use it to target the student's latent subgoal, missing precondition, or next-step intent, but do not reveal or paraphrase that reference to the student.",
|
||||
"- The instruction must request a brief diagnostic readout inside the existing <think>...</think> block.",
|
||||
"- The student must still output exactly one admissible action inside <action>...</action>.",
|
||||
"- Do not ask for exhaustive inventories, full plans, or long chain-of-thought.",
|
||||
"- The instruction text should be ready to append directly to the student's prompt.",
|
||||
],
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": field_counts,
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
gamefiles = [str(item.get("gamefile") or "") for item in selected_items]
|
||||
if any(not gamefile for gamefile in gamefiles):
|
||||
return []
|
||||
eval_dataset, is_train = self._infer_dataset_from_gamefile(gamefiles[0])
|
||||
deep_env = ALFWorldBatchRun(
|
||||
env_num=len(selected_items),
|
||||
eval_dataset=eval_dataset,
|
||||
seed=random_seed or 42,
|
||||
is_train=is_train,
|
||||
specific_gamefiles=gamefiles,
|
||||
workers=min(self.workers, max(len(selected_items), 1)),
|
||||
result_ids=[str(item["id"]) for item in selected_items],
|
||||
)
|
||||
deep_results = self._run_batch(
|
||||
deep_env,
|
||||
skill_content=skill_content,
|
||||
out_dir=rollout_dir,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, selected_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(TASKS)
|
||||
123
reflact/envs/alfworld/dataloader.py
Normal file
123
reflact/envs/alfworld/dataloader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""ALFWorld task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
from reflact.datasets.base import BatchSpec, SplitDataLoader
|
||||
|
||||
|
||||
class ALFWorldDataLoader(SplitDataLoader):
|
||||
"""ALFWorld batch planner.
|
||||
|
||||
In split_dir mode, batches are fixed gamefile items so ablations differ
|
||||
only in how the same training set is batched.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "split_dir",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
train_size: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
self.train_size_override = int(train_size or 0)
|
||||
|
||||
@staticmethod
|
||||
def _metadata_for_items(items: list[dict], split: str, phase: str) -> dict:
|
||||
gamefiles = [str(item.get("gamefile") or "") for item in items]
|
||||
if any(not gamefile for gamefile in gamefiles):
|
||||
raise ValueError("ALFWorld split items must contain non-empty gamefile paths.")
|
||||
eval_dataset = "train"
|
||||
is_train = phase == "train"
|
||||
first = gamefiles[0] if gamefiles else ""
|
||||
if "/valid_seen/" in first:
|
||||
eval_dataset = "eval_in_distribution"
|
||||
is_train = False
|
||||
elif "/valid_unseen/" in first:
|
||||
eval_dataset = "eval_out_of_distribution"
|
||||
is_train = False
|
||||
return {
|
||||
"eval_dataset": eval_dataset,
|
||||
"is_train": is_train,
|
||||
"gamefiles": gamefiles,
|
||||
"result_ids": [str(item.get("id") or idx) for idx, item in enumerate(items)],
|
||||
}
|
||||
|
||||
def get_train_size(self) -> int:
|
||||
if self.train_size_override > 0:
|
||||
return self.train_size_override
|
||||
return super().get_train_size()
|
||||
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
batch = super().build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
items = list(batch.payload or [])
|
||||
batch.metadata.update(self._metadata_for_items(items, "train", "train"))
|
||||
return BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
metadata=batch.metadata,
|
||||
)
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
batches = super().plan_train_epoch(
|
||||
epoch=epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
accumulation=accumulation,
|
||||
batch_size=batch_size,
|
||||
seed=seed,
|
||||
**kwargs,
|
||||
)
|
||||
for batch in batches:
|
||||
items = list(batch.payload or [])
|
||||
batch.metadata.update(self._metadata_for_items(items, "train", "train"))
|
||||
return batches
|
||||
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
batch = super().build_eval_batch(
|
||||
env_num=env_num,
|
||||
split=split,
|
||||
seed=seed,
|
||||
**kwargs,
|
||||
)
|
||||
items = list(batch.payload or [])
|
||||
batch.metadata.update(self._metadata_for_items(items, split, "eval"))
|
||||
return BatchSpec(
|
||||
phase="eval",
|
||||
split=split,
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
metadata=batch.metadata,
|
||||
)
|
||||
55
reflact/envs/alfworld/prompts/analyst_error.md
Normal file
55
reflact/envs/alfworld/prompts/analyst_error.md
Normal file
@@ -0,0 +1,55 @@
|
||||
You are an expert failure-analysis agent for ALFWorld embodied household tasks.
|
||||
|
||||
You will be given MULTIPLE failed agent trajectories from a single minibatch
|
||||
and the current skill document.
|
||||
Your job is to identify the most important COMMON failure patterns across
|
||||
the batch and propose a concise set of skill edits.
|
||||
|
||||
## ALFWorld Task Types
|
||||
- pick_and_place: Put object in/on a receptacle
|
||||
- pick_two_obj_and_place: Put two instances of an object in/on a receptacle
|
||||
- look_at_obj_in_light: Examine an object under a desklamp
|
||||
- pick_heat_then_place_in_recep: Heat an object and put it in/on a receptacle
|
||||
- pick_cool_then_place_in_recep: Cool an object and put it in/on a receptacle
|
||||
- pick_clean_then_place_in_recep: Clean an object and put it in/on a receptacle
|
||||
|
||||
## Failure Type Categories
|
||||
- **navigation_loop**: the agent revisits the same locations repeatedly without progress
|
||||
- **missed_object**: the agent fails to pick up a visible/reachable goal object
|
||||
- **wrong_sequence**: the agent performs actions in the wrong order (e.g., placing before transforming)
|
||||
- **premature_stop**: the agent stops or gets stuck before completing all goal conditions
|
||||
- **action_loop**: the agent repeats the same action without advancing
|
||||
- **appliance_error**: the agent misuses or skips an appliance (microwave, fridge, sink)
|
||||
- **rule_missing**: the skill lacks a relevant rule for this situation
|
||||
- **rule_wrong**: an existing skill rule is misleading or incorrect
|
||||
- **rule_ignored**: the skill has the right rule but the agent did not follow it
|
||||
- **other**: none of the above
|
||||
|
||||
## Analysis Process
|
||||
1. Read ALL trajectories in the minibatch.
|
||||
2. Identify the most prevalent, systematic failure patterns across them.
|
||||
3. For each pattern, classify its failure type.
|
||||
4. Propose skill edits that address the COMMON patterns — not individual edge cases.
|
||||
5. Edits must be generalizable; do not hardcode task-specific values.
|
||||
6. Only patch gaps in the skill — do not duplicate existing content.
|
||||
|
||||
You will be told the maximum number of edits (the budget L). Produce AT MOST L edits,
|
||||
focusing on the highest-impact patterns. You may produce fewer if warranted.
|
||||
|
||||
Respond ONLY with a valid JSON object (no markdown fences, no extra text):
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the batch's common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown to add at end of skill>"},
|
||||
{"op": "insert_after", "target": "<exact heading/text to insert after>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<exact text to replace>", "content": "<replacement>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
Only include edits that are needed. "edits" can be an empty list if no patch is warranted.
|
||||
33
reflact/envs/alfworld/prompts/analyst_success.md
Normal file
33
reflact/envs/alfworld/prompts/analyst_success.md
Normal file
@@ -0,0 +1,33 @@
|
||||
You are an expert success-pattern analyst for AI agents operating in ALFWorld,
|
||||
a text-based embodied household environment.
|
||||
|
||||
You will be given MULTIPLE successful agent trajectories from a single minibatch
|
||||
and the current skill document. Your job is to identify generalizable behavior
|
||||
patterns that are COMMON across the batch and worth encoding in the skill.
|
||||
|
||||
## Rules
|
||||
- Only propose patches for patterns NOT already covered in the skill.
|
||||
- Focus on patterns that appear across MULTIPLE trajectories in the batch.
|
||||
- Be concise. Patterns must generalize beyond specific tasks.
|
||||
- Prefer reinforcing existing sections over adding new top-level sections.
|
||||
- If the agents' success involved efficient exploration or smart appliance usage,
|
||||
consider reinforcing that in the patch.
|
||||
|
||||
You will be told the maximum number of edits (the budget L). Produce AT MOST L edits,
|
||||
focusing on the most broadly applicable patterns. You may produce fewer if warranted.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns are worth encoding>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
"edits" may be empty if the skill already covers all observed patterns.
|
||||
35
reflact/envs/alfworld/prompts/deep_probe.md
Normal file
35
reflact/envs/alfworld/prompts/deep_probe.md
Normal file
@@ -0,0 +1,35 @@
|
||||
You are an expert diagnostic-probe designer for ALFWorld embodied tasks.
|
||||
|
||||
You will design one short diagnostic instruction to append to the student's prompt
|
||||
for a handful of representative ALFWorld trajectories.
|
||||
|
||||
The goal is to expose whether the student has the right intermediate subgoal,
|
||||
object/receptacle state, and next-step intention without substantially changing
|
||||
the current scaffold.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT substantially change the student's existing action-selection scaffold.
|
||||
2. Do NOT prescribe a brand-new planner or long multi-step policy.
|
||||
3. Do NOT ask for exhaustive search over all objects or all admissible actions.
|
||||
4. Keep the diagnostic readout brief and place it inside the existing <think>...</think> block.
|
||||
5. The student must still output exactly one admissible action inside <action>...</action>.
|
||||
6. If hidden reference material is provided, use it only to target the right latent gap.
|
||||
7. Never copy hidden reference content into the student-facing probe.
|
||||
|
||||
## Good Probe Targets
|
||||
- current subgoal
|
||||
- target object / target receptacle / target state
|
||||
- decisive missing precondition
|
||||
- why one candidate action is better than a tempting alternative
|
||||
- whether the current step should explore, transform an object, or place it
|
||||
|
||||
## Bad Probe Targets
|
||||
- a full optimal plan from start to finish
|
||||
- exhaustive object inventories
|
||||
- a new theorem-like or planner-like protocol
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this probe reveals the latent skill gap>",
|
||||
"probe_instruction": "<the exact instruction text to append to the student prompt>"
|
||||
}
|
||||
8
reflact/envs/alfworld/prompts/rollout_no_history.md
Normal file
8
reflact/envs/alfworld/prompts/rollout_no_history.md
Normal file
@@ -0,0 +1,8 @@
|
||||
|
||||
You are an expert agent operating in the ALFRED Embodied Environment.
|
||||
Your current observation is: {current_observation}
|
||||
Your admissible actions of the current situation are: [{admissible_actions}].
|
||||
|
||||
Now it's your turn to take an action.
|
||||
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
|
||||
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
|
||||
9
reflact/envs/alfworld/prompts/rollout_with_history.md
Normal file
9
reflact/envs/alfworld/prompts/rollout_with_history.md
Normal file
@@ -0,0 +1,9 @@
|
||||
|
||||
You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
|
||||
Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history}
|
||||
You are now at step {current_step} and your current observation is: {current_observation}
|
||||
Your admissible actions of the current situation are: [{admissible_actions}].
|
||||
|
||||
Now it's your turn to take an action.
|
||||
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
|
||||
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
|
||||
16
reflact/envs/alfworld/prompts/rollout_with_memory.md
Normal file
16
reflact/envs/alfworld/prompts/rollout_with_memory.md
Normal file
@@ -0,0 +1,16 @@
|
||||
|
||||
You are an expert agent operating in the ALFRED Embodied Environment. Your task is to: {task_description}
|
||||
|
||||
## Retrieved Relevant Experience
|
||||
|
||||
{retrieved_memories}
|
||||
|
||||
## Current Progress
|
||||
|
||||
Prior to this step, you have already taken {step_count} step(s). Below are the most recent {history_length} observations and the corresponding actions you took: {action_history}
|
||||
You are now at step {current_step} and your current observation is: {current_observation}
|
||||
Your admissible actions of the current situation are: [{admissible_actions}].
|
||||
|
||||
Now it's your turn to take an action.
|
||||
You should first reason step-by-step about the current situation. This reasoning process MUST be enclosed within <think> </think> tags.
|
||||
Once you've finished your reasoning, you should choose an admissible action for current step and present it within <action> </action> tags.
|
||||
4
reflact/envs/alfworld/reflect.py
Normal file
4
reflact/envs/alfworld/reflect.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""ALFWorld Reflect stage.
|
||||
|
||||
Prompts are now loaded from .md files by the base adapter.
|
||||
"""
|
||||
359
reflact/envs/alfworld/rollout.py
Normal file
359
reflact/envs/alfworld/rollout.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""ALFWorld rollout module for ReflACT.
|
||||
|
||||
Provides:
|
||||
- build_alfworld_env(): build ALFWorld environment (wraps vendored SkillRL env)
|
||||
- run_alfworld_batch(): run a batch of ALFWorld episodes in parallel
|
||||
- TASKS: list of ALFWorld task types
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import concurrent.futures
|
||||
import numpy as np
|
||||
|
||||
from reflact.model import chat_student
|
||||
|
||||
# ── Constants ─────────────────────────────────────────────────────────────────
|
||||
|
||||
TASKS = [
|
||||
"pick_and_place",
|
||||
"pick_two_obj_and_place",
|
||||
"look_at_obj_in_light",
|
||||
"pick_heat_then_place_in_recep",
|
||||
"pick_cool_then_place_in_recep",
|
||||
"pick_clean_then_place_in_recep",
|
||||
]
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_task_type(gamefile: str) -> str:
|
||||
for task in TASKS:
|
||||
if task in gamefile:
|
||||
return task
|
||||
return "other"
|
||||
|
||||
|
||||
def _extract_action(model_response: str) -> str | None:
|
||||
match = re.search(r"<action>(.*?)</action>", model_response, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def _extract_think(model_response: str) -> str | None:
|
||||
match = re.search(r"<think>(.*?)</think>", model_response, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
def _build_skill_prompt(skill_content: str) -> str:
|
||||
"""Build the skill section to inject into the agent's system prompt."""
|
||||
if not skill_content or not skill_content.strip():
|
||||
return ""
|
||||
return (
|
||||
"\n\n## Skill Knowledge\n"
|
||||
"Below is a skill document with learned strategies. "
|
||||
"Use these guidelines to inform your decisions:\n\n"
|
||||
f"{skill_content}\n"
|
||||
)
|
||||
|
||||
|
||||
def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) -> str:
|
||||
if not diagnostic_instruction or not diagnostic_instruction.strip():
|
||||
return prompt
|
||||
return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n"
|
||||
|
||||
|
||||
# ── Environment builder ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def build_alfworld_env(
|
||||
env_num: int,
|
||||
eval_dataset: str = "eval_out_of_distribution",
|
||||
seed: int = 42,
|
||||
is_train: bool = False,
|
||||
specific_gamefiles: list[str] | None = None,
|
||||
):
|
||||
"""Build ALFWorld environment manager.
|
||||
|
||||
Args:
|
||||
env_num: number of parallel environments
|
||||
eval_dataset: 'eval_in_distribution' or 'eval_out_of_distribution' or train
|
||||
seed: random seed
|
||||
is_train: whether to use training set
|
||||
|
||||
Returns:
|
||||
env_manager: AlfWorldEnvironmentManager instance
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
from functools import partial
|
||||
|
||||
from reflact.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs
|
||||
from reflact.envs.alfworld.vendor.alfworld_projection import alfworld_projection
|
||||
from reflact.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager
|
||||
|
||||
HERE = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml")
|
||||
env_kwargs = {"eval_dataset": eval_dataset}
|
||||
|
||||
envs = build_alfworld_envs(
|
||||
alf_config_path,
|
||||
seed=seed,
|
||||
env_num=env_num,
|
||||
group_n=1,
|
||||
is_train=is_train,
|
||||
env_kwargs=env_kwargs,
|
||||
resources_per_worker=None,
|
||||
gamefiles=specific_gamefiles,
|
||||
)
|
||||
|
||||
config = OmegaConf.create(
|
||||
{
|
||||
"env": {
|
||||
"history_length": 2,
|
||||
"env_name": "alfworld/AlfredTWEnv",
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
projection_f = partial(alfworld_projection)
|
||||
env_manager = AlfWorldEnvironmentManager(envs, projection_f, config)
|
||||
return env_manager
|
||||
|
||||
|
||||
# ── Batch rollout ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_alfworld_batch(
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
max_steps: int = 50,
|
||||
out_root: str = "",
|
||||
max_api_workers: int = 8,
|
||||
temperature: float = 0.4,
|
||||
max_completion_tokens: int = 2048,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
result_ids: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
"""Run a batch of ALFWorld episodes.
|
||||
|
||||
Returns a list of result dicts compatible with SkillReflection v2 pipeline:
|
||||
[
|
||||
{
|
||||
"id": "<env_idx>_<gamefile_hash>",
|
||||
"hard": 0 or 1,
|
||||
"soft": 0.0 or 1.0,
|
||||
"n_turns": <int>,
|
||||
"fail_reason": "<str>",
|
||||
"agent_ok": True,
|
||||
"task_type": "<str>",
|
||||
"gamefile": "<str>",
|
||||
"task_description": "<str>",
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
Also saves conversation.json per environment in out_root/predictions/<task_id>/
|
||||
"""
|
||||
skill_prompt = _build_skill_prompt(skill_content)
|
||||
|
||||
obs, infos = env_manager.reset({})
|
||||
env_num = len(obs["text"])
|
||||
env_dones = [False] * env_num
|
||||
overall_success = [False] * env_num
|
||||
|
||||
# Build per-env metadata
|
||||
env_meta: list[dict] = []
|
||||
for i in range(env_num):
|
||||
gamefile = infos[i].get("extra.gamefile", "") if isinstance(infos[i], dict) else ""
|
||||
task_type = _get_task_type(gamefile)
|
||||
# Extract task description from initial observation
|
||||
task_desc = ""
|
||||
anchor_text = obs["anchor"][i] if "anchor" in obs else ""
|
||||
task_start = anchor_text.find("Your task is to: ")
|
||||
if task_start != -1:
|
||||
task_desc = anchor_text[task_start + len("Your task is to: "):].strip()
|
||||
|
||||
env_meta.append({
|
||||
"gamefile": gamefile,
|
||||
"task_type": task_type,
|
||||
"task_description": task_desc,
|
||||
})
|
||||
|
||||
# Per-env conversation records
|
||||
conversations: list[list[dict]] = [[] for _ in range(env_num)]
|
||||
|
||||
for step_idx in range(max_steps):
|
||||
if all(env_dones):
|
||||
break
|
||||
|
||||
active_indices = [i for i in range(env_num) if not env_dones[i]]
|
||||
|
||||
# Build prompts with skill injection
|
||||
prompts: dict[int, str] = {}
|
||||
for i in active_indices:
|
||||
prompt = obs["text"][i]
|
||||
if skill_prompt:
|
||||
# Inject skill before the action instruction
|
||||
prompt = skill_prompt + "\n" + prompt
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
prompt = _append_diagnostic_instruction(prompt, diagnostic_instruction)
|
||||
prompts[i] = prompt
|
||||
|
||||
# Call API in parallel
|
||||
actions = ["None"] * env_num
|
||||
action_timeout = 180
|
||||
|
||||
def call_api(idx):
|
||||
try:
|
||||
response, _ = chat_student(
|
||||
system="You are an expert agent operating in the ALFRED Embodied Environment.",
|
||||
user=prompts[idx],
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=120,
|
||||
)
|
||||
response = (response or "").strip()
|
||||
if not response:
|
||||
return idx, "<think>empty model response</think><action>look</action>"
|
||||
if _extract_action(response) is None:
|
||||
return idx, "<think>missing action tag</think><action>look</action>"
|
||||
return idx, response
|
||||
except Exception as e:
|
||||
return idx, "<think>error</think><action>look</action>"
|
||||
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers)
|
||||
try:
|
||||
futures = {executor.submit(call_api, i): i for i in active_indices}
|
||||
started_at = {future: time.time() for future in futures}
|
||||
pending_futs = set(futures)
|
||||
while pending_futs:
|
||||
done, _ = concurrent.futures.wait(
|
||||
pending_futs,
|
||||
timeout=5,
|
||||
return_when=concurrent.futures.FIRST_COMPLETED,
|
||||
)
|
||||
now = time.time()
|
||||
timed_out = [
|
||||
future for future in pending_futs - done
|
||||
if now - started_at[future] >= action_timeout
|
||||
]
|
||||
for future in done:
|
||||
pending_futs.remove(future)
|
||||
try:
|
||||
idx, response = future.result()
|
||||
except Exception: # noqa: BLE001
|
||||
idx = futures[future]
|
||||
response = "<think>error</think><action>look</action>"
|
||||
actions[idx] = response
|
||||
for future in timed_out:
|
||||
pending_futs.remove(future)
|
||||
idx = futures[future]
|
||||
actions[idx] = "<think>api timeout</think><action>look</action>"
|
||||
finally:
|
||||
executor.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
# Save model responses before stepping
|
||||
model_responses = {i: actions[i] for i in active_indices}
|
||||
|
||||
# Step environment
|
||||
obs, rewards, dones, infos = env_manager.step(actions)
|
||||
|
||||
# Record trajectory
|
||||
for i in active_indices:
|
||||
step_record = {
|
||||
"step": step_idx,
|
||||
"action": _extract_action(model_responses[i]),
|
||||
"reasoning": _extract_think(model_responses[i]),
|
||||
"model_response": model_responses[i],
|
||||
"env_feedback": obs["anchor"][i] if "anchor" in obs else "",
|
||||
"reward": float(rewards[i]),
|
||||
"done": bool(dones[i]),
|
||||
}
|
||||
conversations[i].append(step_record)
|
||||
|
||||
# Update done status
|
||||
for i in range(env_num):
|
||||
if env_dones[i]:
|
||||
continue
|
||||
if dones[i]:
|
||||
env_dones[i] = True
|
||||
won = bool(infos[i].get("won", False))
|
||||
overall_success[i] = won
|
||||
|
||||
# Build results and save conversations
|
||||
results: list[dict] = []
|
||||
pred_dir = os.path.join(out_root, "predictions") if out_root else ""
|
||||
|
||||
for i in range(env_num):
|
||||
gamefile = env_meta[i]["gamefile"]
|
||||
task_type = env_meta[i]["task_type"]
|
||||
task_desc = env_meta[i]["task_description"]
|
||||
n_turns = len(conversations[i])
|
||||
won = overall_success[i]
|
||||
|
||||
# Generate stable task ID from env index and gamefile
|
||||
task_id = str(result_ids[i]) if result_ids and i < len(result_ids) else f"env_{i:03d}"
|
||||
|
||||
fail_reason = ""
|
||||
if not won:
|
||||
if not env_dones[i]:
|
||||
fail_reason = f"Timeout after {max_steps} steps"
|
||||
else:
|
||||
fail_reason = "Episode ended without completing the task"
|
||||
|
||||
result = {
|
||||
"id": task_id,
|
||||
"hard": 1 if won else 0,
|
||||
"soft": 1.0 if won else 0.0,
|
||||
"n_turns": n_turns,
|
||||
"fail_reason": fail_reason,
|
||||
"agent_ok": True, # ALFWorld agent always runs OK (no crash)
|
||||
"task_type": task_type,
|
||||
"gamefile": gamefile,
|
||||
"task_description": task_desc,
|
||||
"instruction_type": task_type, # for compatibility with v2 pipeline
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
# Save conversation
|
||||
if pred_dir:
|
||||
conv_dir = os.path.join(pred_dir, task_id)
|
||||
os.makedirs(conv_dir, exist_ok=True)
|
||||
with open(os.path.join(conv_dir, "conversation.json"), "w") as f:
|
||||
json.dump(conversations[i], f, ensure_ascii=False, indent=2)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ── Item loading (for compatibility with split_three_way) ────────────────────
|
||||
|
||||
|
||||
def load_alfworld_items(
|
||||
eval_dataset: str,
|
||||
env_num: int,
|
||||
seed: int = 42,
|
||||
is_train: bool = False,
|
||||
) -> list[dict]:
|
||||
"""Create pseudo-item dicts for ALFWorld environments.
|
||||
|
||||
Since ALFWorld doesn't have a static JSON dataset like SpreadsheetBench,
|
||||
we create lightweight item dicts that carry enough metadata for the pipeline.
|
||||
The actual environment is built dynamically.
|
||||
|
||||
Returns:
|
||||
List of dicts with "id" keys, one per environment slot.
|
||||
"""
|
||||
items = []
|
||||
for i in range(env_num):
|
||||
items.append({
|
||||
"id": f"env_{i:03d}",
|
||||
"eval_dataset": eval_dataset,
|
||||
"env_index": i,
|
||||
})
|
||||
return items
|
||||
45
reflact/envs/alfworld/skills/initial.md
Normal file
45
reflact/envs/alfworld/skills/initial.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# ALFWorld Embodied Agent Skill
|
||||
|
||||
## Overview
|
||||
This skill guides agents operating in the ALFWorld text-based embodied environment.
|
||||
The agent must complete household tasks by navigating rooms, interacting with objects,
|
||||
and using appliances. Actions must be chosen from the admissible action list provided
|
||||
at each step.
|
||||
|
||||
**Output format**: Always output `<think>...</think>` for reasoning, then `<action>...</action>` for the chosen action.
|
||||
|
||||
---
|
||||
|
||||
## Task Types
|
||||
|
||||
| Type | Goal | Key Steps |
|
||||
|------|------|-----------|
|
||||
| Pick & Place | Put object X in/on receptacle Y | Find X -> take X -> go to Y -> put X in/on Y |
|
||||
| Pick Two & Place | Put two instances of X in/on Y | Find X1 -> take -> place -> find X2 -> take -> place |
|
||||
| Examine in Light | Examine object X under desklamp | Find X -> take X -> find desklamp -> use desklamp |
|
||||
| Clean & Place | Clean object X and put in/on Y | Find X -> take X -> go to sink -> clean X -> go to Y -> put X |
|
||||
| Heat & Place | Heat object X and put in/on Y | Find X -> take X -> go to microwave -> heat X -> go to Y -> put X |
|
||||
| Cool & Place | Cool object X and put in/on Y | Find X -> take X -> go to fridge -> cool X -> go to Y -> put X |
|
||||
|
||||
---
|
||||
|
||||
## General Principles
|
||||
|
||||
1. **Decompose the task**: Parse the goal into ordered sub-goals (locate, acquire, transform, deliver). Complete each before moving to the next.
|
||||
2. **Systematic exploration**: Search each surface and container exactly once before revisiting. Open closed containers (drawers, cabinets, fridge) before judging them empty.
|
||||
3. **Grab immediately**: When a required object is visible and reachable, take it right away before moving elsewhere.
|
||||
4. **Transform before placing**: If the task requires cleaning, heating, or cooling, perform the state change at the appropriate appliance before heading to the final destination.
|
||||
5. **Direct delivery**: Once holding the transformed (or untransformed) goal object, navigate straight to the target receptacle and place it.
|
||||
6. **Track progress**: Maintain an internal count of how many objects still need to be found and placed. Only stop searching when the count reaches zero.
|
||||
7. **Avoid loops**: Never repeat the same action more than twice in a row. If stuck, move to a different unexplored location.
|
||||
8. **Only choose admissible actions**: Always pick an action from the admissible action list. Do not invent actions.
|
||||
|
||||
---
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
- **Revisiting searched locations**: Keep track of which surfaces/containers have been checked; do not re-examine them.
|
||||
- **Ignoring visible objects**: If the target object appears in the observation, pick it up immediately.
|
||||
- **Skipping state changes**: Do not place an object at the destination without first cleaning/heating/cooling it when required.
|
||||
- **Premature termination**: Do not stop the episode until all goal conditions are verified as met.
|
||||
- **Action loops**: Repeatedly toggling or examining the same object wastes steps. Move on to new locations instead.
|
||||
9
reflact/envs/alfworld/vendor/__init__.py
vendored
Normal file
9
reflact/envs/alfworld/vendor/__init__.py
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Vendored ALFWorld environment runtime.
|
||||
|
||||
Minimal subset of SkillRL's agent_system package needed to run
|
||||
ALFWorld environments with ReflACT. Original source:
|
||||
https://github.com/NTU-LANTERN/SkillRL (Apache-2.0 License)
|
||||
"""
|
||||
from .alfworld_envs import AlfworldEnvs, build_alfworld_envs
|
||||
from .alfworld_projection import alfworld_projection
|
||||
from .env_manager import AlfWorldEnvironmentManager
|
||||
221
reflact/envs/alfworld/vendor/alfworld_envs.py
vendored
Normal file
221
reflact/envs/alfworld/vendor/alfworld_envs.py
vendored
Normal file
@@ -0,0 +1,221 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/env_package/alfworld/envs.py
|
||||
# Modified: imports use pip-installed alfworld package instead of vendored copy.
|
||||
|
||||
import os
|
||||
import multiprocessing as mp
|
||||
import traceback
|
||||
import yaml
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
|
||||
from alfworld.agents.environment import get_environment
|
||||
|
||||
|
||||
def load_config_file(path):
|
||||
assert os.path.exists(path), f"Invalid config file: {path}"
|
||||
with open(path) as reader:
|
||||
config = yaml.safe_load(reader)
|
||||
return config
|
||||
|
||||
|
||||
def compute_reward(info, multi_modal=False):
|
||||
if multi_modal:
|
||||
reward = 10.0 * float(info['won']) + float(info['goal_condition_success_rate'])
|
||||
else:
|
||||
reward = 10.0 * float(info['won'])
|
||||
return reward
|
||||
|
||||
|
||||
class AlfworldWorker:
|
||||
"""Stateful worker that holds one ALFWorld sub-environment."""
|
||||
|
||||
def __init__(self, config, seed, base_env, gamefile=None):
|
||||
if gamefile:
|
||||
base_env.game_files = [gamefile]
|
||||
if hasattr(base_env, "num_games"):
|
||||
base_env.num_games = 1
|
||||
self.env = base_env.init_env(batch_size=1)
|
||||
self.env.seed(seed)
|
||||
|
||||
def step(self, action):
|
||||
actions = [action]
|
||||
obs, scores, dones, infos = self.env.step(actions)
|
||||
infos['observation_text'] = obs
|
||||
return obs, scores, dones, infos
|
||||
|
||||
def reset(self):
|
||||
obs, infos = self.env.reset()
|
||||
infos['observation_text'] = obs
|
||||
return obs, infos
|
||||
|
||||
|
||||
def _worker_loop(cmd_q, result_q, config, seed, is_train, eval_dataset, gamefile):
|
||||
"""Run one ALFWorld environment in a child process."""
|
||||
try:
|
||||
env_type = config['env']['type']
|
||||
base_env = get_environment(env_type)(
|
||||
config,
|
||||
train_eval='train' if is_train else eval_dataset,
|
||||
)
|
||||
worker = AlfworldWorker(config, seed, base_env, gamefile)
|
||||
result_q.put((True, "ready"))
|
||||
except BaseException:
|
||||
result_q.put((False, traceback.format_exc()))
|
||||
return
|
||||
|
||||
while True:
|
||||
cmd, payload = cmd_q.get()
|
||||
if cmd == "close":
|
||||
result_q.put((True, None))
|
||||
return
|
||||
try:
|
||||
if cmd == "reset":
|
||||
result = worker.reset()
|
||||
elif cmd == "step":
|
||||
result = worker.step(payload)
|
||||
else:
|
||||
raise ValueError(f"Unknown ALFWorld worker command: {cmd}")
|
||||
result_q.put((True, result))
|
||||
except BaseException:
|
||||
result_q.put((False, traceback.format_exc()))
|
||||
|
||||
|
||||
class _ProcessWorker:
|
||||
"""Small stdlib actor wrapper for one environment process."""
|
||||
|
||||
def __init__(self, ctx, config, seed, is_train, eval_dataset, gamefile=None):
|
||||
self.cmd_q = ctx.Queue(maxsize=1)
|
||||
self.result_q = ctx.Queue(maxsize=1)
|
||||
self.process = ctx.Process(
|
||||
target=_worker_loop,
|
||||
args=(self.cmd_q, self.result_q, config, seed, is_train, eval_dataset, gamefile),
|
||||
)
|
||||
self.process.start()
|
||||
ok, payload = self.result_q.get()
|
||||
if not ok:
|
||||
self.close(kill=True)
|
||||
raise RuntimeError(f"Failed to start ALFWorld worker:\n{payload}")
|
||||
|
||||
def send(self, cmd, payload=None):
|
||||
self.cmd_q.put((cmd, payload))
|
||||
|
||||
def recv(self):
|
||||
ok, payload = self.result_q.get()
|
||||
if not ok:
|
||||
raise RuntimeError(f"ALFWorld worker failed:\n{payload}")
|
||||
return payload
|
||||
|
||||
def close(self, kill=False):
|
||||
if self.process.is_alive() and not kill:
|
||||
try:
|
||||
self.send("close")
|
||||
self.recv()
|
||||
except Exception:
|
||||
kill = True
|
||||
if kill and self.process.is_alive():
|
||||
self.process.terminate()
|
||||
self.process.join(timeout=5)
|
||||
if self.process.is_alive():
|
||||
self.process.kill()
|
||||
self.process.join(timeout=1)
|
||||
self.cmd_q.close()
|
||||
self.result_q.close()
|
||||
|
||||
|
||||
class AlfworldEnvs(gym.Env):
|
||||
"""Vectorized ALFWorld environment using local process workers."""
|
||||
|
||||
def __init__(self, alf_config_path, seed, env_num, group_n,
|
||||
resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None):
|
||||
super().__init__()
|
||||
if env_kwargs is None:
|
||||
env_kwargs = {}
|
||||
|
||||
eval_dataset = env_kwargs.get('eval_dataset', 'eval_in_distribution')
|
||||
config = load_config_file(alf_config_path)
|
||||
env_type = config['env']['type']
|
||||
self.multi_modal = (env_type == 'AlfredThorEnv')
|
||||
self.num_processes = env_num * group_n
|
||||
self.group_n = group_n
|
||||
self.gamefiles = list(gamefiles or [])
|
||||
if self.gamefiles and len(self.gamefiles) != self.num_processes:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_processes} gamefiles, got {len(self.gamefiles)}"
|
||||
)
|
||||
|
||||
start_method = os.environ.get("ALFWORLD_WORKER_START_METHOD") or None
|
||||
ctx = mp.get_context(start_method) if start_method else mp.get_context()
|
||||
self.workers = []
|
||||
for i in range(self.num_processes):
|
||||
worker_gamefile = self.gamefiles[i] if self.gamefiles else None
|
||||
worker = _ProcessWorker(
|
||||
ctx,
|
||||
config,
|
||||
seed + (i // self.group_n),
|
||||
is_train,
|
||||
eval_dataset,
|
||||
worker_gamefile,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
|
||||
self.prev_admissible_commands = [None for _ in range(self.num_processes)]
|
||||
|
||||
def step(self, actions):
|
||||
assert len(actions) == self.num_processes
|
||||
|
||||
for i, worker in enumerate(self.workers):
|
||||
worker.send("step", actions[i])
|
||||
results = [worker.recv() for worker in self.workers]
|
||||
|
||||
text_obs_list = []
|
||||
rewards_list = []
|
||||
dones_list = []
|
||||
info_list = []
|
||||
|
||||
for i, (obs, scores, dones, info) in enumerate(results):
|
||||
for k in info.keys():
|
||||
info[k] = info[k][0]
|
||||
text_obs_list.append(obs[0])
|
||||
dones_list.append(dones[0])
|
||||
info_list.append(info)
|
||||
self.prev_admissible_commands[i] = info['admissible_commands']
|
||||
rewards_list.append(compute_reward(info, self.multi_modal))
|
||||
|
||||
image_obs_list = None
|
||||
return text_obs_list, image_obs_list, rewards_list, dones_list, info_list
|
||||
|
||||
def reset(self):
|
||||
for worker in self.workers:
|
||||
worker.send("reset")
|
||||
results = [worker.recv() for worker in self.workers]
|
||||
|
||||
text_obs_list = []
|
||||
info_list = []
|
||||
|
||||
for i, (obs, info) in enumerate(results):
|
||||
for k in info.keys():
|
||||
info[k] = info[k][0]
|
||||
text_obs_list.append(obs[0])
|
||||
self.prev_admissible_commands[i] = info['admissible_commands']
|
||||
info_list.append(info)
|
||||
|
||||
image_obs_list = None
|
||||
return text_obs_list, image_obs_list, info_list
|
||||
|
||||
@property
|
||||
def get_admissible_commands(self):
|
||||
return self.prev_admissible_commands
|
||||
|
||||
def close(self):
|
||||
for worker in self.workers:
|
||||
worker.close()
|
||||
|
||||
|
||||
def build_alfworld_envs(alf_config_path, seed, env_num, group_n,
|
||||
resources_per_worker, is_train=True, env_kwargs=None, gamefiles=None):
|
||||
"""Build vectorized ALFWorld environments."""
|
||||
return AlfworldEnvs(
|
||||
alf_config_path, seed, env_num, group_n,
|
||||
resources_per_worker, is_train, env_kwargs, gamefiles,
|
||||
)
|
||||
60
reflact/envs/alfworld/vendor/alfworld_projection.py
vendored
Normal file
60
reflact/envs/alfworld/vendor/alfworld_projection.py
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/env_package/alfworld/projection.py
|
||||
|
||||
from typing import List
|
||||
import re
|
||||
|
||||
|
||||
def alfworld_projection(actions: List[str], action_pools: List[List[str]]):
|
||||
"""Process raw model outputs into valid ALFWorld actions.
|
||||
|
||||
Extracts text from ``<action>...</action>`` tags and validates that
|
||||
the response also contains ``<think>...</think>`` tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
actions : list[str]
|
||||
Raw model outputs, one per environment.
|
||||
action_pools : list[list[str]]
|
||||
Admissible action lists per environment (unused but kept for API compat).
|
||||
|
||||
Returns
|
||||
-------
|
||||
actions : list[str]
|
||||
Cleaned action strings.
|
||||
valids : list[int]
|
||||
1 if the action was successfully parsed, 0 otherwise.
|
||||
"""
|
||||
valids = [0] * len(actions)
|
||||
|
||||
for i in range(len(actions)):
|
||||
original_str = actions[i]
|
||||
actions[i] = actions[i].lower()
|
||||
|
||||
start_tag = "<action>"
|
||||
end_tag = "</action>"
|
||||
start_idx = actions[i].find(start_tag)
|
||||
end_idx = actions[i].find(end_tag)
|
||||
try:
|
||||
if start_idx == -1 or end_idx == -1:
|
||||
actions[i] = actions[i][-30:]
|
||||
continue
|
||||
|
||||
extracted_action = actions[i][start_idx + len(start_tag):end_idx].strip().lower()
|
||||
actions[i] = extracted_action
|
||||
valids[i] = 1
|
||||
|
||||
except Exception:
|
||||
actions[i] = actions[i][-30:]
|
||||
|
||||
# Require <think>...</think>
|
||||
think_start_idx = original_str.find("<think>")
|
||||
think_end_idx = original_str.find("</think>")
|
||||
if think_start_idx == -1 or think_end_idx == -1:
|
||||
valids[i] = 0
|
||||
|
||||
# Reject responses containing Chinese characters
|
||||
if re.search(r'[\u4e00-\u9fff]', original_str):
|
||||
valids[i] = 0
|
||||
|
||||
return actions, valids
|
||||
8
reflact/envs/alfworld/vendor/alfworld_prompts.py
vendored
Normal file
8
reflact/envs/alfworld/vendor/alfworld_prompts.py
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/prompts/alfworld.py
|
||||
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
ALFWORLD_TEMPLATE_NO_HIS = load_prompt("rollout_no_history", env="alfworld")
|
||||
ALFWORLD_TEMPLATE = load_prompt("rollout_with_history", env="alfworld")
|
||||
ALFWORLD_TEMPLATE_WITH_MEMORY = load_prompt("rollout_with_memory", env="alfworld")
|
||||
145
reflact/envs/alfworld/vendor/config_tw.yaml
vendored
Normal file
145
reflact/envs/alfworld/vendor/config_tw.yaml
vendored
Normal file
@@ -0,0 +1,145 @@
|
||||
dataset:
|
||||
data_path: '$ALFWORLD_DATA/json_2.1.1/train'
|
||||
eval_id_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_seen' # null/None to disable
|
||||
eval_ood_data_path: '$ALFWORLD_DATA/json_2.1.1/valid_unseen' # null/None to disable
|
||||
num_train_games: -1 # max training games (<=0 indicates full dataset)
|
||||
num_eval_games: -1 # max evaluation games (<=0 indicates full dataset)
|
||||
|
||||
logic:
|
||||
domain: '$ALFWORLD_DATA/logic/alfred.pddl' # PDDL domain file that defines the world dynamics
|
||||
grammar: '$ALFWORLD_DATA/logic/alfred.twl2' # Grammar file that defines the text feedbacks
|
||||
|
||||
env:
|
||||
type: 'AlfredTWEnv' # 'AlfredTWEnv' or 'AlfredThorEnv' or 'AlfredHybrid'
|
||||
# regen_game_files: False # check if game is solvable by expert and save to game.tw-pddl file
|
||||
domain_randomization: False # shuffle Textworld print order and object id nums
|
||||
task_types: [1, 2, 3, 4, 5, 6] # task-type ids: 1 - Pick & Place, 2 - Examine in Light, 3 - Clean & Place, 4 - Heat & Place, 5 - Cool & Place, 6 - Pick Two & Place
|
||||
expert_timeout_steps: 150 # max steps before timeout for expert to solve the task
|
||||
expert_type: "handcoded" # 'handcoded' or 'planner'. Note: the planner is very slow for real-time use
|
||||
goal_desc_human_anns_prob: 0.0 # prob of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED)
|
||||
|
||||
hybrid:
|
||||
start_eps: 100000 # starting episode of hybrid training, tw-only training upto this point
|
||||
thor_prob: 0.5 # prob of AlfredThorEnv during hybrid training
|
||||
eval_mode: "tw" # 'tw' or 'thor' - env used for evaluation during hybrid training
|
||||
|
||||
thor:
|
||||
screen_width: 300 # width of THOR window
|
||||
screen_height: 300 # height of THOR window
|
||||
smooth_nav: False # smooth rotations, looks, and translations during navigation (very slow)
|
||||
save_frames_to_disk: False # save frame PNGs to disk (useful for making videos)
|
||||
save_frames_path: './videos/' # path to save frame PNGs
|
||||
|
||||
controller:
|
||||
type: 'oracle' # 'oracle' or 'oracle_astar' or 'mrcnn' or 'mrcnn_astar' (aka BUTLER)
|
||||
debug: False
|
||||
load_receps: True # load receptacle locations from precomputed dict (if available)
|
||||
|
||||
mask_rcnn:
|
||||
pretrained_model_path: '$ALFWORLD_DATA/detectors/mrcnn.pth'
|
||||
|
||||
general:
|
||||
random_seed: 42
|
||||
use_cuda: True # disable this when running on machine without cuda
|
||||
visdom: False # plot training/eval curves, run with visdom server
|
||||
task: 'alfred'
|
||||
training_method: 'dagger' # 'dqn' or 'dagger'
|
||||
save_path: './training/' # path to save pytorch models
|
||||
observation_pool_capacity: 3 # k-size queue, 0 indicates no observation
|
||||
hide_init_receptacles: False # remove initial observation containing navigable receptacles
|
||||
|
||||
training:
|
||||
batch_size: 10
|
||||
max_episode: 50000
|
||||
smoothing_eps: 0.1
|
||||
optimizer:
|
||||
learning_rate: 0.001
|
||||
clip_grad_norm: 5
|
||||
|
||||
evaluate:
|
||||
run_eval: True
|
||||
batch_size: 10
|
||||
env:
|
||||
type: "AlfredTWEnv"
|
||||
|
||||
checkpoint:
|
||||
report_frequency: 1000 # report every N episode
|
||||
experiment_tag: 'test' # name of experiment
|
||||
load_pretrained: False # during test, enable this so that the agent load your pretrained model
|
||||
load_from_tag: 'not loading anything' # name of pre-trained model to load in save_path
|
||||
|
||||
model:
|
||||
encoder_layers: 1
|
||||
decoder_layers: 1
|
||||
encoder_conv_num: 5
|
||||
block_hidden_dim: 64
|
||||
n_heads: 1
|
||||
dropout: 0.1
|
||||
block_dropout: 0.1
|
||||
recurrent: True
|
||||
|
||||
rl:
|
||||
action_space: "admissible" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'beam_search_choice' or 'exhaustive' (not working)
|
||||
max_target_length: 20 # max token length for seq2seq generation
|
||||
beam_width: 10 # 1 means greedy
|
||||
generate_top_k: 3
|
||||
|
||||
training:
|
||||
max_nb_steps_per_episode: 50 # terminate after this many steps
|
||||
learn_start_from_this_episode: 0 # delay updates until this epsiode
|
||||
target_net_update_frequency: 500 # sync target net with online net per this many epochs
|
||||
|
||||
replay:
|
||||
accumulate_reward_from_final: True
|
||||
count_reward_lambda: 0.0 # 0 to disable
|
||||
novel_object_reward_lambda: 0.0 # 0 to disable
|
||||
discount_gamma_game_reward: 0.9
|
||||
discount_gamma_count_reward: 0.5
|
||||
discount_gamma_novel_object_reward: 0.5
|
||||
replay_memory_capacity: 500000 # adjust this depending on your RAM size
|
||||
replay_memory_priority_fraction: 0.5
|
||||
update_per_k_game_steps: 5
|
||||
replay_batch_size: 64
|
||||
multi_step: 3
|
||||
replay_sample_history_length: 4
|
||||
replay_sample_update_from: 2
|
||||
|
||||
epsilon_greedy:
|
||||
noisy_net: False # if this is true, then epsilon greedy is disabled
|
||||
epsilon_anneal_episodes: 1000 # -1 if not annealing
|
||||
epsilon_anneal_from: 0.3
|
||||
epsilon_anneal_to: 0.1
|
||||
|
||||
dagger:
|
||||
action_space: "generation" # 'admissible' (candidates from text engine) or 'generation' (seq2seq-style generation) or 'exhaustive' (not working)
|
||||
max_target_length: 20 # max token length for seq2seq generation
|
||||
beam_width: 10 # 1 means greedy
|
||||
generate_top_k: 5
|
||||
unstick_by_beam_search: False # use beam-search for failed actions, set True during evaluation
|
||||
|
||||
training:
|
||||
max_nb_steps_per_episode: 50 # terminate after this many steps
|
||||
|
||||
fraction_assist:
|
||||
fraction_assist_anneal_episodes: 50000
|
||||
fraction_assist_anneal_from: 1.0
|
||||
fraction_assist_anneal_to: 0.01
|
||||
|
||||
fraction_random:
|
||||
fraction_random_anneal_episodes: 0
|
||||
fraction_random_anneal_from: 0.0
|
||||
fraction_random_anneal_to: 0.0
|
||||
|
||||
replay:
|
||||
replay_memory_capacity: 500000
|
||||
update_per_k_game_steps: 5
|
||||
replay_batch_size: 64
|
||||
replay_sample_history_length: 4
|
||||
replay_sample_update_from: 2
|
||||
|
||||
vision_dagger:
|
||||
model_type: "resnet" # 'resnet' (whole image features) or 'maskrcnn_whole' (whole image MaskRCNN feats) or 'maskrcnn' (top k MaskRCNN detection feats) or 'no_vision' (zero vision input)
|
||||
resnet_fc_dim: 64
|
||||
maskrcnn_top_k_boxes: 10 # top k box features
|
||||
use_exploration_frame_feats: False # append feats from initial exploration (memory intensive!)
|
||||
sequence_aggregation_method: "average" # 'sum' or 'average' or 'rnn'
|
||||
84
reflact/envs/alfworld/vendor/env_base.py
vendored
Normal file
84
reflact/envs/alfworld/vendor/env_base.py
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/base.py
|
||||
# Trimmed to only include what ALFWorld needs.
|
||||
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
def to_numpy(data):
|
||||
"""Convert data to numpy array."""
|
||||
# Lazy-check for torch.Tensor to avoid hard dependency on torch
|
||||
_torch_tensor = None
|
||||
try:
|
||||
import torch
|
||||
_torch_tensor = torch.Tensor
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if _torch_tensor is not None and isinstance(data, _torch_tensor):
|
||||
data = data.detach().cpu().numpy()
|
||||
elif isinstance(data, np.ndarray):
|
||||
pass
|
||||
elif isinstance(data, (int, float, bool, Tuple, List)):
|
||||
data = np.array(data)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type: {type(data)})")
|
||||
return data
|
||||
|
||||
|
||||
class EnvironmentManagerBase:
|
||||
"""Base class for vectorized environment managers.
|
||||
|
||||
Manages a set of parallel environments, handles action projection,
|
||||
observation post-processing, and history tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, envs, projection_f, config):
|
||||
self.envs = envs
|
||||
self.projection_f = projection_f
|
||||
self.config = config
|
||||
|
||||
def reset(self, kwargs) -> Dict[str, Any]:
|
||||
obs, infos = self.envs.reset()
|
||||
return {'text': None, 'image': obs, 'anchor': None}, infos
|
||||
|
||||
def step(self, text_actions: List[str]):
|
||||
actions, valids = self.projection_f(text_actions)
|
||||
next_obs, rewards, dones, infos = self.envs.step(actions)
|
||||
|
||||
next_observations = {
|
||||
'text': None,
|
||||
'image': next_obs,
|
||||
'anchor': None,
|
||||
}
|
||||
for i, info in enumerate(infos):
|
||||
info['is_action_valid'] = to_numpy(valids[i])
|
||||
|
||||
rewards = to_numpy(rewards)
|
||||
dones = to_numpy(dones)
|
||||
return next_observations, rewards, dones, infos
|
||||
|
||||
def close(self) -> None:
|
||||
self.envs.close()
|
||||
|
||||
def success_evaluator(self, *args, **kwargs) -> Dict[str, np.ndarray]:
|
||||
total_infos = kwargs['total_infos']
|
||||
total_batch_list = kwargs['total_batch_list']
|
||||
batch_size = len(total_batch_list)
|
||||
|
||||
success = defaultdict(list)
|
||||
for bs in range(batch_size):
|
||||
self._process_batch(bs, total_batch_list, total_infos, success)
|
||||
assert len(success['success_rate']) == batch_size
|
||||
return {key: np.array(value) for key, value in success.items()}
|
||||
|
||||
def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
|
||||
for i in reversed(range(len(total_batch_list[batch_idx]))):
|
||||
batch_item = total_batch_list[batch_idx][i]
|
||||
if batch_item['active_masks']:
|
||||
info = total_infos[batch_idx][i]
|
||||
won_value = float(info['won'])
|
||||
success['success_rate'].append(won_value)
|
||||
return
|
||||
139
reflact/envs/alfworld/vendor/env_manager.py
vendored
Normal file
139
reflact/envs/alfworld/vendor/env_manager.py
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/environments/env_manager.py
|
||||
# Trimmed to only include AlfWorldEnvironmentManager and its helpers.
|
||||
|
||||
from typing import List, Dict, Any
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
from reflact.envs.alfworld.vendor.env_base import EnvironmentManagerBase, to_numpy
|
||||
from reflact.envs.alfworld.vendor.alfworld_prompts import (
|
||||
ALFWORLD_TEMPLATE,
|
||||
ALFWORLD_TEMPLATE_NO_HIS,
|
||||
ALFWORLD_TEMPLATE_WITH_MEMORY,
|
||||
)
|
||||
from reflact.envs.alfworld.vendor.memory import SimpleMemory
|
||||
|
||||
|
||||
def parse_gamefile(infos):
|
||||
gamefile = []
|
||||
for info in infos:
|
||||
if 'extra.gamefile' in info:
|
||||
gamefile.append(info['extra.gamefile'])
|
||||
else:
|
||||
gamefile.append(None)
|
||||
return gamefile
|
||||
|
||||
|
||||
def set_gamefile(infos, gamefile):
|
||||
for i in range(len(infos)):
|
||||
if 'extra.gamefile' in infos[i]:
|
||||
infos[i]['extra.gamefile'] = gamefile[i]
|
||||
else:
|
||||
infos[i]['extra.gamefile'] = None
|
||||
return infos
|
||||
|
||||
|
||||
class AlfWorldEnvironmentManager(EnvironmentManagerBase):
|
||||
"""Manages parallel ALFWorld environments with observation templating."""
|
||||
|
||||
def __init__(self, envs, projection_f, config):
|
||||
self.memory = SimpleMemory()
|
||||
self.retrieval_memory = None
|
||||
super().__init__(envs, projection_f, config)
|
||||
|
||||
def reset(self, kwargs):
|
||||
text_obs, image_obs, infos = self.envs.reset()
|
||||
self.gamefile = parse_gamefile(infos)
|
||||
self.memory.reset(batch_size=len(text_obs))
|
||||
self.tasks = []
|
||||
self.pre_text_obs = text_obs
|
||||
self.extract_task(text_obs)
|
||||
|
||||
full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands, init=True)
|
||||
return {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}, infos
|
||||
|
||||
def step(self, text_actions: List[str]):
|
||||
actions, valids = self.projection_f(text_actions, self.envs.get_admissible_commands)
|
||||
text_obs, image_obs, rewards, dones, infos = self.envs.step(actions)
|
||||
self.memory.store({'text_obs': self.pre_text_obs, 'action': actions})
|
||||
self.pre_text_obs = text_obs
|
||||
|
||||
full_text_obs = self.build_text_obs(text_obs, self.envs.get_admissible_commands)
|
||||
if infos[0].get("extra.gamefile") is None:
|
||||
infos = set_gamefile(infos, self.gamefile)
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
info['is_action_valid'] = to_numpy(valids[i])
|
||||
|
||||
next_observations = {'text': full_text_obs, 'image': image_obs, 'anchor': text_obs}
|
||||
rewards = to_numpy(rewards)
|
||||
dones = to_numpy(dones)
|
||||
return next_observations, rewards, dones, infos
|
||||
|
||||
def extract_task(self, text_obs: List[str]):
|
||||
for obs in text_obs:
|
||||
task_start = obs.find('Your task is to: ')
|
||||
if task_start != -1:
|
||||
self.tasks.append(obs[task_start + len('Your task is to: '):].strip())
|
||||
else:
|
||||
raise ValueError("Task description not found in text observation.")
|
||||
|
||||
def build_text_obs(self, text_obs: List[str], admissible_actions: List[List[str]], init: bool = False) -> List[str]:
|
||||
postprocess_text_obs = []
|
||||
if not init and self.config.env.history_length > 0:
|
||||
memory_contexts, valid_lens = self.memory.fetch(
|
||||
self.config.env.history_length,
|
||||
obs_key="text_obs",
|
||||
action_key="action",
|
||||
)
|
||||
|
||||
for i in range(len(text_obs)):
|
||||
reformatted_admissible_actions = "\n ".join(
|
||||
f"'{s}'" for s in admissible_actions[i] if s != 'help'
|
||||
)
|
||||
|
||||
if init or self.config.env.history_length <= 0:
|
||||
obs = ALFWORLD_TEMPLATE_NO_HIS.format(
|
||||
current_observation=text_obs[i],
|
||||
admissible_actions=reformatted_admissible_actions,
|
||||
)
|
||||
else:
|
||||
obs = ALFWORLD_TEMPLATE.format(
|
||||
task_description=self.tasks[i],
|
||||
step_count=len(self.memory[i]),
|
||||
history_length=valid_lens[i],
|
||||
action_history=memory_contexts[i],
|
||||
current_step=len(self.memory[i]) + 1,
|
||||
current_observation=text_obs[i],
|
||||
admissible_actions=reformatted_admissible_actions,
|
||||
)
|
||||
postprocess_text_obs.append(obs)
|
||||
return postprocess_text_obs
|
||||
|
||||
def _process_batch(self, batch_idx, total_batch_list, total_infos, success):
|
||||
for i in reversed(range(len(total_batch_list[batch_idx]))):
|
||||
batch_item = total_batch_list[batch_idx][i]
|
||||
if batch_item['active_masks']:
|
||||
info = total_infos[batch_idx][i]
|
||||
won_value = float(info['won'])
|
||||
success['success_rate'].append(won_value)
|
||||
|
||||
gamefile = info.get("extra.gamefile")
|
||||
if gamefile:
|
||||
self._process_gamefile(gamefile, won_value, success)
|
||||
return
|
||||
|
||||
def _process_gamefile(self, gamefile, won_value, success):
|
||||
tasks = [
|
||||
"pick_and_place",
|
||||
"pick_two_obj_and_place",
|
||||
"look_at_obj_in_light",
|
||||
"pick_heat_then_place_in_recep",
|
||||
"pick_cool_then_place_in_recep",
|
||||
"pick_clean_then_place_in_recep",
|
||||
]
|
||||
for task in tasks:
|
||||
if task in gamefile:
|
||||
success[f"{task}_success_rate"].append(won_value)
|
||||
break
|
||||
87
reflact/envs/alfworld/vendor/memory.py
vendored
Normal file
87
reflact/envs/alfworld/vendor/memory.py
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
# Vendored from SkillRL (Apache-2.0 License)
|
||||
# Original: agent_system/memory/base.py + agent_system/memory/memory.py
|
||||
# Merged into a single file for simplicity.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""Base class for memory management."""
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self, batch_size: int):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def store(self, record: Dict[str, List[Any]]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch(self, step: int):
|
||||
pass
|
||||
|
||||
|
||||
class SimpleMemory(BaseMemory):
|
||||
"""Per-environment history buffer for storing observations and actions."""
|
||||
|
||||
def __init__(self):
|
||||
self._data = None
|
||||
self.keys = None
|
||||
self.batch_size = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._data[idx]
|
||||
|
||||
def reset(self, batch_size: int):
|
||||
if self._data is not None:
|
||||
self._data.clear()
|
||||
self._data = [[] for _ in range(batch_size)]
|
||||
self.batch_size = batch_size
|
||||
self.keys = None
|
||||
|
||||
def store(self, record: Dict[str, List[Any]]):
|
||||
if self.keys is None:
|
||||
self.keys = list(record.keys())
|
||||
assert self.keys == list(record.keys())
|
||||
|
||||
for env_idx in range(self.batch_size):
|
||||
self._data[env_idx].append({k: record[k][env_idx] for k in self.keys})
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
history_length: int,
|
||||
obs_key: str = "text_obs",
|
||||
action_key: str = "action",
|
||||
) -> Tuple[List[str], List[int]]:
|
||||
memory_contexts, valid_lengths = [], []
|
||||
|
||||
for env_idx in range(self.batch_size):
|
||||
recent = self._data[env_idx][-history_length:]
|
||||
valid_len = len(recent)
|
||||
start_idx = len(self._data[env_idx]) - valid_len
|
||||
|
||||
lines = []
|
||||
for j, rec in enumerate(recent):
|
||||
step_num = start_idx + j + 1
|
||||
act = rec[action_key]
|
||||
obs = rec[obs_key]
|
||||
lines.append(
|
||||
f"[Observation {step_num}: '{obs}', Action {step_num}: '{act}']"
|
||||
)
|
||||
|
||||
memory_contexts.append("\n".join(lines))
|
||||
valid_lengths.append(valid_len)
|
||||
|
||||
return memory_contexts, valid_lengths
|
||||
1
reflact/envs/babyvision/__init__.py
Normal file
1
reflact/envs/babyvision/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""BabyVision environment package for ReflACT."""
|
||||
267
reflact/envs/babyvision/adapter.py
Normal file
267
reflact/envs/babyvision/adapter.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""BabyVision environment adapter for ReflACT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from reflact.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from reflact.datasets.base import BatchSpec
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
from reflact.envs.base import EnvAdapter
|
||||
from reflact.envs.babyvision.dataloader import BabyVisionDataLoader
|
||||
from reflact.envs.babyvision.rollout import run_batch
|
||||
from reflact.model import get_student_backend
|
||||
|
||||
|
||||
class BabyVisionAdapter(EnvAdapter):
|
||||
"""BabyVision adapter."""
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
cot = str(item.get("cot") or "").strip()
|
||||
if not cot:
|
||||
return ""
|
||||
return f"## Reference CoT\n{cot}"
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
cot = str(item.get("cot") or "").strip()
|
||||
if not cot:
|
||||
return {"fields": [], "preview": ""}
|
||||
return {
|
||||
"fields": ["cot"],
|
||||
"preview": cot[:400],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
max_turns: int = 1,
|
||||
workers: int = 32,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.workers = workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.image_detail = image_detail
|
||||
self.judge_model = judge_model
|
||||
self.judge_max_completion_tokens = judge_max_completion_tokens
|
||||
self.judge_retries = judge_retries
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = BabyVisionDataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
judge_model=self.judge_model,
|
||||
judge_max_completion_tokens=self.judge_max_completion_tokens,
|
||||
judge_retries=self.judge_retries,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
env_manager = kwargs.get("env_manager")
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
codex_backend = get_student_backend() == "codex_exec"
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
env_manager if isinstance(env_manager, list) else None,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
if codex_backend:
|
||||
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
|
||||
selected_metadata = []
|
||||
cot_count = 0
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
if meta["fields"]:
|
||||
cot_count += 1
|
||||
selected_metadata.append({
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("subtype") or item.get("task_type") or "babyvision"),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
})
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields=cot({cot_count}/{len(selected_items)})"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
diagnostic_trace_context_by_id = None
|
||||
if codex_backend:
|
||||
selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
|
||||
selected_items=selected_items,
|
||||
selected_examples=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
probe=probe,
|
||||
)
|
||||
probe_record = {
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": {
|
||||
"cot": cot_count,
|
||||
},
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
}
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(probe_record, f, ensure_ascii=False, indent=2)
|
||||
deep_results = run_batch(
|
||||
items=selected_items,
|
||||
out_root=rollout_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=min(self.workers, max(len(selected_items), 1)),
|
||||
image_detail=self.image_detail,
|
||||
judge_model=self.judge_model,
|
||||
judge_max_completion_tokens=self.judge_max_completion_tokens,
|
||||
judge_retries=self.judge_retries,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, selected_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return self.dataloader.get_task_types()
|
||||
214
reflact/envs/babyvision/dataloader.py
Normal file
214
reflact/envs/babyvision/dataloader.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""BabyVision task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from reflact.datasets.base import SplitDataLoader
|
||||
|
||||
|
||||
# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
|
||||
|
||||
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]
|
||||
|
||||
|
||||
def _iter_jsonl(path: str) -> list[dict]:
|
||||
items: list[dict] = []
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
items.append(json.loads(line))
|
||||
return items
|
||||
|
||||
|
||||
def _normalize_ans_type(raw: Any, options: list[dict], choice_answer: Any) -> str:
|
||||
text = str(raw or "").strip().lower()
|
||||
if text in {"choice", "multiple_choice", "mcq", "option"}:
|
||||
return "choice"
|
||||
if text in {"blank", "open", "open_ended", "fill_blank", "short_answer"}:
|
||||
return "blank"
|
||||
if options or choice_answer not in (None, "", []):
|
||||
return "choice"
|
||||
return "blank"
|
||||
|
||||
|
||||
def _coerce_options(raw: Any) -> list[dict]:
|
||||
options: list[dict] = []
|
||||
if isinstance(raw, list):
|
||||
for idx, item in enumerate(raw):
|
||||
if isinstance(item, dict):
|
||||
text = str(item.get("text") or item.get("content") or item.get("option") or "").strip()
|
||||
label = str(item.get("label") or _CHOICE_LABELS[idx]).strip()
|
||||
else:
|
||||
text = str(item).strip()
|
||||
label = _CHOICE_LABELS[idx]
|
||||
if text:
|
||||
options.append({"label": label, "text": text})
|
||||
elif isinstance(raw, dict):
|
||||
for idx, (key, value) in enumerate(raw.items()):
|
||||
text = str(value).strip()
|
||||
if text:
|
||||
options.append({"label": str(key).strip() or _CHOICE_LABELS[idx], "text": text})
|
||||
return options
|
||||
|
||||
|
||||
def _normalize_choice_answer(choice_answer: Any, options: list[dict]) -> dict[str, str]:
|
||||
if not options:
|
||||
return {"label": "", "text": ""}
|
||||
|
||||
if isinstance(choice_answer, dict):
|
||||
label = str(choice_answer.get("label") or "").strip().upper()
|
||||
text = str(choice_answer.get("text") or "").strip()
|
||||
for option in options:
|
||||
if label and option["label"].strip().upper() == label:
|
||||
return {"label": option["label"], "text": option["text"]}
|
||||
if text and option["text"] == text:
|
||||
return {"label": option["label"], "text": option["text"]}
|
||||
|
||||
if isinstance(choice_answer, int):
|
||||
idx = choice_answer
|
||||
if 0 <= idx < len(options):
|
||||
return dict(options[idx])
|
||||
if 1 <= idx <= len(options):
|
||||
return dict(options[idx - 1])
|
||||
|
||||
text = str(choice_answer or "").strip()
|
||||
label = text.upper().rstrip(".):")
|
||||
for option in options:
|
||||
if option["label"].strip().upper() == label:
|
||||
return dict(option)
|
||||
if option["text"] == text:
|
||||
return dict(option)
|
||||
|
||||
return {"label": "", "text": ""}
|
||||
|
||||
|
||||
def _coerce_blank_answers(raw: Any) -> list[str]:
|
||||
if isinstance(raw, list):
|
||||
return [str(item).strip() for item in raw if str(item).strip()]
|
||||
if raw is None:
|
||||
return []
|
||||
text = str(raw).strip()
|
||||
return [text] if text else []
|
||||
|
||||
|
||||
def load_items(data_path: str) -> list[dict]:
|
||||
"""Load and normalise BabyVision items from a directory or JSONL file."""
|
||||
if not data_path:
|
||||
raise ValueError("BabyVision requires data_path pointing to a local dataset directory or meta_data.jsonl.")
|
||||
|
||||
if os.path.isdir(data_path):
|
||||
meta_path = os.path.join(data_path, "meta_data.jsonl")
|
||||
image_root = os.path.join(data_path, "images")
|
||||
else:
|
||||
meta_path = data_path
|
||||
image_root = os.path.join(os.path.dirname(data_path), "images")
|
||||
|
||||
if not os.path.exists(meta_path):
|
||||
raise ValueError(
|
||||
"BabyVision expected a meta_data.jsonl file. "
|
||||
f"Could not find: {meta_path}"
|
||||
)
|
||||
|
||||
raw_items = _iter_jsonl(meta_path)
|
||||
items: list[dict] = []
|
||||
for idx, raw in enumerate(raw_items):
|
||||
options = _coerce_options(raw.get("options") or raw.get("choices") or raw.get("choiceOptions"))
|
||||
ans_type = _normalize_ans_type(raw.get("ansType"), options, raw.get("choiceAns"))
|
||||
correct_choice = _normalize_choice_answer(raw.get("choiceAns"), options)
|
||||
blank_answers = _coerce_blank_answers(raw.get("blankAns"))
|
||||
|
||||
image_name = str(
|
||||
raw.get("image")
|
||||
or raw.get("image_path")
|
||||
or raw.get("image_file")
|
||||
or raw.get("img")
|
||||
or ""
|
||||
).strip()
|
||||
if not image_name:
|
||||
continue
|
||||
image_path = image_name if os.path.isabs(image_name) else os.path.join(image_root, image_name)
|
||||
if not os.path.exists(image_path):
|
||||
alt = os.path.join(os.path.dirname(meta_path), image_name)
|
||||
if os.path.exists(alt):
|
||||
image_path = alt
|
||||
else:
|
||||
continue
|
||||
|
||||
task_id = str(raw.get("taskId") or raw.get("id") or idx + 1)
|
||||
task_type = str(raw.get("type") or raw.get("taskType") or "unknown").strip() or "unknown"
|
||||
subtype = str(raw.get("subtype") or raw.get("subType") or task_type).strip() or task_type
|
||||
question = str(raw.get("question") or raw.get("query") or "").strip()
|
||||
if not question:
|
||||
continue
|
||||
|
||||
if ans_type == "choice" and not correct_choice["label"]:
|
||||
continue
|
||||
if ans_type != "choice" and not blank_answers:
|
||||
continue
|
||||
|
||||
items.append({
|
||||
"id": task_id,
|
||||
"task_type": task_type,
|
||||
"subtype": subtype,
|
||||
"question": question,
|
||||
"image_path": os.path.abspath(image_path),
|
||||
"ans_type": ans_type,
|
||||
"choices": options,
|
||||
"correct_choice": correct_choice,
|
||||
"blank_answers": blank_answers,
|
||||
"cot": str(raw.get("coT") or raw.get("cot") or "").strip(),
|
||||
"source_path": os.path.abspath(meta_path),
|
||||
})
|
||||
|
||||
if not items:
|
||||
raise ValueError(f"No valid BabyVision items loaded from {data_path}")
|
||||
return items
|
||||
|
||||
|
||||
# ── Dataloader ───────────────────────────────────────────────────────────
|
||||
|
||||
class BabyVisionDataLoader(SplitDataLoader):
|
||||
"""BabyVision dataloader."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
self._task_types: list[str] = []
|
||||
|
||||
def load_raw_items(self, data_path: str) -> list[dict]:
|
||||
return load_items(data_path)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
all_items = self.train_items + self.val_items + self.test_items
|
||||
task_types = {
|
||||
item.get("subtype") or item.get("task_type") or "unknown"
|
||||
for item in all_items
|
||||
}
|
||||
self._task_types = sorted(task_types)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(self._task_types)
|
||||
160
reflact/envs/babyvision/evaluator.py
Normal file
160
reflact/envs/babyvision/evaluator.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""BabyVision evaluation helpers using the official-style LLM judge."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
|
||||
import regex
|
||||
|
||||
from reflact.model import chat_with_deployment
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
_EVAL_MODE = "babyvision_judge_v2_official_style"
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
text = str(text).strip().lower()
|
||||
text = "".join(ch for ch in text if ch not in string.punctuation)
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def extract_boxed_answer(text: str | None) -> str | None:
|
||||
"""Extract the final answer using the official BabyVision rule."""
|
||||
if text is None:
|
||||
return None
|
||||
|
||||
pattern = r'\\boxed\{((?:[^{}]|{(?:[^{}]|{.*})*})*)\}'
|
||||
matches = regex.findall(pattern, text)
|
||||
if matches:
|
||||
return matches[-1]
|
||||
|
||||
pattern_alt = r'<\|begin_of_box\|>(.*?)<\|end_of_box\|>'
|
||||
matches_alt = regex.findall(pattern_alt, text)
|
||||
if matches_alt:
|
||||
return matches_alt[-1].strip()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _token_f1(prediction: str, gold: str) -> float:
|
||||
pred_tokens = normalize_text(prediction).split()
|
||||
gold_tokens = normalize_text(gold).split()
|
||||
if not pred_tokens and not gold_tokens:
|
||||
return 1.0
|
||||
if not pred_tokens or not gold_tokens:
|
||||
return 0.0
|
||||
pred_set = {}
|
||||
gold_set = {}
|
||||
for tok in pred_tokens:
|
||||
pred_set[tok] = pred_set.get(tok, 0) + 1
|
||||
for tok in gold_tokens:
|
||||
gold_set[tok] = gold_set.get(tok, 0) + 1
|
||||
common = 0
|
||||
for tok, count in pred_set.items():
|
||||
common += min(count, gold_set.get(tok, 0))
|
||||
if common == 0:
|
||||
return 0.0
|
||||
precision = common / len(pred_tokens)
|
||||
recall = common / len(gold_tokens)
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def _format_choices(choices: list[dict]) -> str:
|
||||
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
|
||||
|
||||
|
||||
def _judge_answer(
|
||||
*,
|
||||
item: dict,
|
||||
prediction_text: str,
|
||||
extracted_answer: str,
|
||||
judge_model: str,
|
||||
max_completion_tokens: int,
|
||||
retries: int,
|
||||
) -> dict:
|
||||
if item["ans_type"] == "choice":
|
||||
ground_truth = str(item["correct_choice"]["label"])
|
||||
else:
|
||||
if len(item["blank_answers"]) == 1:
|
||||
ground_truth = item["blank_answers"][0]
|
||||
else:
|
||||
ground_truth = " | ".join(item["blank_answers"])
|
||||
|
||||
question = str(item["question"])
|
||||
if item["ans_type"] == "choice" and item.get("choices"):
|
||||
question = f"{question}\nChoices:\n{_format_choices(item['choices'])}"
|
||||
|
||||
raw, _ = chat_with_deployment(
|
||||
deployment=judge_model,
|
||||
system="You are a careful and strict evaluator.",
|
||||
user=load_prompt("judge", env="babyvision").format(
|
||||
question=question,
|
||||
groundtruth=ground_truth,
|
||||
modeloutput=extracted_answer,
|
||||
),
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=retries,
|
||||
stage="babyvision_judge",
|
||||
)
|
||||
judge_response_clean = str(raw).strip().lower()
|
||||
if "true" in judge_response_clean:
|
||||
correct = True
|
||||
elif "false" in judge_response_clean:
|
||||
correct = False
|
||||
else:
|
||||
correct = False
|
||||
return {
|
||||
"raw": raw,
|
||||
"correct": correct,
|
||||
"reason": judge_response_clean,
|
||||
"matched_gold": ground_truth if correct else "",
|
||||
}
|
||||
|
||||
|
||||
def evaluate_item(
|
||||
*,
|
||||
item: dict,
|
||||
prediction_text: str,
|
||||
judge_model: str,
|
||||
max_completion_tokens: int = 256,
|
||||
retries: int = 5,
|
||||
) -> dict:
|
||||
answer = extract_boxed_answer(prediction_text)
|
||||
judge = _judge_answer(
|
||||
item=item,
|
||||
prediction_text=prediction_text,
|
||||
extracted_answer=answer,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=retries,
|
||||
)
|
||||
hard = 1.0 if judge["correct"] else 0.0
|
||||
|
||||
result = {
|
||||
"evaluation_mode": _EVAL_MODE,
|
||||
"predicted_answer": answer,
|
||||
"em": hard,
|
||||
"f1": hard,
|
||||
"sub_em": hard,
|
||||
"judge_model": judge_model,
|
||||
"judge_raw": judge["raw"],
|
||||
"judge_reason": judge["reason"],
|
||||
"matched_gold": judge["matched_gold"],
|
||||
}
|
||||
|
||||
if item["ans_type"] == "choice":
|
||||
result["predicted_label"] = str(answer or "").strip().upper().rstrip(".):")
|
||||
result["predicted_text"] = ""
|
||||
result["correct_label"] = str(item["correct_choice"].get("label") or "")
|
||||
result["correct_text"] = str(item["correct_choice"].get("text") or "")
|
||||
else:
|
||||
result["gold_answers"] = list(item["blank_answers"])
|
||||
best_f1 = 0.0
|
||||
for gold in item["blank_answers"]:
|
||||
best_f1 = max(best_f1, _token_f1(str(answer or ""), gold))
|
||||
result["string_f1"] = best_f1
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def evaluation_mode() -> str:
|
||||
return _EVAL_MODE
|
||||
36
reflact/envs/babyvision/prompts/analyst_error.md
Normal file
36
reflact/envs/babyvision/prompts/analyst_error.md
Normal file
@@ -0,0 +1,36 @@
|
||||
You are an expert failure-analysis agent for child-level visual reasoning tasks.
|
||||
|
||||
You will be given MULTIPLE failed BabyVision trajectories from a minibatch and the current skill document.
|
||||
Each trajectory includes the text prompt, the model answer, and the evaluation result.
|
||||
You do not have direct access to raw pixel content during reflection, so focus on general reasoning,
|
||||
option-selection, and visual-question-answering behaviors that can be improved through prompting.
|
||||
|
||||
## Failure Type Categories
|
||||
- **visual_detail_miss**: the agent likely overlooked a salient visual attribute, relation, count, or object state
|
||||
- **option_mismatch**: the agent selected the wrong option despite relevant evidence likely being present
|
||||
- **instruction_slip**: the agent ignored output format or answered too vaguely
|
||||
- **answer_granularity**: the agent gave an answer that was too broad, too narrow, or mismatched the expected specificity
|
||||
- **other**: none of the above
|
||||
|
||||
## Rules
|
||||
1. Focus on patterns recurring across the minibatch.
|
||||
2. Prefer reusable behaviors for inspecting images and grounding answers in visible evidence.
|
||||
3. Do not memorize dataset-specific answers.
|
||||
4. Only patch gaps not already covered by the current skill.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
25
reflact/envs/babyvision/prompts/analyst_success.md
Normal file
25
reflact/envs/babyvision/prompts/analyst_success.md
Normal file
@@ -0,0 +1,25 @@
|
||||
You are an expert success-pattern analyst for child-level visual reasoning tasks.
|
||||
|
||||
You will be given MULTIPLE successful BabyVision trajectories from a minibatch and the current skill document.
|
||||
Identify generalizable behavior patterns that help the agent inspect the image carefully and answer at the right level of specificity.
|
||||
|
||||
## Rules
|
||||
- Focus on broadly useful visual QA behaviors.
|
||||
- Prefer patterns about systematic image inspection, comparing options, and concise grounded answers.
|
||||
- Do not add dataset-specific facts.
|
||||
- "edits" may be empty if the skill already captures the useful patterns.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns matter>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
25
reflact/envs/babyvision/prompts/deep_probe.md
Normal file
25
reflact/envs/babyvision/prompts/deep_probe.md
Normal file
@@ -0,0 +1,25 @@
|
||||
You are an expert diagnostic-probe designer for BabyVision-style visual reasoning tasks.
|
||||
|
||||
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
|
||||
Design one SMALL diagnostic instruction that exposes the student's intermediate visual judgment without materially changing the original scaffold.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT substantially change the original scaffold.
|
||||
2. Do NOT prescribe a new step-by-step solving method.
|
||||
3. You MAY ask for a short structured list of a few intermediate conclusions, candidate cues, or counted units, as long as it stays close to the original scaffold.
|
||||
4. Do NOT ask for exhaustive listing of all cells, all objects, or a full chain-of-thought.
|
||||
5. Ask only for a short readout that reveals the student's current latent state.
|
||||
6. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>.
|
||||
|
||||
## Good Probe Targets
|
||||
- top answer and runner-up
|
||||
- decisive visual cue
|
||||
- suspicious region or compared objects
|
||||
- counting unit or formatting interpretation
|
||||
- 2-4 short intermediate conclusions that directly support the final answer
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this probe is informative>",
|
||||
"probe_instruction": "<the exact instruction text to append to the student prompt>"
|
||||
}
|
||||
35
reflact/envs/babyvision/prompts/judge.md
Normal file
35
reflact/envs/babyvision/prompts/judge.md
Normal file
@@ -0,0 +1,35 @@
|
||||
You are a careful and strict evaluator. You will be given:
|
||||
|
||||
1. **Question**
|
||||
2. **Ground Truth Answer** (correct answer)
|
||||
3. **Model Output** (answer from another model)
|
||||
|
||||
**Your goal:** Determine if the Model Output **accurately matches** the Ground Truth Answer in meaning.
|
||||
|
||||
* Matching means: the facts, entities, and key details are equivalent, even if phrasing differs.
|
||||
* Not matching means: the Model Output is wrong, incomplete, contains extra incorrect facts, or changes the meaning.
|
||||
|
||||
**Process (internal reasoning):**
|
||||
|
||||
1. Read and understand the Question, Ground Truth Answer, and Model Output.
|
||||
2. Ignore small wording differences, formatting, or synonyms.
|
||||
3. If all factual content matches, conclude `1`. Otherwise, conclude `0`.
|
||||
|
||||
**Important:**
|
||||
|
||||
* Think through your decision step-by-step **internally** before responding.
|
||||
* In your final output, return **only** True or False, with no extra text or explanation.
|
||||
|
||||
**Output format:**
|
||||
|
||||
True
|
||||
|
||||
or
|
||||
|
||||
False
|
||||
|
||||
**Input:**
|
||||
|
||||
Question: {question},
|
||||
Ground Truth Answer: {groundtruth},
|
||||
Model Output: {modeloutput}
|
||||
13
reflact/envs/babyvision/prompts/rollout_system.md
Normal file
13
reflact/envs/babyvision/prompts/rollout_system.md
Normal file
@@ -0,0 +1,13 @@
|
||||
You are an expert visual reasoning agent solving child-level image understanding tasks.
|
||||
|
||||
{skill_section}## Task Format
|
||||
You will receive one image and one question about it.
|
||||
Inspect the image carefully before answering. Ground the answer in visible evidence.
|
||||
|
||||
## Answer Format
|
||||
Think step by step, then provide your final answer in \boxed{{Answer}} format.
|
||||
- For multiple-choice questions, output only the single choice label, such as \boxed{{A}}.
|
||||
- For open questions, output only a short final answer inside \boxed{{...}}.
|
||||
|
||||
Example:
|
||||
\boxed{{B}}
|
||||
4
reflact/envs/babyvision/reflect.py
Normal file
4
reflact/envs/babyvision/reflect.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""BabyVision Reflect stage.
|
||||
|
||||
Prompts are now loaded from .md files by the base adapter.
|
||||
"""
|
||||
467
reflact/envs/babyvision/rollout.py
Normal file
467
reflact/envs/babyvision/rollout.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""BabyVision rollout — multimodal visual QA with image input."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from reflact.envs.babyvision.evaluator import evaluate_item, evaluation_mode, extract_boxed_answer
|
||||
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
|
||||
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
def _build_system(skill_content: str) -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return load_prompt("rollout_system", env="babyvision").format(skill_section=skill_section)
|
||||
|
||||
|
||||
def _format_choices(choices: list[dict]) -> str:
|
||||
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
|
||||
|
||||
|
||||
def _build_user_text(
|
||||
item: dict,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> str:
|
||||
parts = []
|
||||
if diagnostic_trace_context.strip():
|
||||
parts.append(
|
||||
"## Previous Codex Trace Snapshot\n"
|
||||
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
|
||||
f"{diagnostic_trace_context.strip()}"
|
||||
)
|
||||
parts.append(f"## Question\n{item['question']}")
|
||||
if item["ans_type"] == "choice":
|
||||
parts.append(f"## Choices\n{_format_choices(item['choices'])}")
|
||||
parts.append("Answer using the single correct option label in \\boxed{...}.")
|
||||
else:
|
||||
parts.append("Answer with a short phrase in \\boxed{...}.")
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _image_to_data_uri(path: str) -> str:
|
||||
mime = mimetypes.guess_type(path)[0] or "image/png"
|
||||
with open(path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
def _build_messages(
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
image_detail: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> tuple[list[dict], str, str]:
|
||||
system = _build_system(skill_content)
|
||||
user_text = _build_user_text(
|
||||
item,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
image_url = {
|
||||
"url": _image_to_data_uri(item["image_path"]),
|
||||
}
|
||||
if image_detail and image_detail != "auto":
|
||||
image_url["detail"] = image_detail
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_text},
|
||||
{"type": "image_url", "image_url": image_url},
|
||||
],
|
||||
},
|
||||
]
|
||||
return messages, system, user_text
|
||||
|
||||
|
||||
def _build_codex_skill(skill_content: str) -> str:
|
||||
return render_skill_md(
|
||||
skill_content,
|
||||
description="Dynamic ReflACT skill for solving the current BabyVision visual reasoning question.",
|
||||
preamble=(
|
||||
"Use this skill when answering the current visual reasoning question.\n"
|
||||
"Inspect the attached image carefully and return the final answer in \\boxed{...}."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _run_codex_once(
|
||||
*,
|
||||
pred_dir: str,
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
model: str,
|
||||
timeout: int,
|
||||
image_detail: str,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
previous_response: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
user_text = _build_user_text(
|
||||
item,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
task_parts = [user_text]
|
||||
if previous_response:
|
||||
task_parts.append(
|
||||
"## Previous Attempt\n"
|
||||
f"{previous_response}\n\n"
|
||||
"Review the same image and question carefully. If needed, correct the answer."
|
||||
)
|
||||
task_text = "\n\n".join(task_parts)
|
||||
skill_md = _build_codex_skill(skill_content)
|
||||
work_dir = os.path.join(pred_dir, "codex_exec")
|
||||
prepare_workspace(
|
||||
work_dir=work_dir,
|
||||
skill_md=skill_md,
|
||||
task_text=task_text,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
prompt = (
|
||||
"Use the `reflact-student` skill available in this workspace.\n"
|
||||
"Read `task.md`, inspect the attached image, and answer the question.\n"
|
||||
"Return the final answer in \\boxed{...}."
|
||||
)
|
||||
final_message, raw = run_student_exec(
|
||||
work_dir=work_dir,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
return final_message or raw, raw, skill_md, task_text
|
||||
|
||||
|
||||
def process_one(
|
||||
item: dict,
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
"id": item_id,
|
||||
"question": item["question"],
|
||||
"task_type": item.get("subtype") or item.get("task_type") or "babyvision",
|
||||
"task_description": item["question"],
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"predicted_label": "",
|
||||
"predicted_text": "",
|
||||
"response": "",
|
||||
"fail_reason": "",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
"image_path": item["image_path"],
|
||||
"ans_type": item["ans_type"],
|
||||
"evaluation_mode": evaluation_mode(),
|
||||
"judge_model": judge_model,
|
||||
}
|
||||
if item["ans_type"] == "choice":
|
||||
result["correct_label"] = item["correct_choice"]["label"]
|
||||
result["correct_text"] = item["correct_choice"]["text"]
|
||||
else:
|
||||
result["gold_answers"] = item["blank_answers"]
|
||||
|
||||
try:
|
||||
pred_dir = os.path.join(out_root, "predictions", item_id)
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
if is_student_exec_backend():
|
||||
from reflact.model import azure_openai as _llm
|
||||
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{"role": "user", "content": f"{item['question']}\n\n[image] {os.path.basename(item['image_path'])}"}
|
||||
]
|
||||
system_prompt = ""
|
||||
user_text = ""
|
||||
for turn in range(max_turns):
|
||||
response, raw, system_prompt, user_text = _run_codex_once(
|
||||
pred_dir=pred_dir,
|
||||
item=item,
|
||||
skill_content=skill_content,
|
||||
model=_llm.STUDENT_DEPLOYMENT,
|
||||
timeout=120,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode if turn == 0 else False,
|
||||
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
|
||||
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
|
||||
previous_response=response if turn > 0 else "",
|
||||
)
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": response})
|
||||
if extract_boxed_answer(response) is not None:
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(
|
||||
item=item,
|
||||
prediction_text=response,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=judge_max_completion_tokens,
|
||||
retries=judge_retries,
|
||||
)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["judge_raw"] = eval_result["judge_raw"]
|
||||
result["judge_reason"] = eval_result["judge_reason"]
|
||||
result["matched_gold"] = eval_result["matched_gold"]
|
||||
if item["ans_type"] == "choice":
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}"
|
||||
)
|
||||
else:
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_answer']}' "
|
||||
f"but expected {item['blank_answers']} ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Gold answers: {item['blank_answers']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}\n"
|
||||
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
return result
|
||||
|
||||
messages, system_prompt, user_text = _build_messages(
|
||||
item,
|
||||
skill_content,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"}
|
||||
]
|
||||
|
||||
for turn in range(max_turns):
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=768,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
else:
|
||||
refinement_text = (
|
||||
f"Your previous answer was:\n{response}\n\n"
|
||||
"Review the same image and question carefully. "
|
||||
"If needed, correct your answer. Output the final answer in \\boxed{...}."
|
||||
)
|
||||
refinement_messages = [
|
||||
messages[0],
|
||||
messages[1],
|
||||
{"role": "assistant", "content": response},
|
||||
{"role": "user", "content": refinement_text},
|
||||
]
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=refinement_messages,
|
||||
max_completion_tokens=512,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
response = resp_text
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
|
||||
if extract_boxed_answer(resp_text) is not None:
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(
|
||||
item=item,
|
||||
prediction_text=response,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=judge_max_completion_tokens,
|
||||
retries=judge_retries,
|
||||
)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["judge_raw"] = eval_result["judge_raw"]
|
||||
result["judge_reason"] = eval_result["judge_reason"]
|
||||
result["matched_gold"] = eval_result["matched_gold"]
|
||||
|
||||
if item["ans_type"] == "choice":
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}' ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}"
|
||||
)
|
||||
else:
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_answer']}' "
|
||||
f"but expected {item['blank_answers']} ({eval_result['judge_reason']})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Gold answers: {item['blank_answers']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result['judge_reason']}\n"
|
||||
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
|
||||
)
|
||||
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
result["fail_reason"] = f"error: {e}"
|
||||
return result
|
||||
|
||||
|
||||
def run_batch(
|
||||
items: list[dict],
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
workers: int = 32,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None,
|
||||
) -> list[dict]:
|
||||
results_path = os.path.join(out_root, "results.jsonl")
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
expected_eval_mode = evaluation_mode()
|
||||
done_ids: set[str] = set()
|
||||
existing: list[dict] = []
|
||||
rewrite_results = False
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
row = json.loads(line)
|
||||
if row.get("evaluation_mode") != expected_eval_mode:
|
||||
rewrite_results = True
|
||||
continue
|
||||
done_ids.add(str(row["id"]))
|
||||
existing.append(row)
|
||||
except Exception:
|
||||
rewrite_results = True
|
||||
|
||||
pending = [item for item in items if str(item["id"]) not in done_ids]
|
||||
if not pending and not rewrite_results:
|
||||
return existing
|
||||
|
||||
results = list(existing)
|
||||
file_mode = "w" if rewrite_results else "a"
|
||||
with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
if rewrite_results:
|
||||
for row in existing:
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
futs = {
|
||||
ex.submit(
|
||||
process_one,
|
||||
item,
|
||||
out_root,
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
image_detail=image_detail,
|
||||
judge_model=judge_model,
|
||||
judge_max_completion_tokens=judge_max_completion_tokens,
|
||||
judge_retries=judge_retries,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
|
||||
): item
|
||||
for item in pending
|
||||
}
|
||||
for fut in as_completed(futs):
|
||||
row = fut.result()
|
||||
results.append(row)
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
return results
|
||||
18
reflact/envs/babyvision/skills/initial.md
Normal file
18
reflact/envs/babyvision/skills/initial.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# BabyVision Visual QA Heuristics
|
||||
|
||||
## Image Inspection
|
||||
- First identify the main objects, their attributes, and their spatial relations before answering.
|
||||
- If the question involves counting, compare all relevant instances carefully instead of stopping after the first match.
|
||||
- If the question asks about color, size, position, or action, verify the specific visible evidence for that attribute.
|
||||
|
||||
## Multiple Choice
|
||||
- Compare every option against the visible image evidence before deciding.
|
||||
- Prefer the option that matches the image exactly; reject options that are only partially true or too vague.
|
||||
- When two options are close, check the smallest discriminating visual detail.
|
||||
|
||||
## Open Answers
|
||||
- Answer with the shortest phrase that is fully supported by the image.
|
||||
- Match the expected level of specificity: not broader than the image evidence, not narrower than the question asks.
|
||||
|
||||
## Final Answer
|
||||
- Output only the final answer inside <answer>...</answer>.
|
||||
396
reflact/envs/base.py
Normal file
396
reflact/envs/base.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""ReflACT environment adapter — abstract interface.
|
||||
|
||||
To connect ReflACT to a new environment (benchmark, simulator, etc.),
|
||||
implement a subclass of :class:`EnvAdapter` with environment-specific
|
||||
rollout and reflection logic.
|
||||
|
||||
Example::
|
||||
|
||||
class MyBenchAdapter(EnvAdapter):
|
||||
def build_train_env(self, batch_size, seed, **kw):
|
||||
return MyEnvManager(split="train", n=batch_size, seed=seed)
|
||||
|
||||
def build_eval_env(self, env_num, split, seed, **kw):
|
||||
return MyEnvManager(split=split, n=env_num, seed=seed)
|
||||
|
||||
def rollout(self, env_manager, skill_content, out_dir, **kw):
|
||||
# Run episodes, return [{"id": ..., "hard": 0/1, "soft": 0.0-1.0, ...}]
|
||||
...
|
||||
|
||||
def reflect(self, results, skill_content, out_dir, **kw):
|
||||
# Analyze trajectories, return list of patch dicts
|
||||
...
|
||||
|
||||
def get_task_types(self):
|
||||
return ["task_a", "task_b"]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import random
|
||||
|
||||
from reflact.datasets.base import BaseDataLoader, BatchSpec
|
||||
from reflact.model.codex_harness import extract_codex_trace_prefix, format_codex_trace_steps, parse_codex_raw
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
|
||||
class EnvAdapter(ABC):
|
||||
"""Abstract adapter for connecting ReflACT to any environment.
|
||||
|
||||
Subclasses must implement all abstract methods. The ReflACT trainer
|
||||
calls these methods at the appropriate pipeline stages.
|
||||
"""
|
||||
|
||||
# ── Lifecycle hooks ────────────────────────────────────────────────────
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
"""Called once by the trainer before the training loop begins.
|
||||
|
||||
Override to perform one-time initialization that requires the full
|
||||
config (e.g., data loading, split creation). Default is a no-op.
|
||||
"""
|
||||
self._cfg = dict(cfg)
|
||||
|
||||
def get_dataloader(self) -> BaseDataLoader | None:
|
||||
"""Return the task dataloader used by this adapter, if any."""
|
||||
return None
|
||||
|
||||
def requires_ray(self) -> bool:
|
||||
"""Return whether this adapter requires Ray runtime initialization."""
|
||||
return False
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
"""Optional deeper diagnostic reflection pass.
|
||||
|
||||
Default behavior is a no-op. Dataset-backed adapters may override this
|
||||
to re-query the student on a small representative subset of the current
|
||||
batch using minimally-perturbed diagnostic prompts that expose
|
||||
intermediate reasoning state.
|
||||
"""
|
||||
return []
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
"""Return hidden reference material for deep reflection, if any."""
|
||||
return str(item.get("reference_text") or "").strip()
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
"""Return structured metadata about hidden reference material."""
|
||||
reference_text = self.build_reference_text(item)
|
||||
if not reference_text:
|
||||
return {"fields": [], "preview": ""}
|
||||
return {
|
||||
"fields": ["reference_text"],
|
||||
"preview": reference_text[:400],
|
||||
}
|
||||
|
||||
def get_codex_deep_probe_prompt(self) -> str | None:
|
||||
env_name = getattr(self, "_cfg", {}).get("env_name")
|
||||
return load_prompt("deep_probe_codex", env=env_name)
|
||||
|
||||
def attach_codex_probe_context(
|
||||
self,
|
||||
results: list[dict],
|
||||
prediction_dir: str,
|
||||
) -> list[dict]:
|
||||
"""Attach compact Codex step metadata for codex-aware deep reflection."""
|
||||
enriched: list[dict] = []
|
||||
for row in results:
|
||||
merged = dict(row)
|
||||
tid = str(row.get("id"))
|
||||
raw_path = os.path.join(prediction_dir, tid, "codex_raw.txt")
|
||||
if os.path.exists(raw_path):
|
||||
with open(raw_path, encoding="utf-8") as f:
|
||||
raw = f.read()
|
||||
parsed = parse_codex_raw(raw)
|
||||
merged["codex_probe_trace_steps"] = format_codex_trace_steps(raw)
|
||||
merged["codex_probe_step_count"] = len(parsed["steps"])
|
||||
enriched.append(merged)
|
||||
return enriched
|
||||
|
||||
def resolve_codex_probe_target(
|
||||
self,
|
||||
*,
|
||||
selected_items: list[dict],
|
||||
selected_examples: list[dict],
|
||||
prediction_dir: str,
|
||||
probe: dict,
|
||||
) -> tuple[list[dict], dict[str, str] | None, dict]:
|
||||
"""Resolve the teacher-selected codex probe target and raw trace prefix."""
|
||||
target_id = str(probe.get("probe_target_id", "")).strip()
|
||||
selected_id_set = {str(item["id"]) for item in selected_items}
|
||||
if target_id not in selected_id_set:
|
||||
target_id = str(selected_items[0]["id"])
|
||||
target_item = next(item for item in selected_items if str(item["id"]) == target_id)
|
||||
target_result = next(
|
||||
(row for row in selected_examples if str(row.get("id")) == target_id),
|
||||
None,
|
||||
)
|
||||
max_probe_step = int((target_result or {}).get("codex_probe_step_count", 0))
|
||||
default_probe_step = max_probe_step - 1 if max_probe_step > 1 else max_probe_step
|
||||
probe_after_step = int(probe.get("probe_after_step", default_probe_step))
|
||||
if max_probe_step > 0:
|
||||
probe_after_step = max(0, min(probe_after_step, max_probe_step))
|
||||
else:
|
||||
probe_after_step = 0
|
||||
raw_path = os.path.join(prediction_dir, target_id, "codex_raw.txt")
|
||||
trace_prefix = ""
|
||||
if os.path.exists(raw_path):
|
||||
with open(raw_path, encoding="utf-8") as f:
|
||||
trace_prefix = extract_codex_trace_prefix(f.read(), after_step=probe_after_step)
|
||||
updated_probe = dict(probe)
|
||||
updated_probe["probe_target_id"] = target_id
|
||||
updated_probe["probe_after_step"] = probe_after_step
|
||||
return [target_item], {target_id: trace_prefix}, updated_probe
|
||||
|
||||
def attach_reference_context(
|
||||
self,
|
||||
results: list[dict],
|
||||
items: list[dict] | None,
|
||||
) -> list[dict]:
|
||||
"""Attach environment-specific hidden reference text to result dicts."""
|
||||
if not results or not items:
|
||||
return list(results)
|
||||
|
||||
item_by_id = {
|
||||
str(item.get("id")): item
|
||||
for item in items
|
||||
if isinstance(item, dict) and item.get("id") is not None
|
||||
}
|
||||
enriched: list[dict] = []
|
||||
for row in results:
|
||||
merged = dict(row)
|
||||
item = item_by_id.get(str(row.get("id")))
|
||||
if item:
|
||||
reference_text = self.build_reference_text(item)
|
||||
if reference_text:
|
||||
merged["reference_text"] = reference_text
|
||||
enriched.append(merged)
|
||||
return enriched
|
||||
|
||||
def select_representative_items(
|
||||
self,
|
||||
results: list[dict],
|
||||
items: list[dict] | None,
|
||||
*,
|
||||
n_failures: int,
|
||||
n_successes: int,
|
||||
seed: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Select a small diverse subset of current-batch items by outcome."""
|
||||
if not items:
|
||||
return []
|
||||
|
||||
item_by_id = {
|
||||
str(item.get("id")): item
|
||||
for item in items
|
||||
if isinstance(item, dict) and item.get("id") is not None
|
||||
}
|
||||
failures = [
|
||||
(result, item_by_id[str(result.get("id"))])
|
||||
for result in results
|
||||
if not result.get("hard") and str(result.get("id")) in item_by_id
|
||||
]
|
||||
successes = [
|
||||
(result, item_by_id[str(result.get("id"))])
|
||||
for result in results
|
||||
if result.get("hard") and str(result.get("id")) in item_by_id
|
||||
]
|
||||
|
||||
rng = random.Random(seed)
|
||||
|
||||
def _pick(pool: list[tuple[dict, dict]], quota: int) -> list[dict]:
|
||||
if quota <= 0 or not pool:
|
||||
return []
|
||||
shuffled = list(pool)
|
||||
rng.shuffle(shuffled)
|
||||
|
||||
picked_ids: set[str] = set()
|
||||
picked: list[dict] = []
|
||||
seen_types: set[str] = set()
|
||||
|
||||
for result, item in shuffled:
|
||||
task_type = str(result.get("task_type") or item.get("task_type") or item.get("subtype") or "unknown")
|
||||
item_id = str(item["id"])
|
||||
if task_type in seen_types or item_id in picked_ids:
|
||||
continue
|
||||
picked.append(item)
|
||||
picked_ids.add(item_id)
|
||||
seen_types.add(task_type)
|
||||
if len(picked) >= quota:
|
||||
return picked
|
||||
|
||||
for _, item in shuffled:
|
||||
item_id = str(item["id"])
|
||||
if item_id in picked_ids:
|
||||
continue
|
||||
picked.append(item)
|
||||
picked_ids.add(item_id)
|
||||
if len(picked) >= quota:
|
||||
break
|
||||
return picked
|
||||
|
||||
selected = _pick(failures, n_failures)
|
||||
selected_ids = {str(item["id"]) for item in selected}
|
||||
selected.extend(
|
||||
item for item in _pick(successes, n_successes)
|
||||
if str(item["id"]) not in selected_ids
|
||||
)
|
||||
return selected
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
"""Build an environment manager or item list from a :class:`BatchSpec`.
|
||||
|
||||
Default behavior preserves the legacy adapter API by routing training
|
||||
batches through :meth:`build_train_env` and evaluation batches through
|
||||
:meth:`build_eval_env`.
|
||||
"""
|
||||
if batch.phase == "train":
|
||||
return self.build_train_env(batch_size=batch.batch_size, seed=batch.seed, **kwargs)
|
||||
return self.build_eval_env(
|
||||
env_num=batch.batch_size,
|
||||
split=batch.split,
|
||||
seed=batch.seed,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
"""Build a training environment manager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
An environment manager that can be passed to :meth:`rollout`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
"""Build an evaluation environment manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
env_num : int
|
||||
Number of evaluation environments.
|
||||
split : str
|
||||
Dataset split (e.g. ``"valid_seen"``, ``"valid_unseen"``).
|
||||
seed : int
|
||||
Random seed for reproducibility.
|
||||
|
||||
Returns
|
||||
-------
|
||||
object
|
||||
An environment manager that can be passed to :meth:`rollout`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
"""Run a batch of episodes using the current skill.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict]
|
||||
Each dict conforms to :class:`~reflact.types.RolloutResult`:
|
||||
must have ``"id"`` (str), ``"hard"`` (0/1), ``"soft"``
|
||||
(float 0-1). May include env-specific fields.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
"""Analyze rollout results and produce patches.
|
||||
|
||||
Each returned dict conforms to :class:`~reflact.types.RawPatch`:
|
||||
``"patch"`` (with ``"edits"`` list) + ``"source_type"``
|
||||
(``"failure"`` or ``"success"``).
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict | None]
|
||||
Raw analyst outputs; ``None`` entries are filtered out.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_task_types(self) -> list[str]:
|
||||
"""Return the list of task type names for this environment."""
|
||||
|
||||
# ── Prompt configuration (two-level priority) ────────────────────────
|
||||
#
|
||||
# Priority: env-specific prompt file > generic default prompt file.
|
||||
#
|
||||
# Prompts are loaded from ``.md`` files via ``load_prompt(name, env)``:
|
||||
# 1. ``reflact/envs/<env>/prompts/<name>.md`` (env-specific)
|
||||
# 2. ``reflact/prompts/<name>.md`` (generic fallback)
|
||||
#
|
||||
# Subclasses can still override ``get_*_prompt()`` for full control.
|
||||
|
||||
@property
|
||||
def _env_name(self) -> str:
|
||||
"""Derive the env directory name from this adapter's module path."""
|
||||
# e.g. "reflact.envs.searchqa.adapter" → "searchqa"
|
||||
module = type(self).__module__
|
||||
parts = module.split(".")
|
||||
if len(parts) >= 3 and parts[-3] == "envs":
|
||||
return parts[-2]
|
||||
return ""
|
||||
|
||||
def _load_env_prompt(self, name: str) -> str | None:
|
||||
"""Load a prompt with env-specific override. Returns None if not found."""
|
||||
try:
|
||||
return load_prompt(name, env=self._env_name)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
def get_error_minibatch_prompt(self) -> str | None:
|
||||
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
|
||||
raw_mode = str(update_mode).strip().lower()
|
||||
if raw_mode in {"full_rewrite", "full_rewrite_minibatch", "minibatch_full_rewrite", "skill_rewrite_minibatch"}:
|
||||
prompt = self._load_env_prompt("analyst_error_full_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
if raw_mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
|
||||
prompt = self._load_env_prompt("analyst_error_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
return self._load_env_prompt("analyst_error")
|
||||
|
||||
def get_success_minibatch_prompt(self) -> str | None:
|
||||
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
|
||||
raw_mode = str(update_mode).strip().lower()
|
||||
if raw_mode in {"full_rewrite", "full_rewrite_minibatch", "minibatch_full_rewrite", "skill_rewrite_minibatch"}:
|
||||
prompt = self._load_env_prompt("analyst_success_full_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
if raw_mode in {"rewrite", "rewrite_from_suggestions", "suggestions", "rewrite_suggestions"}:
|
||||
prompt = self._load_env_prompt("analyst_success_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
return self._load_env_prompt("analyst_success")
|
||||
|
||||
def get_deep_probe_prompt(self) -> str | None:
|
||||
return self._load_env_prompt("deep_probe")
|
||||
|
||||
def get_meta_reflect_prompt(self) -> str | None:
|
||||
update_mode = getattr(self, "_cfg", {}).get("skill_update_mode", "patch")
|
||||
if str(update_mode).strip().lower() == "rewrite_from_suggestions":
|
||||
prompt = self._load_env_prompt("meta_reflect_rewrite")
|
||||
if prompt is not None:
|
||||
return prompt
|
||||
return self._load_env_prompt("meta_reflect")
|
||||
114
reflact/envs/deep_reflect.py
Normal file
114
reflact/envs/deep_reflect.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable
|
||||
|
||||
from reflact.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
def run_no_reference_deep_reflect(
|
||||
adapter: Any,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
*,
|
||||
env_manager: Any = None,
|
||||
prediction_dir: str | None = None,
|
||||
random_seed: int | None = None,
|
||||
step_buffer_context: str = "",
|
||||
output_requirements: list[str] | None = None,
|
||||
metadata_builder: Callable[[dict], dict] | None = None,
|
||||
) -> list[dict | None]:
|
||||
"""Run teacher-designed diagnostic probing without hidden references."""
|
||||
if not getattr(adapter, "use_deep_reflect", False):
|
||||
return []
|
||||
if not isinstance(env_manager, list):
|
||||
return []
|
||||
|
||||
prediction_dir = prediction_dir or os.path.join(out_dir, "predictions")
|
||||
selected_items = adapter.select_representative_items(
|
||||
results,
|
||||
env_manager,
|
||||
n_failures=getattr(adapter, "deep_reflect_failures", 4),
|
||||
n_successes=getattr(adapter, "deep_reflect_successes", 2),
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
if metadata_builder is None:
|
||||
selected_metadata = [
|
||||
{
|
||||
"id": str(item.get("id")),
|
||||
"task_type": str(item.get("task_type") or item.get("topic") or "unknown"),
|
||||
"question_preview": str(item.get("question") or "")[:200],
|
||||
}
|
||||
for item in selected_items
|
||||
]
|
||||
else:
|
||||
selected_metadata = [metadata_builder(item) for item in selected_items]
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
"mode=no_reference_probe"
|
||||
)
|
||||
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_results,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=adapter.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
output_requirements=output_requirements,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"mode": "no_reference_probe",
|
||||
"selected_count": len(selected_items),
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
deep_results = adapter.rollout(
|
||||
selected_items,
|
||||
skill_content,
|
||||
rollout_dir,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=getattr(adapter, "analyst_workers", 8),
|
||||
failure_only=getattr(adapter, "failure_only", False),
|
||||
minibatch_size=getattr(adapter, "minibatch_size", 8),
|
||||
edit_budget=getattr(adapter, "edit_budget", 4),
|
||||
random_seed=random_seed,
|
||||
error_system=adapter.get_error_minibatch_prompt(),
|
||||
success_system=adapter.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(getattr(adapter, "_cfg", {}), "get", lambda *_: "patch")(
|
||||
"skill_update_mode",
|
||||
"patch",
|
||||
),
|
||||
)
|
||||
1
reflact/envs/docvqa/__init__.py
Normal file
1
reflact/envs/docvqa/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""DocVQA environment package for ReflACT."""
|
||||
153
reflact/envs/docvqa/adapter.py
Normal file
153
reflact/envs/docvqa/adapter.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from reflact.datasets.base import BatchSpec
|
||||
from reflact.envs.base import EnvAdapter
|
||||
from reflact.envs.deep_reflect import run_no_reference_deep_reflect
|
||||
from reflact.envs.docvqa.dataloader import DocVQADataLoader
|
||||
from reflact.envs.docvqa.rollout import run_batch
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
class DocVQAAdapter(EnvAdapter):
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "split_dir",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 120,
|
||||
workers: int = 16,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
exec_timeout: int = 600,
|
||||
image_detail: str = "auto",
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.exec_timeout = exec_timeout
|
||||
self.image_detail = image_detail
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = DocVQADataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(self, env_manager, skill_content: str, out_dir: str, **kwargs) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
exec_timeout=self.exec_timeout,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
|
||||
def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
return run_no_reference_deep_reflect(
|
||||
self,
|
||||
results,
|
||||
skill_content,
|
||||
out_dir,
|
||||
env_manager=kwargs.get("env_manager"),
|
||||
prediction_dir=kwargs.get("prediction_dir"),
|
||||
random_seed=kwargs.get("random_seed"),
|
||||
step_buffer_context=kwargs.get("step_buffer_context", ""),
|
||||
output_requirements=[
|
||||
"- There is no hidden reference block. Use only the document image prompt, student output, and evaluation result to infer what intermediate state is worth probing.",
|
||||
"- The instruction must explicitly request a short <analysis>...</analysis> block before the final <answer>...</answer>.",
|
||||
"- The readout should focus on visual region, field/table/figure label, OCR text read, candidate answer, and answer-format normalization.",
|
||||
"- Do not ask for exhaustive transcription or a full chain-of-thought.",
|
||||
"- The instruction text should be ready to append directly to the student's prompt.",
|
||||
],
|
||||
metadata_builder=lambda item: {
|
||||
"id": str(item.get("id")),
|
||||
"task_type": str(item.get("task_type") or "docvqa"),
|
||||
"question_preview": str(item.get("question") or "")[:200],
|
||||
"image_path": item.get("image_path", ""),
|
||||
"docId": item.get("docId", ""),
|
||||
"page": item.get("ucsf_document_page_no", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
seen: list[str] = []
|
||||
for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items:
|
||||
task_type = str(item.get("task_type") or "docvqa")
|
||||
if task_type not in seen:
|
||||
seen.append(task_type)
|
||||
return seen or ["docvqa"]
|
||||
61
reflact/envs/docvqa/dataloader.py
Normal file
61
reflact/envs/docvqa/dataloader.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import csv
|
||||
from pathlib import Path
|
||||
|
||||
from reflact.datasets.base import SplitDataLoader
|
||||
|
||||
|
||||
def _parse_answers(raw: str) -> list[str]:
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
try:
|
||||
parsed = ast.literal_eval(text)
|
||||
except Exception:
|
||||
return [text]
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()]
|
||||
return [str(parsed).strip()]
|
||||
|
||||
|
||||
def _extract_document_path(question: str) -> tuple[str, str]:
|
||||
marker = "document_path:"
|
||||
if marker not in question:
|
||||
return question.strip(), ""
|
||||
main, tail = question.split(marker, 1)
|
||||
return main.strip(), tail.strip()
|
||||
|
||||
|
||||
def _normalize_row(row: dict[str, str]) -> dict:
|
||||
question_text, document_path = _extract_document_path(str(row.get("question") or ""))
|
||||
answers = _parse_answers(row.get("answer") or row.get("ground_truth") or "")
|
||||
image_path = str(row.get("image_path") or document_path or "").strip()
|
||||
task_type = str(row.get("topic") or row.get("category") or "docvqa").strip() or "docvqa"
|
||||
return {
|
||||
"id": str(row.get("questionId") or row.get("id") or "").strip(),
|
||||
"question": question_text,
|
||||
"answer": answers[0] if answers else "",
|
||||
"answers": answers,
|
||||
"task_type": task_type,
|
||||
"subtask": task_type,
|
||||
"image_paths": [image_path] if image_path else [],
|
||||
"image_path": image_path,
|
||||
"questionId": str(row.get("questionId") or "").strip(),
|
||||
"docId": str(row.get("docId") or "").strip(),
|
||||
"ucsf_document_id": str(row.get("ucsf_document_id") or "").strip(),
|
||||
"ucsf_document_page_no": str(row.get("ucsf_document_page_no") or "").strip(),
|
||||
"source_split": str(row.get("source_split") or "").strip(),
|
||||
}
|
||||
|
||||
|
||||
class DocVQADataLoader(SplitDataLoader):
|
||||
def load_split_items(self, split_path: str) -> list[dict]:
|
||||
path = Path(split_path)
|
||||
csv_files = sorted(path.glob("*.csv"))
|
||||
if not csv_files:
|
||||
raise FileNotFoundError(f"No .csv file found in {split_path}")
|
||||
with csv_files[0].open(encoding="utf-8", newline="") as f:
|
||||
reader = csv.DictReader(f)
|
||||
return [_normalize_row(row) for row in reader]
|
||||
113
reflact/envs/docvqa/evaluator.py
Normal file
113
reflact/envs/docvqa/evaluator.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
DEFAULT_ANLS_THRESHOLD = 0.5
|
||||
|
||||
|
||||
def _normalize_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
text = str(value).strip().lower()
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _levenshtein_distance(a: str, b: str) -> int:
|
||||
if a == b:
|
||||
return 0
|
||||
if not a:
|
||||
return len(b)
|
||||
if not b:
|
||||
return len(a)
|
||||
if len(a) > len(b):
|
||||
a, b = b, a
|
||||
previous = list(range(len(b) + 1))
|
||||
for i, char_a in enumerate(a, start=1):
|
||||
current = [i]
|
||||
for j, char_b in enumerate(b, start=1):
|
||||
insert_cost = current[j - 1] + 1
|
||||
delete_cost = previous[j] + 1
|
||||
replace_cost = previous[j - 1] + (char_a != char_b)
|
||||
current.append(min(insert_cost, delete_cost, replace_cost))
|
||||
previous = current
|
||||
return previous[-1]
|
||||
|
||||
|
||||
def _score_single_answer(predicted: Any, target: Any, threshold: float) -> float:
|
||||
predicted_norm = _normalize_text(predicted)
|
||||
target_norm = _normalize_text(target)
|
||||
if not predicted_norm and not target_norm:
|
||||
return 1.0
|
||||
if not predicted_norm or not target_norm:
|
||||
return 0.0
|
||||
distance = _levenshtein_distance(predicted_norm, target_norm)
|
||||
normalized_distance = distance / max(len(predicted_norm), len(target_norm))
|
||||
if normalized_distance >= threshold:
|
||||
return 0.0
|
||||
return 1.0 - normalized_distance
|
||||
|
||||
|
||||
def _extract_answer_strings(raw: Any) -> list[str]:
|
||||
if raw is None:
|
||||
return [""]
|
||||
if isinstance(raw, str):
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return [""]
|
||||
parsed = None
|
||||
if text[0] in "[{":
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
parsed = ast.literal_eval(text)
|
||||
except (ValueError, SyntaxError):
|
||||
parsed = None
|
||||
if parsed is None:
|
||||
return [text]
|
||||
return _extract_answer_strings(parsed)
|
||||
if isinstance(raw, dict):
|
||||
for key in ("answers", "ground_truth", "answer"):
|
||||
if key in raw:
|
||||
return _extract_answer_strings(raw[key])
|
||||
return [str(raw)]
|
||||
if isinstance(raw, Iterable) and not isinstance(raw, (bytes, bytearray)):
|
||||
answers: list[str] = []
|
||||
for item in raw:
|
||||
if isinstance(item, dict):
|
||||
for key in ("text", "answer", "value"):
|
||||
if key in item:
|
||||
answers.extend(_extract_answer_strings(item[key]))
|
||||
break
|
||||
else:
|
||||
answers.append(str(item))
|
||||
continue
|
||||
answers.append(str(item))
|
||||
return answers or [""]
|
||||
return [str(raw)]
|
||||
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
lower = text.lower()
|
||||
start = lower.rfind("<answer>")
|
||||
end = lower.rfind("</answer>")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return text[start + len("<answer>"):end].strip()
|
||||
lines = [line.strip() for line in text.splitlines() if line.strip()]
|
||||
return lines[-1] if lines else text.strip()
|
||||
|
||||
|
||||
def evaluate(prediction_text: str, gold_answers: Any) -> dict:
|
||||
answer = extract_answer(prediction_text)
|
||||
answers = _extract_answer_strings(gold_answers)
|
||||
score = 0.0
|
||||
for target in answers:
|
||||
score = max(score, _score_single_answer(answer, target, DEFAULT_ANLS_THRESHOLD))
|
||||
return {
|
||||
"anls": score,
|
||||
"predicted_answer": answer,
|
||||
"gold_answers": answers,
|
||||
}
|
||||
35
reflact/envs/docvqa/prompts/analyst_error.md
Normal file
35
reflact/envs/docvqa/prompts/analyst_error.md
Normal file
@@ -0,0 +1,35 @@
|
||||
You are an expert failure-analysis agent for visual document question answering tasks.
|
||||
|
||||
You will be given MULTIPLE failed DocVQA trajectories from a single minibatch and the current skill document. Each trajectory includes the model response and an evaluation result scored with ANLS against one or more acceptable answers.
|
||||
|
||||
Your job is to identify the most important COMMON failure patterns across the batch and propose concise skill edits.
|
||||
|
||||
## Failure Type Categories
|
||||
- evidence_miss: the model overlooked the relevant visible region or line
|
||||
- near_match_confusion: the model selected a nearby but incorrect text span
|
||||
- normalization_error: the answer differed mainly in formatting, spacing, punctuation, or minor text normalization
|
||||
- reading_error: the model misread the document content
|
||||
- other: none of the above
|
||||
|
||||
## Rules
|
||||
- Focus on common, reusable reading and extraction behaviors.
|
||||
- Do not hardcode image-specific answers.
|
||||
- Prefer concise edits that improve evidence selection and exact span extraction.
|
||||
|
||||
Respond ONLY with a valid JSON object (no markdown fences, no extra text):
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the batch's common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown to add at end of skill>"},
|
||||
{"op": "insert_after", "target": "<exact heading/text to insert after>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<exact text to replace>", "content": "<replacement>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
Only include edits that are needed. "edits" can be an empty list if no patch is warranted.
|
||||
24
reflact/envs/docvqa/prompts/analyst_success.md
Normal file
24
reflact/envs/docvqa/prompts/analyst_success.md
Normal file
@@ -0,0 +1,24 @@
|
||||
You are an expert success-pattern analyst for visual document question answering tasks.
|
||||
|
||||
You will be given MULTIPLE successful DocVQA trajectories from a single minibatch and the current skill document. Your job is to identify common visual reading and exact-answer extraction behaviors worth encoding in the skill.
|
||||
|
||||
## Rules
|
||||
- Focus on patterns shared across multiple successful trajectories.
|
||||
- Reinforce reusable behaviors like locating the right region, copying exact spans, and preferring the shortest exact answer over paraphrase.
|
||||
- Only propose patches for patterns not already captured by the current skill.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number of trajectories analysed>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns are worth encoding>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
"edits" may be empty if the skill already covers all observed patterns.
|
||||
12
reflact/envs/docvqa/prompts/rollout_system.md
Normal file
12
reflact/envs/docvqa/prompts/rollout_system.md
Normal file
@@ -0,0 +1,12 @@
|
||||
You are an expert visual document question answering agent.
|
||||
|
||||
{skill_section}You will receive a document image and a question about the document.
|
||||
Read the visual evidence carefully and answer concisely.
|
||||
|
||||
Rules:
|
||||
- Ground the answer in the visible document content.
|
||||
- Prefer exact spans, numbers, dates, and names from the document.
|
||||
- Do not invent content that is not visible.
|
||||
- If multiple near-matches exist, choose the one best supported by the document.
|
||||
|
||||
Return the final answer inside <answer>...</answer>.
|
||||
365
reflact/envs/docvqa/rollout.py
Normal file
365
reflact/envs/docvqa/rollout.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||
|
||||
from reflact.envs.docvqa.evaluator import evaluate
|
||||
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
|
||||
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
|
||||
def _build_system(skill_content: str) -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return load_prompt("rollout_system", env="docvqa").format(skill_section=skill_section)
|
||||
|
||||
|
||||
def _image_to_data_uri(path: str) -> str:
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
mime = mimetypes.guess_type(path)[0] or "image/png"
|
||||
with open(path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
def _build_messages(
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
image_detail: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
) -> tuple[list[dict], str, str]:
|
||||
system = _build_system(skill_content)
|
||||
user_text = item["question"] + "\n\nReturn the final answer inside <answer>...</answer>."
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
user_text += f"\n\n## Training Readout\n{diagnostic_instruction.strip()}"
|
||||
image_url = {"url": _image_to_data_uri(item["image_path"])}
|
||||
if image_detail and image_detail != "auto":
|
||||
image_url["detail"] = image_detail
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_text},
|
||||
{"type": "image_url", "image_url": image_url},
|
||||
],
|
||||
},
|
||||
]
|
||||
return messages, system, user_text
|
||||
|
||||
|
||||
def _build_codex_skill(skill_content: str) -> str:
|
||||
return render_skill_md(
|
||||
skill_content,
|
||||
description="Dynamic ReflACT skill for solving the current DocVQA document-image question.",
|
||||
preamble=(
|
||||
"Use this skill when answering the current DocVQA question.\n"
|
||||
"Inspect the attached document image carefully and return the final answer inside <answer>...</answer>."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _run_codex_once(
|
||||
*,
|
||||
pred_dir: str,
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
model: str,
|
||||
timeout: int,
|
||||
image_detail: str,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
previous_response: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
_ = image_detail
|
||||
_messages, _system, user_text = _build_messages(
|
||||
item,
|
||||
skill_content,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
)
|
||||
task_parts = [user_text]
|
||||
image_abs = os.path.abspath(item["image_path"])
|
||||
task_parts.append(
|
||||
"## Document Image\n"
|
||||
"The document image is available in this workspace via `ATTACHMENTS.md`.\n"
|
||||
f"Original image path: `{image_abs}`\n"
|
||||
"Open or inspect that image before answering; do not answer from memory."
|
||||
)
|
||||
if previous_response:
|
||||
task_parts.append(
|
||||
"## Previous Attempt\n"
|
||||
f"{previous_response}\n\n"
|
||||
"Review the same document image carefully and correct the answer if needed."
|
||||
)
|
||||
task_text = "\n\n".join(task_parts)
|
||||
skill_md = _build_codex_skill(skill_content)
|
||||
work_dir = os.path.join(pred_dir, "codex_exec")
|
||||
prepare_workspace(
|
||||
work_dir=work_dir,
|
||||
skill_md=skill_md,
|
||||
task_text=task_text,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
prompt = (
|
||||
"Use the `reflact-student` skill available in this workspace.\n"
|
||||
"Read `task.md`, inspect the attached document image, and answer the DocVQA question.\n"
|
||||
"Return the final answer inside <answer>...</answer>."
|
||||
)
|
||||
final_message, raw = run_student_exec(
|
||||
work_dir=work_dir,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
return final_message or raw, raw, skill_md, task_text
|
||||
|
||||
|
||||
def process_one(
|
||||
item: dict,
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 120,
|
||||
image_detail: str = "auto",
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
"id": item_id,
|
||||
"question": item["question"],
|
||||
"task_type": item.get("subtask") or item.get("task_type") or "docvqa",
|
||||
"task_description": item["question"],
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"response": "",
|
||||
"fail_reason": "",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
"image_paths": item.get("image_paths", []),
|
||||
"gold_answer": item.get("answers", []),
|
||||
}
|
||||
try:
|
||||
response = ""
|
||||
system_prompt = ""
|
||||
user_text = ""
|
||||
conversation: list[dict] = []
|
||||
if is_student_exec_backend():
|
||||
from reflact.model import azure_openai as _llm
|
||||
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": item["question"] + "\n\n" + f"[image] {os.path.basename(item['image_path'])}",
|
||||
}
|
||||
]
|
||||
for turn in range(max_turns):
|
||||
response, _raw, system_prompt, user_text = _run_codex_once(
|
||||
pred_dir=os.path.join(out_root, "predictions", item_id),
|
||||
item=item,
|
||||
skill_content=skill_content,
|
||||
model=_llm.STUDENT_DEPLOYMENT,
|
||||
timeout=exec_timeout,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode if turn == 0 else False,
|
||||
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
|
||||
previous_response=response if turn > 0 else "",
|
||||
)
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": response})
|
||||
if "<answer>" in response.lower():
|
||||
break
|
||||
else:
|
||||
messages, system_prompt, user_text = _build_messages(
|
||||
item,
|
||||
skill_content,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
)
|
||||
conversation = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_text + "\n\n" + f"[image] {os.path.basename(item['image_path'])}",
|
||||
}
|
||||
]
|
||||
for turn in range(max_turns):
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=768,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
else:
|
||||
refinement_messages = [
|
||||
messages[0],
|
||||
messages[1],
|
||||
{"role": "assistant", "content": response},
|
||||
{"role": "user", "content": "Review the same image carefully and answer again. Keep the final answer inside <answer>...</answer>."},
|
||||
]
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=refinement_messages,
|
||||
max_completion_tokens=512,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
response = resp_text
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
|
||||
if "<answer>" in resp_text.lower():
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
|
||||
pred_dir = os.path.join(out_root, "predictions", item_id)
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate(response, item.get("answers", []))
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["anls"] >= 0.999)
|
||||
result["soft"] = eval_result["anls"]
|
||||
if result["soft"] <= 0.0:
|
||||
result["fail_reason"] = f"predicted '{eval_result['predicted_answer']}' but expected one of {item.get('answers', [])}"
|
||||
|
||||
eval_detail = (
|
||||
"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Gold answers: {item.get('answers', [])!r}\n"
|
||||
f"ANLS: {eval_result['anls']:.4f}"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
result["fail_reason"] = f"error: {e}"
|
||||
return result
|
||||
|
||||
|
||||
def run_batch(
|
||||
items: list[dict],
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 120,
|
||||
workers: int = 16,
|
||||
image_detail: str = "auto",
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
task_timeout: int = 600,
|
||||
) -> list[dict]:
|
||||
task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
|
||||
results_path = os.path.join(out_root, "results.jsonl")
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
done_ids: set[str] = set()
|
||||
existing: list[dict] = []
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except Exception:
|
||||
continue
|
||||
done_ids.add(str(row["id"]))
|
||||
existing.append(row)
|
||||
|
||||
pending = [item for item in items if str(item["id"]) not in done_ids]
|
||||
if not pending:
|
||||
return existing
|
||||
|
||||
def _timeout_result(item: dict) -> dict:
|
||||
return {
|
||||
"id": str(item["id"]),
|
||||
"question": item.get("question", ""),
|
||||
"task_type": item.get("subtask") or item.get("task_type") or "docvqa",
|
||||
"task_description": item.get("question", ""),
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"response": "",
|
||||
"fail_reason": f"task-timeout-{task_timeout}s",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
"image_paths": item.get("image_paths", []),
|
||||
"gold_answer": item.get("answers", []),
|
||||
"phase": "timeout",
|
||||
}
|
||||
|
||||
def _error_result(item: dict, exc: Exception) -> dict:
|
||||
row = _timeout_result(item)
|
||||
row["phase"] = "error"
|
||||
row["fail_reason"] = f"unexpected: {type(exc).__name__}: {exc}"
|
||||
return row
|
||||
|
||||
started_at: dict[str, float] = {}
|
||||
|
||||
def _run_one(item: dict) -> dict:
|
||||
started_at[str(item["id"])] = time.time()
|
||||
return process_one(
|
||||
item,
|
||||
out_root,
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
exec_timeout=exec_timeout,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
)
|
||||
|
||||
results = list(existing)
|
||||
with open(results_path, "a", encoding="utf-8") as outf:
|
||||
ex = ThreadPoolExecutor(max_workers=workers)
|
||||
try:
|
||||
futs = {ex.submit(_run_one, item): item for item in pending}
|
||||
pending_futs = set(futs)
|
||||
while pending_futs:
|
||||
done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
|
||||
now = time.time()
|
||||
timed_out = [
|
||||
fut for fut in pending_futs - done
|
||||
if str(futs[fut]["id"]) in started_at
|
||||
and now - started_at[str(futs[fut]["id"])] >= task_timeout
|
||||
]
|
||||
for fut in done:
|
||||
pending_futs.remove(fut)
|
||||
item = futs[fut]
|
||||
try:
|
||||
res = fut.result()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
res = _error_result(item, exc)
|
||||
results.append(res)
|
||||
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
for fut in timed_out:
|
||||
pending_futs.remove(fut)
|
||||
fut.cancel()
|
||||
res = _timeout_result(futs[fut])
|
||||
results.append(res)
|
||||
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
finally:
|
||||
ex.shutdown(wait=False, cancel_futures=True)
|
||||
return results
|
||||
11
reflact/envs/docvqa/skills/initial.md
Normal file
11
reflact/envs/docvqa/skills/initial.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# DocVQA Skill
|
||||
|
||||
## Visual Evidence Discipline
|
||||
- Read the document carefully before answering.
|
||||
- Prefer the smallest exact text span that answers the question.
|
||||
- When several nearby strings look similar, choose the one whose surrounding labels or layout best match the question.
|
||||
|
||||
## Exact Answer Discipline
|
||||
- Copy names, numbers, and dates exactly from the document whenever possible.
|
||||
- Prefer direct extraction over paraphrase.
|
||||
- Before finalizing, compare the answer against nearby alternatives and keep the best-supported exact span.
|
||||
1
reflact/envs/livemathematicianbench/__init__.py
Normal file
1
reflact/envs/livemathematicianbench/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""LiveMathematicianBench environment package for ReflACT."""
|
||||
284
reflact/envs/livemathematicianbench/adapter.py
Normal file
284
reflact/envs/livemathematicianbench/adapter.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""LiveMathematicianBench environment adapter for ReflACT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from reflact.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from reflact.datasets.base import BatchSpec
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
from reflact.envs.base import EnvAdapter
|
||||
from reflact.envs.livemathematicianbench.dataloader import LiveMathematicianBenchDataLoader
|
||||
from reflact.envs.livemathematicianbench.rollout import run_batch
|
||||
from reflact.model import get_student_backend
|
||||
|
||||
|
||||
class LiveMathematicianBenchAdapter(EnvAdapter):
|
||||
"""LiveMathematicianBench adapter."""
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
parts: list[str] = []
|
||||
theorem = str(item.get("theorem") or "").strip()
|
||||
sketch = str(item.get("sketch") or "").strip()
|
||||
if theorem:
|
||||
parts.append(f"## Reference Theorem\n{theorem}")
|
||||
if sketch:
|
||||
parts.append(f"## Reference Sketch\n{sketch}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
fields: list[str] = []
|
||||
previews: list[str] = []
|
||||
theorem = str(item.get("theorem") or "").strip()
|
||||
sketch = str(item.get("sketch") or "").strip()
|
||||
if theorem:
|
||||
fields.append("theorem")
|
||||
previews.append(f"[theorem]\n{theorem[:220]}")
|
||||
if sketch:
|
||||
fields.append("sketch")
|
||||
previews.append(f"[sketch]\n{sketch[:220]}")
|
||||
return {
|
||||
"fields": fields,
|
||||
"preview": "\n\n".join(previews)[:500],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 300,
|
||||
workers: int = 64,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
shuffle_choices: bool = True,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False,
|
||||
exec_timeout: int = 600,
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.exec_timeout = exec_timeout
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.use_theorem = use_theorem
|
||||
self.use_sketch = use_sketch
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = LiveMathematicianBenchDataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
shuffle_choices=shuffle_choices,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
exec_timeout=self.exec_timeout,
|
||||
workers=self.workers,
|
||||
use_theorem=self.use_theorem,
|
||||
use_sketch=self.use_sketch,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
env_manager = kwargs.get("env_manager")
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
codex_backend = get_student_backend() == "codex_exec"
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
env_manager if isinstance(env_manager, list) else None,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
if codex_backend:
|
||||
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
|
||||
selected_metadata = []
|
||||
theorem_count = 0
|
||||
sketch_count = 0
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
if "theorem" in meta["fields"]:
|
||||
theorem_count += 1
|
||||
if "sketch" in meta["fields"]:
|
||||
sketch_count += 1
|
||||
selected_metadata.append({
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("theorem_type", ["math_mcq"])[0] if item.get("theorem_type") else "math_mcq"),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
})
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields=theorem({theorem_count}/{len(selected_items)}),"
|
||||
f"sketch({sketch_count}/{len(selected_items)})"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
diagnostic_trace_context_by_id = None
|
||||
if codex_backend:
|
||||
selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
|
||||
selected_items=selected_items,
|
||||
selected_examples=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
probe=probe,
|
||||
)
|
||||
probe_record = {
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": {
|
||||
"theorem": theorem_count,
|
||||
"sketch": sketch_count,
|
||||
},
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
}
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(probe_record, f, ensure_ascii=False, indent=2)
|
||||
deep_results = run_batch(
|
||||
items=selected_items,
|
||||
out_root=rollout_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=min(self.workers, max(len(selected_items), 1)),
|
||||
use_theorem=self.use_theorem,
|
||||
use_sketch=self.use_sketch,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, selected_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return self.dataloader.get_task_types()
|
||||
308
reflact/envs/livemathematicianbench/dataloader.py
Normal file
308
reflact/envs/livemathematicianbench/dataloader.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""LiveMathematicianBench task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
from reflact.datasets.base import BatchSpec, SplitDataLoader
|
||||
|
||||
|
||||
# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
|
||||
|
||||
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]
|
||||
|
||||
|
||||
def _load_json(path: str) -> Any:
|
||||
with open(path) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _iter_monthly_files(data_path: str) -> list[str]:
|
||||
if not data_path:
|
||||
return []
|
||||
if os.path.isfile(data_path):
|
||||
return [data_path]
|
||||
if os.path.isdir(data_path):
|
||||
nested = glob.glob(
|
||||
os.path.join(data_path, "**", "qa_*_final.json"),
|
||||
recursive=True,
|
||||
)
|
||||
flat = glob.glob(os.path.join(data_path, "qa_*_final.json"))
|
||||
return sorted(set(nested + flat))
|
||||
return []
|
||||
|
||||
|
||||
def _coerce_choices(raw_choices: Any) -> list[dict]:
|
||||
if isinstance(raw_choices, list):
|
||||
choices: list[dict] = []
|
||||
for idx, item in enumerate(raw_choices):
|
||||
if isinstance(item, dict):
|
||||
label = str(item.get("label") or _CHOICE_LABELS[idx]).strip()
|
||||
text = str(item.get("text") or item.get("content") or "").strip()
|
||||
else:
|
||||
label = _CHOICE_LABELS[idx]
|
||||
text = str(item).strip()
|
||||
if text:
|
||||
choices.append({"label": label, "text": text})
|
||||
return choices
|
||||
|
||||
if isinstance(raw_choices, dict):
|
||||
labels = sorted(raw_choices.keys())
|
||||
return [
|
||||
{"label": str(label).strip(), "text": str(raw_choices[label]).strip()}
|
||||
for label in labels
|
||||
if str(raw_choices[label]).strip()
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _coerce_theorem_types(raw: Any) -> list[str]:
|
||||
if isinstance(raw, list):
|
||||
return [str(x).strip() for x in raw if str(x).strip()]
|
||||
if raw is None:
|
||||
return []
|
||||
text = str(raw).strip()
|
||||
return [text] if text else []
|
||||
|
||||
|
||||
def _normalize_label(text: str) -> str:
|
||||
return str(text).strip().upper().rstrip(".):")
|
||||
|
||||
|
||||
def _normalize_item(item: dict, row_idx: int, source_path: str) -> dict:
|
||||
mcq = item.get("mcq", {}) if isinstance(item.get("mcq"), dict) else {}
|
||||
question = str(mcq.get("question") or item.get("question") or "").strip()
|
||||
choices = _coerce_choices(mcq.get("choices") or item.get("choices") or [])
|
||||
correct = mcq.get("correct_choice") or item.get("correct_choice") or {}
|
||||
|
||||
if isinstance(correct, dict):
|
||||
correct_label = _normalize_label(correct.get("label", ""))
|
||||
correct_text = str(correct.get("text") or "").strip()
|
||||
else:
|
||||
correct_label = _normalize_label(correct)
|
||||
correct_text = ""
|
||||
|
||||
choice_by_label = {
|
||||
_normalize_label(choice["label"]): choice["text"]
|
||||
for choice in choices
|
||||
}
|
||||
if correct_label and not correct_text:
|
||||
correct_text = choice_by_label.get(correct_label, "")
|
||||
if correct_label and correct_text and correct_label not in choice_by_label:
|
||||
choices.append({"label": correct_label, "text": correct_text})
|
||||
choices.sort(key=lambda choice: _CHOICE_LABELS.index(choice["label"]) if choice["label"] in _CHOICE_LABELS else len(_CHOICE_LABELS))
|
||||
choice_by_label[correct_label] = correct_text
|
||||
|
||||
month = str(item.get("month") or "").strip()
|
||||
item_no = item.get("no", row_idx + 1)
|
||||
item_id = f"{month}:{item_no}" if month else str(item_no)
|
||||
|
||||
return {
|
||||
"id": item_id,
|
||||
"month": month,
|
||||
"no": item_no,
|
||||
"paper_link": str(item.get("paper_link") or "").strip(),
|
||||
"theorem": str(item.get("theorem") or "").strip(),
|
||||
"sketch": str(item.get("sketch") or "").strip(),
|
||||
"theorem_type": _coerce_theorem_types(item.get("theorem_type")),
|
||||
"question": question,
|
||||
"choices": choices,
|
||||
"correct_choice": {
|
||||
"label": correct_label,
|
||||
"text": correct_text,
|
||||
},
|
||||
"source_path": source_path,
|
||||
}
|
||||
|
||||
|
||||
def load_items(data_path: str) -> list[dict]:
|
||||
"""Load and normalise LiveMathematicianBench items from JSON files."""
|
||||
files = _iter_monthly_files(data_path)
|
||||
if not files:
|
||||
raise ValueError(
|
||||
"LiveMathematicianBench requires data_path to be a qa_*_final.json file "
|
||||
"or a directory containing monthly qa_*_final.json files."
|
||||
)
|
||||
|
||||
items: list[dict] = []
|
||||
for path in files:
|
||||
raw = _load_json(path)
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"Expected JSON array in {path}, got {type(raw).__name__}")
|
||||
for row_idx, item in enumerate(raw):
|
||||
norm = _normalize_item(item, row_idx=row_idx, source_path=path)
|
||||
if norm["question"] and norm["choices"] and norm["correct_choice"]["label"]:
|
||||
items.append(norm)
|
||||
if not items:
|
||||
raise ValueError(f"No valid LiveMathematicianBench items loaded from {data_path}")
|
||||
return items
|
||||
|
||||
|
||||
# ── Dataloader ───────────────────────────────────────────────────────────
|
||||
|
||||
class LiveMathematicianBenchDataLoader(SplitDataLoader):
|
||||
"""LiveMathematicianBench dataloader with per-seed choice shuffling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
shuffle_choices: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
self.shuffle_choices = shuffle_choices
|
||||
self._task_types: list[str] = []
|
||||
|
||||
def load_raw_items(self, data_path: str) -> list[dict]:
|
||||
return load_items(data_path)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
all_items = self.train_items + self.val_items + self.test_items
|
||||
task_types: set[str] = set()
|
||||
for item in all_items:
|
||||
for name in item.get("theorem_type", []):
|
||||
if name:
|
||||
task_types.add(name)
|
||||
self._task_types = sorted(task_types)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(self._task_types)
|
||||
|
||||
# ── Choice shuffling ─────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _item_shuffle_seed(item_id: str, seed: int) -> int:
|
||||
digest = hashlib.sha256(f"{seed}:{item_id}".encode("utf-8")).hexdigest()
|
||||
return int(digest[:16], 16)
|
||||
|
||||
def _shuffle_item_choices(self, item: dict, seed: int) -> dict:
|
||||
if not self.shuffle_choices:
|
||||
return {
|
||||
**item,
|
||||
"choices": [dict(c) for c in item["choices"]],
|
||||
"correct_choice": dict(item["correct_choice"]),
|
||||
}
|
||||
|
||||
shuffled_choices = [dict(c) for c in item["choices"]]
|
||||
rng = random.Random(self._item_shuffle_seed(str(item["id"]), seed))
|
||||
rng.shuffle(shuffled_choices)
|
||||
|
||||
original_correct = _normalize_label(item["correct_choice"]["label"])
|
||||
remapped_choices: list[dict] = []
|
||||
new_correct_choice = dict(item["correct_choice"])
|
||||
|
||||
for idx, choice in enumerate(shuffled_choices):
|
||||
new_label = _CHOICE_LABELS[idx]
|
||||
old_label = _normalize_label(choice["label"])
|
||||
remapped_choices.append({"label": new_label, "text": choice["text"]})
|
||||
if old_label == original_correct:
|
||||
new_correct_choice = {"label": new_label, "text": choice["text"]}
|
||||
|
||||
transformed = dict(item)
|
||||
transformed["choices"] = remapped_choices
|
||||
transformed["correct_choice"] = new_correct_choice
|
||||
return transformed
|
||||
|
||||
def _materialize_batch(self, items: list[dict], seed: int) -> list[dict]:
|
||||
return [self._shuffle_item_choices(item, seed) for item in items]
|
||||
|
||||
# ── Batch construction (override for choice shuffling) ───────────────
|
||||
|
||||
def plan_train_epoch(
|
||||
self,
|
||||
*,
|
||||
epoch: int,
|
||||
steps_per_epoch: int,
|
||||
accumulation: int,
|
||||
batch_size: int,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> list[BatchSpec]:
|
||||
"""Build a shuffled epoch while preserving per-batch choice shuffling."""
|
||||
epoch_rng = random.Random(seed + epoch * 1000)
|
||||
items = list(self.train_items)
|
||||
epoch_rng.shuffle(items)
|
||||
|
||||
total_batches = steps_per_epoch * accumulation
|
||||
if total_batches <= 0:
|
||||
return []
|
||||
|
||||
batches: list[BatchSpec] = []
|
||||
cursor = 0
|
||||
for batch_idx in range(total_batches):
|
||||
batch_seed = seed + epoch * 1000 + batch_idx + 1
|
||||
batch_items = items[cursor: cursor + batch_size]
|
||||
cursor += len(batch_items)
|
||||
|
||||
if not batch_items and items:
|
||||
refill_rng = random.Random(batch_seed)
|
||||
batch_items = list(items)
|
||||
refill_rng.shuffle(batch_items)
|
||||
batch_items = batch_items[:batch_size]
|
||||
|
||||
batch_items = self._materialize_batch(batch_items, batch_seed)
|
||||
batches.append(
|
||||
BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=batch_seed,
|
||||
batch_size=len(batch_items),
|
||||
payload=batch_items,
|
||||
)
|
||||
)
|
||||
|
||||
return batches
|
||||
|
||||
def build_train_batch(self, batch_size: int, seed: int, **kwargs) -> BatchSpec:
|
||||
rng = random.Random(seed)
|
||||
items = list(self.train_items)
|
||||
rng.shuffle(items)
|
||||
items = self._materialize_batch(items[:batch_size], seed)
|
||||
return BatchSpec(
|
||||
phase="train",
|
||||
split="train",
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
)
|
||||
|
||||
def build_eval_batch(
|
||||
self,
|
||||
env_num: int,
|
||||
split: str,
|
||||
seed: int,
|
||||
**kwargs,
|
||||
) -> BatchSpec:
|
||||
items = self.get_split_items(split)
|
||||
if env_num and env_num < len(items):
|
||||
items = items[:env_num]
|
||||
items = self._materialize_batch(items, seed)
|
||||
return BatchSpec(
|
||||
phase="eval",
|
||||
split=split,
|
||||
seed=seed,
|
||||
batch_size=len(items),
|
||||
payload=items,
|
||||
)
|
||||
62
reflact/envs/livemathematicianbench/evaluator.py
Normal file
62
reflact/envs/livemathematicianbench/evaluator.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""LiveMathematicianBench evaluation helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
matches = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL | re.IGNORECASE)
|
||||
if matches:
|
||||
return matches[-1].strip()
|
||||
lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
|
||||
if lines:
|
||||
return lines[-1]
|
||||
return text.strip()
|
||||
|
||||
|
||||
def normalize_label(text: str) -> str:
|
||||
return str(text).strip().upper().rstrip(".):")
|
||||
|
||||
|
||||
def parse_choice_label(prediction_text: str, choices: list[dict]) -> str:
|
||||
answer = extract_answer(prediction_text)
|
||||
label = normalize_label(answer)
|
||||
valid_labels = {normalize_label(choice.get("label", "")) for choice in choices}
|
||||
if label in valid_labels:
|
||||
return label
|
||||
|
||||
answer_lower = answer.lower()
|
||||
for choice in choices:
|
||||
choice_label = normalize_label(choice.get("label", ""))
|
||||
choice_text = str(choice.get("text", "")).strip()
|
||||
if choice_text and choice_text.lower() == answer_lower:
|
||||
return choice_label
|
||||
|
||||
first_token = normalize_label(answer.split()[0]) if answer.split() else ""
|
||||
if first_token in valid_labels:
|
||||
return first_token
|
||||
return label
|
||||
|
||||
|
||||
def evaluate(prediction_text: str, correct_choice: dict, choices: list[dict]) -> dict:
|
||||
predicted_label = parse_choice_label(prediction_text, choices)
|
||||
correct_label = normalize_label(correct_choice.get("label", ""))
|
||||
predicted_text = ""
|
||||
correct_text = str(correct_choice.get("text", "")).strip()
|
||||
|
||||
for choice in choices:
|
||||
if normalize_label(choice.get("label", "")) == predicted_label:
|
||||
predicted_text = str(choice.get("text", "")).strip()
|
||||
break
|
||||
|
||||
is_correct = float(predicted_label == correct_label)
|
||||
return {
|
||||
"em": is_correct,
|
||||
"f1": is_correct,
|
||||
"sub_em": is_correct,
|
||||
"predicted_answer": predicted_label or extract_answer(prediction_text),
|
||||
"predicted_label": predicted_label,
|
||||
"predicted_text": predicted_text,
|
||||
"correct_label": correct_label,
|
||||
"correct_text": correct_text,
|
||||
}
|
||||
37
reflact/envs/livemathematicianbench/prompts/analyst_error.md
Normal file
37
reflact/envs/livemathematicianbench/prompts/analyst_error.md
Normal file
@@ -0,0 +1,37 @@
|
||||
You are an expert failure-analysis agent for theorem-grounded mathematical multiple-choice questions.
|
||||
|
||||
You will be given MULTIPLE failed trajectories from a single minibatch and the current skill document.
|
||||
Each trajectory includes the student's response and an evaluation result showing the predicted option
|
||||
versus the correct option.
|
||||
|
||||
Your job is to identify COMMON reasoning failures across the batch and propose concise skill edits.
|
||||
|
||||
## Failure Type Categories
|
||||
- **quantifier_miss**: the agent missed exact quantifiers, scope, or existence/uniqueness conditions
|
||||
- **strength_mismatch**: the agent preferred a weaker or stronger statement than what was proved
|
||||
- **condition_miss**: the agent ignored hypotheses, equality cases, or domain restrictions
|
||||
- **option_confusion**: the agent confused similar answer choices or failed to compare them exactly
|
||||
- **other**: none of the above
|
||||
|
||||
## Rules
|
||||
1. Focus on patterns that recur across the minibatch.
|
||||
2. Prefer edits that improve exact choice discrimination, not theorem-specific memorization.
|
||||
3. Do not hardcode paper-specific content.
|
||||
4. Only patch gaps not already covered by the skill.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
You are an expert success-pattern analyst for theorem-grounded mathematical multiple-choice questions.
|
||||
|
||||
You will be given MULTIPLE successful trajectories from a minibatch and the current skill document.
|
||||
Identify generalizable behavior patterns that are genuinely helping the agent choose the exact correct option.
|
||||
|
||||
## Rules
|
||||
- Focus on broadly useful reasoning behaviors.
|
||||
- Prefer patterns about exact comparison of options, quantifiers, and equality conditions.
|
||||
- Do not add theorem-specific facts.
|
||||
- "edits" may be empty if the skill already captures the useful patterns.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns matter>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
23
reflact/envs/livemathematicianbench/prompts/deep_probe.md
Normal file
23
reflact/envs/livemathematicianbench/prompts/deep_probe.md
Normal file
@@ -0,0 +1,23 @@
|
||||
You are an expert diagnostic-probe designer for theorem-grounded mathematical multiple-choice tasks.
|
||||
|
||||
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
|
||||
Design one SMALL diagnostic instruction that exposes the student's intermediate judgment without materially changing the original scaffold.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT substantially change the original scaffold.
|
||||
2. Do NOT prescribe a new multi-step theorem-solving procedure.
|
||||
3. Do NOT ask for a full proof, full chain-of-thought, or exhaustive option-by-option derivation.
|
||||
4. Ask only for a short readout of the signals already behind the student's current answer.
|
||||
5. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>.
|
||||
|
||||
## Good Probe Targets
|
||||
- top choice and runner-up
|
||||
- decisive constraint
|
||||
- why the runner-up was rejected
|
||||
- strongest-vs-weaker discrimination signal
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this probe is informative>",
|
||||
"probe_instruction": "<the exact instruction text to append to the student prompt>"
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
You are an expert diagnostic-probe designer for theorem-grounded mathematical multiple-choice tasks executed through a Codex trace.
|
||||
|
||||
You will be shown representative trajectories, the current student skill, the student's original prompt context, hidden reference fields, and numbered Codex trace steps.
|
||||
Choose exactly one trajectory and one probe point. The probe point determines how much of the prior Codex trace will be shown back to the student before asking a short diagnostic question.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT reveal or paraphrase the hidden reference directly to the student.
|
||||
2. Do NOT prescribe a new full solving procedure.
|
||||
3. Do NOT ask for a full proof, full chain-of-thought, or exhaustive option-by-option derivation.
|
||||
4. Ask only for a short readout of the signal that should already exist at that point in the student's process.
|
||||
5. The probe instruction must explicitly request a short <analysis>...</analysis> block before the final <answer>...</answer>.
|
||||
6. Select a probe point that is informative about theorem choice, decisive constraint, option elimination, or why a stronger/weaker option should be rejected.
|
||||
|
||||
## Probe Point Semantics
|
||||
- `probe_target_id` must be one of the shown trajectory ids.
|
||||
- `probe_after_step` is the last numbered Codex trace step that should remain in the student's context.
|
||||
- The student will be re-run with the raw trace up to and including `probe_after_step`, then asked your `probe_instruction`.
|
||||
- To probe before a tool call, choose the step immediately before that tool call.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this trajectory and probe point expose the student's intermediate state>",
|
||||
"probe_target_id": "<trajectory id>",
|
||||
"probe_after_step": <integer step number>,
|
||||
"probe_instruction": "<the exact instruction text to append to the student's prompt>"
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
You are an expert mathematical reasoning agent solving multiple-choice questions.
|
||||
|
||||
{skill_section}## Task Format
|
||||
You will receive one mathematics multiple-choice question and its answer choices.
|
||||
Reason carefully about quantifiers, hypotheses, extremal wording, and exact equality conditions.
|
||||
|
||||
## Answer Format
|
||||
Think step by step, then provide your final answer inside <answer>...</answer> tags.
|
||||
Inside the tags, output only the single choice label, such as A or C.
|
||||
|
||||
Example:
|
||||
<answer>B</answer>
|
||||
4
reflact/envs/livemathematicianbench/reflect.py
Normal file
4
reflact/envs/livemathematicianbench/reflect.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""LiveMathematicianBench Reflect stage.
|
||||
|
||||
Prompts are now loaded from .md files by the base adapter.
|
||||
"""
|
||||
401
reflact/envs/livemathematicianbench/rollout.py
Normal file
401
reflact/envs/livemathematicianbench/rollout.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""LiveMathematicianBench rollout — theorem-grounded math MCQ agent."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||
|
||||
from reflact.envs.livemathematicianbench.evaluator import evaluate
|
||||
from reflact.model import chat_student, get_student_backend, is_student_exec_backend
|
||||
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
def _build_system(skill_content: str) -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return load_prompt("rollout_system", env="livemathematicianbench").format(skill_section=skill_section)
|
||||
|
||||
|
||||
def _format_choices(choices: list[dict]) -> str:
|
||||
return "\n".join(
|
||||
f"{choice['label']}. {choice['text']}"
|
||||
for choice in choices
|
||||
)
|
||||
|
||||
|
||||
def _build_user(
|
||||
item: dict,
|
||||
*,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> str:
|
||||
parts = [f"## Question\n{item['question']}", f"## Choices\n{_format_choices(item['choices'])}"]
|
||||
if use_theorem and item.get("theorem"):
|
||||
parts.append(f"## Theorem\n{item['theorem']}")
|
||||
if use_sketch and item.get("sketch"):
|
||||
parts.append(f"## Proof Sketch\n{item['sketch']}")
|
||||
if diagnostic_trace_context.strip():
|
||||
parts.append(
|
||||
"## Previous Codex Trace Snapshot\n"
|
||||
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
|
||||
f"{diagnostic_trace_context.strip()}"
|
||||
)
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _build_codex_skill(skill_content: str) -> str:
|
||||
return render_skill_md(
|
||||
skill_content,
|
||||
description="Dynamic ReflACT skill for solving the current LiveMathematicianBench multiple-choice question.",
|
||||
preamble=(
|
||||
"Use this skill when solving the current math multiple-choice question.\n"
|
||||
"Inspect the option wording carefully and output only the final choice label inside <answer>...</answer>."
|
||||
),
|
||||
)
|
||||
|
||||
def _run_codex_once(
|
||||
*,
|
||||
pred_dir: str,
|
||||
skill_content: str,
|
||||
item: dict,
|
||||
model: str,
|
||||
timeout: int,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
previous_response: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
user = _build_user(
|
||||
item,
|
||||
use_theorem=use_theorem,
|
||||
use_sketch=use_sketch,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
task_parts = [user]
|
||||
if previous_response:
|
||||
task_parts.append(
|
||||
"## Previous Attempt\n"
|
||||
f"{previous_response}\n\n"
|
||||
"Re-evaluate the exact option wording. If needed, correct it."
|
||||
)
|
||||
task_text = "\n\n".join(task_parts)
|
||||
skill_md = _build_codex_skill(skill_content)
|
||||
work_dir = os.path.join(pred_dir, "codex_exec")
|
||||
prepare_workspace(work_dir=work_dir, skill_md=skill_md, task_text=task_text)
|
||||
prompt = (
|
||||
"Use the `reflact-student` skill available in this workspace.\n"
|
||||
"Read `task.md` and solve the multiple-choice problem.\n"
|
||||
"Output only the final choice label inside <answer>...</answer>."
|
||||
)
|
||||
final_message, raw = run_student_exec(
|
||||
work_dir=work_dir,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
)
|
||||
return final_message or raw, raw, skill_md, task_text
|
||||
|
||||
|
||||
def process_one(
|
||||
item: dict,
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
exec_timeout: int = 300,
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
"id": item_id,
|
||||
"question": item["question"],
|
||||
"task_type": item.get("theorem_type", ["math_mcq"])[0] if item.get("theorem_type") else "math_mcq",
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"predicted_label": "",
|
||||
"predicted_text": "",
|
||||
"correct_label": item["correct_choice"]["label"],
|
||||
"correct_text": item["correct_choice"]["text"],
|
||||
"response": "",
|
||||
"fail_reason": "",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
pred_dir = os.path.join(out_root, "predictions", item_id)
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
if is_student_exec_backend():
|
||||
from reflact.model import azure_openai as _llm
|
||||
|
||||
conversation: list[dict] = []
|
||||
response = ""
|
||||
system = ""
|
||||
user = ""
|
||||
for turn in range(max_turns):
|
||||
response, raw, system, user = _run_codex_once(
|
||||
pred_dir=pred_dir,
|
||||
skill_content=skill_content,
|
||||
item=item,
|
||||
model=_llm.STUDENT_DEPLOYMENT,
|
||||
timeout=exec_timeout,
|
||||
use_theorem=use_theorem,
|
||||
use_sketch=use_sketch,
|
||||
diagnostic_mode=diagnostic_mode if turn == 0 else False,
|
||||
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
|
||||
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
|
||||
previous_response=response if turn > 0 else "",
|
||||
)
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": response})
|
||||
if "<answer>" in response.lower():
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation)
|
||||
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user)
|
||||
|
||||
eval_result = evaluate(response, item["correct_choice"], item["choices"])
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"MCQ=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}'"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Exact Match: {eval_result['em']}"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
return result
|
||||
|
||||
system = _build_system(skill_content)
|
||||
user = _build_user(
|
||||
item,
|
||||
use_theorem=use_theorem,
|
||||
use_sketch=use_sketch,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
conversation: list[dict] = []
|
||||
response = ""
|
||||
|
||||
for turn in range(max_turns):
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_student(
|
||||
system=system,
|
||||
user=user,
|
||||
max_completion_tokens=16384,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
else:
|
||||
refinement = (
|
||||
f"Your previous answer was:\n{response}\n\n"
|
||||
"Re-evaluate the exact option wording. If needed, correct it. "
|
||||
"Output only the final choice label inside <answer>...</answer>."
|
||||
)
|
||||
resp_text, _ = chat_student(
|
||||
system=system,
|
||||
user=refinement,
|
||||
max_completion_tokens=16384,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
response = resp_text
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
|
||||
if "<answer>" in resp_text.lower():
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation)
|
||||
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user)
|
||||
|
||||
eval_result = evaluate(response, item["correct_choice"], item["choices"])
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"MCQ=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}'"
|
||||
)
|
||||
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Exact Match: {eval_result['em']}"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e: # noqa: BLE001
|
||||
result["fail_reason"] = f"error: {e}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def run_batch(
|
||||
items: list[dict],
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 300,
|
||||
workers: int = 64,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None,
|
||||
task_timeout: int = 600,
|
||||
) -> list[dict]:
|
||||
task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
|
||||
results_path = os.path.join(out_root, "results.jsonl")
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
done_ids: set[str] = set()
|
||||
existing: list[dict] = []
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path) as f:
|
||||
for line in f:
|
||||
try:
|
||||
r = json.loads(line)
|
||||
done_ids.add(str(r["id"]))
|
||||
existing.append(r)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
pending = [it for it in items if str(it["id"]) not in done_ids]
|
||||
if not pending:
|
||||
return existing
|
||||
|
||||
results = list(existing)
|
||||
|
||||
started_at: dict[str, float] = {}
|
||||
|
||||
def _run_one(it: dict) -> dict:
|
||||
started_at[str(it["id"])] = time.time()
|
||||
return process_one(
|
||||
it,
|
||||
out_root,
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
exec_timeout=exec_timeout,
|
||||
use_theorem=use_theorem,
|
||||
use_sketch=use_sketch,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""),
|
||||
)
|
||||
|
||||
def _timeout_result(it: dict) -> dict:
|
||||
correct = it.get("correct_choice") or {}
|
||||
return {
|
||||
"id": str(it["id"]),
|
||||
"question": it.get("question", ""),
|
||||
"task_type": it.get("theorem_type", ["math_mcq"])[0] if it.get("theorem_type") else "math_mcq",
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"predicted_label": "",
|
||||
"predicted_text": "",
|
||||
"correct_label": correct.get("label", ""),
|
||||
"correct_text": correct.get("text", ""),
|
||||
"response": "",
|
||||
"fail_reason": f"task-timeout-{task_timeout}s",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
}
|
||||
|
||||
def _error_result(it: dict, exc: Exception) -> dict:
|
||||
res = _timeout_result(it)
|
||||
res["fail_reason"] = f"error: {type(exc).__name__}: {exc}"
|
||||
return res
|
||||
|
||||
with open(results_path, "a") as outf:
|
||||
ex = ThreadPoolExecutor(max_workers=workers)
|
||||
try:
|
||||
futs = {
|
||||
ex.submit(_run_one, it): it
|
||||
for it in pending
|
||||
}
|
||||
pending_futs = set(futs)
|
||||
while pending_futs:
|
||||
done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
|
||||
now = time.time()
|
||||
timed_out = [
|
||||
fut for fut in pending_futs - done
|
||||
if str(futs[fut]["id"]) in started_at
|
||||
and now - started_at[str(futs[fut]["id"])] >= task_timeout
|
||||
]
|
||||
for fut in done:
|
||||
pending_futs.remove(fut)
|
||||
item = futs[fut]
|
||||
try:
|
||||
res = fut.result()
|
||||
except Exception as e: # noqa: BLE001
|
||||
res = _error_result(item, e)
|
||||
results.append(res)
|
||||
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
for fut in timed_out:
|
||||
pending_futs.remove(fut)
|
||||
res = _timeout_result(futs[fut])
|
||||
results.append(res)
|
||||
outf.write(json.dumps(res, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
finally:
|
||||
ex.shutdown(wait=False, cancel_futures=True)
|
||||
|
||||
return results
|
||||
16
reflact/envs/livemathematicianbench/skills/initial.md
Normal file
16
reflact/envs/livemathematicianbench/skills/initial.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# Live Mathematical MCQ Heuristics
|
||||
|
||||
## Option Comparison
|
||||
- Compare all options before committing. The correct choice is often the strongest statement justified by the question, while nearby distractors are weaker, overstrong, or miss an equality case.
|
||||
- Track exact quantifiers such as "there exists", "for every", "if and only if", and "exactly when".
|
||||
|
||||
## Theorem-Level Precision
|
||||
- Check whether an option weakens the conclusion by dropping a characterization, equality clause, or full equivalence.
|
||||
- Check whether an option overstates the theorem by upgrading regularity, removing scale restrictions, or changing an existential statement into a universal one.
|
||||
|
||||
## Hypotheses
|
||||
- Verify the hypotheses and domain carefully. Distractors often keep the theorem shape but alter the required assumptions.
|
||||
- Pay close attention to equality cases, extremal conditions, and whether a result applies to the full family or only a restricted subfamily.
|
||||
|
||||
## Final Answer
|
||||
- Output the final answer as the single option label only.
|
||||
5
reflact/envs/mathverse/__init__.py
Normal file
5
reflact/envs/mathverse/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""MathVerse environment package."""
|
||||
|
||||
from reflact.envs.mathverse.adapter import MathVerseAdapter
|
||||
|
||||
__all__ = ["MathVerseAdapter"]
|
||||
280
reflact/envs/mathverse/adapter.py
Normal file
280
reflact/envs/mathverse/adapter.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""MathVerse environment adapter for ReflACT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from reflact.datasets.base import BatchSpec
|
||||
from reflact.envs.base import EnvAdapter
|
||||
from reflact.envs.mathverse.dataloader import MathVerseDataLoader
|
||||
from reflact.envs.mathverse.rollout import run_batch
|
||||
from reflact.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
from reflact.model import get_student_backend
|
||||
|
||||
|
||||
class MathVerseAdapter(EnvAdapter):
|
||||
"""MathVerse adapter."""
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
if not self.use_text_dominant_reference:
|
||||
return ""
|
||||
question = str(item.get("text_dominant_question") or "").strip()
|
||||
if not question:
|
||||
return ""
|
||||
return f"## Reference Full Question\n{question}"
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
if not self.use_text_dominant_reference:
|
||||
return {"fields": [], "preview": ""}
|
||||
question = str(item.get("text_dominant_question") or "").strip()
|
||||
if not question:
|
||||
return {"fields": [], "preview": ""}
|
||||
return {
|
||||
"fields": ["text_dominant_question"],
|
||||
"preview": question[:400],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_root: str = "",
|
||||
problem_version: str = "Text Lite",
|
||||
use_text_dominant_reference: bool = False,
|
||||
max_turns: int = 1,
|
||||
workers: int = 16,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.workers = workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.image_detail = image_detail
|
||||
self.judge_model = judge_model
|
||||
self.judge_max_completion_tokens = judge_max_completion_tokens
|
||||
self.judge_retries = judge_retries
|
||||
self.problem_version = problem_version
|
||||
self.use_text_dominant_reference = use_text_dominant_reference
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = MathVerseDataLoader(
|
||||
split_dir=split_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
data_root=data_root,
|
||||
problem_version=problem_version,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
judge_model=self.judge_model,
|
||||
judge_max_completion_tokens=self.judge_max_completion_tokens,
|
||||
judge_retries=self.judge_retries,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
env_manager = kwargs.get("env_manager")
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
env_manager if isinstance(env_manager, list) else None,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
codex_backend = get_student_backend() == "codex_exec"
|
||||
if codex_backend:
|
||||
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
|
||||
selected_metadata = []
|
||||
ref_count = 0
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
if meta["fields"]:
|
||||
ref_count += 1
|
||||
record = {
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("task_type") or item.get("question_type") or "mathverse"),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
}
|
||||
if codex_backend:
|
||||
record["codex_probe_step_count"] = int(
|
||||
next(
|
||||
(row.get("codex_probe_step_count", 0) for row in selected_examples if str(row.get("id")) == str(item["id"])),
|
||||
0,
|
||||
)
|
||||
)
|
||||
selected_metadata.append(record)
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields=text_dominant_question({ref_count}/{len(selected_items)})"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
|
||||
targeted_items = selected_items
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None
|
||||
if codex_backend:
|
||||
targeted_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
|
||||
selected_items=selected_items,
|
||||
selected_examples=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
probe=probe,
|
||||
)
|
||||
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": {
|
||||
"text_dominant_question": ref_count,
|
||||
},
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
deep_results = run_batch(
|
||||
items=targeted_items,
|
||||
out_root=rollout_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=min(self.workers, max(len(targeted_items), 1)),
|
||||
image_detail=self.image_detail,
|
||||
judge_model=self.judge_model,
|
||||
judge_max_completion_tokens=self.judge_max_completion_tokens,
|
||||
judge_retries=self.judge_retries,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, targeted_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return self.dataloader.get_task_types()
|
||||
228
reflact/envs/mathverse/dataloader.py
Normal file
228
reflact/envs/mathverse/dataloader.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""MathVerse task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from reflact.datasets.base import SplitDataLoader
|
||||
|
||||
|
||||
_CHOICE_LABELS = ["A", "B", "C", "D", "E", "F", "G"]
|
||||
_CHOICE_BLOCK_RE = re.compile(r"\bChoices?\s*:\s*", re.IGNORECASE)
|
||||
_CHOICE_ITEM_RE = re.compile(r"([A-G])\s*[:.)]\s*(.*?)(?=(?:\s+[A-G]\s*[:.)])|$)", re.DOTALL)
|
||||
|
||||
|
||||
def _load_json(path: str) -> Any:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _normalize_space(text: Any) -> str:
|
||||
return re.sub(r"\s+", " ", str(text or "").strip())
|
||||
|
||||
|
||||
def _resolve_image_path(raw_path: str, *, data_root: str, source_path: str) -> str:
|
||||
candidates = []
|
||||
if raw_path:
|
||||
if os.path.isabs(raw_path):
|
||||
candidates.append(raw_path)
|
||||
else:
|
||||
if data_root:
|
||||
candidates.append(os.path.join(data_root, raw_path))
|
||||
candidates.append(os.path.join(data_root, "images", raw_path))
|
||||
candidates.append(os.path.join(os.path.dirname(source_path), raw_path))
|
||||
for candidate in candidates:
|
||||
if candidate and os.path.exists(candidate):
|
||||
return os.path.abspath(candidate)
|
||||
return ""
|
||||
|
||||
|
||||
def _split_question_and_choices(question: str) -> tuple[str, list[dict]]:
|
||||
text = str(question or "").strip()
|
||||
match = _CHOICE_BLOCK_RE.search(text)
|
||||
if not match:
|
||||
return text, []
|
||||
|
||||
stem = text[:match.start()].strip()
|
||||
choice_block = text[match.end():].strip()
|
||||
choices: list[dict] = []
|
||||
for idx, m in enumerate(_CHOICE_ITEM_RE.finditer(choice_block)):
|
||||
label = (m.group(1) or _CHOICE_LABELS[idx]).strip().upper()
|
||||
choice_text = _normalize_space(m.group(2))
|
||||
if choice_text:
|
||||
choices.append({"label": label, "text": choice_text})
|
||||
return stem or text, choices
|
||||
|
||||
|
||||
def _build_text_dominant_map(data_root: str) -> dict[str, str]:
|
||||
if not data_root:
|
||||
return {}
|
||||
candidates = [
|
||||
os.path.join(data_root, "testmini.json"),
|
||||
os.path.join(data_root, "data", "testmini.json"),
|
||||
]
|
||||
source_path = next((path for path in candidates if os.path.exists(path)), "")
|
||||
if not source_path:
|
||||
return {}
|
||||
|
||||
raw = _load_json(source_path)
|
||||
if not isinstance(raw, list):
|
||||
return {}
|
||||
|
||||
mapping: dict[str, str] = {}
|
||||
for item in raw:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if str(item.get("problem_version") or "").strip() != "Text Dominant":
|
||||
continue
|
||||
problem_index = str(item.get("problem_index") or "").strip()
|
||||
question = str(item.get("question") or "").strip()
|
||||
if problem_index and question:
|
||||
mapping[problem_index] = question
|
||||
return mapping
|
||||
|
||||
|
||||
def _normalize_item(
|
||||
item: dict,
|
||||
*,
|
||||
row_idx: int,
|
||||
source_path: str,
|
||||
data_root: str,
|
||||
problem_version: str,
|
||||
text_dominant_map: dict[str, str],
|
||||
) -> dict | None:
|
||||
raw_problem_version = str(item.get("problem_version") or "").strip()
|
||||
if problem_version and raw_problem_version and raw_problem_version != problem_version:
|
||||
return None
|
||||
|
||||
question = str(item.get("question") or "").strip()
|
||||
question_type = str(item.get("question_type") or "").strip()
|
||||
answer = str(item.get("answer") or "").strip()
|
||||
image_rel = str(item.get("image") or "").strip()
|
||||
image_path = _resolve_image_path(image_rel, data_root=data_root, source_path=source_path)
|
||||
if not answer or not image_path:
|
||||
return None
|
||||
|
||||
metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {}
|
||||
subject = str(metadata.get("subject") or "").strip()
|
||||
subfield = str(metadata.get("subfield") or "").strip()
|
||||
source = str(metadata.get("source") or "").strip()
|
||||
|
||||
question_stem, choices = _split_question_and_choices(question)
|
||||
is_choice = question_type == "multi-choice" or bool(choices)
|
||||
|
||||
correct_choice = {"label": "", "text": ""}
|
||||
if is_choice:
|
||||
label = str(answer).strip().upper().rstrip(".):")
|
||||
choice_text = ""
|
||||
for choice in choices:
|
||||
if choice["label"].upper() == label:
|
||||
choice_text = choice["text"]
|
||||
break
|
||||
correct_choice = {"label": label, "text": choice_text}
|
||||
|
||||
problem_index = str(item.get("problem_index") or "").strip()
|
||||
sample_index = str(item.get("sample_index") or row_idx + 1).strip()
|
||||
item_id = problem_index or sample_index
|
||||
task_type = subfield or subject or question_type or "mathverse"
|
||||
|
||||
return {
|
||||
"id": item_id,
|
||||
"sample_index": sample_index,
|
||||
"problem_index": problem_index,
|
||||
"problem_version": raw_problem_version or problem_version,
|
||||
"question": question,
|
||||
"question_stem": question_stem,
|
||||
"question_for_eval": str(item.get("question_for_eval") or question).strip(),
|
||||
"question_type": question_type or ("multi-choice" if is_choice else "free-form"),
|
||||
"is_choice": is_choice,
|
||||
"choices": choices,
|
||||
"correct_choice": correct_choice,
|
||||
"answer": answer,
|
||||
"gold_answers": [answer] if answer else [],
|
||||
"image_rel": image_rel,
|
||||
"image_path": image_path,
|
||||
"query_wo": str(item.get("query_wo") or "").strip(),
|
||||
"query_cot": str(item.get("query_cot") or "").strip(),
|
||||
"metadata": {
|
||||
"split": str(metadata.get("split") or "").strip(),
|
||||
"source": source,
|
||||
"subject": subject,
|
||||
"subfield": subfield,
|
||||
},
|
||||
"task_type": task_type,
|
||||
"source_path": os.path.abspath(source_path),
|
||||
"text_dominant_question": str(
|
||||
item.get("text_dominant_question")
|
||||
or text_dominant_map.get(problem_index, "")
|
||||
).strip(),
|
||||
}
|
||||
|
||||
|
||||
class MathVerseDataLoader(SplitDataLoader):
|
||||
"""MathVerse dataloader."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
data_root: str = "",
|
||||
problem_version: str = "Text Lite",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(split_dir=split_dir, seed=seed, limit=limit)
|
||||
self.data_root = data_root
|
||||
self.problem_version = problem_version
|
||||
self._task_types: list[str] = []
|
||||
self._text_dominant_map = _build_text_dominant_map(data_root)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
if not self.data_root:
|
||||
self.data_root = str(cfg.get("data_root") or "")
|
||||
if not self.problem_version:
|
||||
self.problem_version = str(cfg.get("problem_version") or "Text Lite")
|
||||
self._text_dominant_map = _build_text_dominant_map(self.data_root)
|
||||
super().setup(cfg)
|
||||
all_items = self.train_items + self.val_items + self.test_items
|
||||
task_types = {
|
||||
item.get("task_type") or item.get("question_type") or "mathverse"
|
||||
for item in all_items
|
||||
}
|
||||
self._task_types = sorted(str(x) for x in task_types if str(x).strip())
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(self._task_types)
|
||||
|
||||
def load_split_items(self, split_path: str) -> list[dict]:
|
||||
raw_items = super().load_split_items(split_path)
|
||||
source_path = next(
|
||||
(
|
||||
os.path.join(split_path, name)
|
||||
for name in sorted(os.listdir(split_path))
|
||||
if name.endswith(".json")
|
||||
),
|
||||
split_path,
|
||||
)
|
||||
items: list[dict] = []
|
||||
for row_idx, item in enumerate(raw_items):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
norm = _normalize_item(
|
||||
item,
|
||||
row_idx=row_idx,
|
||||
source_path=source_path,
|
||||
data_root=self.data_root,
|
||||
problem_version=self.problem_version,
|
||||
text_dominant_map=self._text_dominant_map,
|
||||
)
|
||||
if norm is not None:
|
||||
items.append(norm)
|
||||
if not items:
|
||||
raise ValueError(
|
||||
f"No valid MathVerse items loaded from {split_path} "
|
||||
f"for problem_version={self.problem_version!r}"
|
||||
)
|
||||
return items
|
||||
180
reflact/envs/mathverse/evaluator.py
Normal file
180
reflact/envs/mathverse/evaluator.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""MathVerse evaluation helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
|
||||
from reflact.model import chat_with_deployment
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
|
||||
_EVAL_MODE = "mathverse_choice_or_judge_v1"
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
text = str(text or "").strip().lower()
|
||||
text = text.replace("\\,", " ")
|
||||
text = text.replace("\\ ", " ")
|
||||
text = "".join(ch for ch in text if ch not in string.punctuation)
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def normalize_math_text(text: str) -> str:
|
||||
text = str(text or "").strip()
|
||||
text = text.replace("$", "")
|
||||
text = text.replace("\\mathrm", "")
|
||||
text = text.replace("{", "")
|
||||
text = text.replace("}", "")
|
||||
text = text.replace("~", " ")
|
||||
text = text.replace("\\,", " ")
|
||||
text = text.replace("\\ ", " ")
|
||||
return " ".join(text.split()).lower()
|
||||
|
||||
|
||||
def extract_answer(text: str | None) -> str:
|
||||
raw = str(text or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
tags = re.findall(r"<answer>\s*(.*?)\s*</answer>", raw, re.IGNORECASE | re.DOTALL)
|
||||
if tags:
|
||||
return tags[-1].strip()
|
||||
|
||||
boxed = re.findall(r"\\boxed\{(.*?)\}", raw, re.IGNORECASE | re.DOTALL)
|
||||
if boxed:
|
||||
return boxed[-1].strip()
|
||||
|
||||
lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]
|
||||
if lines:
|
||||
return lines[-1]
|
||||
return raw
|
||||
|
||||
|
||||
def _judge_answer(
|
||||
*,
|
||||
item: dict,
|
||||
extracted_answer: str,
|
||||
judge_model: str,
|
||||
max_completion_tokens: int,
|
||||
retries: int,
|
||||
) -> dict:
|
||||
question = str(item.get("question_for_eval") or item.get("question") or "").strip()
|
||||
ground_truth = str(item.get("answer") or "").strip()
|
||||
raw, _ = chat_with_deployment(
|
||||
deployment=judge_model,
|
||||
system="You are a careful and strict mathematical answer evaluator.",
|
||||
user=load_prompt("judge", env="mathverse").format(
|
||||
question=question,
|
||||
groundtruth=ground_truth,
|
||||
modeloutput=extracted_answer,
|
||||
),
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=retries,
|
||||
stage="mathverse_judge",
|
||||
)
|
||||
response = str(raw).strip().lower()
|
||||
if "true" in response:
|
||||
correct = True
|
||||
elif "false" in response:
|
||||
correct = False
|
||||
else:
|
||||
correct = False
|
||||
return {
|
||||
"raw": raw,
|
||||
"correct": correct,
|
||||
"reason": response,
|
||||
"matched_gold": ground_truth if correct else "",
|
||||
}
|
||||
|
||||
|
||||
def evaluate_item(
|
||||
*,
|
||||
item: dict,
|
||||
prediction_text: str,
|
||||
judge_model: str,
|
||||
max_completion_tokens: int = 256,
|
||||
retries: int = 5,
|
||||
) -> dict:
|
||||
extracted = extract_answer(prediction_text)
|
||||
|
||||
if item.get("is_choice"):
|
||||
predicted_label = str(extracted).strip().upper().rstrip(".):")
|
||||
correct_label = str(item["correct_choice"].get("label") or "").strip().upper()
|
||||
predicted_text = ""
|
||||
for choice in item.get("choices") or []:
|
||||
if str(choice.get("label") or "").strip().upper() == predicted_label:
|
||||
predicted_text = str(choice.get("text") or "").strip()
|
||||
break
|
||||
hard = 1.0 if predicted_label == correct_label else 0.0
|
||||
return {
|
||||
"evaluation_mode": _EVAL_MODE,
|
||||
"predicted_answer": extracted,
|
||||
"predicted_label": predicted_label,
|
||||
"predicted_text": predicted_text,
|
||||
"correct_label": correct_label,
|
||||
"correct_text": str(item["correct_choice"].get("text") or "").strip(),
|
||||
"em": hard,
|
||||
"f1": hard,
|
||||
"sub_em": hard,
|
||||
"judge_raw": "",
|
||||
"judge_reason": "exact_label_match" if hard else "label_mismatch",
|
||||
"matched_gold": correct_label if hard else "",
|
||||
}
|
||||
|
||||
gold_answer = str(item.get("answer") or "").strip()
|
||||
pred_norm = normalize_math_text(extracted)
|
||||
gold_norm = normalize_math_text(gold_answer)
|
||||
if pred_norm and gold_norm and pred_norm == gold_norm:
|
||||
return {
|
||||
"evaluation_mode": _EVAL_MODE,
|
||||
"predicted_answer": extracted,
|
||||
"em": 1.0,
|
||||
"f1": 1.0,
|
||||
"sub_em": 1.0,
|
||||
"judge_raw": "",
|
||||
"judge_reason": "normalized_exact_match",
|
||||
"matched_gold": gold_answer,
|
||||
"string_f1": 1.0,
|
||||
}
|
||||
|
||||
judge = _judge_answer(
|
||||
item=item,
|
||||
extracted_answer=extracted,
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=retries,
|
||||
)
|
||||
hard = 1.0 if judge["correct"] else 0.0
|
||||
pred_tokens = normalize_text(extracted).split()
|
||||
gold_tokens = normalize_text(gold_answer).split()
|
||||
overlap = 0
|
||||
gold_counts: dict[str, int] = {}
|
||||
for tok in gold_tokens:
|
||||
gold_counts[tok] = gold_counts.get(tok, 0) + 1
|
||||
for tok in pred_tokens:
|
||||
count = gold_counts.get(tok, 0)
|
||||
if count > 0:
|
||||
overlap += 1
|
||||
gold_counts[tok] = count - 1
|
||||
if pred_tokens and gold_tokens and overlap:
|
||||
precision = overlap / len(pred_tokens)
|
||||
recall = overlap / len(gold_tokens)
|
||||
string_f1 = 2 * precision * recall / (precision + recall)
|
||||
else:
|
||||
string_f1 = 0.0
|
||||
|
||||
return {
|
||||
"evaluation_mode": _EVAL_MODE,
|
||||
"predicted_answer": extracted,
|
||||
"em": hard,
|
||||
"f1": hard,
|
||||
"sub_em": hard,
|
||||
"judge_raw": judge["raw"],
|
||||
"judge_reason": judge["reason"],
|
||||
"matched_gold": judge["matched_gold"],
|
||||
"string_f1": string_f1,
|
||||
}
|
||||
|
||||
|
||||
def evaluation_mode() -> str:
|
||||
return _EVAL_MODE
|
||||
37
reflact/envs/mathverse/prompts/analyst_error.md
Normal file
37
reflact/envs/mathverse/prompts/analyst_error.md
Normal file
@@ -0,0 +1,37 @@
|
||||
You are an expert failure-analysis agent for visual mathematical reasoning problems.
|
||||
|
||||
You will be given MULTIPLE failed trajectories from a single minibatch and the current skill document.
|
||||
Each trajectory includes the student's response, the evaluation result, and sometimes a hidden reference
|
||||
containing the fuller Text Dominant version of the same problem.
|
||||
|
||||
Your job is to identify COMMON reasoning failures across the batch and propose concise skill edits.
|
||||
|
||||
## Failure Type Categories
|
||||
- **diagram_underuse**: the agent did not recover key constraints from the image
|
||||
- **constraint_drop**: the agent ignored a condition or relation that should guide the solution
|
||||
- **option_confusion**: the agent failed to discriminate between close answer choices
|
||||
- **format_miss**: the agent solved roughly correctly but returned the wrong final form, unit, or expression
|
||||
- **other**: none of the above
|
||||
|
||||
## Rules
|
||||
1. Focus on patterns that recur across the minibatch.
|
||||
2. Prefer edits that improve visual grounding and exact answer selection.
|
||||
3. Do not hardcode problem-specific formulas or answers.
|
||||
4. If hidden reference text is present, use it only to infer what information the student failed to recover from the Text Lite version.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"failure_summary": [
|
||||
{"failure_type": "<type>", "count": <int>, "description": "<one-line>"}
|
||||
],
|
||||
"patch": {
|
||||
"reasoning": "<why these edits address the common failures>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
26
reflact/envs/mathverse/prompts/analyst_success.md
Normal file
26
reflact/envs/mathverse/prompts/analyst_success.md
Normal file
@@ -0,0 +1,26 @@
|
||||
You are an expert success-pattern analyst for visual mathematical reasoning problems.
|
||||
|
||||
You will be given MULTIPLE successful trajectories from a minibatch and the current skill document.
|
||||
Identify generalizable behavior patterns that genuinely help the agent recover the right constraints
|
||||
from the image and convert them into the exact final answer.
|
||||
|
||||
## Rules
|
||||
- Focus on broadly useful visual-math reasoning behaviors.
|
||||
- Prefer patterns about reading decisive diagram cues, checking hidden assumptions, and matching the final answer format exactly.
|
||||
- Do not add benchmark-specific facts or formulas.
|
||||
- "edits" may be empty if the skill already captures the useful patterns.
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"batch_size": <number>,
|
||||
"success_patterns": ["<pattern 1>", "<pattern 2>"],
|
||||
"patch": {
|
||||
"reasoning": "<why these patterns matter>",
|
||||
"edits": [
|
||||
{"op": "append", "content": "<markdown>"},
|
||||
{"op": "insert_after", "target": "<heading/text>", "content": "<markdown>"},
|
||||
{"op": "replace", "target": "<old text>", "content": "<new text>"},
|
||||
{"op": "delete", "target": "<exact text to remove>"}
|
||||
]
|
||||
}
|
||||
}
|
||||
25
reflact/envs/mathverse/prompts/deep_probe.md
Normal file
25
reflact/envs/mathverse/prompts/deep_probe.md
Normal file
@@ -0,0 +1,25 @@
|
||||
You are an expert diagnostic-probe designer for visual mathematical reasoning tasks.
|
||||
|
||||
You will be shown representative trajectories, the current student skill, and the student's original prompt context.
|
||||
Some trajectories may also include a hidden reference containing the fuller Text Dominant wording of the same problem.
|
||||
Design one SMALL diagnostic instruction that exposes the student's intermediate judgment without materially changing the original scaffold.
|
||||
|
||||
## Hard Constraints
|
||||
1. Do NOT substantially change the original scaffold.
|
||||
2. Do NOT prescribe a new long multi-step solving procedure.
|
||||
3. Do NOT ask for a full proof or full chain-of-thought.
|
||||
4. Ask only for a short readout of the signals already behind the student's current answer.
|
||||
5. Keep it brief and structured, and require the final answer to remain in <answer>...</answer>.
|
||||
6. If hidden reference text is present, use it only to target what visual or textual constraint the student likely missed.
|
||||
|
||||
## Good Probe Targets
|
||||
- decisive diagram cue
|
||||
- top candidate and runner-up
|
||||
- missing relation or quantity
|
||||
- why a near-miss option was rejected
|
||||
|
||||
Respond ONLY with a valid JSON object:
|
||||
{
|
||||
"reasoning": "<why this probe is informative>",
|
||||
"probe_instruction": "<the exact instruction text to append to the student prompt>"
|
||||
}
|
||||
25
reflact/envs/mathverse/prompts/judge.md
Normal file
25
reflact/envs/mathverse/prompts/judge.md
Normal file
@@ -0,0 +1,25 @@
|
||||
You are a careful and strict evaluator for visual math problems.
|
||||
|
||||
You will be given:
|
||||
1. The original question
|
||||
2. The ground-truth answer
|
||||
3. A model output
|
||||
|
||||
Decide whether the model output is mathematically equivalent to the ground-truth answer.
|
||||
|
||||
Rules:
|
||||
- Ignore harmless formatting differences.
|
||||
- Accept mathematically equivalent expressions, equations, and values.
|
||||
- Reject answers that are numerically wrong, symbolically different in meaning, missing required units when the unit changes meaning, or correspond to a different choice.
|
||||
- Do not reward partially correct reasoning if the final answer is wrong.
|
||||
|
||||
Return only:
|
||||
True
|
||||
|
||||
or
|
||||
|
||||
False
|
||||
|
||||
Question: {question}
|
||||
Ground Truth Answer: {groundtruth}
|
||||
Model Output: {modeloutput}
|
||||
11
reflact/envs/mathverse/prompts/rollout_system.md
Normal file
11
reflact/envs/mathverse/prompts/rollout_system.md
Normal file
@@ -0,0 +1,11 @@
|
||||
You are an expert visual mathematical reasoning agent.
|
||||
|
||||
{skill_section}## Task Format
|
||||
You will receive one math problem with an image or diagram.
|
||||
Use the visible diagram as evidence, not just the text.
|
||||
If some information is abbreviated in the text, recover it from the image before answering.
|
||||
|
||||
## Answer Format
|
||||
Think step by step, then provide your final answer inside <answer>...</answer>.
|
||||
- For multiple-choice questions, output only the single option label, such as <answer>B</answer>.
|
||||
- For free-form questions, output only the final mathematical answer, such as <answer>14</answer>.
|
||||
4
reflact/envs/mathverse/reflect.py
Normal file
4
reflact/envs/mathverse/reflect.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""MathVerse Reflect stage.
|
||||
|
||||
Prompts are loaded from .md files by the base adapter.
|
||||
"""
|
||||
415
reflact/envs/mathverse/rollout.py
Normal file
415
reflact/envs/mathverse/rollout.py
Normal file
@@ -0,0 +1,415 @@
|
||||
"""MathVerse rollout — single-image multimodal math reasoning."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from reflact.envs.mathverse.evaluator import evaluate_item, evaluation_mode, extract_answer
|
||||
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
|
||||
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
|
||||
def _build_system(skill_content: str) -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return load_prompt("rollout_system", env="mathverse").format(skill_section=skill_section)
|
||||
|
||||
|
||||
def _format_choices(choices: list[dict]) -> str:
|
||||
return "\n".join(f"{choice['label']}. {choice['text']}" for choice in choices)
|
||||
|
||||
|
||||
def _build_user_text(
|
||||
item: dict,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> str:
|
||||
parts = []
|
||||
if diagnostic_trace_context.strip():
|
||||
parts.append(
|
||||
"## Previous Codex Trace Snapshot\n"
|
||||
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
|
||||
f"{diagnostic_trace_context.strip()}"
|
||||
)
|
||||
question = str(item.get("question_stem") or item.get("question") or "").strip()
|
||||
if question:
|
||||
parts.append(f"## Question\n{question}")
|
||||
else:
|
||||
parts.append("## Question\nRead the full problem statement from the image.")
|
||||
|
||||
if item.get("is_choice"):
|
||||
choices = item.get("choices") or []
|
||||
if choices:
|
||||
parts.append(f"## Choices\n{_format_choices(choices)}")
|
||||
parts.append("Return only the final option label inside <answer>...</answer>.")
|
||||
else:
|
||||
parts.append("Return only the final mathematical answer inside <answer>...</answer>.")
|
||||
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
parts.append(f"## Training Readout\n{diagnostic_instruction.strip()}")
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
def _image_to_data_uri(path: str) -> str:
|
||||
mime = mimetypes.guess_type(path)[0] or "image/png"
|
||||
with open(path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
def _build_messages(
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
image_detail: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> tuple[list[dict], str, str]:
|
||||
system = _build_system(skill_content)
|
||||
user_text = _build_user_text(
|
||||
item,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
image_url = {"url": _image_to_data_uri(item["image_path"])}
|
||||
if image_detail and image_detail != "auto":
|
||||
image_url["detail"] = image_detail
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_text},
|
||||
{"type": "image_url", "image_url": image_url},
|
||||
],
|
||||
},
|
||||
]
|
||||
return messages, system, user_text
|
||||
|
||||
|
||||
def _build_codex_skill(skill_content: str) -> str:
|
||||
return render_skill_md(
|
||||
skill_content,
|
||||
description="Dynamic ReflACT skill for solving the current MathVerse visual math problem.",
|
||||
preamble=(
|
||||
"Use this skill when solving the current MathVerse problem.\n"
|
||||
"Read the image carefully and return the final answer inside <answer>...</answer>."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _run_codex_once(
|
||||
*,
|
||||
pred_dir: str,
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
model: str,
|
||||
timeout: int,
|
||||
image_detail: str,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
previous_response: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
user_text = _build_user_text(
|
||||
item,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
task_parts = [user_text]
|
||||
if previous_response:
|
||||
task_parts.append(
|
||||
"## Previous Attempt\n"
|
||||
f"{previous_response}\n\n"
|
||||
"Re-check the diagram and the mathematical constraints. Correct the final answer if needed."
|
||||
)
|
||||
task_text = "\n\n".join(task_parts)
|
||||
skill_md = _build_codex_skill(skill_content)
|
||||
work_dir = os.path.join(pred_dir, "codex_exec")
|
||||
prepare_workspace(
|
||||
work_dir=work_dir,
|
||||
skill_md=skill_md,
|
||||
task_text=task_text,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
prompt = (
|
||||
"Use the `reflact-student` skill available in this workspace.\n"
|
||||
"Read `task.md`, inspect the attached image, solve the problem, and return only the final answer inside <answer>...</answer>."
|
||||
)
|
||||
final_message, raw = run_student_exec(
|
||||
work_dir=work_dir,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
images=[item["image_path"]],
|
||||
)
|
||||
return final_message or raw, raw, skill_md, task_text
|
||||
|
||||
|
||||
def process_one(
|
||||
item: dict,
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
"id": item_id,
|
||||
"question": item["question"],
|
||||
"task_type": item.get("task_type") or item.get("question_type") or "mathverse",
|
||||
"task_description": item.get("question_stem") or item["question"],
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"predicted_label": "",
|
||||
"predicted_text": "",
|
||||
"response": "",
|
||||
"fail_reason": "",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
"image_path": item["image_path"],
|
||||
"question_type": item["question_type"],
|
||||
"evaluation_mode": evaluation_mode(),
|
||||
"judge_model": judge_model,
|
||||
}
|
||||
if item.get("is_choice"):
|
||||
result["correct_label"] = item["correct_choice"]["label"]
|
||||
result["correct_text"] = item["correct_choice"]["text"]
|
||||
else:
|
||||
result["gold_answers"] = item.get("gold_answers") or [item["answer"]]
|
||||
|
||||
try:
|
||||
pred_dir = os.path.join(out_root, "predictions", item_id)
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
if is_student_exec_backend():
|
||||
from reflact.model import azure_openai as _llm
|
||||
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{"role": "user", "content": f"{item['question']}\n\n[image] {os.path.basename(item['image_path'])}"}
|
||||
]
|
||||
system_prompt = ""
|
||||
user_text = ""
|
||||
for turn in range(max_turns):
|
||||
response, raw, system_prompt, user_text = _run_codex_once(
|
||||
pred_dir=pred_dir,
|
||||
item=item,
|
||||
skill_content=skill_content,
|
||||
model=_llm.STUDENT_DEPLOYMENT,
|
||||
timeout=120,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode if turn == 0 else False,
|
||||
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
|
||||
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
|
||||
previous_response=response if turn > 0 else "",
|
||||
)
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": response})
|
||||
if extract_answer(response):
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
else:
|
||||
messages, system_prompt, user_text = _build_messages(
|
||||
item,
|
||||
skill_content,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
response = ""
|
||||
conversation = [
|
||||
{"role": "user", "content": f"{user_text}\n\n[image] {os.path.basename(item['image_path'])}"}
|
||||
]
|
||||
for turn in range(max_turns):
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=1024,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
else:
|
||||
refinement_text = (
|
||||
f"Your previous answer was:\n{response}\n\n"
|
||||
"Re-check the diagram and the mathematical constraints. "
|
||||
"If needed, correct your answer. Output only the final answer inside <answer>...</answer>."
|
||||
)
|
||||
refinement_messages = [
|
||||
messages[0],
|
||||
messages[1],
|
||||
{"role": "assistant", "content": response},
|
||||
{"role": "user", "content": refinement_text},
|
||||
]
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=refinement_messages,
|
||||
max_completion_tokens=768,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
response = resp_text
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
|
||||
if extract_answer(resp_text):
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(
|
||||
item=item,
|
||||
prediction_text=result["response"],
|
||||
judge_model=judge_model,
|
||||
max_completion_tokens=judge_max_completion_tokens,
|
||||
retries=judge_retries,
|
||||
)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["judge_raw"] = eval_result.get("judge_raw", "")
|
||||
result["judge_reason"] = eval_result.get("judge_reason", "")
|
||||
result["matched_gold"] = eval_result.get("matched_gold", "")
|
||||
|
||||
if item.get("is_choice"):
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"choice=0: predicted '{eval_result['predicted_label'] or eval_result['predicted_answer']}' "
|
||||
f"but expected '{eval_result['correct_label']}'"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question_for_eval']}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Predicted text: {eval_result['predicted_text']!r}\n"
|
||||
f"Correct label: {eval_result['correct_label']!r}\n"
|
||||
f"Correct text: {eval_result['correct_text']!r}\n"
|
||||
f"Exact Match: {eval_result['em']}"
|
||||
)
|
||||
else:
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"judge=0: predicted '{eval_result['predicted_answer']}' "
|
||||
f"but expected '{item['answer']}' ({eval_result.get('judge_reason', '')})"
|
||||
)
|
||||
eval_detail = (
|
||||
f"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question_for_eval']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Gold answer: {item['answer']!r}\n"
|
||||
f"Judge correct: {eval_result['em']}\n"
|
||||
f"Judge reason: {eval_result.get('judge_reason', '')}\n"
|
||||
f"String F1: {eval_result.get('string_f1', 0.0):.4f}"
|
||||
)
|
||||
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
result["fail_reason"] = f"error: {e}"
|
||||
return result
|
||||
|
||||
|
||||
def run_batch(
|
||||
items: list[dict],
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
workers: int = 32,
|
||||
image_detail: str = "auto",
|
||||
judge_model: str = "gpt-5.4",
|
||||
judge_max_completion_tokens: int = 256,
|
||||
judge_retries: int = 5,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None,
|
||||
) -> list[dict]:
|
||||
results_path = os.path.join(out_root, "results.jsonl")
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
expected_eval_mode = evaluation_mode()
|
||||
done_ids: set[str] = set()
|
||||
existing: list[dict] = []
|
||||
rewrite_results = False
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
row = json.loads(line)
|
||||
if row.get("evaluation_mode") != expected_eval_mode:
|
||||
rewrite_results = True
|
||||
continue
|
||||
done_ids.add(str(row["id"]))
|
||||
existing.append(row)
|
||||
except Exception:
|
||||
rewrite_results = True
|
||||
|
||||
pending = [item for item in items if str(item["id"]) not in done_ids]
|
||||
if not pending and not rewrite_results:
|
||||
return existing
|
||||
|
||||
results = list(existing)
|
||||
file_mode = "w" if rewrite_results else "a"
|
||||
with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
if rewrite_results:
|
||||
for row in existing:
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
futs = {
|
||||
ex.submit(
|
||||
process_one,
|
||||
item,
|
||||
out_root,
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
image_detail=image_detail,
|
||||
judge_model=judge_model,
|
||||
judge_max_completion_tokens=judge_max_completion_tokens,
|
||||
judge_retries=judge_retries,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
|
||||
): item
|
||||
for item in pending
|
||||
}
|
||||
for fut in as_completed(futs):
|
||||
row = fut.result()
|
||||
results.append(row)
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
return results
|
||||
15
reflact/envs/mathverse/skills/initial.md
Normal file
15
reflact/envs/mathverse/skills/initial.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# MathVerse Visual Math Heuristics
|
||||
|
||||
## Diagram First
|
||||
- Read the diagram before locking onto an equation or option.
|
||||
- Recover missing labels, lengths, angles, axes, or object relations from the image when the text is abbreviated.
|
||||
- If the text seems underspecified, assume the image may contain the decisive constraint.
|
||||
|
||||
## Constraint Tracking
|
||||
- Write down the few constraints that actually determine the answer instead of solving from vague intuition.
|
||||
- Prefer geometric or functional relations that are directly supported by the figure.
|
||||
- For multiple-choice questions, compare the final candidate against every option exactly.
|
||||
|
||||
## Final Answer
|
||||
- Use the image and the text consistently.
|
||||
- Return only the final answer inside <answer>...</answer>.
|
||||
2
reflact/envs/mmrb/__init__.py
Normal file
2
reflact/envs/mmrb/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""MMRB environment package."""
|
||||
|
||||
283
reflact/envs/mmrb/adapter.py
Normal file
283
reflact/envs/mmrb/adapter.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""MMRB environment adapter for ReflACT."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from reflact.gradient.deep_probe import generate_deep_probe_instruction
|
||||
from reflact.datasets.base import BatchSpec
|
||||
from reflact.gradient.reflect import run_minibatch_reflect
|
||||
from reflact.envs.base import EnvAdapter
|
||||
from reflact.envs.mmrb.dataloader import MMRBDataLoader
|
||||
from reflact.envs.mmrb.rollout import run_batch
|
||||
from reflact.model import get_student_backend
|
||||
|
||||
|
||||
class MMRBAdapter(EnvAdapter):
|
||||
"""MMRB adapter."""
|
||||
|
||||
def build_reference_text(self, item: dict) -> str:
|
||||
reasoning_steps = item.get("reasoning_steps") or []
|
||||
if not reasoning_steps:
|
||||
return ""
|
||||
|
||||
blocks: list[str] = []
|
||||
for path_idx, path in enumerate(reasoning_steps, 1):
|
||||
if not isinstance(path, list) or not path:
|
||||
continue
|
||||
lines = [f"### Reasoning Path {path_idx}"]
|
||||
for step in path:
|
||||
if not isinstance(step, dict):
|
||||
continue
|
||||
step_no = step.get("reasoning step", "?")
|
||||
step_type = str(step.get("reasoning type") or "").strip()
|
||||
rationale = str(step.get("rationale") or "").strip()
|
||||
if rationale:
|
||||
prefix = f"{step_no}. [{step_type}] " if step_type else f"{step_no}. "
|
||||
lines.append(prefix + rationale)
|
||||
if len(lines) > 1:
|
||||
blocks.append("\n".join(lines))
|
||||
if not blocks:
|
||||
return ""
|
||||
return "## Reference Reasoning Steps\n" + "\n\n".join(blocks[:3])
|
||||
|
||||
def get_reference_metadata(self, item: dict) -> dict:
|
||||
reasoning_steps = item.get("reasoning_steps") or []
|
||||
path_count = 0
|
||||
preview_parts: list[str] = []
|
||||
for path in reasoning_steps:
|
||||
if not isinstance(path, list) or not path:
|
||||
continue
|
||||
path_count += 1
|
||||
first = path[0] if isinstance(path[0], dict) else {}
|
||||
step_type = str(first.get("reasoning type") or "").strip()
|
||||
rationale = str(first.get("rationale") or "").strip()
|
||||
preview_parts.append(f"[path {path_count}] {step_type}: {rationale[:180]}")
|
||||
if not path_count:
|
||||
return {"fields": [], "preview": ""}
|
||||
return {
|
||||
"fields": ["reasoning_steps"],
|
||||
"preview": "\n".join(preview_parts)[:500],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
max_turns: int = 1,
|
||||
workers: int = 16,
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
image_detail: str = "auto",
|
||||
use_deep_reflect: bool = False,
|
||||
deep_reflect_failures: int = 4,
|
||||
deep_reflect_successes: int = 2,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.workers = workers
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
self.edit_budget = edit_budget
|
||||
self.image_detail = image_detail
|
||||
self.use_deep_reflect = use_deep_reflect
|
||||
self.deep_reflect_failures = deep_reflect_failures
|
||||
self.deep_reflect_successes = deep_reflect_successes
|
||||
self.dataloader = MMRBDataLoader(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
self.dataloader.setup(cfg)
|
||||
|
||||
def get_dataloader(self):
|
||||
return self.dataloader
|
||||
|
||||
def build_env_from_batch(self, batch: BatchSpec, **kwargs):
|
||||
return list(batch.payload or [])
|
||||
|
||||
def build_train_env(self, batch_size: int, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_train_batch(batch_size=batch_size, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def build_eval_env(self, env_num: int, split: str, seed: int, **kwargs):
|
||||
batch = self.dataloader.build_eval_batch(env_num=env_num, split=split, seed=seed, **kwargs)
|
||||
return self.build_env_from_batch(batch, **kwargs)
|
||||
|
||||
def rollout(
|
||||
self,
|
||||
env_manager,
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict]:
|
||||
items: list[dict] = env_manager
|
||||
return run_batch(
|
||||
items=items,
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def deep_reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
if not self.use_deep_reflect:
|
||||
return []
|
||||
|
||||
env_manager = kwargs.get("env_manager")
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
codex_backend = get_student_backend() == "codex_exec"
|
||||
selected_items = self.select_representative_items(
|
||||
results,
|
||||
env_manager if isinstance(env_manager, list) else None,
|
||||
n_failures=self.deep_reflect_failures,
|
||||
n_successes=self.deep_reflect_successes,
|
||||
seed=random_seed,
|
||||
)
|
||||
if not selected_items:
|
||||
return []
|
||||
selected_ids = {str(item["id"]) for item in selected_items}
|
||||
selected_results = [row for row in results if str(row.get("id")) in selected_ids]
|
||||
selected_examples = self.attach_reference_context(selected_results, selected_items)
|
||||
if codex_backend:
|
||||
selected_examples = self.attach_codex_probe_context(selected_examples, prediction_dir)
|
||||
|
||||
reasoning_count = 0
|
||||
selected_metadata = []
|
||||
for item in selected_items:
|
||||
meta = self.get_reference_metadata(item)
|
||||
if meta["fields"]:
|
||||
reasoning_count += 1
|
||||
selected_metadata.append({
|
||||
"id": str(item["id"]),
|
||||
"task_type": str(item.get("subtask") or item.get("task_type") or "mmrb"),
|
||||
"reference_fields": meta["fields"],
|
||||
"reference_preview": meta["preview"],
|
||||
})
|
||||
|
||||
deep_dir = os.path.join(out_dir, "deep_reflect")
|
||||
rollout_dir = os.path.join(deep_dir, "rollout")
|
||||
patches_dir = os.path.join(deep_dir, "patches")
|
||||
os.makedirs(deep_dir, exist_ok=True)
|
||||
print(
|
||||
f" [2b/6 DEEP REFLECT setup] selected={len(selected_items)} "
|
||||
f"reference_fields=reasoning_steps({reasoning_count}/{len(selected_items)})"
|
||||
)
|
||||
probe = generate_deep_probe_instruction(
|
||||
skill_content=skill_content,
|
||||
items=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
system_prompt=self.get_codex_deep_probe_prompt() if codex_backend else self.get_deep_probe_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
if not probe:
|
||||
return []
|
||||
diagnostic_trace_context_by_id = None
|
||||
if codex_backend:
|
||||
selected_items, diagnostic_trace_context_by_id, probe = self.resolve_codex_probe_target(
|
||||
selected_items=selected_items,
|
||||
selected_examples=selected_examples,
|
||||
prediction_dir=prediction_dir,
|
||||
probe=probe,
|
||||
)
|
||||
probe_record = {
|
||||
**probe,
|
||||
"reference_summary": {
|
||||
"selected_count": len(selected_items),
|
||||
"field_counts": {"reasoning_steps": reasoning_count},
|
||||
},
|
||||
"selected_examples": selected_metadata,
|
||||
}
|
||||
with open(os.path.join(deep_dir, "probe.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(probe_record, f, ensure_ascii=False, indent=2)
|
||||
deep_results = run_batch(
|
||||
items=selected_items,
|
||||
out_root=rollout_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
workers=min(self.workers, max(len(selected_items), 1)),
|
||||
image_detail=self.image_detail,
|
||||
diagnostic_mode=True,
|
||||
diagnostic_instruction=probe["probe_instruction"],
|
||||
diagnostic_trace_context_by_id=diagnostic_trace_context_by_id,
|
||||
)
|
||||
deep_results = self.attach_reference_context(deep_results, selected_items)
|
||||
return run_minibatch_reflect(
|
||||
results=deep_results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=os.path.join(rollout_dir, "predictions"),
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return self.dataloader.get_task_types()
|
||||
146
reflact/envs/mmrb/dataloader.py
Normal file
146
reflact/envs/mmrb/dataloader.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""MMRB task dataloader."""
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from reflact.datasets.base import SplitDataLoader
|
||||
|
||||
|
||||
# ── Raw data loading utilities (for preprocessing / standalone eval) ─────
|
||||
|
||||
def _load_json(path: str) -> Any:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _iter_data_files(data_path: str) -> list[str]:
|
||||
if not data_path:
|
||||
return []
|
||||
if os.path.isfile(data_path):
|
||||
return [data_path]
|
||||
if os.path.isdir(data_path):
|
||||
nested = glob.glob(os.path.join(data_path, "**", "*_human.json"), recursive=True)
|
||||
flat = glob.glob(os.path.join(data_path, "*_human.json"))
|
||||
return sorted(set(nested + flat))
|
||||
return []
|
||||
|
||||
|
||||
def _normalize_space(text: str) -> str:
|
||||
return re.sub(r"\s+", " ", str(text or "").strip())
|
||||
|
||||
|
||||
def _normalize_item(item: dict, row_idx: int, source_path: str) -> dict | None:
|
||||
question = _normalize_space(item.get("question") or "")
|
||||
answer = _normalize_space(item.get("answer") or "")
|
||||
raw_image_paths = item.get("image_paths") or []
|
||||
if not question or not answer or not isinstance(raw_image_paths, list) or not raw_image_paths:
|
||||
return None
|
||||
|
||||
base_dir = os.path.dirname(source_path)
|
||||
image_paths: list[str] = []
|
||||
for raw_path in raw_image_paths:
|
||||
rel = str(raw_path or "").strip()
|
||||
if not rel:
|
||||
continue
|
||||
abs_path = rel if os.path.isabs(rel) else os.path.abspath(os.path.join(base_dir, rel))
|
||||
if os.path.exists(abs_path):
|
||||
image_paths.append(abs_path)
|
||||
if not image_paths:
|
||||
return None
|
||||
|
||||
options_raw = item.get("options") or []
|
||||
options = [_normalize_space(opt) for opt in options_raw if _normalize_space(opt)]
|
||||
source = _normalize_space(item.get("source") or "unknown")
|
||||
subtask = _normalize_space(item.get("subtask") or "unknown")
|
||||
item_index = item.get("index", row_idx)
|
||||
item_id = f"{source}:{subtask}:{item_index}"
|
||||
|
||||
return {
|
||||
"id": item_id,
|
||||
"source": source,
|
||||
"subtask": subtask,
|
||||
"task_type": subtask,
|
||||
"question": question,
|
||||
"answer": answer,
|
||||
"options": options,
|
||||
"is_choice": bool(options),
|
||||
"image_paths": image_paths,
|
||||
"reasoning_steps": item.get("reasoning_steps") or [],
|
||||
"annotation_time": item.get("annotation_time"),
|
||||
"source_path": os.path.abspath(source_path),
|
||||
}
|
||||
|
||||
|
||||
def load_items(data_path: str) -> list[dict]:
|
||||
"""Load and normalise MMRB items from JSON files."""
|
||||
files = _iter_data_files(data_path)
|
||||
if not files:
|
||||
raise ValueError(
|
||||
"MMRB requires data_path to be a *_human.json file or a directory "
|
||||
"containing extracted MMRB subtask folders."
|
||||
)
|
||||
|
||||
items: list[dict] = []
|
||||
for path in files:
|
||||
raw = _load_json(path)
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"Expected JSON array in {path}, got {type(raw).__name__}")
|
||||
for row_idx, item in enumerate(raw):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
norm = _normalize_item(item, row_idx=row_idx, source_path=path)
|
||||
if norm is not None:
|
||||
items.append(norm)
|
||||
|
||||
if not items:
|
||||
raise ValueError(f"No valid MMRB items loaded from {data_path}")
|
||||
return items
|
||||
|
||||
|
||||
# ── Dataloader ───────────────────────────────────────────────────────────
|
||||
|
||||
class MMRBDataLoader(SplitDataLoader):
|
||||
"""MMRB dataloader."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split_dir: str = "",
|
||||
data_path: str = "",
|
||||
split_mode: str = "ratio",
|
||||
split_ratio: str = "2:1:7",
|
||||
split_seed: int = 42,
|
||||
split_output_dir: str = "",
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
split_dir=split_dir,
|
||||
data_path=data_path,
|
||||
split_mode=split_mode,
|
||||
split_ratio=split_ratio,
|
||||
split_seed=split_seed,
|
||||
split_output_dir=split_output_dir,
|
||||
seed=seed,
|
||||
limit=limit,
|
||||
)
|
||||
self._task_types: list[str] = []
|
||||
|
||||
def load_raw_items(self, data_path: str) -> list[dict]:
|
||||
return load_items(data_path)
|
||||
|
||||
def setup(self, cfg: dict) -> None:
|
||||
super().setup(cfg)
|
||||
all_items = self.train_items + self.val_items + self.test_items
|
||||
task_types = {
|
||||
item.get("subtask") or item.get("task_type") or "unknown"
|
||||
for item in all_items
|
||||
}
|
||||
self._task_types = sorted(task_types)
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(self._task_types)
|
||||
102
reflact/envs/mmrb/evaluator.py
Normal file
102
reflact/envs/mmrb/evaluator.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""MMRB evaluation helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import string
|
||||
|
||||
|
||||
_EVAL_MODE = "mmrb_exact_match_v1"
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
text = str(text or "").strip().lower()
|
||||
text = "".join(ch for ch in text if ch not in string.punctuation)
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def extract_answer(text: str | None) -> str:
|
||||
raw = str(text or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
answer_tags = re.findall(r"<answer>\s*(.*?)\s*</answer>", raw, re.IGNORECASE | re.DOTALL)
|
||||
if answer_tags:
|
||||
return answer_tags[-1].strip()
|
||||
|
||||
bracket = re.findall(r"Answer\s*\[\s*(.*?)\s*\]", raw, re.IGNORECASE | re.DOTALL)
|
||||
if bracket:
|
||||
return bracket[-1].strip()
|
||||
|
||||
boxed = re.findall(r"\\boxed\{(.*?)\}", raw, re.IGNORECASE | re.DOTALL)
|
||||
if boxed:
|
||||
return boxed[-1].strip()
|
||||
|
||||
single = raw.strip().rstrip(".):")
|
||||
if re.fullmatch(r"[A-Z]", single, re.IGNORECASE):
|
||||
return single.strip()
|
||||
|
||||
patterns = [
|
||||
r"final answer\s*(?:is)?\s*[::]?\s*(.+)",
|
||||
r"the answer is\s*[::]?\s*(.+)",
|
||||
r"answer\s*[::]?\s*(.+)$",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, raw, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).strip().strip("*")
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def evaluate_item(*, item: dict, prediction_text: str) -> dict:
|
||||
predicted_answer = extract_answer(prediction_text)
|
||||
gold_answer = str(item.get("answer") or "").strip()
|
||||
predicted_norm = normalize_text(predicted_answer)
|
||||
gold_norm = normalize_text(gold_answer)
|
||||
|
||||
hard = 0.0
|
||||
matched_gold = ""
|
||||
predicted_label = ""
|
||||
predicted_text = predicted_answer
|
||||
|
||||
if item.get("is_choice"):
|
||||
predicted_label = str(predicted_answer).strip().upper().rstrip(".):")
|
||||
if predicted_label == str(gold_answer).strip().upper():
|
||||
hard = 1.0
|
||||
matched_gold = gold_answer
|
||||
else:
|
||||
for option in item.get("options") or []:
|
||||
label_match = re.match(r"\(?([A-Z])\)", option)
|
||||
if not label_match:
|
||||
continue
|
||||
label = label_match.group(1).upper()
|
||||
option_text = option[label_match.end():].strip(" .:-")
|
||||
if predicted_norm and normalize_text(option_text) == predicted_norm:
|
||||
predicted_label = label
|
||||
predicted_text = option_text
|
||||
break
|
||||
if predicted_label == str(gold_answer).strip().upper():
|
||||
hard = 1.0
|
||||
matched_gold = gold_answer
|
||||
else:
|
||||
if predicted_norm and gold_norm and (
|
||||
predicted_norm == gold_norm or predicted_norm in gold_norm or gold_norm in predicted_norm
|
||||
):
|
||||
hard = 1.0
|
||||
matched_gold = gold_answer
|
||||
|
||||
return {
|
||||
"evaluation_mode": _EVAL_MODE,
|
||||
"predicted_answer": predicted_answer,
|
||||
"predicted_label": predicted_label,
|
||||
"predicted_text": predicted_text,
|
||||
"em": hard,
|
||||
"f1": hard,
|
||||
"sub_em": hard,
|
||||
"matched_gold": matched_gold,
|
||||
}
|
||||
|
||||
|
||||
def evaluation_mode() -> str:
|
||||
return _EVAL_MODE
|
||||
|
||||
10
reflact/envs/mmrb/prompts/rollout_system.md
Normal file
10
reflact/envs/mmrb/prompts/rollout_system.md
Normal file
@@ -0,0 +1,10 @@
|
||||
You are an expert multi-image reasoning agent.
|
||||
|
||||
{skill_section}## Task Format
|
||||
You will receive a question grounded in multiple images.
|
||||
Use the image order exactly as presented in the prompt and compare evidence across images carefully.
|
||||
|
||||
## Answer Format
|
||||
- Put the final answer inside <answer>...</answer>.
|
||||
- For multiple-choice questions, output only the single option letter inside <answer>...</answer>.
|
||||
- For open questions, output only the short final answer inside <answer>...</answer>.
|
||||
439
reflact/envs/mmrb/rollout.py
Normal file
439
reflact/envs/mmrb/rollout.py
Normal file
@@ -0,0 +1,439 @@
|
||||
"""MMRB rollout."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from reflact.envs.mmrb.evaluator import evaluate_item, evaluation_mode
|
||||
from reflact.model import chat_student_messages, get_student_backend, is_student_exec_backend
|
||||
from reflact.model.codex_harness import prepare_workspace, render_skill_md, run_student_exec
|
||||
from reflact.prompts import load_prompt
|
||||
|
||||
_IMAGE_REF_RE = re.compile(r"\{image#(\d+)\}", re.IGNORECASE)
|
||||
|
||||
|
||||
def _build_system(skill_content: str) -> str:
|
||||
if skill_content.strip():
|
||||
skill_section = f"## Skill\n{skill_content.strip()}\n\n"
|
||||
else:
|
||||
skill_section = ""
|
||||
return load_prompt("rollout_system", env="mmrb").format(skill_section=skill_section)
|
||||
|
||||
|
||||
def _image_to_data_uri(path: str) -> str:
|
||||
mime = mimetypes.guess_type(path)[0] or "image/png"
|
||||
with open(path, "rb") as f:
|
||||
encoded = base64.b64encode(f.read()).decode("ascii")
|
||||
return f"data:{mime};base64,{encoded}"
|
||||
|
||||
|
||||
def _build_user_content(
|
||||
item: dict,
|
||||
image_detail: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> tuple[list[dict], str]:
|
||||
raw_question = str(item["question"])
|
||||
content: list[dict] = []
|
||||
text_parts: list[str] = []
|
||||
used_indices: set[int] = set()
|
||||
cursor = 0
|
||||
|
||||
if diagnostic_trace_context.strip():
|
||||
prefix = (
|
||||
"## Previous Codex Trace Snapshot\n"
|
||||
"This is a partial transcript from an earlier attempt. Use it as your current reasoning context.\n\n"
|
||||
f"{diagnostic_trace_context.strip()}\n\n"
|
||||
)
|
||||
content.append({"type": "text", "text": prefix})
|
||||
text_parts.append(prefix)
|
||||
|
||||
for match in _IMAGE_REF_RE.finditer(raw_question):
|
||||
if match.start() > cursor:
|
||||
chunk = raw_question[cursor:match.start()]
|
||||
if chunk:
|
||||
content.append({"type": "text", "text": chunk})
|
||||
text_parts.append(chunk)
|
||||
|
||||
image_idx = int(match.group(1)) - 1
|
||||
marker = f"[Image #{image_idx + 1}]"
|
||||
text_parts.append(marker)
|
||||
if 0 <= image_idx < len(item["image_paths"]):
|
||||
image_url = {"url": _image_to_data_uri(item["image_paths"][image_idx])}
|
||||
if image_detail and image_detail != "auto":
|
||||
image_url["detail"] = image_detail
|
||||
content.append({"type": "image_url", "image_url": image_url})
|
||||
used_indices.add(image_idx)
|
||||
else:
|
||||
content.append({"type": "text", "text": marker})
|
||||
cursor = match.end()
|
||||
|
||||
if cursor < len(raw_question):
|
||||
tail = raw_question[cursor:]
|
||||
if tail:
|
||||
content.append({"type": "text", "text": tail})
|
||||
text_parts.append(tail)
|
||||
|
||||
for idx, path in enumerate(item["image_paths"]):
|
||||
if idx in used_indices:
|
||||
continue
|
||||
marker = f"\n[Additional Image #{idx + 1}]"
|
||||
text_parts.append(marker)
|
||||
content.append({"type": "text", "text": marker})
|
||||
image_url = {"url": _image_to_data_uri(path)}
|
||||
if image_detail and image_detail != "auto":
|
||||
image_url["detail"] = image_detail
|
||||
content.append({"type": "image_url", "image_url": image_url})
|
||||
|
||||
answer_instruction = (
|
||||
"\n\nAnswer with the single correct option letter inside <answer>...</answer>."
|
||||
if item.get("is_choice")
|
||||
else "\n\nAnswer with the short final answer inside <answer>...</answer>."
|
||||
)
|
||||
content.append({"type": "text", "text": answer_instruction})
|
||||
text_parts.append(answer_instruction)
|
||||
|
||||
if diagnostic_mode and diagnostic_instruction.strip():
|
||||
diag_block = f"\n\n## Training Readout\n{diagnostic_instruction.strip()}"
|
||||
content.append({"type": "text", "text": diag_block})
|
||||
text_parts.append(diag_block)
|
||||
|
||||
return content, "".join(text_parts)
|
||||
|
||||
|
||||
def _build_messages(
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
image_detail: str,
|
||||
*,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
) -> tuple[list[dict], str, str]:
|
||||
system = _build_system(skill_content)
|
||||
user_content, user_text = _build_user_content(
|
||||
item,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
return messages, system, user_text
|
||||
|
||||
|
||||
def _build_codex_skill(skill_content: str) -> str:
|
||||
return render_skill_md(
|
||||
skill_content,
|
||||
description="Dynamic ReflACT skill for solving the current MMRB multi-image reasoning question.",
|
||||
preamble=(
|
||||
"Use this skill when solving the current multi-image reasoning task.\n"
|
||||
"Inspect all attached images carefully and return the final answer inside <answer>...</answer>."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _run_codex_once(
|
||||
*,
|
||||
pred_dir: str,
|
||||
item: dict,
|
||||
skill_content: str,
|
||||
model: str,
|
||||
timeout: int,
|
||||
image_detail: str,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
previous_response: str = "",
|
||||
) -> tuple[str, str, str, str]:
|
||||
user_text = _build_user_content(
|
||||
item,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)[1]
|
||||
task_parts = [user_text]
|
||||
if previous_response:
|
||||
task_parts.append(
|
||||
"## Previous Attempt\n"
|
||||
f"{previous_response}\n\n"
|
||||
"Review the same images carefully and answer again."
|
||||
)
|
||||
task_text = "\n\n".join(task_parts)
|
||||
skill_md = _build_codex_skill(skill_content)
|
||||
work_dir = os.path.join(pred_dir, "codex_exec")
|
||||
prepare_workspace(
|
||||
work_dir=work_dir,
|
||||
skill_md=skill_md,
|
||||
task_text=task_text,
|
||||
images=item["image_paths"],
|
||||
)
|
||||
prompt = (
|
||||
"Use the `reflact-student` skill available in this workspace.\n"
|
||||
"Read `task.md`, inspect all attached images, and answer the question.\n"
|
||||
"Keep the final answer inside <answer>...</answer>."
|
||||
)
|
||||
final_message, raw = run_student_exec(
|
||||
work_dir=work_dir,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
images=item["image_paths"],
|
||||
)
|
||||
return final_message or raw, raw, skill_md, task_text
|
||||
|
||||
|
||||
def process_one(
|
||||
item: dict,
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
image_detail: str = "auto",
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
"id": item_id,
|
||||
"question": item["question"],
|
||||
"task_type": item.get("subtask") or item.get("task_type") or "mmrb",
|
||||
"task_description": item["question"],
|
||||
"hard": 0,
|
||||
"soft": 0.0,
|
||||
"predicted_answer": "",
|
||||
"predicted_label": "",
|
||||
"predicted_text": "",
|
||||
"response": "",
|
||||
"fail_reason": "",
|
||||
"agent_ok": False,
|
||||
"n_turns": 0,
|
||||
"image_paths": item["image_paths"],
|
||||
"gold_answer": item["answer"],
|
||||
"evaluation_mode": evaluation_mode(),
|
||||
}
|
||||
|
||||
try:
|
||||
pred_dir = os.path.join(out_root, "predictions", item_id)
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
if is_student_exec_backend():
|
||||
from reflact.model import azure_openai as _llm
|
||||
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": item["question"] + "\n\n" + "\n".join(
|
||||
f"[image] {os.path.basename(path)}" for path in item["image_paths"]
|
||||
),
|
||||
}
|
||||
]
|
||||
system_prompt = ""
|
||||
user_text = ""
|
||||
for turn in range(max_turns):
|
||||
response, raw, system_prompt, user_text = _run_codex_once(
|
||||
pred_dir=pred_dir,
|
||||
item=item,
|
||||
skill_content=skill_content,
|
||||
model=_llm.STUDENT_DEPLOYMENT,
|
||||
timeout=120,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode if turn == 0 else False,
|
||||
diagnostic_instruction=diagnostic_instruction if turn == 0 else "",
|
||||
diagnostic_trace_context=diagnostic_trace_context if turn == 0 else "",
|
||||
previous_response=response if turn > 0 else "",
|
||||
)
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": response})
|
||||
if "<answer>" in response.lower():
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(item=item, prediction_text=response)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["matched_gold"] = eval_result["matched_gold"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"predicted '{eval_result['predicted_answer']}' but expected '{item['answer']}'"
|
||||
)
|
||||
eval_detail = (
|
||||
"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Gold answer: {item['answer']!r}\n"
|
||||
f"Correct: {eval_result['em']}\n"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
return result
|
||||
|
||||
messages, system_prompt, user_text = _build_messages(
|
||||
item,
|
||||
skill_content,
|
||||
image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
)
|
||||
response = ""
|
||||
conversation: list[dict] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_text + "\n\n" + "\n".join(
|
||||
f"[image] {os.path.basename(path)}" for path in item["image_paths"]
|
||||
),
|
||||
}
|
||||
]
|
||||
|
||||
for turn in range(max_turns):
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=768,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
else:
|
||||
refinement_messages = [
|
||||
messages[0],
|
||||
messages[1],
|
||||
{"role": "assistant", "content": response},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Review the same images carefully and answer again. Keep the final answer inside <answer>...</answer>.",
|
||||
},
|
||||
]
|
||||
resp_text, _ = chat_student_messages(
|
||||
messages=refinement_messages,
|
||||
max_completion_tokens=512,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
)
|
||||
response = resp_text
|
||||
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
|
||||
if "<answer>" in resp_text.lower():
|
||||
break
|
||||
|
||||
result["response"] = response
|
||||
result["agent_ok"] = True
|
||||
result["n_turns"] = len(conversation) - 1
|
||||
|
||||
with open(os.path.join(pred_dir, "student_system_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(system_prompt)
|
||||
with open(os.path.join(pred_dir, "student_user_prompt.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(user_text)
|
||||
|
||||
eval_result = evaluate_item(item=item, prediction_text=response)
|
||||
result["evaluation_mode"] = eval_result["evaluation_mode"]
|
||||
result["predicted_answer"] = eval_result["predicted_answer"]
|
||||
result["predicted_label"] = eval_result["predicted_label"]
|
||||
result["predicted_text"] = eval_result["predicted_text"]
|
||||
result["matched_gold"] = eval_result["matched_gold"]
|
||||
result["hard"] = int(eval_result["em"])
|
||||
result["soft"] = eval_result["f1"]
|
||||
if not result["hard"]:
|
||||
result["fail_reason"] = (
|
||||
f"predicted '{eval_result['predicted_answer']}' but expected '{item['answer']}'"
|
||||
)
|
||||
|
||||
eval_detail = (
|
||||
"[EVALUATION RESULT]\n"
|
||||
f"Question: {item['question']}\n"
|
||||
f"Predicted answer: {eval_result['predicted_answer']!r}\n"
|
||||
f"Predicted label: {eval_result['predicted_label']!r}\n"
|
||||
f"Gold answer: {item['answer']!r}\n"
|
||||
f"Correct: {eval_result['em']}\n"
|
||||
)
|
||||
conversation.append({"role": "system", "content": eval_detail})
|
||||
with open(os.path.join(pred_dir, "conversation.json"), "w", encoding="utf-8") as f:
|
||||
json.dump(conversation, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
result["fail_reason"] = f"error: {e}"
|
||||
return result
|
||||
|
||||
|
||||
def run_batch(
|
||||
items: list[dict],
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
*,
|
||||
max_turns: int = 1,
|
||||
workers: int = 16,
|
||||
image_detail: str = "auto",
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None,
|
||||
) -> list[dict]:
|
||||
results_path = os.path.join(out_root, "results.jsonl")
|
||||
os.makedirs(out_root, exist_ok=True)
|
||||
|
||||
expected_eval_mode = evaluation_mode()
|
||||
done_ids: set[str] = set()
|
||||
existing: list[dict] = []
|
||||
rewrite_results = False
|
||||
if os.path.exists(results_path):
|
||||
with open(results_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
try:
|
||||
row = json.loads(line)
|
||||
if row.get("evaluation_mode") != expected_eval_mode:
|
||||
rewrite_results = True
|
||||
continue
|
||||
done_ids.add(str(row["id"]))
|
||||
existing.append(row)
|
||||
except Exception:
|
||||
rewrite_results = True
|
||||
|
||||
pending = [item for item in items if str(item["id"]) not in done_ids]
|
||||
if not pending and not rewrite_results:
|
||||
return existing
|
||||
|
||||
results = list(existing)
|
||||
file_mode = "w" if rewrite_results else "a"
|
||||
with open(results_path, file_mode, encoding="utf-8") as outf, ThreadPoolExecutor(max_workers=workers) as ex:
|
||||
if rewrite_results:
|
||||
for row in existing:
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
futs = {
|
||||
ex.submit(
|
||||
process_one,
|
||||
item,
|
||||
out_root,
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
image_detail=image_detail,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
|
||||
): item
|
||||
for item in pending
|
||||
}
|
||||
for fut in as_completed(futs):
|
||||
row = fut.result()
|
||||
results.append(row)
|
||||
outf.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
outf.flush()
|
||||
return results
|
||||
17
reflact/envs/mmrb/skills/initial.md
Normal file
17
reflact/envs/mmrb/skills/initial.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# MMRB Multi-Image Reasoning Heuristics
|
||||
|
||||
## Cross-Image Alignment
|
||||
- Track the role of each image by its index and compare evidence across all referenced images before deciding.
|
||||
- When the question depends on sequence, correspondence, or retrieval, verify the relation between images instead of judging each image independently.
|
||||
|
||||
## Option Elimination
|
||||
- For multiple-choice tasks, compare all options and reject choices that match only part of the visual evidence.
|
||||
- If options differ by a small visual detail, use the most discriminative cue rather than a coarse scene impression.
|
||||
|
||||
## Open Answers
|
||||
- For open-ended tasks, give the shortest answer that is fully supported by the combined images.
|
||||
- Preserve exact entities, attributes, counts, and directions when the images support them directly.
|
||||
|
||||
## Final Answer
|
||||
- Output only the final answer inside <answer>...</answer>.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user