Fail fast on systemic SearchQA rollout failures

This commit is contained in:
summerview1997
2026-06-16 09:20:57 +08:00
parent 46b3207b96
commit da799620ba

View File

@@ -13,20 +13,31 @@ from __future__ import annotations
import json
import os
import time
import traceback
from collections import Counter
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from skillopt.model import chat_target, get_target_backend, is_target_exec_backend
from skillopt.envs.searchqa.evaluator import evaluate
from skillopt.model import chat_target, is_target_exec_backend
from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_target_exec
from skillopt.prompts import load_prompt
from skillopt.envs.searchqa.evaluator import evaluate
# ── Prompt templates ─────────────────────────────────────────────────────────
_MAX_CONTEXT_CHARS = 6000
def _raise_on_systemic_failure(results: list[dict]) -> None:
"""Abort when all rollout rows failed before any agent response."""
if not results or not all(row.get("agent_ok") is False for row in results):
return
reasons = Counter(str(row.get("fail_reason") or "unknown error") for row in results)
common_reason, count = reasons.most_common(1)[0]
raise RuntimeError(
f"SearchQA rollout failed for all {len(results)} items before an agent "
f"response ({count}x): {common_reason}"
)
def _truncate_context(context: str, max_chars: int = _MAX_CONTEXT_CHARS) -> str:
"""Truncate context at [DOC] boundaries to stay within budget."""
if len(context) <= max_chars:
@@ -379,6 +390,7 @@ def run_batch(
pending = [it for it in items if str(it["id"]) not in done_ids]
if not pending:
_raise_on_systemic_failure(existing)
return existing
total = len(existing) + len(pending)
@@ -478,4 +490,5 @@ def run_batch(
finally:
ex.shutdown(wait=False, cancel_futures=True)
_raise_on_systemic_failure(results)
return results