Files
Cuzyoung 4a1b984d87 refactor: rename teacher/student to optimizer/target, remove best skills, fix slow update
- Rename teacher -> optimizer, student -> target across all code, configs, docs, prompts
- CLI: --teacher_model -> --optimizer_model, --student_model -> --target_model
- Remove best_skill files, keep only initial skills
- Fix slow update gate (force write into skill)
- Fix SLOW_UPDATE marker stripping
- Remove deep_reflect and meta_reflect mechanisms
- Update .env.example with export prefix and azure_cli docs
- Add endpoint empty validation in azure_openai.py

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-24 19:15:10 +00:00

307 lines
10 KiB
Python

"""Standardized I/O types for the ReflACT pipeline.

Shared dataclass definitions for the 6-stage per-step pipeline
and the 2 epoch-level stages. All types support round-trip
conversion to/from plain dicts for incremental adoption.

Re-exports
----------
GateResult, GateAction — from skillopt.evaluation.gate
BatchSpec — from skillopt.datasets.base
"""
from __future__ import annotations

from dataclasses import dataclass, field, fields as dc_fields
from typing import Any, ClassVar, Literal

from skillopt.evaluation.gate import GateAction, GateResult  # noqa: F401
from skillopt.datasets.base import BatchSpec  # noqa: F401
# ── Atomic types ─────────────────────────────────────────────────────────
# Closed set of operation names accepted by ``Edit.op``.
EditOp = Literal["append", "insert_after", "replace", "delete"]
@dataclass
class Edit:
    """One edit operation applied to a skill document.

    The original note lists the stages that pass edits along as
    Reflect → Aggregate → Select → Update → MetaReflect.
    NOTE(review): confirm MetaReflect is still a live stage.

    Only ``op`` is required. Every other field has a "not set" default
    (``""`` for text, ``None`` for the optional metadata) and is left out
    of the dict produced by :meth:`to_dict` while it holds that default
    (``content`` is the exception: it is always emitted).
    """

    op: EditOp
    content: str = ""
    target: str = ""
    support_count: int | None = None
    source_type: Literal["failure", "success"] | None = None
    merge_level: int | None = None
    update_origin: str = ""
    update_target: str = ""

    @classmethod
    def from_dict(cls, d: dict) -> Edit:
        """Build an ``Edit`` from a plain dict.

        Missing keys fall back to the field defaults; a missing ``"op"``
        becomes ``"append"``. Values are taken as-is, without validation.
        """
        # Text fields default to the empty string ...
        kwargs: dict[str, Any] = {
            name: d.get(name, "")
            for name in ("content", "target", "update_origin", "update_target")
        }
        # ... while the optional metadata defaults to None.
        for name in ("support_count", "source_type", "merge_level"):
            kwargs[name] = d.get(name)
        return cls(op=d.get("op", "append"), **kwargs)

    def to_dict(self) -> dict:
        """Serialise to a plain dict, omitting fields that are not set.

        ``op`` and ``content`` are always present. ``target`` and the
        ``update_*`` strings are emitted only when non-empty; the optional
        metadata only when it is not ``None`` (so ``0`` is kept).
        """
        out: dict[str, Any] = {"op": self.op, "content": self.content}
        if self.target:
            out["target"] = self.target
        for name in ("support_count", "source_type", "merge_level"):
            value = getattr(self, name)
            if value is not None:
                out[name] = value
        for name in ("update_origin", "update_target"):
            value = getattr(self, name)
            if value:
                out[name] = value
        return out
@dataclass
class Patch:
    """A set of edits with reasoning.

    Output of Aggregate (③), Select (④); input to Update (⑤).

    ``edits`` normally holds :class:`Edit` objects, but both converters
    pass through elements of any other type untouched.
    """

    edits: list[Edit] = field(default_factory=list)
    reasoning: str = ""
    ranking_details: dict[str, Any] | None = None

    @classmethod
    def from_dict(cls, d: dict) -> Patch:
        """Build a ``Patch`` from a plain dict.

        Dict elements of ``d["edits"]`` are converted with
        ``Edit.from_dict``; any other element is kept as-is. A missing
        or ``None`` ``"edits"`` value yields an empty edit list.
        """
        # ``or []`` rather than a ``.get`` default: a payload carrying an
        # explicit ``"edits": null`` would otherwise fail on iteration.
        edits_raw = d.get("edits") or []
        return cls(
            edits=[Edit.from_dict(e) if isinstance(e, dict) else e for e in edits_raw],
            reasoning=d.get("reasoning", ""),
            ranking_details=d.get("ranking_details"),
        )

    def to_dict(self) -> dict:
        """Serialise to a plain dict.

        ``reasoning`` and ``edits`` are always present;
        ``ranking_details`` only when it is not ``None``.
        """
        d: dict[str, Any] = {
            "reasoning": self.reasoning,
            "edits": [e.to_dict() if isinstance(e, Edit) else e for e in self.edits],
        }
        if self.ranking_details is not None:
            d["ranking_details"] = self.ranking_details
        return d
