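Make rollout completion tokens configurable

Thread a `max_completion_tokens` setting (default 16384) through the environment adapters and their batch runners down to the rollout chat calls, replacing the hard-coded per-call limits (512, 768, and 16384) that were previously passed there. Note that this also changes existing defaults: `run_alfworld_batch` 2048 -> 16384, the OfficeQA adapter/`process_one`/`run_batch` 64000 -> 16384, and `run_react`'s `max_output_tokens` 4096 -> 16384.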

This commit is contained in:
hwq
2026-05-28 09:45:47 +00:00
parent 99212e3956
commit 786d57b5cf
20 changed files with 63 additions and 24 deletions

View File

@@ -81,10 +81,13 @@ class ALFWorldAdapter(EnvAdapter):
analyst_workers: int = 16,
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4, ) -> None:
edit_budget: int = 4,
max_completion_tokens: int = 16384,
) -> None:
self.max_steps = max_steps
self.workers = max(int(workers or 1), 1)
self.max_api_workers = max_api_workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -349,6 +352,7 @@ class ALFWorldAdapter(EnvAdapter):
max_steps=self.max_steps,
out_root=out_dir,
max_api_workers=self.max_api_workers,
max_completion_tokens=self.max_completion_tokens,
result_ids=getattr(env_manager, "_skillopt_result_ids", None),
)
@@ -411,6 +415,7 @@ class ALFWorldAdapter(EnvAdapter):
max_steps=self.max_steps,
out_root=out_dir,
max_api_workers=min(self.max_api_workers, chunk_size),
max_completion_tokens=self.max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
result_ids=chunk_ids,

View File

@@ -134,7 +134,7 @@ def run_alfworld_batch(
out_root: str = "",
max_api_workers: int = 8,
temperature: float = 0.4,
max_completion_tokens: int = 2048,
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
result_ids: list[str] | None = None,

View File

@@ -27,10 +27,13 @@ class DocVQAAdapter(EnvAdapter):
edit_budget: int = 4,
seed: int = 42,
limit: int = 0,
image_detail: str = "auto", ) -> None:
image_detail: str = "auto",
max_completion_tokens: int = 16384,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -75,6 +78,7 @@ class DocVQAAdapter(EnvAdapter):
exec_timeout=self.exec_timeout,
workers=self.workers,
image_detail=self.image_detail,
max_completion_tokens=self.max_completion_tokens,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
task_timeout=self.exec_timeout,

View File

@@ -134,6 +134,7 @@ def process_one(
max_turns: int = 1,
exec_timeout: int = 120,
image_detail: str = "auto",
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
) -> dict:
@@ -200,7 +201,7 @@ def process_one(
if turn == 0:
resp_text, _ = chat_target_messages(
messages=messages,
max_completion_tokens=768,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -214,7 +215,7 @@ def process_one(
]
resp_text, _ = chat_target_messages(
messages=refinement_messages,
max_completion_tokens=512,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -266,6 +267,7 @@ def run_batch(
exec_timeout: int = 120,
workers: int = 16,
image_detail: str = "auto",
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
task_timeout: int = 600,
@@ -325,6 +327,7 @@ def run_batch(
max_turns=max_turns,
exec_timeout=exec_timeout,
image_detail=image_detail,
max_completion_tokens=max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
)

View File

@@ -60,10 +60,13 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
limit: int = 0,
shuffle_choices: bool = True,
use_theorem: bool = False,
use_sketch: bool = False, ) -> None:
use_sketch: bool = False,
max_completion_tokens: int = 16384,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -115,6 +118,7 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
max_turns=self.max_turns,
exec_timeout=self.exec_timeout,
workers=self.workers,
max_completion_tokens=self.max_completion_tokens,
use_theorem=self.use_theorem,
use_sketch=self.use_sketch,
diagnostic_mode=kwargs.get("diagnostic_mode", False),

View File

@@ -120,6 +120,7 @@ def process_one(
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
exec_timeout: int = 300,
max_completion_tokens: int = 16384,
) -> dict:
item_id = str(item["id"])
result = {
@@ -219,7 +220,7 @@ def process_one(
resp_text, _ = chat_target(
system=system,
user=user,
max_completion_tokens=16384,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -233,7 +234,7 @@ def process_one(
resp_text, _ = chat_target(
system=system,
user=refinement,
max_completion_tokens=16384,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
timeout=exec_timeout,
@@ -293,6 +294,7 @@ def run_batch(
max_turns: int = 1,
exec_timeout: int = 300,
workers: int = 64,
max_completion_tokens: int = 16384,
use_theorem: bool = False,
use_sketch: bool = False,
diagnostic_mode: bool = False,
@@ -338,6 +340,7 @@ def run_batch(
skill_content,
max_turns=max_turns,
exec_timeout=exec_timeout,
max_completion_tokens=max_completion_tokens,
use_theorem=use_theorem,
use_sketch=use_sketch,
diagnostic_mode=diagnostic_mode,

View File

@@ -26,7 +26,7 @@ class OfficeQAAdapter(EnvAdapter):
seed: int = 42,
limit: int = 0,
max_tool_turns: int = 12,
max_completion_tokens: int = 64000,
max_completion_tokens: int = 16384,
search_mode: str = "offline",
max_queries_per_turn: int = 4,
search_api_url: str = os.environ.get("OFFICEQA_SEARCH_API_URL", "http://localhost:8080/search_tool/search"),

View File

@@ -516,7 +516,7 @@ def process_one(
skill_content: str,
*,
max_tool_turns: int = 12,
max_completion_tokens: int = 64000,
max_completion_tokens: int = 16384,
search_mode: str = _DEFAULT_SEARCH_MODE,
max_queries_per_turn: int = 4,
search_api_url: str = "",
@@ -652,7 +652,7 @@ def process_one(
for turn in range(1, max_tool_turns + 1):
message, _ = chat_target_messages(
messages=messages,
max_completion_tokens=768,
max_completion_tokens=max_completion_tokens,
retries=5,
stage="rollout",
tools=_TOOL_SCHEMAS,
@@ -725,7 +725,7 @@ def run_batch(
*,
workers: int = 8,
max_tool_turns: int = 12,
max_completion_tokens: int = 64000,
max_completion_tokens: int = 16384,
search_mode: str = _DEFAULT_SEARCH_MODE,
max_queries_per_turn: int = 4,
search_api_url: str = "",

View File

@@ -31,10 +31,13 @@ class SearchQAAdapter(EnvAdapter):
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42,
limit: int = 0, ) -> None:
limit: int = 0,
max_completion_tokens: int = 16384,
) -> None:
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -84,6 +87,7 @@ class SearchQAAdapter(EnvAdapter):
max_turns=self.max_turns,
exec_timeout=self.exec_timeout,
workers=self.workers,
max_completion_tokens=self.max_completion_tokens,
diagnostic_mode=kwargs.get("diagnostic_mode", False),
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),

View File

@@ -148,6 +148,7 @@ def process_one(
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
exec_timeout: int = 120,
max_completion_tokens: int = 16384,
) -> dict:
"""Process a single QA item: run agent + evaluate.
@@ -268,7 +269,7 @@ def process_one(
if turn == 0:
resp_text, _ = chat_target(
system=system, user=user,
max_completion_tokens=512,
max_completion_tokens=max_completion_tokens,
retries=5, stage="rollout",
timeout=exec_timeout,
)
@@ -281,7 +282,7 @@ def process_one(
)
resp_text, _ = chat_target(
system=system, user=refinement,
max_completion_tokens=512,
max_completion_tokens=max_completion_tokens,
retries=5, stage="rollout",
timeout=exec_timeout,
)
@@ -352,6 +353,7 @@ def run_batch(
max_turns: int = 1,
exec_timeout: int = 120,
workers: int = 64,
max_completion_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context_by_id: dict[str, str] | None = None,
@@ -423,6 +425,7 @@ def run_batch(
diagnostic_instruction,
(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
exec_timeout,
max_completion_tokens,
)
with open(results_path, "a") as outf:

View File

@@ -44,12 +44,15 @@ class SpreadsheetBenchAdapter(EnvAdapter):
failure_only: bool = False,
minibatch_size: int = 8,
edit_budget: int = 4,
seed: int = 42, ) -> None:
seed: int = 42,
max_completion_tokens: int = 16384,
) -> None:
self.data_root = data_root
self.mode = mode # "single", "multi", or "react"
self.max_turns = max_turns
self.exec_timeout = exec_timeout
self.workers = workers
self.max_completion_tokens = int(max_completion_tokens)
self.analyst_workers = analyst_workers
self.failure_only = failure_only
self.minibatch_size = minibatch_size
@@ -124,6 +127,7 @@ class SpreadsheetBenchAdapter(EnvAdapter):
skill_content=skill_content,
mode=self.mode,
max_turns=self.max_turns,
max_completion_tokens=self.max_completion_tokens,
max_api_workers=self.workers,
task_timeout=self.exec_timeout,
use_eval_feedback=kwargs.get("use_eval_feedback", False),
@@ -138,6 +142,7 @@ class SpreadsheetBenchAdapter(EnvAdapter):
out_root=out_dir,
skill_content=skill_content,
max_turns=self.max_turns,
max_completion_tokens=self.max_completion_tokens,
max_api_workers=self.workers,
task_timeout=max(600, int(self.exec_timeout) + 60),
diagnostic_mode=kwargs.get("diagnostic_mode", False),

View File

@@ -365,7 +365,7 @@ def run_react(
answer_position: str = "",
skill_content: str = "",
max_turns: int = 30,
max_output_tokens: int = 4096,
max_output_tokens: int = 16384,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",

View File

@@ -174,6 +174,7 @@ def process_one(
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
diagnostic_trace_context: str = "",
max_completion_tokens: int = 16384,
) -> dict:
"""Run the ReAct agent on a single SpreadsheetBench task.
@@ -283,6 +284,7 @@ def process_one(
answer_position=answer_position_eval,
skill_content=skill_content,
max_turns=max_turns,
max_output_tokens=max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
@@ -409,6 +411,7 @@ def run_spreadsheet_batch(
out_root: str,
skill_content: str,
max_turns: int = 30,
max_completion_tokens: int = 16384,
max_api_workers: int = 64,
task_timeout: int = 600,
diagnostic_mode: bool = False,
@@ -479,6 +482,7 @@ def run_spreadsheet_batch(
diagnostic_mode,
diagnostic_instruction,
(diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""),
max_completion_tokens,
)
ex = ThreadPoolExecutor(max_workers=max_api_workers)
@@ -542,6 +546,7 @@ def process_one_codegen(
skill_content: str,
mode: str = "single",
max_turns: int = 5,
max_completion_tokens: int = 16384,
use_eval_feedback: bool = False,
diagnostic_mode: bool = False,
diagnostic_instruction: str = "",
@@ -653,6 +658,7 @@ def process_one_codegen(
answer_position=answer_position_eval,
skill_content=skill_content,
max_turns=max_turns,
max_output_tokens=max_completion_tokens,
gold_path=first_gold if use_eval_feedback else "",
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
@@ -666,6 +672,7 @@ def process_one_codegen(
instruction_type=instruction_type,
answer_position=answer_position_eval,
skill_content=skill_content,
max_output_tokens=max_completion_tokens,
diagnostic_mode=diagnostic_mode,
diagnostic_instruction=diagnostic_instruction,
diagnostic_trace_context=diagnostic_trace_context,
@@ -790,6 +797,7 @@ def run_spreadsheet_batch_codegen(
skill_content: str,
mode: str = "single",
max_turns: int = 5,
max_completion_tokens: int = 16384,
max_api_workers: int = 32,
task_timeout: int = 0,
use_eval_feedback: bool = False,
@@ -845,6 +853,7 @@ def run_spreadsheet_batch_codegen(
skill_content,
mode,
max_turns,
max_completion_tokens,
use_eval_feedback,
diagnostic_mode,
diagnostic_instruction,