This commit is contained in:
hwq
2026-05-30 15:01:34 +00:00
parent 4f3a9bc055
commit 1f75d022a5
7 changed files with 226 additions and 58 deletions

View File

@@ -11,7 +11,6 @@ import json
import os
import re
import sys
import time
import concurrent.futures
import numpy as np
@@ -206,7 +205,6 @@ def run_alfworld_batch(
# Call API in parallel
actions = ["None"] * env_num
action_timeout = 180
def call_api(idx):
try:
@@ -216,7 +214,7 @@ def run_alfworld_batch(
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=120,
timeout=None,
)
response = (response or "").strip()
if not response:
@@ -230,7 +228,6 @@ def run_alfworld_batch(
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers)
try:
futures = {executor.submit(call_api, i): i for i in active_indices}
started_at = {future: time.time() for future in futures}
pending_futs = set(futures)
while pending_futs:
done, _ = concurrent.futures.wait(
@@ -238,11 +235,6 @@ def run_alfworld_batch(
timeout=5,
return_when=concurrent.futures.FIRST_COMPLETED,
)
now = time.time()
timed_out = [
future for future in pending_futs - done
if now - started_at[future] >= action_timeout
]
for future in done:
pending_futs.remove(future)
try:
@@ -251,10 +243,6 @@ def run_alfworld_batch(
idx = futures[future]
response = "<think>error</think><action>look</action>"
actions[idx] = response
for future in timed_out:
pending_futs.remove(future)
idx = futures[future]
actions[idx] = "<think>api timeout</think><action>look</action>"
finally:
executor.shutdown(wait=False, cancel_futures=True)

View File

@@ -119,7 +119,7 @@ def process_one(
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
exec_timeout: int = 300,
exec_timeout: int | None = 300,
max_completion_tokens: int = 16384,
) -> dict:
item_id = str(item["id"])
@@ -143,6 +143,7 @@ def process_one(
try:
pred_dir = os.path.join(out_root, "predictions", item_id)
os.makedirs(pred_dir, exist_ok=True)
llm_timeout = int(exec_timeout) if exec_timeout and int(exec_timeout) > 0 else None
if is_target_exec_backend():
from skillopt.model import azure_openai as _llm
@@ -157,7 +158,7 @@ def process_one(
skill_content=skill_content,
item=item,
model=_llm.TARGET_DEPLOYMENT,
timeout=exec_timeout,
timeout=llm_timeout,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode if turn == 0 else False,
@@ -223,7 +224,7 @@ def process_one(
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
timeout=llm_timeout,
)
else:
refinement = (
@@ -237,7 +238,7 @@ def process_one(
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
timeout=llm_timeout,
)
response = resp_text
conversation.append({"type": "message", "turn": turn + 1, "content": resp_text})
@@ -292,7 +293,7 @@ def run_batch(
skill_content: str,
*,
max_turns: int = 1,
exec_timeout: int = 300,
exec_timeout: int | None = 300,
workers: int = 64,
max_completion_tokens: int = 16384,
use_theorem: bool = False,
@@ -300,9 +301,14 @@ def run_batch(
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context_by_id: dict[str, str] | None = None,
task_timeout: int = 600,
task_timeout: int | None = 600,
) -> list[dict]:
task_timeout = max(int(task_timeout), int(exec_timeout) + 60)
exec_timeout_value = int(exec_timeout) if exec_timeout and int(exec_timeout) > 0 else 0
task_timeout_value = int(task_timeout) if task_timeout and int(task_timeout) > 0 else 0
if exec_timeout_value <= 0 or task_timeout_value <= 0:
task_timeout = None
else:
task_timeout = max(task_timeout_value, exec_timeout_value + 60)
results_path = os.path.join(out_root, "results.jsonl")
os.makedirs(out_root, exist_ok=True)
@@ -385,6 +391,7 @@ def run_batch(
now = time.time()
timed_out = [
fut for fut in pending_futs - done
if task_timeout is not None
if str(futs[fut]["id"]) in started_at
and now - started_at[str(futs[fut]["id"])] >= task_timeout
]

View File

@@ -6,14 +6,11 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from skillopt.envs.officeqa.evaluator import evaluate
from skillopt.envs.officeqa.tool_runtime import (
build_oracle_parsed_pages_context,
custom_search,
resolve_candidate_files,
resolve_docs_roots,
run_tool,
)
try:
from skillopt.envs.sealqa.tool_runtime import custom_search
except ImportError:
custom_search = None # type: ignore[assignment]
from skillopt.model import chat_target_messages, get_target_backend, is_target_exec_backend
from skillopt.model.codex_harness import prepare_workspace, render_skill_md, run_target_exec
from skillopt.prompts import load_prompt

View File

@@ -5,16 +5,31 @@ import html
import json
import os
import re
import socket
import time
from functools import lru_cache
from html.parser import HTMLParser
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.parse import parse_qs, urlparse
from urllib.request import Request, urlopen
_MAX_READ_CHARS = 4000
_MAX_GREP_MATCHES = 20
_MAX_GLOB_MATCHES = 50
_MAX_ORACLE_PAGE_CHARS = 24000
_MAX_ORACLE_CONTEXT_CHARS = 80000
DEFAULT_USER_AGENT = (
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 "
"(KHTML, like Gecko) Chrome/135.0 Safari/537.36"
)
DEFAULT_CUSTOM_SEARCH_URL = "http://apisix.westus2.cloudapp.azure.com/search_tool/search"
DEFAULT_CUSTOM_SEARCH_AUTH_ENV = "OFFICEQA_CUSTOM_SEARCH_AUTH"
DEFAULT_CUSTOM_SEARCH_PROVIDER = "duckduckgo"
DEFAULT_CUSTOM_SEARCH_MAX_RESULTS = 4
DEFAULT_CUSTOM_SEARCH_TIMEOUT = 20
DEFAULT_CUSTOM_SEARCH_MAX_RETRIES = 4
DEFAULT_CUSTOM_SEARCH_INITIAL_BACKOFF_SECONDS = 1.0
def _normalize_data_dirs(data_dirs: list[str] | tuple[str, ...] | str | None, project_root: Path) -> list[str]:
@@ -352,6 +367,141 @@ def build_oracle_parsed_pages_context(
)
def _extract_search_items(payload: object) -> list[dict]:
if isinstance(payload, list):
return [item for item in payload if isinstance(item, dict)]
if not isinstance(payload, dict):
return []
candidate_keys = (
"results",
"items",
"data",
"organic",
"organic_results",
"search_results",
"webPages",
"value",
)
for key in candidate_keys:
value = payload.get(key)
if isinstance(value, list):
return [item for item in value if isinstance(item, dict)]
if isinstance(value, dict):
nested = _extract_search_items(value)
if nested:
return nested
return []
def _normalize_search_item(item: dict, index: int) -> str:
title = str(
item.get("title")
or item.get("name")
or item.get("headline")
or item.get("source")
or f"Result {index}"
).strip()
url = str(
item.get("url")
or item.get("link")
or item.get("href")
or item.get("display_url")
or ""
).strip()
snippet = str(
item.get("snippet")
or item.get("description")
or item.get("body")
or item.get("text")
or item.get("content")
or ""
).strip()
lines = [f"[{index}] {title}"]
if url:
lines.append(f"URL: {url}")
if snippet:
lines.append(f"Snippet: {snippet}")
return "\n".join(lines)
def _format_search_payload(query: str, payload: object) -> str:
items = _extract_search_items(payload)
header = f"Query: {query}"
if not items:
body = json.dumps(payload, ensure_ascii=False) if payload else "[no results]"
return f"{header}\n{body}"
rendered = [_normalize_search_item(item, index) for index, item in enumerate(items, start=1)]
return f"{header}\n\n" + "\n\n".join(rendered)
def _is_retryable_search_http_error(status_code: int) -> bool:
return status_code in {408, 429} or status_code >= 500
def custom_search(
query: str,
*,
api_url: str = DEFAULT_CUSTOM_SEARCH_URL,
auth_token: str | None = None,
auth_env: str = DEFAULT_CUSTOM_SEARCH_AUTH_ENV,
provider: str = DEFAULT_CUSTOM_SEARCH_PROVIDER,
max_num_results: int = DEFAULT_CUSTOM_SEARCH_MAX_RESULTS,
timeout: int = DEFAULT_CUSTOM_SEARCH_TIMEOUT,
max_retries: int = DEFAULT_CUSTOM_SEARCH_MAX_RETRIES,
initial_backoff_seconds: float = DEFAULT_CUSTOM_SEARCH_INITIAL_BACKOFF_SECONDS,
) -> str:
query = str(query or "").strip()
if not query:
raise ValueError("custom_search query must be non-empty")
token = str(auth_token or os.environ.get(auth_env, "")).strip()
if not token:
raise ValueError(f"custom_search auth token missing; set {auth_env}")
payload = json.dumps(
{
"query": query,
"max_num_results": int(max_num_results),
"provider": provider,
},
ensure_ascii=False,
).encode("utf-8")
req = Request(
api_url,
data=payload,
headers={
"Authorization": token,
"Content-Type": "application/json",
"User-Agent": DEFAULT_USER_AGENT,
},
method="POST",
)
attempts = max(1, int(max_retries) + 1)
last_error: RuntimeError | None = None
for attempt in range(1, attempts + 1):
try:
with urlopen(req, timeout=timeout) as response:
raw_body = response.read().decode("utf-8", errors="ignore")
break
except HTTPError as exc:
detail = exc.read().decode("utf-8", errors="ignore")
last_error = RuntimeError(f"custom_search HTTP {exc.code}: {detail[:1000]}")
if attempt >= attempts or not _is_retryable_search_http_error(exc.code):
raise last_error from exc
except (URLError, TimeoutError, socket.timeout) as exc:
last_error = RuntimeError(f"custom_search connection error: {exc}")
if attempt >= attempts:
raise last_error from exc
backoff_seconds = max(0.0, float(initial_backoff_seconds)) * (2 ** (attempt - 1))
if backoff_seconds > 0:
time.sleep(backoff_seconds)
else:
raise last_error or RuntimeError("custom_search failed without a captured error")
try:
parsed = json.loads(raw_body)
except json.JSONDecodeError:
return f"Query: {query}\n\n{raw_body.strip() or '[empty response]'}"
return _format_search_payload(query, parsed)
def run_tool(name: str, arguments: dict, *, allowed_roots: list[str], allowed_files: list[str]) -> tuple[str, str]:
if name == "glob":
pattern = str(arguments.get("pattern") or "*")

View File

@@ -188,7 +188,7 @@ def _build_user(
# ── LLM call with retry ────────────────────────────────────────────────────
def _llm_call_with_retry(call_fn, *, retries: int = 5, timeout: int = 120):
def _llm_call_with_retry(call_fn, *, retries: int = 5, timeout: int | None = 120):
"""Wrap an LLM API call with retry and per-call timeout."""
last_err = None
for attempt in range(retries):
@@ -335,7 +335,7 @@ def _chat_call(
deployment: str,
messages: list[dict],
max_output_tokens: int,
llm_timeout: int = 120,
llm_timeout: int | None = 120,
) -> str:
"""Single LLM call, no tools. Returns raw text."""
reasoning_effort = get_reasoning_effort()
@@ -402,8 +402,8 @@ def run_single(
answer_position: str = "",
skill_content: str = "",
max_output_tokens: int = 16384,
llm_timeout: int = 120,
task_timeout: int = 300,
llm_timeout: int | None = 120,
task_timeout: int | None = 300,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
@@ -416,8 +416,9 @@ def run_single(
Returns ``{"code": str, "raw": str, "n_turns": 1}``.
"""
no_task_timeout = task_timeout is None or task_timeout <= 0
if is_target_exec_backend():
deadline = time.time() + task_timeout
deadline = None if no_task_timeout else time.time() + task_timeout
deployment = _get_deployment()
work_dir, skill_md, task_md, prompt = _prepare_codex_workspace(
instruction=instruction,
@@ -430,8 +431,11 @@ def run_single(
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
)
remaining = max(10, int(deadline - time.time()))
effective_timeout = min(task_timeout, remaining)
if deadline is None:
effective_timeout = 10**9
else:
remaining = max(10, int(deadline - time.time()))
effective_timeout = min(task_timeout, remaining)
final_message, raw = _run_exec_backend(
work_dir=work_dir,
prompt=prompt,
@@ -453,7 +457,7 @@ def run_single(
"target_user_prompt": f"{prompt}\n\n## Task File\n\n{task_md}",
}
deadline = time.time() + task_timeout
deadline = None if no_task_timeout else time.time() + task_timeout
client = get_target_client()
deployment = _get_deployment()
system = _build_system(skill_content)
@@ -472,8 +476,11 @@ def run_single(
{"role": "user", "content": user},
]
remaining = max(10, int(deadline - time.time()))
effective_timeout = min(llm_timeout, remaining)
if deadline is None:
effective_timeout = None
else:
remaining = max(10, int(deadline - time.time()))
effective_timeout = min(llm_timeout or remaining, remaining)
raw = _chat_call(client, deployment, messages, max_output_tokens, llm_timeout=effective_timeout)
time.sleep(3) # Rate-limit cooldown after successful LLM call
code = extract_code(raw)
@@ -497,8 +504,8 @@ def run_multi(
skill_content: str = "",
max_turns: int = 5,
max_output_tokens: int = 16384,
llm_timeout: int = 120,
task_timeout: int = 600,
llm_timeout: int | None = 120,
task_timeout: int | None = 600,
gold_path: str = "",
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
@@ -520,8 +527,9 @@ def run_multi(
Returns ``{"code": str, "raw": str, "n_turns": int, "conversation": [...]}``.
"""
no_task_timeout = task_timeout is None or task_timeout <= 0
if is_target_exec_backend():
deadline = time.time() + task_timeout
deadline = None if no_task_timeout else time.time() + task_timeout
deployment = _get_deployment()
work_dir, skill_md, task_md, initial_prompt = _prepare_codex_workspace(
instruction=instruction,
@@ -549,11 +557,13 @@ def run_multi(
solution_path = os.path.join(work_dir, "solution.py")
for turn in range(max_turns):
remaining = deadline - time.time()
if remaining <= 10:
break
effective_timeout = max(10, int(remaining))
if deadline is None:
effective_timeout = 10**9
else:
remaining = deadline - time.time()
if remaining <= 10:
break
effective_timeout = max(10, int(remaining))
final_message, raw = _run_exec_backend(
work_dir=work_dir,
prompt=prompt,
@@ -577,7 +587,12 @@ def run_multi(
"Write a complete `solution.py` that reads `INPUT_PATH` and saves `OUTPUT_PATH`."
)
else:
ok, err = run_generated_code(code, input_xlsx, output_path)
ok, err = run_generated_code(
code,
input_xlsx,
output_path,
timeout=None if no_task_timeout else 120,
)
if ok:
if gold_path and answer_position:
from skillopt.envs.spreadsheetbench.rollout import _auto_verify_output
@@ -617,7 +632,7 @@ def run_multi(
"target_user_prompt": f"{initial_prompt}\n\n## Task File\n\n{task_md}",
}
deadline = time.time() + task_timeout
deadline = None if no_task_timeout else time.time() + task_timeout
client = get_target_client()
deployment = _get_deployment()
system = _build_system(skill_content)
@@ -640,12 +655,14 @@ def run_multi(
raw = ""
for turn in range(max_turns):
remaining = deadline - time.time()
if remaining <= 10:
# Not enough time for another round
break
effective_timeout = min(llm_timeout, int(remaining))
if deadline is None:
effective_timeout = None
else:
remaining = deadline - time.time()
if remaining <= 10:
# Not enough time for another round
break
effective_timeout = min(llm_timeout or int(remaining), int(remaining))
raw = _chat_call(client, deployment, messages, max_output_tokens, llm_timeout=effective_timeout)
time.sleep(3) # Rate-limit cooldown after successful LLM call
code = extract_code(raw)
@@ -663,7 +680,12 @@ def run_multi(
continue
# Execute the code
ok, err = run_generated_code(code, input_xlsx, output_path)
ok, err = run_generated_code(
code,
input_xlsx,
output_path,
timeout=None if no_task_timeout else 120,
)
if ok:
# Execution succeeded — check correctness if gold_path available
if gold_path and answer_position:

View File

@@ -34,7 +34,7 @@ def _strip_path_assignments(code: str) -> str:
return _PATH_ASSIGN_RE.sub("", code)
def run_generated_code(code: str, input_path: str, output_path: str, timeout: int = 120) -> tuple[bool, str]:
def run_generated_code(code: str, input_path: str, output_path: str, timeout: int | None = 120) -> tuple[bool, str]:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
cleaned = _strip_path_assignments(code)
indented = textwrap.indent(cleaned, " ")
@@ -51,7 +51,7 @@ def run_generated_code(code: str, input_path: str, output_path: str, timeout: in
[sys.executable, tmp],
capture_output=True,
text=True,
timeout=timeout,
timeout=timeout if timeout and timeout > 0 else None,
)
if proc.returncode != 0:
return False, (proc.stdout + "\n" + proc.stderr).strip()

View File

@@ -547,6 +547,7 @@ def process_one_codegen(
mode: str = "single",
max_turns: int = 5,
max_completion_tokens: int = 16384,
task_timeout: int = 600,
use_eval_feedback: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
@@ -659,6 +660,7 @@ def process_one_codegen(
skill_content=skill_content,
max_turns=max_turns,
max_output_tokens=max_completion_tokens,
task_timeout=task_timeout,
gold_path=first_gold if use_eval_feedback else "",
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
@@ -673,6 +675,7 @@ def process_one_codegen(
answer_position=answer_position_eval,
skill_content=skill_content,
max_output_tokens=max_completion_tokens,
task_timeout=task_timeout,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
@@ -810,10 +813,10 @@ def run_spreadsheet_batch_codegen(
Args:
mode: "single" or "multi".
task_timeout: Hard per-task timeout in seconds at the future level.
0 = auto (single: 300s, multi: 600s).
0 or negative disables the per-task timeout.
"""
if task_timeout <= 0:
task_timeout = 300 if mode == "single" else 600
no_task_timeout = task_timeout <= 0
task_timeout_label = "none" if no_task_timeout else f"{task_timeout}s"
os.makedirs(out_root, exist_ok=True)
@@ -833,7 +836,7 @@ def run_spreadsheet_batch_codegen(
pending = [it for it in items if str(it["id"]) not in done_ids]
print(
f" [spreadsheet codegen-{mode}] total={len(items)} done={len(done_ids)} "
f"pending={len(pending)} workers={max_api_workers} task_timeout={task_timeout}s"
f"pending={len(pending)} workers={max_api_workers} task_timeout={task_timeout_label}"
)
if not pending:
@@ -854,6 +857,7 @@ def run_spreadsheet_batch_codegen(
mode,
max_turns,
max_completion_tokens,
task_timeout,
use_eval_feedback,
diagnostic_mode,
diagnostic_instruction,
@@ -903,7 +907,7 @@ def run_spreadsheet_batch_codegen(
while pending_futs:
done, _ = wait(pending_futs, timeout=5, return_when=FIRST_COMPLETED)
now = time.time()
timed_out = [
timed_out = [] if no_task_timeout else [
fut for fut in pending_futs - done
if str(futs[fut]["id"]) in started_at
and now - started_at[str(futs[fut]["id"])] >= task_timeout