diff --git a/skillopt_webui/app.py b/skillopt_webui/app.py index ef0c68f..e4978c5 100644 --- a/skillopt_webui/app.py +++ b/skillopt_webui/app.py @@ -9,15 +9,19 @@ import glob import json import os import signal +import socket import subprocess import sys import threading -import time from pathlib import Path +from urllib.parse import urlparse import gradio as gr import yaml +from skillopt.config import flatten_config +from skillopt.config import load_config as load_merged_config + PROJECT_ROOT = Path(__file__).resolve().parent.parent @@ -42,6 +46,131 @@ def config_to_display(cfg: dict) -> str: return yaml.dump(cfg, default_flow_style=False, sort_keys=False) +def _can_connect_to_url(url: str, timeout: float = 0.5) -> bool: + parsed = urlparse(url) + host = parsed.hostname + if not host: + return False + port = parsed.port or (443 if parsed.scheme == "https" else 80) + try: + with socket.create_connection((host, port), timeout=timeout): + return True + except OSError: + return False + + +def _load_env_file(path: Path, env: dict[str, str]) -> None: + for line in path.read_text().splitlines(): + line = line.strip() + if line.startswith("export "): + line = line[len("export "):].strip() + if line and not line.startswith("#") and "=" in line: + key, value = line.split("=", 1) + env[key.strip()] = value.strip().strip("\"'") + + +def build_training_env() -> dict[str, str]: + """Build the environment shared by preflight and the training subprocess.""" + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + + dot_env = PROJECT_ROOT / ".env" + if dot_env.is_file(): + _load_env_file(dot_env, env) + + secrets_dir = PROJECT_ROOT / ".secrets" + if secrets_dir.is_dir(): + for env_file in sorted(secrets_dir.glob("*.env")): + _load_env_file(env_file, env) + + # Propagate OPTIMIZER_* to base AZURE_OPENAI_* when base is missing, + # so target/default endpoints inherit from optimizer config. + for suffix in ( + "ENDPOINT", "API_VERSION", "AUTH_MODE", "MANAGED_IDENTITY_CLIENT_ID", + "AD_SCOPE", "API_KEY", + ): + base_key = f"AZURE_OPENAI_{suffix}" + optimizer_key = f"OPTIMIZER_AZURE_OPENAI_{suffix}" + if not env.get(base_key) and env.get(optimizer_key): + env[base_key] = env[optimizer_key] + return env + + +def validate_training_config( + config_path: str, + overrides: dict, + env: dict[str, str] | None = None, +) -> str | None: + """Return an actionable preflight error, or None when training can start.""" + env = env or os.environ + cfg_options = [ + f"{key}={value}" for key, value in overrides.items() + if value is not None and value != "" + ] + try: + cfg = flatten_config(load_merged_config(str(PROJECT_ROOT / config_path), cfg_options)) + except Exception as exc: + return f"❌ Invalid config: {exc}" + + shared_endpoint = ( + cfg.get("azure_openai_endpoint") + or cfg.get("azure_endpoint") + or env.get("AZURE_OPENAI_ENDPOINT") + ) + missing_openai_roles = [] + for role in ("optimizer", "target"): + if cfg.get(f"{role}_backend") != "openai_chat": + continue + role_endpoint = ( + cfg.get(f"{role}_azure_openai_endpoint") + or env.get(f"{role.upper()}_AZURE_OPENAI_ENDPOINT") + or shared_endpoint + ) + if not role_endpoint: + missing_openai_roles.append(role) + if missing_openai_roles: + configured_backend = cfg.get("model_backend") + detail = "" + if configured_backend in {"qwen", "qwen_chat"}: + detail = ( + "\nNote: model.backend is qwen, but explicit optimizer_backend/" + "target_backend values are still openai_chat." + ) + return ( + "❌ Model backend is not ready: missing Azure/OpenAI-compatible endpoint " + f"for {', '.join(missing_openai_roles)}.\n" + "Set model.azure_openai_endpoint (or AZURE_OPENAI_ENDPOINT), or change " + "the role backends to the backend you intend to use." + f"{detail}" + ) + + qwen_failures = [] + qwen_shared = ( + cfg.get("qwen_chat_base_url") + or env.get("QWEN_CHAT_BASE_URL") + or "http://localhost:8000/v1" + ) + for role in ("optimizer", "target"): + if cfg.get(f"{role}_backend") != "qwen_chat": + continue + base_url = ( + cfg.get(f"{role}_qwen_chat_base_url") + or env.get(f"{role.upper()}_QWEN_CHAT_BASE_URL") + or qwen_shared + ) + if not _can_connect_to_url(str(base_url)): + qwen_failures.append(f"{role}={base_url}") + if qwen_failures: + return ( + "❌ Model backend is not ready: cannot connect to qwen_chat endpoint " + f"for {', '.join(qwen_failures)}.\n" + "Start your OpenAI-compatible Qwen/vLLM server, or set " + "model.qwen_chat_base_url / OPTIMIZER_QWEN_CHAT_BASE_URL / " + "TARGET_QWEN_CHAT_BASE_URL to the correct URL." + ) + return None + + # ─── Training process management ──────────────────────────────────────────── class TrainingManager: @@ -63,6 +192,11 @@ class TrainingManager: if self.running: return "⚠️ Training already running. Stop it first." + env = build_training_env() + preflight_error = validate_training_config(config_path, overrides, env) + if preflight_error: + return preflight_error + cmd = [ sys.executable, "scripts/train.py", "--config", config_path, @@ -75,30 +209,6 @@ class TrainingManager: cmd.append("--cfg-options") cmd.extend(cfg_options) - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - # Auto-load API credentials from .secrets/*.env - secrets_dir = PROJECT_ROOT / ".secrets" - if secrets_dir.is_dir(): - for env_file in sorted(secrets_dir.glob("*.env")): - for line in env_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - k, v = line.split("=", 1) - env[k] = v - # Propagate OPTIMIZER_* to base AZURE_OPENAI_* when base is missing, - # so target/default endpoints inherit from optimizer config. - _propagate = [ - ("ENDPOINT", ""), ("API_VERSION", ""), ("AUTH_MODE", ""), - ("MANAGED_IDENTITY_CLIENT_ID", ""), ("AD_SCOPE", ""), - ("API_KEY", ""), - ] - for suffix, _ in _propagate: - base_key = f"AZURE_OPENAI_{suffix}" - optimizer_key = f"OPTIMIZER_AZURE_OPENAI_{suffix}" - if not env.get(base_key) and env.get(optimizer_key): - env[base_key] = env[optimizer_key] - try: proc = subprocess.Popen( cmd, diff --git a/tests/test_webui_env_preflight.py b/tests/test_webui_env_preflight.py new file mode 100644 index 0000000..5b84d86 --- /dev/null +++ b/tests/test_webui_env_preflight.py @@ -0,0 +1,89 @@ +import pytest +import yaml + +pytest.importorskip("gradio") + +from skillopt_webui import app as webui_app + + +def _write_config(tmp_path, model): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump({ + "model": model, + "env": {"name": "searchqa"}, + }), + encoding="utf-8", + ) + return str(config_path) + + +def test_build_training_env_loads_project_dotenv(tmp_path, monkeypatch): + monkeypatch.setattr(webui_app, "PROJECT_ROOT", tmp_path) + (tmp_path / ".env").write_text( + "\n".join([ + "export QWEN_CHAT_BASE_URL=http://qwen.example/v1", + "QWEN_CHAT_MODEL=test-model", + "QWEN_CHAT_API_KEY='secret-value'", + ]), + encoding="utf-8", + ) + + env = webui_app.build_training_env() + + assert env["QWEN_CHAT_BASE_URL"] == "http://qwen.example/v1" + assert env["QWEN_CHAT_MODEL"] == "test-model" + assert env["QWEN_CHAT_API_KEY"] == "secret-value" + + +def test_preflight_reports_missing_openai_chat_endpoint(tmp_path, monkeypatch): + monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False) + monkeypatch.delenv("OPTIMIZER_AZURE_OPENAI_ENDPOINT", raising=False) + monkeypatch.delenv("TARGET_AZURE_OPENAI_ENDPOINT", raising=False) + config_path = _write_config( + tmp_path, + { + "backend": "qwen", + "optimizer_backend": "openai_chat", + "target_backend": "openai_chat", + }, + ) + + error = webui_app.validate_training_config(config_path, {}) + + assert "missing Azure/OpenAI-compatible endpoint for optimizer, target" in error + assert "model.backend is qwen" in error + + +def test_preflight_reports_unreachable_qwen_endpoint(tmp_path, monkeypatch): + monkeypatch.setattr(webui_app, "_can_connect_to_url", lambda _url: False) + config_path = _write_config( + tmp_path, + { + "backend": "qwen", + "optimizer_backend": "qwen_chat", + "target_backend": "qwen_chat", + "qwen_chat_base_url": "http://127.0.0.1:9/v1", + }, + ) + + error = webui_app.validate_training_config(config_path, {}) + + assert "cannot connect to qwen_chat endpoint" in error + assert "127.0.0.1:9" in error + + +def test_preflight_accepts_reachable_qwen_endpoint(tmp_path, monkeypatch): + seen_urls = [] + monkeypatch.setattr(webui_app, "_can_connect_to_url", lambda url: seen_urls.append(url) or True) + config_path = _write_config( + tmp_path, + { + "optimizer_backend": "qwen_chat", + "target_backend": "qwen_chat", + "qwen_chat_base_url": "http://qwen.example/v1", + }, + ) + + assert webui_app.validate_training_config(config_path, {}) is None + assert seen_urls == ["http://qwen.example/v1", "http://qwen.example/v1"]