Files
microsoft-SkillOpt/skillopt/model/common.py

230 lines
7.4 KiB
Python

"""Shared model utilities for ReflACT backends."""
from __future__ import annotations
import json
import threading
from dataclasses import dataclass, field
from typing import Any
_RESPONSES_API_MODELS = {
"gpt-5.3-codex",
"gpt-5.1-codex",
"gpt-5.2-codex",
"gpt-5-codex",
"codex-mini",
"gpt-5.4-pro",
}
_BACKEND_DEFAULT_MODELS = {
"azure_openai": "gpt-4o",
"openai_chat": "gpt-4o",
"codex": "gpt-4o",
"codex_exec": "gpt-4o",
"claude": "claude-sonnet-4-6",
"claude_chat": "claude-sonnet-4-6",
"claude_code_exec": "claude-sonnet-4-6",
"qwen_chat": "Qwen/Qwen3.5-4B",
"minimax_chat": "MiniMax-M2.7",
}
# Accepted backend spellings (lower-case) -> canonical backend key.
_BACKEND_ALIASES = {
    "azure": "azure_openai",
    "azure_openai": "azure_openai",
    "azure-openai": "azure_openai",
    "openai_chat": "openai_chat",
    "openai": "codex",
    "codex": "codex",
    "codex_exec": "codex_exec",
    "claude": "claude_chat",
    "claude_chat": "claude_chat",
    "claude_code_exec": "claude_code_exec",
    "anthropic": "claude_chat",
    "qwen": "qwen_chat",
    "qwen_chat": "qwen_chat",
    "minimax": "minimax_chat",
    "minimax_chat": "minimax_chat",
}


def normalize_backend_name(name: str | None) -> str:
    """Map a user-supplied backend name to its canonical backend key.

    Matching ignores case and surrounding whitespace. Hyphens are accepted in
    place of underscores, so ``"claude-code-exec"`` resolves like
    ``"claude_code_exec"``.

    Args:
        name: Backend name or alias; ``None`` is allowed.

    Returns:
        The canonical key from ``_BACKEND_ALIASES``, or ``"azure_openai"`` when
        *name* is ``None``/empty/blank. An unrecognized name is returned in
        its stripped, lower-cased form, with its hyphens left intact.
    """
    normalized = str(name or "").strip().lower()
    if not normalized:
        return "azure_openai"
    if normalized in _BACKEND_ALIASES:
        return _BACKEND_ALIASES[normalized]
    # Only "azure-openai" is listed with a hyphen. Accept hyphenated spellings
    # of every other alias too, but do not rewrite names that match nothing.
    return _BACKEND_ALIASES.get(normalized.replace("-", "_"), normalized)
def default_model_for_backend(backend: str | None) -> str:
    """Return the default model name for *backend*.

    The backend is first canonicalized with ``normalize_backend_name``. Any
    backend without an entry in ``_BACKEND_DEFAULT_MODELS`` falls back to the
    ``"azure_openai"`` default.
    """
    canonical = normalize_backend_name(backend)
    fallback = _BACKEND_DEFAULT_MODELS["azure_openai"]
    return _BACKEND_DEFAULT_MODELS.get(canonical, fallback)
def needs_responses_api(model: str) -> bool:
    """Return True if *model* should be called through the Responses API.

    The name is stripped and lower-cased, then matched against
    ``_RESPONSES_API_MODELS``. It matches either exactly or as an entry
    followed by ``"-"`` and any suffix (e.g. ``"gpt-5-codex-2025"``).
    ``None``/empty input never matches.
    """
    candidate = str(model or "").strip().lower()
    if candidate in _RESPONSES_API_MODELS:
        return True
    dashed_prefixes = tuple(f"{base}-" for base in _RESPONSES_API_MODELS)
    return candidate.startswith(dashed_prefixes)