# ── Stage ① ROLLOUT ──────────────────────────────────────────────────────
@dataclass
class RolloutResult:
    """Result of a single episode/task rollout.

    ``id``, ``hard`` and ``soft`` are required. The remaining named fields
    are optional; anything that is not a declared field is carried in
    ``extras`` and flattened back to top-level keys by :meth:`to_dict`.
    """

    id: str
    hard: int
    soft: float
    n_turns: int = 0
    fail_reason: str = ""
    task_type: str = ""
    task_description: str = ""
    predicted_answer: str = ""
    question: str = ""
    reference_text: str = ""
    target_system_prompt: str = ""
    target_user_prompt: str = ""
    spreadsheet_preview: str = ""
    extras: dict[str, Any] = field(default_factory=dict)

    # Optional string fields, in serialisation order. Shared by
    # ``from_dict`` and ``to_dict`` so the two cannot drift apart.
    _TEXT_FIELDS: ClassVar[tuple[str, ...]] = (
        "fail_reason",
        "task_type",
        "task_description",
        "predicted_answer",
        "question",
        "reference_text",
        "target_system_prompt",
        "target_user_prompt",
        "spreadsheet_preview",
    )

    # Cache of declared field names. It is a ClassVar so the dataclass
    # machinery does not treat it as an instance field: as a regular field
    # it showed up in ``fields()`` / ``asdict()`` and put a frozenset into
    # the ``asdict()`` output once the cache had been filled.
    _KNOWN_FIELDS: ClassVar[frozenset[str] | None] = None

    @classmethod
    def _get_known_fields(cls) -> frozenset[str]:
        """Return the names of the dataclass fields declared on ``cls``.

        The result is cached per class. The lookup goes through
        ``cls.__dict__`` so a subclass never reuses the set computed for
        its parent.
        """
        cached = cls.__dict__.get("_KNOWN_FIELDS")
        if cached is None:
            cached = frozenset(f.name for f in dc_fields(cls))
            cls._KNOWN_FIELDS = cached
        return cached

    @classmethod
    def from_dict(cls, d: dict) -> RolloutResult:
        """Build a ``RolloutResult`` from a plain dict.

        Declared fields are coerced with ``str`` / ``int`` / ``float`` and
        fall back to ``""`` / ``0`` / ``0.0`` when missing. Every other key
        is copied into ``extras``. A top-level ``"extras"`` key is itself a
        declared field name, so it is neither copied nor read.
        """
        known = cls._get_known_fields()
        extras = {k: v for k, v in d.items() if k not in known}
        return cls(
            id=str(d.get("id", "")),
            hard=int(d.get("hard", 0)),
            soft=float(d.get("soft", 0.0)),
            n_turns=int(d.get("n_turns", 0)),
            extras=extras,
            **{name: str(d.get(name, "")) for name in cls._TEXT_FIELDS},
        )

    def to_dict(self) -> dict:
        """Serialise to a flat dict.

        ``id``, ``hard`` and ``soft`` are always present. The optional
        fields are emitted only when truthy. ``extras`` is merged in last,
        so its entries win over a same-named key.
        """
        d: dict[str, Any] = {
            "id": self.id,
            "hard": self.hard,
            "soft": self.soft,
        }
        for attr in ("n_turns", *self._TEXT_FIELDS):
            val = getattr(self, attr)
            if val:
                d[attr] = val
        d.update(self.extras)
        return d
