mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
- Rename teacher -> optimizer, student -> target across all code, configs, docs, prompts - CLI: --teacher_model -> --optimizer_model, --student_model -> --target_model - Remove best_skill files, keep only initial skills - Fix slow update gate (force write into skill) - Fix SLOW_UPDATE marker stripping - Remove deep_reflect and meta_reflect mechanisms - Update .env.example with export prefix and azure_cli docs - Add endpoint empty validation in azure_openai.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
307 lines
10 KiB
Python
307 lines
10 KiB
Python
"""Standardized I/O types for the ReflACT pipeline.
|
|
|
|
Shared dataclass definitions for the 6-stage per-step pipeline
|
|
and the 2 epoch-level stages. All types support round-trip
|
|
conversion to/from plain dicts for incremental adoption.
|
|
|
|
Re-exports
|
|
----------
|
|
GateResult, GateAction — from skillopt.evaluation.gate
|
|
BatchSpec — from skillopt.datasets.base
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field, fields as dc_fields
|
|
from typing import Any, Literal
|
|
|
|
from skillopt.evaluation.gate import GateAction, GateResult # noqa: F401
|
|
from skillopt.datasets.base import BatchSpec # noqa: F401
|
|
|
|
|
|
# ── Atomic types ─────────────────────────────────────────────────────────
|
|
|
|
EditOp = Literal["append", "insert_after", "replace", "delete"]
|
|
|
|
|
|
@dataclass
|
|
class Edit:
|
|
"""A single edit operation on a skill document.
|
|
|
|
Used across Reflect → Aggregate → Select → Update → MetaReflect.
|
|
"""
|
|
|
|
op: EditOp
|
|
content: str = ""
|
|
target: str = ""
|
|
support_count: int | None = None
|
|
source_type: Literal["failure", "success"] | None = None
|
|
merge_level: int | None = None
|
|
update_origin: str = ""
|
|
update_target: str = ""
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> Edit:
|
|
return cls(
|
|
op=d.get("op", "append"),
|
|
content=d.get("content", ""),
|
|
target=d.get("target", ""),
|
|
support_count=d.get("support_count"),
|
|
source_type=d.get("source_type"),
|
|
merge_level=d.get("merge_level"),
|
|
update_origin=d.get("update_origin", ""),
|
|
update_target=d.get("update_target", ""),
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
d: dict[str, Any] = {"op": self.op, "content": self.content}
|
|
if self.target:
|
|
d["target"] = self.target
|
|
if self.support_count is not None:
|
|
d["support_count"] = self.support_count
|
|
if self.source_type is not None:
|
|
d["source_type"] = self.source_type
|
|
if self.merge_level is not None:
|
|
d["merge_level"] = self.merge_level
|
|
if self.update_origin:
|
|
d["update_origin"] = self.update_origin
|
|
if self.update_target:
|
|
d["update_target"] = self.update_target
|
|
return d
|
|
|
|
|
|
@dataclass
|
|
class Patch:
|
|
"""A set of edits with reasoning.
|
|
|
|
Output of Aggregate (③), Select (④); input to Update (⑤).
|
|
"""
|
|
|
|
edits: list[Edit] = field(default_factory=list)
|
|
reasoning: str = ""
|
|
ranking_details: dict[str, Any] | None = None
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> Patch:
|
|
edits_raw = d.get("edits", [])
|
|
return cls(
|
|
edits=[Edit.from_dict(e) if isinstance(e, dict) else e for e in edits_raw],
|
|
reasoning=d.get("reasoning", ""),
|
|
ranking_details=d.get("ranking_details"),
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
d: dict[str, Any] = {
|
|
"reasoning": self.reasoning,
|
|
"edits": [e.to_dict() if isinstance(e, Edit) else e for e in self.edits],
|
|
}
|
|
if self.ranking_details is not None:
|
|
d["ranking_details"] = self.ranking_details
|
|
return d
|
|
|
|
|
|
# ── Stage ① ROLLOUT ──────────────────────────────────────────────────────
|
|
|
|
@dataclass
|
|
class RolloutResult:
|
|
"""Result of a single episode/task rollout.
|
|
|
|
Universal fields are required; env-specific fields live in ``extras``.
|
|
"""
|
|
|
|
id: str
|
|
hard: int
|
|
soft: float
|
|
n_turns: int = 0
|
|
fail_reason: str = ""
|
|
task_type: str = ""
|
|
task_description: str = ""
|
|
predicted_answer: str = ""
|
|
question: str = ""
|
|
reference_text: str = ""
|
|
target_system_prompt: str = ""
|
|
target_user_prompt: str = ""
|
|
spreadsheet_preview: str = ""
|
|
extras: dict[str, Any] = field(default_factory=dict)
|
|
|
|
_KNOWN_FIELDS: frozenset[str] | None = field(
|
|
default=None, init=False, repr=False, compare=False, # type: ignore[assignment]
|
|
)
|
|
|
|
@classmethod
|
|
def _get_known_fields(cls) -> frozenset[str]:
|
|
if cls._KNOWN_FIELDS is None:
|
|
cls._KNOWN_FIELDS = frozenset(
|
|
f.name for f in dc_fields(cls)
|
|
if f.name != "_KNOWN_FIELDS"
|
|
)
|
|
return cls._KNOWN_FIELDS
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> RolloutResult:
|
|
known = cls._get_known_fields()
|
|
extras = {k: v for k, v in d.items() if k not in known}
|
|
return cls(
|
|
id=str(d.get("id", "")),
|
|
hard=int(d.get("hard", 0)),
|
|
soft=float(d.get("soft", 0.0)),
|
|
n_turns=int(d.get("n_turns", 0)),
|
|
fail_reason=str(d.get("fail_reason", "")),
|
|
task_type=str(d.get("task_type", "")),
|
|
task_description=str(d.get("task_description", "")),
|
|
predicted_answer=str(d.get("predicted_answer", "")),
|
|
question=str(d.get("question", "")),
|
|
reference_text=str(d.get("reference_text", "")),
|
|
target_system_prompt=str(d.get("target_system_prompt", "")),
|
|
target_user_prompt=str(d.get("target_user_prompt", "")),
|
|
spreadsheet_preview=str(d.get("spreadsheet_preview", "")),
|
|
extras=extras,
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
d: dict[str, Any] = {
|
|
"id": self.id,
|
|
"hard": self.hard,
|
|
"soft": self.soft,
|
|
}
|
|
for attr in (
|
|
"n_turns", "fail_reason", "task_type", "task_description",
|
|
"predicted_answer", "question", "reference_text",
|
|
"target_system_prompt", "target_user_prompt",
|
|
"spreadsheet_preview",
|
|
):
|
|
val = getattr(self, attr)
|
|
if val:
|
|
d[attr] = val
|
|
d.update(self.extras)
|
|
return d
|
|
|
|
|
|
# ── Stage ② REFLECT ──────────────────────────────────────────────────────
|
|
|
|
@dataclass
|
|
class FailureSummaryEntry:
|
|
"""One entry in the failure summary produced by error analysts."""
|
|
|
|
failure_type: str
|
|
count: int = 0
|
|
description: str = ""
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict) -> FailureSummaryEntry:
|
|
return cls(
|
|
failure_type=d.get("failure_type", ""),
|
|
count=int(d.get("count", 0)),
|
|
description=d.get("description", ""),
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
return {
|
|
"failure_type": self.failure_type,
|
|
"count": self.count,
|
|
"description": self.description,
|
|
}
|
|
|
|
|
|
@dataclass
|
|
class RawPatch:
|
|
"""Analyst output from the Reflect stage — a patch with provenance.
|
|
|
|
Wraps the dict produced by ``run_error_analyst_minibatch``
|
|
and ``run_success_analyst_minibatch``.
|
|
"""
|
|
|
|
patch: Patch
|
|
source_type: Literal["failure", "success"] = "failure"
|
|
batch_size: int = 0
|
|
failure_summary: list[FailureSummaryEntry] = field(default_factory=list)
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict | None) -> RawPatch | None:
|
|
if d is None:
|
|
return None
|
|
inner = d.get("patch", d)
|
|
if not isinstance(inner, dict):
|
|
return None
|
|
patch = Patch.from_dict(inner)
|
|
return cls(
|
|
patch=patch,
|
|
source_type=d.get("source_type", "failure"),
|
|
batch_size=int(d.get("batch_size", 0)),
|
|
failure_summary=[
|
|
FailureSummaryEntry.from_dict(fs)
|
|
for fs in d.get("failure_summary", [])
|
|
],
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
d: dict[str, Any] = {
|
|
"patch": self.patch.to_dict(),
|
|
"source_type": self.source_type,
|
|
"batch_size": self.batch_size,
|
|
}
|
|
if self.failure_summary:
|
|
d["failure_summary"] = [fs.to_dict() for fs in self.failure_summary]
|
|
return d
|
|
|
|
|
|
# ── Epoch-level: SLOW_UPDATE ─────────────────────────────────────────────
|
|
|
|
@dataclass
|
|
class SlowUpdateResult:
|
|
"""Output of the epoch-level slow update stage (EMA / regularization)."""
|
|
|
|
reasoning: str = ""
|
|
slow_update_content: str = ""
|
|
action: str = ""
|
|
time_s: float | None = None
|
|
prev_hard: float | None = None
|
|
curr_hard: float | None = None
|
|
selection_hard: float | None = None
|
|
selection_soft: float | None = None
|
|
candidate_hash: str = ""
|
|
update_origin: str = ""
|
|
update_target: str = ""
|
|
|
|
@classmethod
|
|
def from_dict(cls, d: dict | None) -> SlowUpdateResult | None:
|
|
if d is None:
|
|
return None
|
|
return cls(
|
|
reasoning=d.get("reasoning", ""),
|
|
slow_update_content=d.get("slow_update_content", ""),
|
|
action=d.get("action", ""),
|
|
time_s=d.get("time_s"),
|
|
prev_hard=d.get("prev_hard"),
|
|
curr_hard=d.get("curr_hard"),
|
|
selection_hard=d.get("selection_hard"),
|
|
selection_soft=d.get("selection_soft"),
|
|
candidate_hash=d.get("candidate_hash", ""),
|
|
update_origin=d.get("update_origin", ""),
|
|
update_target=d.get("update_target", ""),
|
|
)
|
|
|
|
def to_dict(self) -> dict:
|
|
d: dict[str, Any] = {
|
|
"reasoning": self.reasoning,
|
|
"slow_update_content": self.slow_update_content,
|
|
}
|
|
if self.action:
|
|
d["action"] = self.action
|
|
if self.time_s is not None:
|
|
d["time_s"] = self.time_s
|
|
if self.prev_hard is not None:
|
|
d["prev_hard"] = self.prev_hard
|
|
if self.curr_hard is not None:
|
|
d["curr_hard"] = self.curr_hard
|
|
if self.selection_hard is not None:
|
|
d["selection_hard"] = self.selection_hard
|
|
if self.selection_soft is not None:
|
|
d["selection_soft"] = self.selection_soft
|
|
if self.candidate_hash:
|
|
d["candidate_hash"] = self.candidate_hash
|
|
if self.update_origin:
|
|
d["update_origin"] = self.update_origin
|
|
if self.update_target:
|
|
d["update_target"] = self.update_target
|
|
return d
|