mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
Make rollout completion tokens configurable
This commit is contained in:
@@ -81,10 +81,13 @@ class ALFWorldAdapter(EnvAdapter):
|
||||
analyst_workers: int = 16,
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4, ) -> None:
|
||||
edit_budget: int = 4,
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> None:
|
||||
self.max_steps = max_steps
|
||||
self.workers = max(int(workers or 1), 1)
|
||||
self.max_api_workers = max_api_workers
|
||||
self.max_completion_tokens = int(max_completion_tokens)
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
@@ -349,6 +352,7 @@ class ALFWorldAdapter(EnvAdapter):
|
||||
max_steps=self.max_steps,
|
||||
out_root=out_dir,
|
||||
max_api_workers=self.max_api_workers,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
result_ids=getattr(env_manager, "_skillopt_result_ids", None),
|
||||
)
|
||||
|
||||
@@ -411,6 +415,7 @@ class ALFWorldAdapter(EnvAdapter):
|
||||
max_steps=self.max_steps,
|
||||
out_root=out_dir,
|
||||
max_api_workers=min(self.max_api_workers, chunk_size),
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
result_ids=chunk_ids,
|
||||
|
||||
@@ -134,7 +134,7 @@ def run_alfworld_batch(
|
||||
out_root: str = "",
|
||||
max_api_workers: int = 8,
|
||||
temperature: float = 0.4,
|
||||
max_completion_tokens: int = 2048,
|
||||
max_completion_tokens: int = 16384,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
result_ids: list[str] | None = None,
|
||||
|
||||
@@ -27,10 +27,13 @@ class DocVQAAdapter(EnvAdapter):
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
image_detail: str = "auto", ) -> None:
|
||||
image_detail: str = "auto",
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.max_completion_tokens = int(max_completion_tokens)
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
@@ -75,6 +78,7 @@ class DocVQAAdapter(EnvAdapter):
|
||||
exec_timeout=self.exec_timeout,
|
||||
workers=self.workers,
|
||||
image_detail=self.image_detail,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
task_timeout=self.exec_timeout,
|
||||
|
||||
@@ -134,6 +134,7 @@ def process_one(
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 120,
|
||||
image_detail: str = "auto",
|
||||
max_completion_tokens: int = 16384,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
) -> dict:
|
||||
@@ -200,7 +201,7 @@ def process_one(
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_target_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=768,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
@@ -214,7 +215,7 @@ def process_one(
|
||||
]
|
||||
resp_text, _ = chat_target_messages(
|
||||
messages=refinement_messages,
|
||||
max_completion_tokens=512,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
@@ -266,6 +267,7 @@ def run_batch(
|
||||
exec_timeout: int = 120,
|
||||
workers: int = 16,
|
||||
image_detail: str = "auto",
|
||||
max_completion_tokens: int = 16384,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
task_timeout: int = 600,
|
||||
@@ -325,6 +327,7 @@ def run_batch(
|
||||
max_turns=max_turns,
|
||||
exec_timeout=exec_timeout,
|
||||
image_detail=image_detail,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
)
|
||||
|
||||
@@ -60,10 +60,13 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
|
||||
limit: int = 0,
|
||||
shuffle_choices: bool = True,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False, ) -> None:
|
||||
use_sketch: bool = False,
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.max_completion_tokens = int(max_completion_tokens)
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
@@ -115,6 +118,7 @@ class LiveMathematicianBenchAdapter(EnvAdapter):
|
||||
max_turns=self.max_turns,
|
||||
exec_timeout=self.exec_timeout,
|
||||
workers=self.workers,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
use_theorem=self.use_theorem,
|
||||
use_sketch=self.use_sketch,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
|
||||
@@ -120,6 +120,7 @@ def process_one(
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
exec_timeout: int = 300,
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> dict:
|
||||
item_id = str(item["id"])
|
||||
result = {
|
||||
@@ -219,7 +220,7 @@ def process_one(
|
||||
resp_text, _ = chat_target(
|
||||
system=system,
|
||||
user=user,
|
||||
max_completion_tokens=16384,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
@@ -233,7 +234,7 @@ def process_one(
|
||||
resp_text, _ = chat_target(
|
||||
system=system,
|
||||
user=refinement,
|
||||
max_completion_tokens=16384,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
@@ -293,6 +294,7 @@ def run_batch(
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 300,
|
||||
workers: int = 64,
|
||||
max_completion_tokens: int = 16384,
|
||||
use_theorem: bool = False,
|
||||
use_sketch: bool = False,
|
||||
diagnostic_mode: bool = False,
|
||||
@@ -338,6 +340,7 @@ def run_batch(
|
||||
skill_content,
|
||||
max_turns=max_turns,
|
||||
exec_timeout=exec_timeout,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
use_theorem=use_theorem,
|
||||
use_sketch=use_sketch,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
|
||||
@@ -26,7 +26,7 @@ class OfficeQAAdapter(EnvAdapter):
|
||||
seed: int = 42,
|
||||
limit: int = 0,
|
||||
max_tool_turns: int = 12,
|
||||
max_completion_tokens: int = 64000,
|
||||
max_completion_tokens: int = 16384,
|
||||
search_mode: str = "offline",
|
||||
max_queries_per_turn: int = 4,
|
||||
search_api_url: str = os.environ.get("OFFICEQA_SEARCH_API_URL", "http://localhost:8080/search_tool/search"),
|
||||
|
||||
@@ -516,7 +516,7 @@ def process_one(
|
||||
skill_content: str,
|
||||
*,
|
||||
max_tool_turns: int = 12,
|
||||
max_completion_tokens: int = 64000,
|
||||
max_completion_tokens: int = 16384,
|
||||
search_mode: str = _DEFAULT_SEARCH_MODE,
|
||||
max_queries_per_turn: int = 4,
|
||||
search_api_url: str = "",
|
||||
@@ -652,7 +652,7 @@ def process_one(
|
||||
for turn in range(1, max_tool_turns + 1):
|
||||
message, _ = chat_target_messages(
|
||||
messages=messages,
|
||||
max_completion_tokens=768,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5,
|
||||
stage="rollout",
|
||||
tools=_TOOL_SCHEMAS,
|
||||
@@ -725,7 +725,7 @@ def run_batch(
|
||||
*,
|
||||
workers: int = 8,
|
||||
max_tool_turns: int = 12,
|
||||
max_completion_tokens: int = 64000,
|
||||
max_completion_tokens: int = 16384,
|
||||
search_mode: str = _DEFAULT_SEARCH_MODE,
|
||||
max_queries_per_turn: int = 4,
|
||||
search_api_url: str = "",
|
||||
|
||||
@@ -31,10 +31,13 @@ class SearchQAAdapter(EnvAdapter):
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42,
|
||||
limit: int = 0, ) -> None:
|
||||
limit: int = 0,
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> None:
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.max_completion_tokens = int(max_completion_tokens)
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
@@ -84,6 +87,7 @@ class SearchQAAdapter(EnvAdapter):
|
||||
max_turns=self.max_turns,
|
||||
exec_timeout=self.exec_timeout,
|
||||
workers=self.workers,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
diagnostic_instruction=kwargs.get("diagnostic_instruction", ""),
|
||||
diagnostic_trace_context_by_id=kwargs.get("diagnostic_trace_context_by_id"),
|
||||
|
||||
@@ -148,6 +148,7 @@ def process_one(
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
exec_timeout: int = 120,
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> dict:
|
||||
"""Process a single QA item: run agent + evaluate.
|
||||
|
||||
@@ -268,7 +269,7 @@ def process_one(
|
||||
if turn == 0:
|
||||
resp_text, _ = chat_target(
|
||||
system=system, user=user,
|
||||
max_completion_tokens=512,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5, stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
@@ -281,7 +282,7 @@ def process_one(
|
||||
)
|
||||
resp_text, _ = chat_target(
|
||||
system=system, user=refinement,
|
||||
max_completion_tokens=512,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
retries=5, stage="rollout",
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
@@ -352,6 +353,7 @@ def run_batch(
|
||||
max_turns: int = 1,
|
||||
exec_timeout: int = 120,
|
||||
workers: int = 64,
|
||||
max_completion_tokens: int = 16384,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context_by_id: dict[str, str] | None = None,
|
||||
@@ -423,6 +425,7 @@ def run_batch(
|
||||
diagnostic_instruction,
|
||||
(diagnostic_trace_context_by_id or {}).get(str(item["id"]), ""),
|
||||
exec_timeout,
|
||||
max_completion_tokens,
|
||||
)
|
||||
|
||||
with open(results_path, "a") as outf:
|
||||
|
||||
@@ -44,12 +44,15 @@ class SpreadsheetBenchAdapter(EnvAdapter):
|
||||
failure_only: bool = False,
|
||||
minibatch_size: int = 8,
|
||||
edit_budget: int = 4,
|
||||
seed: int = 42, ) -> None:
|
||||
seed: int = 42,
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> None:
|
||||
self.data_root = data_root
|
||||
self.mode = mode # "single", "multi", or "react"
|
||||
self.max_turns = max_turns
|
||||
self.exec_timeout = exec_timeout
|
||||
self.workers = workers
|
||||
self.max_completion_tokens = int(max_completion_tokens)
|
||||
self.analyst_workers = analyst_workers
|
||||
self.failure_only = failure_only
|
||||
self.minibatch_size = minibatch_size
|
||||
@@ -124,6 +127,7 @@ class SpreadsheetBenchAdapter(EnvAdapter):
|
||||
skill_content=skill_content,
|
||||
mode=self.mode,
|
||||
max_turns=self.max_turns,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
max_api_workers=self.workers,
|
||||
task_timeout=self.exec_timeout,
|
||||
use_eval_feedback=kwargs.get("use_eval_feedback", False),
|
||||
@@ -138,6 +142,7 @@ class SpreadsheetBenchAdapter(EnvAdapter):
|
||||
out_root=out_dir,
|
||||
skill_content=skill_content,
|
||||
max_turns=self.max_turns,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
max_api_workers=self.workers,
|
||||
task_timeout=max(600, int(self.exec_timeout) + 60),
|
||||
diagnostic_mode=kwargs.get("diagnostic_mode", False),
|
||||
|
||||
@@ -365,7 +365,7 @@ def run_react(
|
||||
answer_position: str = "",
|
||||
skill_content: str = "",
|
||||
max_turns: int = 30,
|
||||
max_output_tokens: int = 4096,
|
||||
max_output_tokens: int = 16384,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
|
||||
@@ -174,6 +174,7 @@ def process_one(
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
diagnostic_trace_context: str = "",
|
||||
max_completion_tokens: int = 16384,
|
||||
) -> dict:
|
||||
"""Run the ReAct agent on a single SpreadsheetBench task.
|
||||
|
||||
@@ -283,6 +284,7 @@ def process_one(
|
||||
answer_position=answer_position_eval,
|
||||
skill_content=skill_content,
|
||||
max_turns=max_turns,
|
||||
max_output_tokens=max_completion_tokens,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
@@ -409,6 +411,7 @@ def run_spreadsheet_batch(
|
||||
out_root: str,
|
||||
skill_content: str,
|
||||
max_turns: int = 30,
|
||||
max_completion_tokens: int = 16384,
|
||||
max_api_workers: int = 64,
|
||||
task_timeout: int = 600,
|
||||
diagnostic_mode: bool = False,
|
||||
@@ -479,6 +482,7 @@ def run_spreadsheet_batch(
|
||||
diagnostic_mode,
|
||||
diagnostic_instruction,
|
||||
(diagnostic_trace_context_by_id or {}).get(str(it["id"]), ""),
|
||||
max_completion_tokens,
|
||||
)
|
||||
|
||||
ex = ThreadPoolExecutor(max_workers=max_api_workers)
|
||||
@@ -542,6 +546,7 @@ def process_one_codegen(
|
||||
skill_content: str,
|
||||
mode: str = "single",
|
||||
max_turns: int = 5,
|
||||
max_completion_tokens: int = 16384,
|
||||
use_eval_feedback: bool = False,
|
||||
diagnostic_mode: bool = False,
|
||||
diagnostic_instruction: str = "",
|
||||
@@ -653,6 +658,7 @@ def process_one_codegen(
|
||||
answer_position=answer_position_eval,
|
||||
skill_content=skill_content,
|
||||
max_turns=max_turns,
|
||||
max_output_tokens=max_completion_tokens,
|
||||
gold_path=first_gold if use_eval_feedback else "",
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
@@ -666,6 +672,7 @@ def process_one_codegen(
|
||||
instruction_type=instruction_type,
|
||||
answer_position=answer_position_eval,
|
||||
skill_content=skill_content,
|
||||
max_output_tokens=max_completion_tokens,
|
||||
diagnostic_mode=diagnostic_mode,
|
||||
diagnostic_instruction=diagnostic_instruction,
|
||||
diagnostic_trace_context=diagnostic_trace_context,
|
||||
@@ -790,6 +797,7 @@ def run_spreadsheet_batch_codegen(
|
||||
skill_content: str,
|
||||
mode: str = "single",
|
||||
max_turns: int = 5,
|
||||
max_completion_tokens: int = 16384,
|
||||
max_api_workers: int = 32,
|
||||
task_timeout: int = 0,
|
||||
use_eval_feedback: bool = False,
|
||||
@@ -845,6 +853,7 @@ def run_spreadsheet_batch_codegen(
|
||||
skill_content,
|
||||
mode,
|
||||
max_turns,
|
||||
max_completion_tokens,
|
||||
use_eval_feedback,
|
||||
diagnostic_mode,
|
||||
diagnostic_instruction,
|
||||
|
||||
Reference in New Issue
Block a user