Files
Cuzyoung 0dc84162dc feat(optimizer): skill-aware reflection (EmbodiSkill S_app), config-controlled and env-independent
Split failure reflections into SKILL_DEFECT (body edit) vs EXECUTION_LAPSE
(protected appendix note that re-emphasizes an existing rule, never edited
by step-level analysts). Toggle: optimizer.use_skill_aware_reflection
(default false; baseline byte-identical when off).

- optimizer/appendix.py: protected APPENDIX region (inject/extract/append
  with dedup), mirrors the slow_update protected-field pattern
- optimizer/skill_aware.py: analyst prompt augmentation, appendix_notes
  parsing, threshold-gated LLM consolidation, and a process-wide runtime
  switch (configure_skill_aware_reflection) set once by the trainer
- gradient/reflect.py: augment error/success analyst prompts at runtime;
  None-sentinel kwargs resolve from the global switch, so env adapters
  need no per-benchmark wiring (works for all envs, present and future)
- optimizer/skill.py: generalize the protected-region check to
  (slow_update, appendix); edits inside any protected region are skipped
- engine/trainer.py: inject appendix at init, flush per-step
  EXECUTION_LAPSE notes after the gate settles, optional consolidation
- tests: regression suite incl. toggle-off byte-identical guarantee and
  env-independent global-switch resolution (6/6 passing + live smoke)

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
2026-06-10 13:10:08 +00:00

636 lines
25 KiB
Python

