mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-05 23:30:35 +08:00
Merge pull request #63 from summerview1997/codex/webui-env-backend-preflight
Add WebUI env loading and backend preflight
This commit is contained in:
@@ -9,15 +9,19 @@ import glob
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gradio as gr
|
||||
import yaml
|
||||
|
||||
from skillopt.config import flatten_config
|
||||
from skillopt.config import load_config as load_merged_config
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
@@ -42,6 +46,131 @@ def config_to_display(cfg: dict) -> str:
|
||||
return yaml.dump(cfg, default_flow_style=False, sort_keys=False)
|
||||
|
||||
|
||||
def _can_connect_to_url(url: str, timeout: float = 0.5) -> bool:
    """Return True when a TCP connection to the URL's host and port succeeds.

    The port is taken from the URL when present; otherwise 443 is used for
    the ``https`` scheme and 80 for anything else. Only a TCP connect is
    attempted — no request is sent.

    Args:
        url: Endpoint URL, e.g. ``http://localhost:8000/v1``.
        timeout: Connection timeout in seconds.

    Returns:
        False when the URL cannot be parsed, has no host, carries an invalid
        port, or the connection attempt fails; True otherwise.
    """
    try:
        parsed = urlparse(url)
        host = parsed.hostname
        # Reading ``.port`` raises ValueError for non-numeric or out-of-range
        # ports; report that as "unreachable" instead of crashing preflight.
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
    except ValueError:
        return False
    if not host:
        return False
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except (OSError, ValueError):
        # ValueError covers UnicodeError raised while encoding an invalid
        # host name, which is not an OSError.
        return False
|
||||
|
||||
|
||||
def _load_env_file(path: Path, env: dict[str, str]) -> None:
    """Merge ``KEY=VALUE`` assignments from a dotenv-style file into *env*.

    Blank lines, lines starting with ``#``, lines without ``=`` and lines
    with an empty key are ignored. A leading ``export `` prefix is dropped.
    Surrounding whitespace is trimmed from keys and values, and one pair of
    matching single or double quotes around a value is removed.

    Args:
        path: File to read; decoded as UTF-8 (a leading BOM is tolerated).
        env: Mapping updated in place; existing keys are overwritten.
    """
    # "utf-8-sig" decodes plain UTF-8 and also drops a leading BOM, which
    # would otherwise become part of the first key.
    for raw_line in path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if line.startswith("export "):
            line = line[len("export "):].strip()
        if not line or line.startswith("#"):
            continue
        key, separator, value = line.partition("=")
        key = key.strip()
        if not separator or not key:
            continue
        value = value.strip()
        # Remove only a matching pair of quotes so that a value which merely
        # ends (or starts) with a quote character is kept intact.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        env[key] = value
|
||||
|
||||
|
||||
def build_training_env() -> dict[str, str]:
    """Assemble the environment used by both preflight and the training subprocess.

    Starts from a copy of the current process environment, forces
    ``PYTHONUNBUFFERED=1``, then layers on the project ``.env`` file followed
    by every ``.secrets/*.env`` file in sorted order (later files overwrite
    earlier values). Finally, each missing or empty ``AZURE_OPENAI_*`` setting
    is filled from its ``OPTIMIZER_AZURE_OPENAI_*`` counterpart.
    """
    training_env = os.environ.copy()
    training_env["PYTHONUNBUFFERED"] = "1"

    # Collect the files first so the load order is explicit: project .env,
    # then the sorted secrets files.
    env_files = []
    project_env = PROJECT_ROOT / ".env"
    if project_env.is_file():
        env_files.append(project_env)
    secrets_root = PROJECT_ROOT / ".secrets"
    if secrets_root.is_dir():
        env_files.extend(sorted(secrets_root.glob("*.env")))
    for env_file in env_files:
        _load_env_file(env_file, training_env)

    # Let target/default endpoints inherit from the optimizer settings when
    # the base AZURE_OPENAI_* value is absent or empty.
    azure_suffixes = (
        "ENDPOINT",
        "API_VERSION",
        "AUTH_MODE",
        "MANAGED_IDENTITY_CLIENT_ID",
        "AD_SCOPE",
        "API_KEY",
    )
    for suffix in azure_suffixes:
        base_key = f"AZURE_OPENAI_{suffix}"
        optimizer_value = training_env.get(f"OPTIMIZER_AZURE_OPENAI_{suffix}")
        if optimizer_value and not training_env.get(base_key):
            training_env[base_key] = optimizer_value
    return training_env
