Make rollout completion tokens configurable

This commit is contained in:
hwq
2026-05-28 09:45:47 +00:00
parent 99212e3956
commit 786d57b5cf
20 changed files with 63 additions and 24 deletions

View File

@@ -79,7 +79,6 @@ env:
name: ""
skill_init: ""
split_mode: ratio # ratio = build deterministic split from data_path; split_dir = use pre-split train/val/test
split_ratio: "2:1:7" # explicit default for dataset-backed benchmarks: train:val:test
split_seed: 42
split_dir: ""
data_path: ""

View File

@@ -19,11 +19,11 @@ env:
name: alfworld
skill_init: skillopt/envs/alfworld/skills/initial.md
split_mode: split_dir
split_ratio: "2:1:7"
split_dir: data/ablation_splits/alfworld/2-1-7_seed42
data_path: ""
split_output_dir: ""
max_steps: 50
max_completion_tokens: 16384
workers: 8
max_api_workers: 8
limit: 0

View File

@@ -18,11 +18,11 @@ env:
name: docvqa
skill_init: skillopt/envs/docvqa/skills/initial.md
split_mode: split_dir
split_ratio: "2:1:7"
split_dir: data/docvqa/splits
data_path: ""
split_output_dir: ""
max_turns: 1
max_completion_tokens: 16384
workers: 16
image_detail: auto
limit: 0

View File

@@ -9,11 +9,11 @@ env:
name: livemathematicianbench
skill_init: skillopt/envs/livemathematicianbench/skills/initial.md
split_mode: split_dir
split_ratio: "2:1:7"
split_dir: data/ablation_splits/livemathematicianbench/2-1-7_seed42
data_path: ""
split_output_dir: ""
max_turns: 1
max_completion_tokens: 16384
exec_timeout: 300
workers: 64
limit: 0

View File

@@ -23,7 +23,7 @@ env:
- data/officeqa_docs_official
workers: 4
max_tool_turns: 24
max_completion_tokens: 10000
max_completion_tokens: 16384
search_mode: offline
max_queries_per_turn: 4
search_api_url: http://apisix.westus2.cloudapp.azure.com/search_tool/search

View File

@@ -23,10 +23,10 @@ env:
name: searchqa
skill_init: skillopt/envs/searchqa/skills/initial.md
split_mode: split_dir
split_ratio: "2:1:7"
split_dir: data/searchqa_split
data_path: ""
split_output_dir: ""
max_turns: 1
max_completion_tokens: 16384
workers: 24
limit: 0

View File

@@ -23,12 +23,12 @@ env:
name: spreadsheetbench
skill_init: skillopt/envs/spreadsheetbench/skills/initial.md
split_mode: split_dir
split_ratio: "2:1:7"
split_dir: data/spreadsheetbench_split
data_path: ""
split_output_dir: ""
data_root: data/spreadsheetbench_verified_400
mode: multi
max_turns: 30
max_completion_tokens: 16384
exec_timeout: 600
workers: 24

View File

@@ -81,10 +81,13 @@ class ALFWorldAdapter(EnvAdapter):
analyst_workers: int = 16,
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4, ) -> None:
edit_budget: int = 4,
max_completion_tokens: int = 16384,
) -> None:
self.max_steps = max_steps
self.workers = max(int(workers or 1), 1)
self.max_api_workers = max_api_workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -349,6 +352,7 @@ class ALFWorldAdapter(EnvAdapter):
max_steps=self.max_steps,
out_root=out_dir,
max_api_workers=self.max_api_workers,
max_completion_tokens=self.max_completion_tokens,
result_ids=getattr(env_manager, "_skillopt_result_ids", None),
)
@@ -411,6 +415,7 @@ class ALFWorldAdapter(EnvAdapter):
max_steps=self.max_steps,
out_root=out_dir,
max_api_workers=min(self.max_api_workers, chunk_size),
max_completion_tokens=self.max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
result_ids=chunk_ids,

View File

@@ -134,7 +134,7 @@ def run_alfworld_batch(
out_root: str = "",
max_api_workers: int = 8,
temperature: float = 0.4,
max_completion_tokens: int = 2048,
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
result_ids: list[str] | None = None,

View File

@@ -27,10 +27,13 @@ class DocVQAAdapter(EnvAdapter):
edit_budget: int = 4,
seed: int = 42,
limit: int = 0,
image_detail: str = "auto", ) -> None:
image_detail: str = "auto",
max_completion_tokens: int = 16384,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -75,6 +78,7 @@ class DocVQAAdapter(EnvAdapter):
exec_timeout=self.exec_timeout,
workers=self.workers,
image_detail=self.image_detail,
max_completion_tokens=self.max_completion_tokens,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
task_timeout=self.exec_timeout,