class TokenTracker:
    """Accumulate per-stage call counts and token usage.

    Every method holds an internal ``threading.Lock`` while touching the
    counters, so ``record``/``summary``/``reset`` may be called from several
    threads.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # stage -> {"calls", "prompt_tokens", "completion_tokens"} running totals
        self._data: dict[str, dict[str, int]] = {}

    def record(self, stage: str, prompt_tokens: int, completion_tokens: int) -> None:
        """Count one call for *stage* and add its prompt/completion tokens."""
        with self._lock:
            counters = self._data.setdefault(
                stage,
                {"calls": 0, "prompt_tokens": 0, "completion_tokens": 0},
            )
            counters["calls"] += 1
            counters["prompt_tokens"] += prompt_tokens
            counters["completion_tokens"] += completion_tokens

    def summary(self) -> dict[str, dict[str, int]]:
        """Return a snapshot of the counters, one row per stage plus ``"_total"``.

        Stage rows come first, in sorted stage order, and ``"_total"`` holds
        the sums across all stages. Every row has ``calls``, ``prompt_tokens``,
        ``completion_tokens`` and ``total_tokens`` (prompt + completion).
        A stage literally named ``"_total"`` has its row replaced by the
        aggregate.
        """

        def _row(calls: int, prompt: int, completion: int) -> dict[str, int]:
            return {
                "calls": calls,
                "prompt_tokens": prompt,
                "completion_tokens": completion,
                "total_tokens": prompt + completion,
            }

        with self._lock:
            report: dict[str, dict[str, int]] = {
                stage: _row(c["calls"], c["prompt_tokens"], c["completion_tokens"])
                for stage, c in sorted(self._data.items())
            }
            rows = list(self._data.values())
            report["_total"] = _row(
                sum(c["calls"] for c in rows),
                sum(c["prompt_tokens"] for c in rows),
                sum(c["completion_tokens"] for c in rows),
            )
        return report

    def reset(self) -> None:
        """Discard all recorded counters."""
        with self._lock:
            self._data.clear()


# Module-level shared tracker instance.
tracker = TokenTracker()
@dataclass
class CompatToolFunction:
    """Function part of a tool call: the function name and its argument string."""

    name: str
    # Argument payload as a string (presumably JSON text — confirm with producers).
    arguments: str

    def model_dump(self, mode: str = "json") -> dict[str, str]:
        """Return ``{"name": ..., "arguments": ...}``; *mode* is accepted but ignored."""
        _ = mode  # kept only for signature compatibility with SDK-style model_dump()
        return dict(name=self.name, arguments=self.arguments)
@dataclass
class CompatToolCall:
    """A single tool call: its id, its function payload, and a type tag."""

    id: str
    function: CompatToolFunction
    type: str = "function"

    def model_dump(self, mode: str = "json") -> dict[str, Any]:
        """Return ``{"id", "type", "function"}``, with ``function`` dumped recursively.

        *mode* is accepted but ignored.
        """
        _ = mode  # kept only for signature compatibility with SDK-style model_dump()
        nested = self.function.model_dump()
        return {"id": self.id, "type": self.type, "function": nested}
@dataclass
class CompatAssistantMessage:
    """Assistant message: text content, optional tool calls, and free-form metadata."""

    content: str
    tool_calls: list[CompatToolCall] = field(default_factory=list)
    # Extra data attached to the message; not included in model_dump().
    metadata: dict[str, Any] = field(default_factory=dict)

    def model_dump(self, mode: str = "json") -> dict[str, Any]:
        """Return a chat-style assistant message dict.

        The ``"tool_calls"`` key appears only when there is at least one tool
        call. ``metadata`` is omitted. *mode* is accepted but ignored.
        """
        _ = mode  # kept only for signature compatibility with SDK-style model_dump()
        payload: dict[str, Any] = {"role": "assistant", "content": self.content}
        if not self.tool_calls:
            return payload
        payload["tool_calls"] = [call.model_dump() for call in self.tool_calls]
        return payload
def usage_from_openai_usage(usage: Any) -> dict[str, int]:
    """Normalize a Chat-Completions-style usage object into a token-count dict.

    Reads ``prompt_tokens``, ``completion_tokens`` and ``total_tokens`` by
    attribute, treating missing or falsy values as 0. When no non-zero
    ``total_tokens`` is reported, it is computed as prompt + completion.
    A falsy *usage* (e.g. ``None``) yields all zeros.
    """
    prompt = completion = total = 0
    if usage:
        prompt = getattr(usage, "prompt_tokens", 0) or 0
        completion = getattr(usage, "completion_tokens", 0) or 0
        total = getattr(usage, "total_tokens", 0) or (prompt + completion)
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": total,
    }
def usage_from_responses_usage(usage: Any) -> dict[str, int]:
    """Normalize a Responses-API-style usage object into the chat-style token dict.

    ``input_tokens``/``output_tokens`` map to ``prompt_tokens``/``completion_tokens``,
    with missing or falsy values treated as 0. ``total_tokens`` is always
    recomputed as their sum; any total carried by *usage* is not read.
    A falsy *usage* yields all zeros.
    """
    if not usage:
        prompt = completion = 0
    else:
        prompt = getattr(usage, "input_tokens", 0) or 0
        completion = getattr(usage, "output_tokens", 0) or 0
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": prompt + completion,
    }
def compat_message_from_chat_message(message: Any) -> CompatAssistantMessage:
    """Convert a Chat-Completions-style assistant message into a ``CompatAssistantMessage``.

    Fields are read by attribute access:
    - ``content``: missing or falsy becomes ``""``.
    - ``tool_calls[*].id``: missing becomes ``""``.
    - ``tool_calls[*].function.name``: missing becomes ``""``.
    - ``tool_calls[*].function.arguments``: missing or empty becomes ``"{}"``.
    """

    def _convert(raw_call: Any) -> CompatToolCall:
        fn = getattr(raw_call, "function", None)
        return CompatToolCall(
            id=getattr(raw_call, "id", "") or "",
            function=CompatToolFunction(
                name=getattr(fn, "name", "") or "",
                arguments=getattr(fn, "arguments", "") or "{}",
            ),
        )

    text = getattr(message, "content", "") or ""
    raw_calls = getattr(message, "tool_calls", None) or []
    return CompatAssistantMessage(
        content=text,
        tool_calls=[_convert(raw_call) for raw_call in raw_calls],
    )
def compat_message_from_responses_output(output: list[Any]) -> CompatAssistantMessage:
    """Convert a Responses-API ``output`` item list into a ``CompatAssistantMessage``.

    Items are read by attribute access:
    - ``"function_call"`` items become ``CompatToolCall`` entries, in order.
    - For ``"message"`` items, the text of every content part whose type is
      ``"output_text"`` or ``"text"`` is concatenated, without separators,
      into ``content``.
    - All other item and part types are ignored.

    Tool-call arguments are always returned as a string:
    - A string ``arguments`` is used as-is.
    - A structured (non-string) ``arguments`` value is JSON-serialized.
    - When ``arguments`` is absent, the item's ``input`` is JSON-serialized.
    - An empty value becomes ``"{}"``.

    Args:
        output: Sequence of output items; ``None`` is treated as empty.

    Returns:
        A ``CompatAssistantMessage`` with the joined text and the tool calls.
    """
    text_parts: list[str] = []
    tool_calls: list[CompatToolCall] = []
    for item in output or []:
        item_type = getattr(item, "type", "") or ""
        if item_type == "function_call":
            raw_arguments = getattr(item, "arguments", None)
            if raw_arguments is None:
                # Some items carry structured ``input`` instead of an ``arguments`` string.
                raw_arguments = json.dumps(getattr(item, "input", {}) or {})
            elif raw_arguments and not isinstance(raw_arguments, str):
                # str() of a dict/list gives a Python repr (single quotes, True/None),
                # which is not valid JSON; serialize structured arguments instead.
                try:
                    raw_arguments = json.dumps(raw_arguments)
                except (TypeError, ValueError):
                    raw_arguments = str(raw_arguments)
            tool_calls.append(
                CompatToolCall(
                    # Prefer ``call_id``; fall back to the item's own ``id``.
                    id=getattr(item, "call_id", "") or getattr(item, "id", "") or "",
                    function=CompatToolFunction(
                        name=getattr(item, "name", "") or "",
                        arguments=str(raw_arguments or "{}"),
                    ),
                )
            )
            continue
        if item_type != "message":
            continue
        for part in getattr(item, "content", []) or []:
            part_type = getattr(part, "type", "") or ""
            if part_type in {"output_text", "text"}:
                text_parts.append(getattr(part, "text", "") or "")
    return CompatAssistantMessage(content="".join(text_parts), tool_calls=tool_calls)