From 786d57b5cfb77a3837ac4b6d84d8127bf8b2945f Mon Sep 17 00:00:00 2001 From: hwq Date: Thu, 28 May 2026 09:45:47 +0000 Subject: [PATCH] Make rollout completion tokens configurable --- configs/_base_/default.yaml | 1 - configs/alfworld/default.yaml | 2 +- configs/docvqa/default.yaml | 2 +- configs/livemathematicianbench/default.yaml | 2 +- configs/officeqa/default.yaml | 2 +- configs/searchqa/default.yaml | 2 +- configs/spreadsheetbench/default.yaml | 2 +- skillopt/envs/alfworld/adapter.py | 7 ++++++- skillopt/envs/alfworld/rollout.py | 2 +- skillopt/envs/docvqa/adapter.py | 6 +++++- skillopt/envs/docvqa/rollout.py | 7 +++++-- skillopt/envs/livemathematicianbench/adapter.py | 6 +++++- skillopt/envs/livemathematicianbench/rollout.py | 7 +++++-- skillopt/envs/officeqa/adapter.py | 2 +- skillopt/envs/officeqa/rollout.py | 6 +++--- skillopt/envs/searchqa/adapter.py | 6 +++++- skillopt/envs/searchqa/rollout.py | 7 +++++-- skillopt/envs/spreadsheetbench/adapter.py | 7 ++++++- skillopt/envs/spreadsheetbench/react_agent.py | 2 +- skillopt/envs/spreadsheetbench/rollout.py | 9 +++++++++ 20 files changed, 63 insertions(+), 24 deletions(-) diff --git a/configs/_base_/default.yaml b/configs/_base_/default.yaml index a454b72..9ec2270 100644 --- a/configs/_base_/default.yaml +++ b/configs/_base_/default.yaml @@ -79,7 +79,6 @@ env: name: "" skill_init: "" split_mode: ratio # ratio = build deterministic split from data_path; split_dir = use pre-split train/val/test - split_ratio: "2:1:7" # explicit default for dataset-backed benchmarks: train:val:test split_seed: 42 split_dir: "" data_path: "" diff --git a/configs/alfworld/default.yaml b/configs/alfworld/default.yaml index d769224..48ce6f0 100644 --- a/configs/alfworld/default.yaml +++ b/configs/alfworld/default.yaml @@ -19,11 +19,11 @@ env: name: alfworld skill_init: skillopt/envs/alfworld/skills/initial.md split_mode: split_dir - split_ratio: "2:1:7" split_dir: data/ablation_splits/alfworld/2-1-7_seed42 data_path: "" split_output_dir: "" max_steps: 50 + max_completion_tokens: 16384 workers: 8 max_api_workers: 8 limit: 0 diff --git a/configs/docvqa/default.yaml b/configs/docvqa/default.yaml index 51bf88f..c3e8ce0 100644 --- a/configs/docvqa/default.yaml +++ b/configs/docvqa/default.yaml @@ -18,11 +18,11 @@ env: name: docvqa skill_init: skillopt/envs/docvqa/skills/initial.md split_mode: split_dir - split_ratio: "2:1:7" split_dir: data/docvqa/splits data_path: "" split_output_dir: "" max_turns: 1 + max_completion_tokens: 16384 workers: 16 image_detail: auto limit: 0 diff --git a/configs/livemathematicianbench/default.yaml b/configs/livemathematicianbench/default.yaml index c337094..465a331 100644 --- a/configs/livemathematicianbench/default.yaml +++ b/configs/livemathematicianbench/default.yaml @@ -9,11 +9,11 @@ env: name: livemathematicianbench skill_init: skillopt/envs/livemathematicianbench/skills/initial.md split_mode: split_dir - split_ratio: "2:1:7" split_dir: data/ablation_splits/livemathematicianbench/2-1-7_seed42 data_path: "" split_output_dir: "" max_turns: 1 + max_completion_tokens: 16384 exec_timeout: 300 workers: 64 limit: 0 diff --git a/configs/officeqa/default.yaml b/configs/officeqa/default.yaml index 992391f..7b72f1a 100644 --- a/configs/officeqa/default.yaml +++ b/configs/officeqa/default.yaml @@ -23,7 +23,7 @@ env: - data/officeqa_docs_official workers: 4 max_tool_turns: 24 - max_completion_tokens: 10000 + max_completion_tokens: 16384 search_mode: offline max_queries_per_turn: 4 search_api_url: http://apisix.westus2.cloudapp.azure.com/search_tool/search diff --git a/configs/searchqa/default.yaml b/configs/searchqa/default.yaml index bd75a7b..a1177ab 100644 --- a/configs/searchqa/default.yaml +++ b/configs/searchqa/default.yaml @@ -23,10 +23,10 @@ env: name: searchqa skill_init: skillopt/envs/searchqa/skills/initial.md split_mode: split_dir - split_ratio: "2:1:7" split_dir: data/searchqa_split data_path: "" split_output_dir: "" max_turns: 1 + max_completion_tokens: 16384 workers: 24 limit: 0 diff --git a/configs/spreadsheetbench/default.yaml b/configs/spreadsheetbench/default.yaml index 13e919f..e93c3a3 100644 --- a/configs/spreadsheetbench/default.yaml +++ b/configs/spreadsheetbench/default.yaml @@ -23,12 +23,12 @@ env: name: spreadsheetbench skill_init: skillopt/envs/spreadsheetbench/skills/initial.md split_mode: split_dir - split_ratio: "2:1:7" split_dir: data/spreadsheetbench_split data_path: "" split_output_dir: "" data_root: data/spreadsheetbench_verified_400 mode: multi max_turns: 30 + max_completion_tokens: 16384 exec_timeout: 600 workers: 24 diff --git a/skillopt/envs/alfworld/adapter.py b/skillopt/envs/alfworld/adapter.py index 41bf73b..e689169 100644 --- a/skillopt/envs/alfworld/adapter.py +++ b/skillopt/envs/alfworld/adapter.py @@ -81,10 +81,13 @@ class ALFWorldAdapter(EnvAdapter): analyst_workers: int = 16, failure_only: bool = False, minibatch_size: int = 8, - edit_budget: int = 4, ) -> None: + edit_budget: int = 4, + max_completion_tokens: int = 16384, + ) -> None: self.max_steps = max_steps self.workers = max(int(workers or 1), 1) self.max_api_workers = max_api_workers + self.max_completion_tokens = int(max_completion_tokens) self.analyst_workers = analyst_workers self.failure_only = failure_only self.minibatch_size = minibatch_size @@ -349,6 +352,7 @@ class ALFWorldAdapter(EnvAdapter): max_steps=self.max_steps, out_root=out_dir, max_api_workers=self.max_api_workers, + max_completion_tokens=self.max_completion_tokens, result_ids=getattr(env_manager, "_skillopt_result_ids", None), ) @@ -411,6 +415,7 @@ class ALFWorldAdapter(EnvAdapter): max_steps=self.max_steps, out_root=out_dir, max_api_workers=min(self.max_api_workers, chunk_size), + max_completion_tokens=self.max_completion_tokens, diagnostic_mode=diagnostic_mode, diagnostic_instruction=diagnostic_instruction, result_ids=chunk_ids, diff --git a/skillopt/envs/alfworld/rollout.py b/skillopt/envs/alfworld/rollout.py index afd84dd..ac6f7f3 100644 --- a/skillopt/envs/alfworld/rollout.py +++ b/skillopt/envs/alfworld/rollout.py @@ -134,7 +134,7 @@ def run_alfworld_batch( out_root: str = "", max_api_workers: int = 8, temperature: float = 0.4, - max_completion_tokens: int = 2048, + max_completion_tokens: int = 16384, diagnostic_mode: bool = False, diagnostic_instruction: str = "", result_ids: list[str] | None = None, diff --git a/skillopt/envs/docvqa/adapter.py b/skillopt/envs/docvqa/adapter.py index 5c95a0b..9184906 100644 --- a/skillopt/envs/docvqa/adapter.py +++ b/skillopt/envs/docvqa/adapter.py @@ -27,10 +27,13 @@ class DocVQAAdapter(EnvAdapter): edit_budget: int = 4, seed: int = 42, limit: int = 0, - image_detail: str = "auto", ) -> None: + image_detail: str = "auto", + max_completion_tokens: int = 16384, + ) -> None: self.max_turns = max_turns self.exec_timeout = exec_timeout self.workers = workers + self.max_completion_tokens = int(max_completion_tokens) self.analyst_workers = analyst_workers self.failure_only = failure_only self.minibatch_size = minibatch_size @@ -75,6 +78,7 @@ class DocVQAAdapter(EnvAdapter): exec_timeout=self.exec_timeout, workers=self.workers, image_detail=self.image_detail, + max_completion_tokens=self.max_completion_tokens, diagnostic_mode=kwargs.get("diagnostic_mode", False), diagnostic_instruction=kwargs.get("diagnostic_instruction", ""), task_timeout=self.exec_timeout, diff --git a/skillopt/envs/docvqa/rollout.py b/skillopt/envs/docvqa/rollout.py index 14d4bb0..6396163 100644 --- a/skillopt/envs/docvqa/rollout.py +++ b/skillopt/envs/docvqa/rollout.py @@ -134,6 +134,7 @@ def process_one( max_turns: int = 1, exec_timeout: int = 120, image_detail: str = "auto", + max_completion_tokens: int = 16384, diagnostic_mode: bool = False, diagnostic_instruction: str = "", ) -> dict: @@ -200,7 +201,7 @@ def process_one( if turn == 0: resp_text, _ = chat_target_messages( messages=messages, - max_completion_tokens=768, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", timeout=exec_timeout, @@ -214,7 +215,7 @@ def process_one( ] resp_text, _ = chat_target_messages( messages=refinement_messages, - max_completion_tokens=512, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", timeout=exec_timeout, @@ -266,6 +267,7 @@ def run_batch( exec_timeout: int = 120, workers: int = 16, image_detail: str = "auto", + max_completion_tokens: int = 16384, diagnostic_mode: bool = False, diagnostic_instruction: str = "", task_timeout: int = 600, @@ -325,6 +327,7 @@ def run_batch( max_turns=max_turns, exec_timeout=exec_timeout, image_detail=image_detail, + max_completion_tokens=max_completion_tokens, diagnostic_mode=diagnostic_mode, diagnostic_instruction=diagnostic_instruction, ) diff --git a/skillopt/envs/livemathematicianbench/adapter.py b/skillopt/envs/livemathematicianbench/adapter.py index b98090c..554b067 100644 --- a/skillopt/envs/livemathematicianbench/adapter.py +++ b/skillopt/envs/livemathematicianbench/adapter.py @@ -60,10 +60,13 @@ class LiveMathematicianBenchAdapter(EnvAdapter): limit: int = 0, shuffle_choices: bool = True, use_theorem: bool = False, - use_sketch: bool = False, ) -> None: + use_sketch: bool = False, + max_completion_tokens: int = 16384, + ) -> None: self.max_turns = max_turns self.exec_timeout = exec_timeout self.workers = workers + self.max_completion_tokens = int(max_completion_tokens) self.analyst_workers = analyst_workers self.failure_only = failure_only self.minibatch_size = minibatch_size @@ -115,6 +118,7 @@ class LiveMathematicianBenchAdapter(EnvAdapter): max_turns=self.max_turns, exec_timeout=self.exec_timeout, workers=self.workers, + max_completion_tokens=self.max_completion_tokens, use_theorem=self.use_theorem, use_sketch=self.use_sketch, diagnostic_mode=kwargs.get("diagnostic_mode", False), diff --git a/skillopt/envs/livemathematicianbench/rollout.py b/skillopt/envs/livemathematicianbench/rollout.py index de4f3dc..a217648 100644 --- a/skillopt/envs/livemathematicianbench/rollout.py +++ b/skillopt/envs/livemathematicianbench/rollout.py @@ -120,6 +120,7 @@ def process_one( diagnostic_instruction: str = "", diagnostic_trace_context: str = "", exec_timeout: int = 300, + max_completion_tokens: int = 16384, ) -> dict: item_id = str(item["id"]) result = { @@ -219,7 +220,7 @@ def process_one( resp_text, _ = chat_target( system=system, user=user, - max_completion_tokens=16384, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", timeout=exec_timeout, @@ -233,7 +234,7 @@ def process_one( resp_text, _ = chat_target( system=system, user=refinement, - max_completion_tokens=16384, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", timeout=exec_timeout, @@ -293,6 +294,7 @@ def run_batch( max_turns: int = 1, exec_timeout: int = 300, workers: int = 64, + max_completion_tokens: int = 16384, use_theorem: bool = False, use_sketch: bool = False, diagnostic_mode: bool = False, @@ -338,6 +340,7 @@ def run_batch( skill_content, max_turns=max_turns, exec_timeout=exec_timeout, + max_completion_tokens=max_completion_tokens, use_theorem=use_theorem, use_sketch=use_sketch, diagnostic_mode=diagnostic_mode, diff --git a/skillopt/envs/officeqa/adapter.py b/skillopt/envs/officeqa/adapter.py index b504309..ba2e6f1 100644 --- a/skillopt/envs/officeqa/adapter.py +++ b/skillopt/envs/officeqa/adapter.py @@ -26,7 +26,7 @@ class OfficeQAAdapter(EnvAdapter): seed: int = 42, limit: int = 0, max_tool_turns: int = 12, - max_completion_tokens: int = 64000, + max_completion_tokens: int = 16384, search_mode: str = "offline", max_queries_per_turn: int = 4, search_api_url: str = os.environ.get("OFFICEQA_SEARCH_API_URL", "http://localhost:8080/search_tool/search"), diff --git a/skillopt/envs/officeqa/rollout.py b/skillopt/envs/officeqa/rollout.py index 7d3a37a..871281b 100644 --- a/skillopt/envs/officeqa/rollout.py +++ b/skillopt/envs/officeqa/rollout.py @@ -516,7 +516,7 @@ def process_one( skill_content: str, *, max_tool_turns: int = 12, - max_completion_tokens: int = 64000, + max_completion_tokens: int = 16384, search_mode: str = _DEFAULT_SEARCH_MODE, max_queries_per_turn: int = 4, search_api_url: str = "", @@ -652,7 +652,7 @@ def process_one( for turn in range(1, max_tool_turns + 1): message, _ = chat_target_messages( messages=messages, - max_completion_tokens=768, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", tools=_TOOL_SCHEMAS, @@ -725,7 +725,7 @@ def run_batch( *, workers: int = 8, max_tool_turns: int = 12, - max_completion_tokens: int = 64000, + max_completion_tokens: int = 16384, search_mode: str = _DEFAULT_SEARCH_MODE, max_queries_per_turn: int = 4, search_api_url: str = "", diff --git a/skillopt/envs/searchqa/adapter.py b/skillopt/envs/searchqa/adapter.py index 15afbd0..2253ebe 100644 --- a/skillopt/envs/searchqa/adapter.py +++ b/skillopt/envs/searchqa/adapter.py @@ -31,10 +31,13 @@ class SearchQAAdapter(EnvAdapter): minibatch_size: int = 8, edit_budget: int = 4, seed: int = 42, - limit: int = 0, ) -> None: + limit: int = 0, + max_completion_tokens: int = 16384, + ) -> None: self.max_turns = max_turns self.exec_timeout = exec_timeout self.workers = workers + self.max_completion_tokens = int(max_completion_tokens) self.analyst_workers = analyst_workers self.failure_only = failure_only self.minibatch_size = minibatch_size @@ -84,6 +87,7 @@ class SearchQAAdapter(EnvAdapter): max_turns=self.max_turns, exec_timeout=self.exec_timeout, workers=self.workers, + max_completion_tokens=self.max_completion_tokens, diagnostic_mode=kwargs.get("diagnostic_mode", False), diagnostic_instruction=kwargs.get("diagnostic_instruction", ""), diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"), diff --git a/skillopt/envs/searchqa/rollout.py b/skillopt/envs/searchqa/rollout.py index b94f671..ab7215d 100644 --- a/skillopt/envs/searchqa/rollout.py +++ b/skillopt/envs/searchqa/rollout.py @@ -148,6 +148,7 @@ def process_one( diagnostic_instruction: str = "", diagnostic_trace_context: str = "", exec_timeout: int = 120, + max_completion_tokens: int = 16384, ) -> dict: """Process a single QA item: run agent + evaluate. @@ -268,7 +269,7 @@ def process_one( if turn == 0: resp_text, _ = chat_target( system=system, user=user, - max_completion_tokens=512, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", timeout=exec_timeout, ) @@ -281,7 +282,7 @@ def process_one( ) resp_text, _ = chat_target( system=system, user=refinement, - max_completion_tokens=512, + max_completion_tokens=max_completion_tokens, retries=5, stage="rollout", timeout=exec_timeout, ) @@ -352,6 +353,7 @@ def run_batch( max_turns: int = 1, exec_timeout: int = 120, workers: int = 64, + max_completion_tokens: int = 16384, diagnostic_mode: bool = False, diagnostic_instruction: str = "", diagnostic_trace_context_by_id: dict[str, str] | None = None, @@ -423,6 +425,7 @@ def run_batch( diagnostic_instruction, (diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""), exec_timeout, + max_completion_tokens, ) with open(results_path, "a") as outf: diff --git a/skillopt/envs/spreadsheetbench/adapter.py b/skillopt/envs/spreadsheetbench/adapter.py index c43ae98..5b2b678 100644 --- a/skillopt/envs/spreadsheetbench/adapter.py +++ b/skillopt/envs/spreadsheetbench/adapter.py @@ -44,12 +44,15 @@ class SpreadsheetBenchAdapter(EnvAdapter): failure_only: bool = False, minibatch_size: int = 8, edit_budget: int = 4, - seed: int = 42, ) -> None: + seed: int = 42, + max_completion_tokens: int = 16384, + ) -> None: self.data_root = data_root self.mode = mode # "single", "multi", or "react" self.max_turns = max_turns self.exec_timeout = exec_timeout self.workers = workers + self.max_completion_tokens = int(max_completion_tokens) self.analyst_workers = analyst_workers self.failure_only = failure_only self.minibatch_size = minibatch_size @@ -124,6 +127,7 @@ class SpreadsheetBenchAdapter(EnvAdapter): skill_content=skill_content, mode=self.mode, max_turns=self.max_turns, + max_completion_tokens=self.max_completion_tokens, max_api_workers=self.workers, task_timeout=self.exec_timeout, use_eval_feedback=kwargs.get("use_eval_feedback", False), @@ -138,6 +142,7 @@ class SpreadsheetBenchAdapter(EnvAdapter): out_root=out_dir, skill_content=skill_content, max_turns=self.max_turns, + max_completion_tokens=self.max_completion_tokens, max_api_workers=self.workers, task_timeout=max(600, int(self.exec_timeout) + 60), diagnostic_mode=kwargs.get("diagnostic_mode", False), diff --git a/skillopt/envs/spreadsheetbench/react_agent.py b/skillopt/envs/spreadsheetbench/react_agent.py index ff296a8..2e72953 100644 --- a/skillopt/envs/spreadsheetbench/react_agent.py +++ b/skillopt/envs/spreadsheetbench/react_agent.py @@ -365,7 +365,7 @@ def run_react( answer_position: str = "", skill_content: str = "", max_turns: int = 30, - max_output_tokens: int = 4096, + max_output_tokens: int = 16384, diagnostic_mode: bool = False, diagnostic_instruction: str = "", diagnostic_trace_context: str = "", diff --git a/skillopt/envs/spreadsheetbench/rollout.py b/skillopt/envs/spreadsheetbench/rollout.py index d9c35d6..d33594e 100644 --- a/skillopt/envs/spreadsheetbench/rollout.py +++ b/skillopt/envs/spreadsheetbench/rollout.py @@ -174,6 +174,7 @@ def process_one( diagnostic_mode: bool = False, diagnostic_instruction: str = "", diagnostic_trace_context: str = "", + max_completion_tokens: int = 16384, ) -> dict: """Run the ReAct agent on a single SpreadsheetBench task. @@ -283,6 +284,7 @@ def process_one( answer_position=answer_position_eval, skill_content=skill_content, max_turns=max_turns, + max_output_tokens=max_completion_tokens, diagnostic_mode=diagnostic_mode, diagnostic_instruction=diagnostic_instruction, diagnostic_trace_context=diagnostic_trace_context, @@ -409,6 +411,7 @@ def run_spreadsheet_batch( out_root: str, skill_content: str, max_turns: int = 30, + max_completion_tokens: int = 16384, max_api_workers: int = 64, task_timeout: int = 600, diagnostic_mode: bool = False, @@ -479,6 +482,7 @@ def run_spreadsheet_batch( diagnostic_mode, diagnostic_instruction, (diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""), + max_completion_tokens, ) ex = ThreadPoolExecutor(max_workers=max_api_workers) @@ -542,6 +546,7 @@ def process_one_codegen( skill_content: str, mode: str = "single", max_turns: int = 5, + max_completion_tokens: int = 16384, use_eval_feedback: bool = False, diagnostic_mode: bool = False, diagnostic_instruction: str = "", @@ -653,6 +658,7 @@ def process_one_codegen( answer_position=answer_position_eval, skill_content=skill_content, max_turns=max_turns, + max_output_tokens=max_completion_tokens, gold_path=first_gold if use_eval_feedback else "", diagnostic_mode=diagnostic_mode, diagnostic_instruction=diagnostic_instruction, @@ -666,6 +672,7 @@ def process_one_codegen( instruction_type=instruction_type, answer_position=answer_position_eval, skill_content=skill_content, + max_output_tokens=max_completion_tokens, diagnostic_mode=diagnostic_mode, diagnostic_instruction=diagnostic_instruction, diagnostic_trace_context=diagnostic_trace_context, @@ -790,6 +797,7 @@ def run_spreadsheet_batch_codegen( skill_content: str, mode: str = "single", max_turns: int = 5, + max_completion_tokens: int = 16384, max_api_workers: int = 32, task_timeout: int = 0, use_eval_feedback: bool = False, @@ -845,6 +853,7 @@ def run_spreadsheet_batch_codegen( skill_content, mode, max_turns, + max_completion_tokens, use_eval_feedback, diagnostic_mode, diagnostic_instruction,