|
||||
|
||||
|
||||
def validate_training_config(
    config_path: str,
    overrides: dict,
    env: dict[str, str] | None = None,
) -> str | None:
    """Return an actionable preflight error, or None when training can start.

    The merged config is loaded with *overrides* applied, then two checks run:

    1. Every role (optimizer, target) whose backend is ``openai_chat`` must
       resolve an endpoint from the role-specific config key, the
       role-specific environment variable, or the shared endpoint.
    2. Every role whose backend is ``qwen_chat`` must have a base URL that
       accepts a TCP connection.

    Args:
        config_path: Config file path, joined onto ``PROJECT_ROOT``.
        overrides: Config overrides; entries whose value is ``None`` or ``""``
            are skipped, the rest are passed as ``key=value`` options.
        env: Environment mapping consulted for endpoint variables. The
            process environment is used only when this is ``None``.

    Returns:
        A user-facing error message, or ``None`` when all checks pass.
    """
    # Fall back to the process environment only when no mapping was given;
    # an explicitly supplied (even empty) mapping is honored as-is.
    if env is None:
        env = os.environ
    cfg_options = [
        f"{key}={value}" for key, value in overrides.items()
        if value is not None and value != ""
    ]
    try:
        cfg = flatten_config(load_merged_config(str(PROJECT_ROOT / config_path), cfg_options))
    except Exception as exc:
        # Boundary handler: any load/merge failure becomes a readable message.
        return f"❌ Invalid config: {exc}"

    shared_endpoint = (
        cfg.get("azure_openai_endpoint")
        or cfg.get("azure_endpoint")
        or env.get("AZURE_OPENAI_ENDPOINT")
    )
    missing_openai_roles = []
    for role in ("optimizer", "target"):
        if cfg.get(f"{role}_backend") != "openai_chat":
            continue
        role_endpoint = (
            cfg.get(f"{role}_azure_openai_endpoint")
            or env.get(f"{role.upper()}_AZURE_OPENAI_ENDPOINT")
            or shared_endpoint
        )
        if not role_endpoint:
            missing_openai_roles.append(role)
    if missing_openai_roles:
        configured_backend = cfg.get("model_backend")
        detail = ""
        # Point out the common mismatch where the top-level backend is qwen
        # but the per-role backends are still openai_chat.
        if configured_backend in {"qwen", "qwen_chat"}:
            detail = (
                "\nNote: model.backend is qwen, but explicit optimizer_backend/"
                "target_backend values are still openai_chat."
            )
        return (
            "❌ Model backend is not ready: missing Azure/OpenAI-compatible endpoint "
            f"for {', '.join(missing_openai_roles)}.\n"
            "Set model.azure_openai_endpoint (or AZURE_OPENAI_ENDPOINT), or change "
            "the role backends to the backend you intend to use."
            f"{detail}"
        )

    qwen_failures = []
    qwen_shared = (
        cfg.get("qwen_chat_base_url")
        or env.get("QWEN_CHAT_BASE_URL")
        or "http://localhost:8000/v1"
    )
    # Each qwen_chat role is probed separately, even when both roles resolve
    # to the same URL.
    for role in ("optimizer", "target"):
        if cfg.get(f"{role}_backend") != "qwen_chat":
            continue
        base_url = (
            cfg.get(f"{role}_qwen_chat_base_url")
            or env.get(f"{role.upper()}_QWEN_CHAT_BASE_URL")
            or qwen_shared
        )
        if not _can_connect_to_url(str(base_url)):
            qwen_failures.append(f"{role}={base_url}")
    if qwen_failures:
        return (
            "❌ Model backend is not ready: cannot connect to qwen_chat endpoint "
            f"for {', '.join(qwen_failures)}.\n"
            "Start your OpenAI-compatible Qwen/vLLM server, or set "
            "model.qwen_chat_base_url / OPTIMIZER_QWEN_CHAT_BASE_URL / "
            "TARGET_QWEN_CHAT_BASE_URL to the correct URL."
        )
    return None