# ── Stage ② REFLECT ──────────────────────────────────────────────────────
@dataclass
class FailureSummaryEntry:
    """A single row of the failure summary emitted by the error analysts."""

    failure_type: str
    count: int = 0
    description: str = ""

    @classmethod
    def from_dict(cls, d: dict) -> FailureSummaryEntry:
        """Build an entry from a plain dict.

        Missing keys fall back to ``""`` / ``0``; ``count`` is coerced
        with ``int``.
        """
        kind = d.get("failure_type", "")
        occurrences = int(d.get("count", 0))
        detail = d.get("description", "")
        return cls(failure_type=kind, count=occurrences, description=detail)

    def to_dict(self) -> dict:
        """Serialise to a plain dict; all three keys are always present."""
        return {
            name: getattr(self, name)
            for name in ("failure_type", "count", "description")
        }
@dataclass
class RawPatch:
    """Reflect-stage analyst output: a patch plus its provenance.

    Wraps the dict produced by ``run_error_analyst_minibatch``
    and ``run_success_analyst_minibatch``.
    """

    patch: Patch
    source_type: Literal["failure", "success"] = "failure"
    batch_size: int = 0
    failure_summary: list[FailureSummaryEntry] = field(default_factory=list)

    @classmethod
    def from_dict(cls, d: dict | None) -> RawPatch | None:
        """Build a ``RawPatch`` from a plain dict.

        Returns ``None`` when *d* is ``None`` or when the patch payload is
        not a dict. The payload is ``d["patch"]`` if that key exists;
        otherwise *d* itself is read as the patch.
        """
        if d is None:
            return None

        payload = d.get("patch", d)
        if not isinstance(payload, dict):
            return None

        parsed = Patch.from_dict(payload)
        return cls(
            patch=parsed,
            source_type=d.get("source_type", "failure"),
            batch_size=int(d.get("batch_size", 0)),
            failure_summary=[
                FailureSummaryEntry.from_dict(item)
                for item in d.get("failure_summary", [])
            ],
        )

    def to_dict(self) -> dict:
        """Serialise to a plain dict.

        ``failure_summary`` is emitted only when the list is non-empty.
        """
        out: dict[str, Any] = {
            "patch": self.patch.to_dict(),
            "source_type": self.source_type,
            "batch_size": self.batch_size,
        }
        if not self.failure_summary:
            return out
        out["failure_summary"] = [entry.to_dict() for entry in self.failure_summary]
        return out
# ── Epoch-level: SLOW_UPDATE ─────────────────────────────────────────────
@dataclass
class SlowUpdateResult:
    """Output of the epoch-level slow update stage (EMA / regularization).

    Text fields default to ``""`` and the numeric fields to ``None``;
    both mean "not set" as far as :meth:`to_dict` is concerned.
    """

    reasoning: str = ""
    slow_update_content: str = ""
    action: str = ""
    time_s: float | None = None
    prev_hard: float | None = None
    curr_hard: float | None = None
    selection_hard: float | None = None
    selection_soft: float | None = None
    candidate_hash: str = ""
    update_origin: str = ""
    update_target: str = ""

    @classmethod
    def from_dict(cls, d: dict | None) -> SlowUpdateResult | None:
        """Build a result from a plain dict, or return ``None`` for ``None``.

        Missing text keys become ``""``; missing numeric keys become
        ``None``. Values are stored without coercion.
        """
        if d is None:
            return None

        text_names = (
            "reasoning",
            "slow_update_content",
            "action",
            "candidate_hash",
            "update_origin",
            "update_target",
        )
        numeric_names = (
            "time_s",
            "prev_hard",
            "curr_hard",
            "selection_hard",
            "selection_soft",
        )
        kwargs: dict[str, Any] = {name: d.get(name, "") for name in text_names}
        kwargs.update({name: d.get(name) for name in numeric_names})
        return cls(**kwargs)

    def to_dict(self) -> dict:
        """Serialise to a plain dict, omitting fields that are not set.

        ``reasoning`` and ``slow_update_content`` are always present.
        Other text fields are emitted only when non-empty; numeric fields
        only when they are not ``None`` (so ``0.0`` is kept).
        """
        out: dict[str, Any] = {
            "reasoning": self.reasoning,
            "slow_update_content": self.slow_update_content,
        }
        if self.action:
            out["action"] = self.action
        for name in (
            "time_s",
            "prev_hard",
            "curr_hard",
            "selection_hard",
            "selection_soft",
        ):
            value = getattr(self, name)
            if value is not None:
                out[name] = value
        for name in ("candidate_hash", "update_origin", "update_target"):
            value = getattr(self, name)
            if value:
                out[name] = value
        return out