feat(moa): per-preset fanout cadence — user_turn runs advisors once per user turn (#57591)

New preset key 'fanout': 'per_iteration' (default, unchanged behavior)
re-runs the reference fan-out whenever the advisory view changes — every
tool iteration. 'user_turn' runs the advisors ONCE per user turn and lets
the aggregator act alone for the rest of the tool loop — the original MoA
shape (upfront multi-model synthesis, then a single acting model), and the
obvious lever on MoA's wall/cost multiplier (advisor generation dominates
per-turn latency).

Implementation reuses the existing turn-scoped reference cache: in
user_turn mode the cache signature hashes only the prefix up to the LAST
user message, so mid-turn advisory-view growth doesn't change the key and
iteration 2+ is a cache HIT (advice reused, zero advisor spend, no
re-trace). A new user message changes the prefix and re-triggers the
fan-out. Unknown fanout values normalize to per_iteration.
This commit is contained in:
Teknium
2026-07-03 01:02:44 -07:00
committed by GitHub
parent 6eb39c2bbe
commit 9e044cf795
2 changed files with 38 additions and 3 deletions

View File

@@ -773,13 +773,33 @@ class MoAChatCompletions:
reference_outputs: list[tuple[str, str, Any]] = []
ref_messages = _reference_messages(messages)
# Fan-out cadence. "per_iteration" (default): advisors re-run whenever
# the advisory view changes — i.e. every tool iteration, since the
# view grows with each tool result. "user_turn": advisors run ONCE per
# user turn; subsequent tool iterations reuse that turn's advice and
# the aggregator acts alone (the original MoA shape: synthesize at the
# start, then let the acting model work). Implemented by hashing only
# the prefix up to the LAST USER message so mid-turn growth doesn't
# change the signature — iteration 2+ becomes a cache HIT.
fanout_mode = str(preset.get("fanout") or "per_iteration").strip().lower()
sig_messages = ref_messages
if fanout_mode == "user_turn":
last_user_idx = None
for _i in range(len(ref_messages) - 1, -1, -1):
if ref_messages[_i].get("role") == "user":
last_user_idx = _i
break
if last_user_idx is not None:
sig_messages = ref_messages[: last_user_idx + 1]
# Turn-scoped cache: only run + display references when the advisory
# view changed (i.e. a new user turn). Within one turn the agent loop
# calls create() once per tool iteration with the same advisory view;
# reuse the cached outputs and skip both the re-run and the re-emit.
# calls create() once per tool iteration; in user_turn mode the
# signature is stable across those iterations (prefix hash above), so
# the fan-out runs once per user turn and iterations reuse the advice.
_sig = hashlib.sha256(
"\u0000".join(
f"{m.get('role')}:{m.get('content')}" for m in ref_messages
f"{m.get('role')}:{m.get('content')}" for m in sig_messages
).encode("utf-8", "replace")
).hexdigest()
_cache_key = (self.preset_name, _sig, tuple(_slot_label(s) for s in reference_models))

View File

@@ -67,6 +67,12 @@ def _coerce_int_or_none(value: Any) -> int | None:
return n if n > 0 else None
def _coerce_fanout(value: Any) -> str:
"""Normalize the fan-out cadence; unknown values fall back to default."""
mode = str(value or "").strip().lower()
return mode if mode in {"per_iteration", "user_turn"} else "per_iteration"
def _clean_slot(slot: Any) -> dict[str, str] | None:
if not isinstance(slot, dict):
return None
@@ -94,6 +100,7 @@ def _default_preset() -> dict[str, Any]:
"aggregator_temperature": None,
"max_tokens": 4096,
"reference_max_tokens": None,
"fanout": "per_iteration",
"enabled": True,
}
@@ -131,6 +138,13 @@ def _normalize_preset(raw: Any) -> dict[str, Any]:
# judgement, so capping roughly halves per-turn wall time. Does NOT cap
# the acting aggregator (its output is the user-visible answer).
"reference_max_tokens": _coerce_int_or_none(raw.get("reference_max_tokens")),
# When the reference fan-out runs. "per_iteration" (default) re-runs
# the advisors whenever the advisory view changes — i.e. every tool
# iteration, so advice tracks live task state. "user_turn" runs the
# advisors ONCE per user turn (the original MoA shape): the
# aggregator gets their upfront plan-level advice, then acts alone
# for the rest of the tool loop.
"fanout": _coerce_fanout(raw.get("fanout")),
}
@@ -177,6 +191,7 @@ def normalize_moa_config(raw: Any) -> dict[str, Any]:
"aggregator_temperature": active["aggregator_temperature"],
"max_tokens": active["max_tokens"],
"reference_max_tokens": active.get("reference_max_tokens"),
"fanout": active.get("fanout", "per_iteration"),
"enabled": active["enabled"],
}