|
||||
|
||||
|
||||
# ─── Training process management ────────────────────────────────────────────
|
||||
|
||||
class TrainingManager:
|
||||
@@ -63,6 +192,11 @@ class TrainingManager:
|
||||
if self.running:
|
||||
return "⚠️ Training already running. Stop it first."
|
||||
|
||||
env = build_training_env()
|
||||
preflight_error = validate_training_config(config_path, overrides, env)
|
||||
if preflight_error:
|
||||
return preflight_error
|
||||
|
||||
cmd = [
|
||||
sys.executable, "scripts/train.py",
|
||||
"--config", config_path,
|
||||
@@ -75,30 +209,6 @@ class TrainingManager:
|
||||
cmd.append("--cfg-options")
|
||||
cmd.extend(cfg_options)
|
||||
|
||||
env = os.environ.copy()
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
# Auto-load API credentials from .secrets/*.env
|
||||
secrets_dir = PROJECT_ROOT / ".secrets"
|
||||
if secrets_dir.is_dir():
|
||||
for env_file in sorted(secrets_dir.glob("*.env")):
|
||||
for line in env_file.read_text().splitlines():
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
env[k] = v
|
||||
# Propagate OPTIMIZER_* to base AZURE_OPENAI_* when base is missing,
|
||||
# so target/default endpoints inherit from optimizer config.
|
||||
_propagate = [
|
||||
("ENDPOINT", ""), ("API_VERSION", ""), ("AUTH_MODE", ""),
|
||||
("MANAGED_IDENTITY_CLIENT_ID", ""), ("AD_SCOPE", ""),
|
||||
("API_KEY", ""),
|
||||
]
|
||||
for suffix, _ in _propagate:
|
||||
base_key = f"AZURE_OPENAI_{suffix}"
|
||||
optimizer_key = f"OPTIMIZER_AZURE_OPENAI_{suffix}"
|
||||
if not env.get(base_key) and env.get(optimizer_key):
|
||||
env[base_key] = env[optimizer_key]
|
||||
|
||||
try:
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
|
||||
89
tests/test_webui_env_preflight.py
Normal file
89
tests/test_webui_env_preflight.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
pytest.importorskip("gradio")
|
||||
|
||||
from skillopt_webui import app as webui_app
|
||||
|
||||
|
||||
def _write_config(tmp_path, model):
    """Write a minimal YAML config under *tmp_path* and return its path as a string.

    The file contains the given ``model`` section plus a fixed
    ``env: {name: searchqa}`` section.
    """
    payload = {
        "model": model,
        "env": {"name": "searchqa"},
    }
    target = tmp_path / "config.yaml"
    target.write_text(yaml.safe_dump(payload), encoding="utf-8")
    return str(target)
|
||||
|
||||
|
||||
def test_build_training_env_loads_project_dotenv(tmp_path, monkeypatch):
    """Values from the project-root ``.env`` appear in the training environment."""
    dotenv_lines = [
        "export QWEN_CHAT_BASE_URL=http://qwen.example/v1",
        "QWEN_CHAT_MODEL=test-model",
        "QWEN_CHAT_API_KEY='secret-value'",
    ]
    monkeypatch.setattr(webui_app, "PROJECT_ROOT", tmp_path)
    dotenv_path = tmp_path / ".env"
    dotenv_path.write_text("\n".join(dotenv_lines), encoding="utf-8")

    env = webui_app.build_training_env()

    # The "export " prefix and the surrounding quotes must both be removed.
    assert env["QWEN_CHAT_BASE_URL"] == "http://qwen.example/v1"
    assert env["QWEN_CHAT_MODEL"] == "test-model"
    assert env["QWEN_CHAT_API_KEY"] == "secret-value"
|
||||
|
||||
|
||||
def test_preflight_reports_missing_openai_chat_endpoint(tmp_path, monkeypatch):
    """Preflight names both roles when no Azure/OpenAI endpoint can be resolved."""
    # Ensure no endpoint leaks in from the developer's real environment.
    for variable in (
        "AZURE_OPENAI_ENDPOINT",
        "OPTIMIZER_AZURE_OPENAI_ENDPOINT",
        "TARGET_AZURE_OPENAI_ENDPOINT",
    ):
        monkeypatch.delenv(variable, raising=False)
    model_section = {
        "backend": "qwen",
        "optimizer_backend": "openai_chat",
        "target_backend": "openai_chat",
    }
    config_path = _write_config(tmp_path, model_section)

    error = webui_app.validate_training_config(config_path, {})

    assert "missing Azure/OpenAI-compatible endpoint for optimizer, target" in error
    assert "model.backend is qwen" in error
|
||||
|
||||
|
||||
def test_preflight_reports_unreachable_qwen_endpoint(tmp_path, monkeypatch):
    """Preflight reports the configured URL when the qwen endpoint is unreachable."""

    def never_reachable(_url):
        return False

    monkeypatch.setattr(webui_app, "_can_connect_to_url", never_reachable)
    model_section = {
        "backend": "qwen",
        "optimizer_backend": "qwen_chat",
        "target_backend": "qwen_chat",
        "qwen_chat_base_url": "http://127.0.0.1:9/v1",
    }
    config_path = _write_config(tmp_path, model_section)

    error = webui_app.validate_training_config(config_path, {})

    assert "cannot connect to qwen_chat endpoint" in error
    assert "127.0.0.1:9" in error
|
||||
|
||||
|
||||
def test_preflight_accepts_reachable_qwen_endpoint(tmp_path, monkeypatch):
    """Preflight passes and probes once per role when the qwen endpoint is reachable."""
    seen_urls = []

    def record_and_accept(url):
        seen_urls.append(url)
        return True

    monkeypatch.setattr(webui_app, "_can_connect_to_url", record_and_accept)
    model_section = {
        "optimizer_backend": "qwen_chat",
        "target_backend": "qwen_chat",
        "qwen_chat_base_url": "http://qwen.example/v1",
    }
    config_path = _write_config(tmp_path, model_section)

    assert webui_app.validate_training_config(config_path, {}) is None
    # One probe for the optimizer role and one for the target role.
    assert seen_urls == ["http://qwen.example/v1", "http://qwen.example/v1"]
|
||||
Reference in New Issue
Block a user