"""ReflACT core Reflect engine -- minibatch trajectory analysis.
Provides environment-agnostic minibatch trajectory analysis: instead of
analyzing each trajectory independently, trajectories are grouped into
minibatches of size M and analyzed together -- analogous to minibatch SGD
vs per-sample SGD in neural network training.
Two-level prompt priority system:
1. **Custom prompt** (adapter returns non-None) -- used as-is.
2. **Generic default prompt** (adapter returns None) -- built-in defaults
that work for any environment without configuration.
Public API
----------
- :func:`fmt_trajectory` -- format one conversation into text
- :func:`fmt_minibatch_trajectories` -- format multiple trajectories for batch analysis
- :func:`run_error_analyst_minibatch` -- one optimizer call for a group of failures
- :func:`run_success_analyst_minibatch` -- one optimizer call for a group of successes
- :func:`run_minibatch_reflect` -- full reflect stage dispatcher
"""
from __future__ import annotations
import json
import os
import random
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from skillopt.model import chat_optimizer
from skillopt.optimizer.meta_skill import format_meta_skill_context
from skillopt.optimizer.skill_aware import (
augment_error_prompt,
augment_success_prompt,
extract_appendix_notes,
get_skill_aware_appendix_source,
is_skill_aware_enabled,
)
from skillopt.optimizer.update_modes import (
get_payload_items,
is_full_rewrite_minibatch_mode,
normalize_update_mode,
payload_key,
payload_label,
truncate_payload,
)
from skillopt.prompts import load_prompt
from skillopt.utils import extract_json
# ── Trajectory formatting ────────────────────────────────────────────────────
def _clip_text(value, limit: int | None = None) -> str:
"""Render optional trajectory fields. Truncation is disabled: the optimizer
is given the full content so it can see exactly what the agent saw/did.
``limit`` is accepted for backward compatibility but ignored.
"""
if value is None:
return ""
return str(value)
def fmt_trajectory(
conversation: list[dict],
max_chars: int | None = None,
) -> str:
"""Format a conversation list into analyst-readable text.
Accepts two common formats:
1. Tool-call records: ``{"type": "tool_call", "cmd": ..., "obs": ...}``
2. Step records: ``{"step": N, "action": ..., "env_feedback": ..., "reasoning": ...}``
Any other dict is rendered via its ``"content"`` key.
"""
lines: list[str] = []
for item in conversation:
if not isinstance(item, dict):
lines.append(f"[agent] {_clip_text(item)}")
continue
if item.get("type") == "tool_call":
cmd = _clip_text(item.get("cmd"))
obs = _clip_text(item.get("obs"))
lines.append(f"[action] {cmd}")
lines.append(f"[obs] {obs}")
elif "action" in item and "env_feedback" in item:
step = item.get("step", "?")
reasoning = _clip_text(item.get("reasoning"))
action = _clip_text(item.get("action"))
feedback = _clip_text(item.get("env_feedback"))
if reasoning:
lines.append(f"[step {step} think] {reasoning}")
lines.append(f"[step {step} action] {action}")
lines.append(f"[step {step} obs] {feedback}")
elif item.get("role") == "system":
# Post-execution verification / enrichment info
msg = _clip_text(item.get("content"))
lines.append(f"[verification] {msg}")
else:
msg = _clip_text(item.get("content"))
role = item.get("role", "agent")
lines.append(f"[{role}] {msg}")
return "\n".join(lines)
# ── Minibatch trajectory formatting ──────────────────────────────────────────
def fmt_minibatch_trajectories(
items: list[dict],
prediction_dir: str,
) -> str:
"""Format multiple trajectories for minibatch analyst consumption.
Each item is a rollout result dict with ``"id"``, ``"task_description"``,
``"task_type"``, ``"fail_reason"``, etc. Reads ``conversation.json``
for each and formats them together with trajectory headers.
If available, includes the spreadsheet preview and target system prompt
so the analyst can see what the agent saw.
Parameters
----------
items : list[dict]
Rollout result dicts belonging to one minibatch.
prediction_dir : str
Path to ``predictions/`` directory containing per-task
``<task_id>/conversation.json`` files.
Returns
-------
str
Formatted text with all trajectories separated by ``---``.
"""
parts: list[str] = []
for idx, item in enumerate(items, 1):
tid = str(item["id"])
conv_path = os.path.join(prediction_dir, tid, "conversation.json")
if not os.path.exists(conv_path):
continue
with open(conv_path) as f:
conversation = json.load(f)
if not conversation:
continue
traj_text = fmt_trajectory(conversation)
header = (
f"### Trajectory {idx} (id={tid})\n"
f"Task: {item.get('task_description', item.get('instruction', ''))}\n"
f"Task type: {item.get('task_type', item.get('instruction_type', ''))}\n"
)
fail_reason = item.get("fail_reason", "")
if fail_reason:
header += f"Failure reason: {fail_reason}\n"
header += f"Steps: {item.get('n_turns', '?')}\n"
reference_text = str(item.get("reference_text") or "").strip()
if reference_text:
header += (
f"\n#### Hidden Reference\n"
f"{reference_text}\n"
)
# ── Append target context (what the agent saw) ──────────────
target_prompt = item.get("target_system_prompt", "")
if not target_prompt:
prompt_path = os.path.join(prediction_dir, tid, "target_system_prompt.txt")
if os.path.exists(prompt_path):
with open(prompt_path) as f:
target_prompt = f.read()
if target_prompt:
header += (
f"\n#### Target System Prompt\n"
f"{target_prompt}\n"
)
user_prompt = item.get("target_user_prompt", "")
if not user_prompt:
user_prompt_path = os.path.join(prediction_dir, tid, "target_user_prompt.txt")
if os.path.exists(user_prompt_path):
with open(user_prompt_path) as f:
user_prompt = f.read()
if user_prompt:
header += (
f"\n#### Target User Prompt\n"
f"{user_prompt}\n"
)
if os.environ.get("REFLACT_CODEX_TRACE_TO_OPTIMIZER", "0") == "1":
codex_trace_summary = item.get("codex_trace_summary", "")
if not codex_trace_summary:
codex_trace_summary_path = os.path.join(prediction_dir, tid, "codex_trace_summary.txt")
if os.path.exists(codex_trace_summary_path):
with open(codex_trace_summary_path) as f:
codex_trace_summary = f.read()
if codex_trace_summary:
header += (
f"\n#### Codex Trace Summary\n"
f"{codex_trace_summary}\n"
)
codex_probe_trace_steps = str(item.get("codex_probe_trace_steps") or "").strip()
if codex_probe_trace_steps:
header += (
f"\n#### Codex Trace Steps\n"
f"{codex_probe_trace_steps}\n"
)
preview = item.get("spreadsheet_preview", "")
if not preview:
preview_path = os.path.join(prediction_dir, tid, "spreadsheet_preview.txt")
if os.path.exists(preview_path):
with open(preview_path) as f:
preview = f.read()
if preview:
header += (
f"\n#### Spreadsheet Preview\n"
f"{preview}\n"
)
parts.append(header + "\n" + traj_text)
return "\n\n---\n\n".join(parts)
# ── Prompt resolution ───────────────────────────────────────────────────────
def _resolve_prompt(custom: str | None, default_name: str, update_mode: str = "patch") -> str:
"""Return *custom* if provided (non-None), otherwise load from file."""
if custom is not None:
return custom
mode = normalize_update_mode(update_mode)
actual_name = default_name
if is_full_rewrite_minibatch_mode(mode):
full_name = f"{default_name}_full_rewrite"
try:
return load_prompt(full_name)
except FileNotFoundError:
actual_name = default_name
elif mode == "rewrite_from_suggestions":
rewrite_name = f"{default_name}_rewrite"
try:
return load_prompt(rewrite_name)
except FileNotFoundError:
actual_name = default_name
return load_prompt(actual_name)
# ── Minibatch analysts ──────────────────────────────────────────────────────
def run_error_analyst_minibatch(
skill_content: str,
items: list[dict],
prediction_dir: str,
edit_budget: int = 4,
*,
system_prompt: str | None = None,
rejection_context: str = "",
trajectory_memory_context: str = "",
step_buffer_context: str = "",
meta_skill_context: str = "",
update_mode: str = "patch",
skill_aware_reflection: bool = False,
) -> dict | None:
"""Analyze a minibatch of failed trajectories in one optimizer call.
Parameters
----------
skill_content : str
Current skill document text.
items : list[dict]
Rollout result dicts (all should have ``hard=0``).
prediction_dir : str
Path to ``predictions/`` directory.
edit_budget : int
Maximum number of edits (L).
system_prompt : str | None
Custom system prompt. ``None`` = use generic default.
rejection_context : str
*Deprecated* — use ``step_buffer_context``.
trajectory_memory_context : str
*Deprecated* — use ``step_buffer_context``.
step_buffer_context : str
Unified summary of previous steps (failure patterns + rejected edits).
Returns
-------
dict | None
Patch dict with ``source_type="failure"``, or ``None`` on error.
"""
mode = normalize_update_mode(update_mode)
actual_system = _resolve_prompt(system_prompt, "analyst_error", mode)
# Skill-aware reflection: augment the resolved prompt at runtime so both
# env-specific and generic analyst prompts get the defect/lapse instruction.
# When the toggle is off this is a no-op (prompt byte-identical to baseline).
if skill_aware_reflection and not is_full_rewrite_minibatch_mode(mode):
actual_system = augment_error_prompt(actual_system)
trajectories_text = fmt_minibatch_trajectories(items, prediction_dir)
if not trajectories_text.strip():
return None
user = (
f"## Current Skill\n{skill_content}\n\n"
)
if is_full_rewrite_minibatch_mode(mode):
user += (
f"## Update Format\n"
f"Produce one complete replacement skill candidate for this minibatch. "
f"Do not output edits, patches, or revise suggestions.\n\n"
)
else:
user += (
f"## {payload_label(mode, title=True)} Budget\n"
f"Produce at most L={edit_budget} {payload_label(mode)}.\n\n"
)
# Unified step buffer context (preferred)
ctx = step_buffer_context or rejection_context or ""
if trajectory_memory_context:
ctx = f"{ctx}\n{trajectory_memory_context}" if ctx else trajectory_memory_context
if ctx.strip():
user += f"## Previous Steps in This Epoch\n{ctx}\n\n"
optimizer_ctx = format_meta_skill_context(meta_skill_context)
if optimizer_ctx:
user += optimizer_ctx + "\n\n"
user += f"## Failed Trajectories ({len(items)} total)\n{trajectories_text}"
try:
response, _ = chat_optimizer(
system=actual_system, user=user,
max_completion_tokens=64000 if is_full_rewrite_minibatch_mode(mode) else 16384,
retries=3,
stage="analyst",
)
result = extract_json(response)
if not result:
return None
notes = extract_appendix_notes(result) if skill_aware_reflection else []
if "patch" in result:
result["source_type"] = "failure"
if not is_full_rewrite_minibatch_mode(mode):
truncate_payload(result["patch"], edit_budget, mode)
if skill_aware_reflection:
result["appendix_notes"] = notes
return result
# Skill-aware: a batch may legitimately yield ONLY execution-lapse notes
# (no body edit). Return a no-op patch so the notes still reach the
# trainer via all_raw_patches; empty edits are dropped from the body
# pipeline by _normalise_patches, so body behavior is unchanged.
if skill_aware_reflection and notes:
return {
"source_type": "failure",
"patch": {"reasoning": "execution-lapse only", "edits": []},
"appendix_notes": notes,
}
except Exception: # noqa: BLE001
traceback.print_exc()
return None
def run_success_analyst_minibatch(
skill_content: str,
items: list[dict],
prediction_dir: str,
edit_budget: int = 4,
*,
system_prompt: str | None = None,
trajectory_memory_context: str = "",
step_buffer_context: str = "",
meta_skill_context: str = "",
update_mode: str = "patch",
skill_aware_reflection: bool = False,
emit_appendix_notes: bool = True,
) -> dict | None:
"""Analyze a minibatch of successful trajectories in one optimizer call.
Parameters
----------
system_prompt : str | None
Custom system prompt. ``None`` = use generic default.
trajectory_memory_context : str
*Deprecated* — use ``step_buffer_context``.
step_buffer_context : str
Unified summary of previous steps (failure patterns + rejected edits).
Returns
-------
dict | None
Patch dict with ``source_type="success"``, or ``None`` on error.
"""
mode = normalize_update_mode(update_mode)
actual_system = _resolve_prompt(system_prompt, "analyst_success", mode)
# Only augment + parse appendix notes on the success side when allowed.
# failure_only mode (paper-faithful S_app) suppresses success-side notes.
sa_emit = skill_aware_reflection and emit_appendix_notes
if sa_emit and not is_full_rewrite_minibatch_mode(mode):
actual_system = augment_success_prompt(actual_system)
trajectories_text = fmt_minibatch_trajectories(items, prediction_dir)
if not trajectories_text.strip():
return None
user = (
f"## Current Skill\n{skill_content}\n\n"
)
if is_full_rewrite_minibatch_mode(mode):
user += (
f"## Update Format\n"
f"Produce one complete replacement skill candidate for this minibatch. "
f"Do not output edits, patches, or revise suggestions.\n\n"
)
else:
user += (
f"## {payload_label(mode, title=True)} Budget\n"
f"Produce at most L={edit_budget} {payload_label(mode)}.\n\n"
)
ctx = step_buffer_context or trajectory_memory_context or ""
if ctx.strip():
user += f"## Previous Steps in This Epoch\n{ctx}\n\n"
optimizer_ctx = format_meta_skill_context(meta_skill_context)
if optimizer_ctx:
user += optimizer_ctx + "\n\n"
user += f"## Successful Trajectories ({len(items)} total)\n{trajectories_text}"
try:
response, _ = chat_optimizer(
system=actual_system, user=user,
max_completion_tokens=64000 if is_full_rewrite_minibatch_mode(mode) else 16384,
retries=3,
stage="analyst",
)
result = extract_json(response)
if result and "patch" in result:
result["source_type"] = "success"
if not is_full_rewrite_minibatch_mode(mode):
truncate_payload(result["patch"], edit_budget, mode)
if sa_emit:
result["appendix_notes"] = extract_appendix_notes(result)
return result
except Exception: # noqa: BLE001
traceback.print_exc()
return None
# ── Minibatch reflect dispatcher ────────────────────────────────────────────
def _split_minibatches(items: list, batch_size: int) -> list[list]:
"""Split items into minibatches of at most *batch_size*."""
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
def _shuffle_for_minibatch(items: list, seed: int | None) -> list:
"""Return items in minibatch order.
Uses a deterministic shuffle when a seed is provided so resume runs keep
the same minibatch composition. Falls back to input order when no seed is
available.
"""
ordered = list(items)
if seed is None:
return ordered
random.Random(seed).shuffle(ordered)
return ordered
def run_minibatch_reflect(
results: list[dict],
skill_content: str,
prediction_dir: str,
patches_dir: str,
workers: int,
failure_only: bool,
minibatch_size: int = 8,
edit_budget: int = 4,
random_seed: int | None = None,
*,
error_system: str | None = None,
success_system: str | None = None,
rejection_context: str = "",
trajectory_memory_context: str = "",
step_buffer_context: str = "",
meta_skill_context: str = "",
update_mode: str = "patch",
skill_aware_reflection: bool | None = None,
skill_aware_appendix_source: str | None = None,
) -> list[dict | None]:
"""Full minibatch reflect stage: group → parallel optimizer calls → patches.
Separates failure and success trajectories, splits each into minibatches
of size M, runs all minibatches in parallel, and saves patch files.
Parameters
----------
results : list[dict]
Rollout result dicts; see :class:`~skillopt.types.RolloutResult`.
skill_content : str
Current skill document.
prediction_dir : str
Path to ``predictions/`` with ``conversation.json`` files.
patches_dir : str
Path to save per-minibatch patch JSON files.
workers : int
Max parallel optimizer calls.
failure_only : bool
If True, skip success trajectories.
minibatch_size : int
Trajectories per group (M).
edit_budget : int
Max edits per minibatch (L).
random_seed : int | None
Optional seed used to shuffle trajectories before minibatch splitting.
error_system, success_system : str | None
Optional custom prompts. ``None`` = use generic defaults.
Returns
-------
list[dict | None]
Patch dicts (with ``source_type`` "failure" or "success").
"""
# Resolve the skill-aware toggle: explicit kwargs win; otherwise fall back
# to the process-wide config switch set by the trainer, so the feature is
# env-independent and adapters need no per-benchmark wiring.
if skill_aware_reflection is None:
skill_aware_reflection = is_skill_aware_enabled()
if skill_aware_appendix_source is None:
skill_aware_appendix_source = get_skill_aware_appendix_source()
os.makedirs(patches_dir, exist_ok=True)
# Separate failure / success
failures = [r for r in results if not r.get("hard") or float(r.get("hard", 0)) < 1e-9]
successes = [r for r in results if r.get("hard")] if not failure_only else []
failures = _shuffle_for_minibatch(failures, random_seed)
successes = _shuffle_for_minibatch(successes, None if random_seed is None else random_seed + 1)
# Split into minibatches
fail_batches = _split_minibatches(failures, minibatch_size)
succ_batches = _split_minibatches(successes, minibatch_size)
n_fail_batches = len(fail_batches)
n_succ_batches = len(succ_batches)
print(
f" [2/6 REFLECT minibatch] "
f"failure={len(failures)}{n_fail_batches} groups "
f"success={len(successes)}{n_succ_batches} groups "
f"(M={minibatch_size}, L={edit_budget}, workers={workers})"
)
raw_patches: list[dict | None] = []
# Resume support: check for already-done minibatch patches
pending_fail: list[tuple[int, list[dict]]] = []
for idx, batch in enumerate(fail_batches):
path = os.path.join(patches_dir, f"minibatch_fail_{idx:03d}.json")
if os.path.exists(path):
with open(path) as f:
raw_patches.append(json.load(f))
else:
pending_fail.append((idx, batch))
pending_succ: list[tuple[int, list[dict]]] = []
for idx, batch in enumerate(succ_batches):
path = os.path.join(patches_dir, f"minibatch_succ_{idx:03d}.json")
if os.path.exists(path):
with open(path) as f:
raw_patches.append(json.load(f))
else:
pending_succ.append((idx, batch))
# ── Worker functions ──────────────────────────────────────────────────
def _do_fail(idx: int, batch: list[dict]) -> tuple[str, dict | None]:
patch = run_error_analyst_minibatch(
skill_content, batch, prediction_dir,
edit_budget=edit_budget,
system_prompt=error_system,
step_buffer_context=step_buffer_context,
# backward compat fallback
rejection_context=rejection_context,
trajectory_memory_context=trajectory_memory_context,
meta_skill_context=meta_skill_context,
update_mode=update_mode,
skill_aware_reflection=skill_aware_reflection,
)
return f"minibatch_fail_{idx:03d}", patch
def _do_succ(idx: int, batch: list[dict]) -> tuple[str, dict | None]:
patch = run_success_analyst_minibatch(
skill_content, batch, prediction_dir,
edit_budget=edit_budget,
system_prompt=success_system,
step_buffer_context=step_buffer_context,
trajectory_memory_context=trajectory_memory_context,
meta_skill_context=meta_skill_context,
update_mode=update_mode,
skill_aware_reflection=skill_aware_reflection,
emit_appendix_notes=(skill_aware_appendix_source != "failure_only"),
)
return f"minibatch_succ_{idx:03d}", patch
# Run all pending minibatches in parallel
all_pending = (
[("fail", idx, batch) for idx, batch in pending_fail]
+ [("succ", idx, batch) for idx, batch in pending_succ]
)
with ThreadPoolExecutor(max_workers=workers) as ex:
futs = {}
for kind, idx, batch in all_pending:
if kind == "fail":
futs[ex.submit(_do_fail, idx, batch)] = (kind, idx, len(batch))
else:
futs[ex.submit(_do_succ, idx, batch)] = (kind, idx, len(batch))
for i, fut in enumerate(as_completed(futs), 1):
kind, idx, batch_len = futs[fut]
tag, patch = fut.result()
if patch:
path = os.path.join(patches_dir, f"{tag}.json")
with open(path, "w") as f:
json.dump(patch, f, ensure_ascii=False, indent=2)
raw_patches.append(patch)
n_edits = len(get_payload_items(patch.get("patch", {}) if patch else {}, update_mode))
print(
f" [analyst] {i}/{len(all_pending)} {tag} "
f"({batch_len} trajs) → {n_edits} {payload_label(update_mode)}"
)
return raw_patches