mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
refactor: make EnvAdapter.reflect a shared default (fixes dropped reflect kwargs)
All six adapters duplicated an identical reflect() that delegates to run_minibatch_reflect. The copies had drifted: OfficeQA/DocVQA silently dropped meta_skill_context and ALFWorld dropped update_mode, so those analysts ran without inputs every other benchmark receives (active under the default use_meta_skill: true). Move the delegation into EnvAdapter.reflect as one default that forwards all kwargs uniformly, and delete the six overrides. reflect is no longer abstract — adapters inherit it and override only for custom logic. Net -225 lines. Behavior change: OfficeQA/DocVQA/ALFWorld reflect now receive the kwargs they previously dropped; the three already-correct benchmarks are unaffected. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -5,8 +5,8 @@ This directory provides scaffold files for adding a new benchmark to SkillOpt.
|
||||
## Files
|
||||
|
||||
- `env_template.py` — Environment adapter template (subclasses
|
||||
`EnvAdapter`; implements the 5 abstract methods so the file is
|
||||
instantiable out of the box).
|
||||
`EnvAdapter`; implements the 4 abstract methods so the file is
|
||||
instantiable out of the box — `reflect` is inherited).
|
||||
- `loader_template.py` — Data loader template (subclasses
|
||||
`SplitDataLoader`; implements `load_split_items` for `.json`/`.jsonl`).
|
||||
- `config_template.yaml` — Config file template.
|
||||
@@ -28,8 +28,8 @@ This directory provides scaffold files for adding a new benchmark to SkillOpt.
|
||||
`TemplateBenchmarkLoader → YourBenchmarkLoader`)
|
||||
and fix the cross-import in `adapter.py`.
|
||||
3. **Implement the TODO blocks** inside `adapter.py:rollout` and the
|
||||
`_normalize_item` helper in `dataloader.py`. If you want real reflection,
|
||||
uncomment the `run_minibatch_reflect` block in `adapter.py:reflect`.
|
||||
`_normalize_item` helper in `dataloader.py`. (`reflect` is inherited from
|
||||
`EnvAdapter`; override it only for custom reflection logic.)
|
||||
4. **Register** the adapter — add a `try / except ImportError` block in
|
||||
`scripts/train.py`'s `_register_builtins()` mapping the registry key
|
||||
to your `YourBenchmarkAdapter` class. There is no
|
||||
|
||||
@@ -14,13 +14,9 @@ For a fully worked example see ``skillopt/envs/officeqa/``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs._template.loader_template import TemplateBenchmarkLoader
|
||||
# When you wire in real reflection, also import:
|
||||
# from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
class TemplateBenchmarkEnv(EnvAdapter):
|
||||
@@ -131,53 +127,12 @@ class TemplateBenchmarkEnv(EnvAdapter):
|
||||
)
|
||||
return results
|
||||
|
||||
# ── Reflect: turn rollout results into patch dicts ─────────────────
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
"""
|
||||
Turn rollouts into a list of raw patch dicts (or None to drop).
|
||||
|
||||
Each non-None dict MUST have:
|
||||
- "patch": {"edits": [...]} a Patch.to_dict() payload
|
||||
- "source_type": "failure" | "success"
|
||||
|
||||
Most benchmarks delegate to
|
||||
:func:`skillopt.gradient.reflect.run_minibatch_reflect` which
|
||||
will call the optimizer model with the
|
||||
``analyst_error_*`` / ``analyst_success_*`` prompts. To enable it,
|
||||
uncomment the import above and call:
|
||||
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=kwargs.get(
|
||||
"prediction_dir", os.path.join(out_dir, "predictions")
|
||||
),
|
||||
patches_dir=kwargs.get(
|
||||
"patches_dir", os.path.join(out_dir, "patches")
|
||||
),
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=kwargs.get("random_seed"),
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=kwargs.get("step_buffer_context", ""),
|
||||
update_mode=getattr(self, "_cfg", {}).get(
|
||||
"skill_update_mode", "patch"
|
||||
),
|
||||
)
|
||||
"""
|
||||
# Template default: produce no patches (no-op trainer step).
|
||||
return [None for _ in results]
|
||||
# ── Reflect (inherited) ─────────────────────────────────────────────
|
||||
#
|
||||
# ``reflect`` is inherited from ``EnvAdapter``: the default delegates to
|
||||
# ``skillopt.gradient.reflect.run_minibatch_reflect`` using your
|
||||
# ``analyst_error_*`` / ``analyst_success_*`` prompts. You do NOT need to
|
||||
# implement it — override only if your benchmark needs custom reflection.
|
||||
|
||||
# ── Stratification hint ────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ from skillopt.envs.alfworld.rollout import (
|
||||
run_alfworld_batch,
|
||||
TASKS,
|
||||
)
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
from skillopt.utils import compute_score
|
||||
|
||||
|
||||
@@ -425,35 +424,5 @@ class ALFWorldAdapter(EnvAdapter):
|
||||
all_results.extend(chunk_results)
|
||||
return all_results
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
)
|
||||
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(TASKS)
|
||||
|
||||
@@ -231,7 +231,6 @@ class EnvAdapter(ABC):
|
||||
(float 0-1). May include env-specific fields.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
@@ -241,15 +240,36 @@ class EnvAdapter(ABC):
|
||||
) -> list[dict | None]:
|
||||
"""Analyze rollout results and produce patches.
|
||||
|
||||
Default implementation: delegate to the shared minibatch reflect
|
||||
stage. Every built-in benchmark uses this unchanged — override only
|
||||
if your environment needs custom reflection logic.
|
||||
|
||||
Each returned dict conforms to :class:`~skillopt.types.RawPatch`:
|
||||
``"patch"`` (with ``"edits"`` list) + ``"source_type"``
|
||||
(``"failure"`` or ``"success"``).
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict | None]
|
||||
Raw analyst outputs; ``None`` entries are filtered out.
|
||||
(``"failure"`` or ``"success"``); ``None`` entries are filtered out.
|
||||
"""
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=kwargs.get(
|
||||
"prediction_dir", os.path.join(out_dir, "predictions")
|
||||
),
|
||||
patches_dir=kwargs.get(
|
||||
"patches_dir", os.path.join(out_dir, "patches")
|
||||
),
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=kwargs.get("random_seed"),
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=kwargs.get("step_buffer_context", ""),
|
||||
meta_skill_context=kwargs.get("meta_skill_context", ""),
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_task_types(self) -> list[str]:
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.docvqa.dataloader import DocVQADataLoader
|
||||
from skillopt.envs.docvqa.rollout import run_batch
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
class DocVQAAdapter(EnvAdapter):
|
||||
@@ -84,28 +81,6 @@ class DocVQAAdapter(EnvAdapter):
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
|
||||
def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
seen: list[str] = []
|
||||
for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items:
|
||||
|
||||
@@ -2,10 +2,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.livemathematicianbench.dataloader import LiveMathematicianBenchDataLoader
|
||||
from skillopt.envs.livemathematicianbench.rollout import run_batch
|
||||
@@ -127,36 +125,5 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return self.dataloader.get_task_types()
|
||||
|
||||
@@ -6,7 +6,6 @@ from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.officeqa.dataloader import OfficeQADataLoader
|
||||
from skillopt.envs.officeqa.rollout import run_batch
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
|
||||
|
||||
class OfficeQAAdapter(EnvAdapter):
|
||||
@@ -104,28 +103,6 @@ class OfficeQAAdapter(EnvAdapter):
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
)
|
||||
|
||||
def reflect(self, results: list[dict], skill_content: str, out_dir: str, **kwargs) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
seen: list[str] = []
|
||||
for item in self.dataloader.train_items + self.dataloader.val_items + self.dataloader.test_items:
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from skillopt.datasets.base import BatchSpec
|
||||
from skillopt.envs.base import EnvAdapter
|
||||
from skillopt.envs.searchqa.dataloader import SearchQADataLoader
|
||||
from skillopt.envs.searchqa.rollout import run_batch
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
from skillopt.model import get_target_backend
|
||||
|
||||
|
||||
@@ -94,36 +92,5 @@ class SearchQAAdapter(EnvAdapter):
|
||||
task_timeout=self.exec_timeout,
|
||||
)
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return ["qa"]
|
||||
|
||||
@@ -16,7 +16,6 @@ from skillopt.envs.spreadsheetbench.rollout import (
|
||||
run_spreadsheet_batch,
|
||||
run_spreadsheet_batch_codegen,
|
||||
)
|
||||
from skillopt.gradient.reflect import run_minibatch_reflect
|
||||
from skillopt.model import get_target_backend, is_target_exec_backend
|
||||
|
||||
|
||||
@@ -156,37 +155,5 @@ class SpreadsheetBenchAdapter(EnvAdapter):
|
||||
|
||||
return results
|
||||
|
||||
def reflect(
|
||||
self,
|
||||
results: list[dict],
|
||||
skill_content: str,
|
||||
out_dir: str,
|
||||
**kwargs,
|
||||
) -> list[dict | None]:
|
||||
"""Analyze rollout results and produce patches (minibatch mode)."""
|
||||
prediction_dir = kwargs.get("prediction_dir", os.path.join(out_dir, "predictions"))
|
||||
patches_dir = kwargs.get("patches_dir", os.path.join(out_dir, "patches"))
|
||||
random_seed = kwargs.get("random_seed")
|
||||
step_buffer_context = kwargs.get("step_buffer_context", "")
|
||||
meta_skill_context = kwargs.get("meta_skill_context", "")
|
||||
|
||||
return run_minibatch_reflect(
|
||||
results=results,
|
||||
skill_content=skill_content,
|
||||
prediction_dir=prediction_dir,
|
||||
patches_dir=patches_dir,
|
||||
workers=self.analyst_workers,
|
||||
failure_only=self.failure_only,
|
||||
minibatch_size=self.minibatch_size,
|
||||
edit_budget=self.edit_budget,
|
||||
random_seed=random_seed,
|
||||
error_system=self.get_error_minibatch_prompt(),
|
||||
success_system=self.get_success_minibatch_prompt(),
|
||||
step_buffer_context=step_buffer_context,
|
||||
meta_skill_context=meta_skill_context,
|
||||
update_mode=getattr(self, "_cfg", {}).get("skill_update_mode", "patch"),
|
||||
)
|
||||
|
||||
|
||||
def get_task_types(self) -> list[str]:
|
||||
return list(TASK_TYPES)
|
||||
|
||||
Reference in New Issue
Block a user