View File

@@ -134,6 +134,7 @@ def process_one(
max_turns: int = 1,
exec_timeout: int = 120,
image_detail: str = "auto",
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
) -> dict:
@@ -200,7 +201,7 @@ def process_one(
if turn == 0:
resp_text, _ = chat_target_messages(
messages=messages,
max_completion_tokens=768,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -214,7 +215,7 @@ def process_one(
]
resp_text, _ = chat_target_messages(
messages=refinement_messages,
max_completion_tokens=512,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -266,6 +267,7 @@ def run_batch(
exec_timeout: int = 120,
workers: int = 16,
image_detail: str = "auto",
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
task_timeout: int = 600,
@@ -325,6 +327,7 @@ def run_batch(
max_turns=max_turns,
exec_timeout=exec_timeout,
image_detail=image_detail,
max_completion_tokens=max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
)

View File

@@ -60,10 +60,13 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
limit: int = 0,
shuffle_choices: bool = True,
use_theorem: bool = False,
use_sketch: bool = False, ) -> None:
use_sketch: bool = False,
max_completion_tokens: int = 16384,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -115,6 +118,7 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
max_turns=self.max_turns,
exec_timeout=self.exec_timeout,
workers=self.workers,
max_completion_tokens=self.max_completion_tokens,
use_theorem=self.use_theorem,
use_sketch=self.use_sketch,
diagnostic_mode=kwargs.get("diagnostic_mode", False),

View File

@@ -120,6 +120,7 @@ def process_one(
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
exec_timeout: int = 300,
max_completion_tokens: int = 16384,
) -> dict:
item_id = str(item["id"])
result = {
@@ -219,7 +220,7 @@ def process_one(
resp_text, _ = chat_target(
system=system,
user=user,
max_completion_tokens=16384,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -233,7 +234,7 @@ def process_one(
resp_text, _ = chat_target(
system=system,
user=refinement,
max_completion_tokens=16384,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -293,6 +294,7 @@ def run_batch(
max_turns: int = 1,
exec_timeout: int = 300,
workers: int = 64,
max_completion_tokens: int = 16384,
use_theorem: bool = False,
use_sketch: bool = False,
diagnostic_mode: bool = False,
@@ -338,6 +340,7 @@ def run_batch(
skill_content,
max_turns=max_turns,
exec_timeout=exec_timeout,
max_completion_tokens=max_completion_tokens,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode,

View File

@@ -26,7 +26,7 @@ class OfficeQAAdapter(EnvAdapter):
seed: int = 42,
limit: int = 0,
max_tool_turns: int = 12,
max_completion_tokens: int = 64000,
max_completion_tokens: int = 16384,
search_mode: str = "offline",
max_queries_per_turn: int = 4,
search_api_url: str = os.environ.get("OFFICEQA_SEARCH_API_URL", "http://localhost:8080/search_tool/search"),

View File

@@ -516,7 +516,7 @@ def process_one(
skill_content: str,
*,
max_tool_turns: int = 12,
max_completion_tokens: int = 64000,
max_completion_tokens: int = 16384,
search_mode: str = _DEFAULT_SEARCH_MODE,
max_queries_per_turn: int = 4,
search_api_url: str = "",
@@ -652,7 +652,7 @@ def process_one(
for turn in range(1, max_tool_turns + 1):
message, _ = chat_target_messages(
messages=messages,
max_completion_tokens=768,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
tools=_TOOL_SCHEMAS,
@@ -725,7 +725,7 @@ def run_batch(
*,
workers: int = 8,
max_tool_turns: int = 12,
max_completion_tokens: int = 64000,
max_completion_tokens: int = 16384,
search_mode: str = _DEFAULT_SEARCH_MODE,
max_queries_per_turn: int = 4,
search_api_url: str = "",

View File

@@ -31,10 +31,13 @@ class SearchQAAdapter(EnvAdapter):
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42,
limit: int = 0, ) -> None:
limit: int = 0,
max_completion_tokens: int = 16384,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -84,6 +87,7 @@ class SearchQAAdapter(EnvAdapter):
max_turns=self.max_turns,
exec_timeout=self.exec_timeout,
workers=self.workers,
max_completion_tokens=self.max_completion_tokens,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),

View File

@@ -148,6 +148,7 @@ def process_one(
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
exec_timeout: int = 120,
max_completion_tokens: int = 16384,
) -> dict:
"""Process a single QA item: run agent + evaluate.
@@ -268,7 +269,7 @@ def process_one(
if turn == 0:
resp_text, _ = chat_target(
system=system, user=user,
max_completion_tokens=512,
max_completion_tokens=max_completion_tokens,
retries=5, stage="rollout",
timeout=exec_timeout,
)
@@ -281,7 +282,7 @@ def process_one(
)
resp_text, _ = chat_target(
system=system, user=refinement,
max_completion_tokens=512,
max_completion_tokens=max_completion_tokens,
retries=5, stage="rollout",
timeout=exec_timeout,
)
@@ -352,6 +353,7 @@ def run_batch(
max_turns: int = 1,
exec_timeout: int = 120,
workers: int = 64,
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context_by_id: dict[str, str] | None = None,
@@ -423,6 +425,7 @@ def run_batch(
diagnostic_instruction,
(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
exec_timeout,
max_completion_tokens,
)
with open(results_path, "a") as outf:

View File

@@ -44,12 +44,15 @@ class SpreadsheetBenchAdapter(EnvAdapter):
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42, ) -> None:
seed: int = 42,
max_completion_tokens: int = 16384,
) -> None:
self.data_root = data_root
self.mode = mode # "single", "multi", or "react"
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -124,6 +127,7 @@ class SpreadsheetBenchAdapter(EnvAdapter):
skill_content=skill_content,
mode=self.mode,
max_turns=self.max_turns,
max_completion_tokens=self.max_completion_tokens,
max_api_workers=self.workers,
task_timeout=self.exec_timeout,
use_eval_feedback=kwargs.get("use_eval_feedback", False),
@@ -138,6 +142,7 @@ class SpreadsheetBenchAdapter(EnvAdapter):
out_root=out_dir,
skill_content=skill_content,
max_turns=self.max_turns,
max_completion_tokens=self.max_completion_tokens,
max_api_workers=self.workers,
task_timeout=max(600, int(self.exec_timeout) + 60),
diagnostic_mode=kwargs.get("diagnostic_mode", False),

View File

@@ -365,7 +365,7 @@ def run_react(
answer_position: str = "",
skill_content: str = "",
max_turns: int = 30,
max_output_tokens: int = 4096,
max_output_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",

View File

@@ -174,6 +174,7 @@ def process_one(
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
max_completion_tokens: int = 16384,
) -> dict:
"""Run the ReAct agent on a single SpreadsheetBench task.
@@ -283,6 +284,7 @@ def process_one(
answer_position=answer_position_eval,
skill_content=skill_content,
max_turns=max_turns,
max_output_tokens=max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
@@ -409,6 +411,7 @@ def run_spreadsheet_batch(
out_root: str,
skill_content: str,
max_turns: int = 30,
max_completion_tokens: int = 16384,
max_api_workers: int = 64,
task_timeout: int = 600,
diagnostic_mode: bool = False,
@@ -479,6 +482,7 @@ def run_spreadsheet_batch(
diagnostic_mode,
diagnostic_instruction,
(diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""),
max_completion_tokens,
)
ex = ThreadPoolExecutor(max_workers=max_api_workers)
@@ -542,6 +546,7 @@ def process_one_codegen(
skill_content: str,
mode: str = "single",
max_turns: int = 5,
max_completion_tokens: int = 16384,
use_eval_feedback: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
@@ -653,6 +658,7 @@ def process_one_codegen(
answer_position=answer_position_eval,
skill_content=skill_content,
max_turns=max_turns,
max_output_tokens=max_completion_tokens,
gold_path=first_gold if use_eval_feedback else "",
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
@@ -666,6 +672,7 @@ def process_one_codegen(
instruction_type=instruction_type,
answer_position=answer_position_eval,
skill_content=skill_content,
max_output_tokens=max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
@@ -790,6 +797,7 @@ def run_spreadsheet_batch_codegen(
skill_content: str,
mode: str = "single",
max_turns: int = 5,
max_completion_tokens: int = 16384,
max_api_workers: int = 32,
task_timeout: int = 0,
use_eval_feedback: bool = False,
@@ -845,6 +853,7 @@ def run_spreadsheet_batch_codegen(
skill_content,
mode,
max_turns,
max_completion_tokens,
use_eval_feedback,
diagnostic_mode,
diagnostic_instruction,