mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
219 lines
7.1 KiB
Python
219 lines
7.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Watch an ablation run root and rerun final failures.
|
|
|
|
This watcher is intended to run in tmux next to scripts/run_ablation_matrix.py.
|
|
It writes STATUS.md on every poll and starts a direct rerun for any run that
|
|
the launcher marks as final [FAIL].
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from run_ablation_matrix import PROJECT_ROOT, build_matrix, command_for
|
|
|
|
|
|
# Launcher log marker: a bracketed event tag followed by the run id
# (the first run of non-whitespace characters after the tag).
RUN_RE = re.compile(r"\[(START|DONE|FAIL|RETRY)\]\s+([^\s]+)")

# Substrings that flag a log as containing an error; matched case-insensitively.
ERROR_RE = re.compile(
    r"Traceback|RuntimeError|AuthenticationError|PermissionDenied|"
    r"DeploymentNotFound|LLM call failed|LLM message call failed|"
    r"BadRequestError|RateLimitError",
    flags=re.IGNORECASE,
)
|
|
|
|
|
|
def read_text(path: Path) -> str:
    """Return the contents of ``path`` decoded as UTF-8, or ``""`` if unreadable.

    Undecodable bytes are substituted (``errors="replace"``) instead of
    raising, so partially written or binary-polluted logs can still be read.

    Args:
        path: File to read.

    Returns:
        The decoded text, or an empty string when the file cannot be read.
    """
    try:
        return path.read_text(encoding="utf-8", errors="replace")
    except OSError:
        # Covers a missing file as well as a directory, a permission problem
        # or a transient I/O error; the watcher polls in a loop and should
        # treat an unreadable log as empty rather than die.
        return ""
|
|
|
|
|
|
def parse_launcher(path: Path) -> dict[str, list[str]]:
    """Collect the run ids announced in a launcher log, grouped by event tag.

    Only the first ``RUN_RE`` match on each line is considered. Run ids are
    kept in log order and may repeat.

    Args:
        path: Launcher log file; an unreadable file yields empty lists.

    Returns:
        A dict with the keys ``START``, ``DONE``, ``FAIL`` and ``RETRY``,
        each mapping to the list of run ids seen with that tag.
    """
    by_tag: dict[str, list[str]] = {
        tag: [] for tag in ("START", "DONE", "FAIL", "RETRY")
    }
    hits = (RUN_RE.search(row) for row in read_text(path).splitlines())
    for hit in hits:
        if hit is None:
            continue
        tag, run_id = hit.group(1), hit.group(2)
        by_tag[tag].append(run_id)
    return by_tag
|
|
|
|
|
|
def active_run_ids(run_root: Path) -> list[str]:
    """List the run ids of ``scripts/train.py`` processes working under ``run_root``.

    Runs ``pgrep -af scripts/train.py`` and scans every returned command line
    for ``<run_root>/<token>``, where ``<token>`` is everything up to the next
    whitespace character.

    Args:
        run_root: Root directory whose sub-paths identify runs.

    Returns:
        The sorted, de-duplicated tokens found. The list is empty when pgrep
        exits non-zero (which includes "no process matched") or when pgrep
        cannot be executed at all.
    """
    try:
        raw = subprocess.check_output(["pgrep", "-af", "scripts/train.py"], text=True)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: pgrep exited non-zero.
        # OSError: pgrep is not installed or not executable; without this the
        # whole watcher would crash on its first poll.
        return []
    pattern = re.compile(re.escape(str(run_root)) + r"/([^\s]+)")
    found = {
        match.group(1)
        for line in raw.splitlines()
        for match in pattern.finditer(line)
    }
    return sorted(found)
|
|
|
|
|
|
def scan_errors(logs_dir: Path) -> dict[str, str]:
    """Map run ids to the first error pattern found in their ``*.log`` files.

    Files are visited in sorted name order. The run id is the file name with
    everything from ``.watchrerun`` onwards dropped and a trailing ``.log``
    removed, so a rerun log reports under the same key as the original log
    (a later file overwrites an earlier entry for the same key).

    Args:
        logs_dir: Directory containing the log files.

    Returns:
        ``{run_id: matched_text}`` for every log with an ``ERROR_RE`` hit.
    """
    found: dict[str, str] = {}
    for log_file in sorted(logs_dir.glob("*.log")):
        hit = ERROR_RE.search(read_text(log_file))
        if hit is None:
            continue
        base = log_file.name.partition(".watchrerun")[0]
        found[base.removesuffix(".log")] = hit.group(0)
    return found
|
|
|
|
|
|
def load_state(path: Path) -> dict:
    """Load the watcher state from ``path``, falling back to an empty state.

    Args:
        path: JSON state file.

    Returns:
        The decoded JSON object. ``{"reruns": {}}`` is returned when the file
        is missing, unreadable, not valid UTF-8/JSON, or does not contain a
        JSON object. A ``"reruns"`` entry that is not an object is replaced
        by ``{}`` so callers can always treat it as a mapping.
    """
    try:
        state = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return {"reruns": {}}
    if not isinstance(state, dict):
        # Valid JSON of the wrong shape (list, number, null, ...) would break
        # callers that use dict methods on the result.
        return {"reruns": {}}
    if not isinstance(state.get("reruns", {}), dict):
        state["reruns"] = {}
    return state
|
|
|
|
|
|
def save_state(path: Path, state: dict) -> None:
    """Write ``state`` to ``path`` as indented JSON via a sibling ``.tmp`` file.

    The JSON is first written to ``path`` with its suffix swapped for
    ``.tmp`` and that file is then renamed over ``path``, so the destination
    is replaced in one step rather than rewritten in place.

    Args:
        path: Destination state file.
        state: JSON-serialisable state to persist.
    """
    scratch = path.with_suffix(".tmp")
    payload = json.dumps(state, indent=2, ensure_ascii=False)
    with scratch.open("w", encoding="utf-8") as handle:
        handle.write(payload)
    scratch.replace(path)
|
|
|
|
|
|
def write_status(
    run_root: Path,
    total: int,
    events: dict[str, list[str]],
    active: list[str],
    completed: list[str],
    pending: list[str],
    errors: dict[str, str],
    reruns: dict[str, int],
) -> None:
    """Render the current watcher view to ``<run_root>/STATUS.md``.

    The report holds a metrics table, the active runs, the launcher's final
    failures (with the number of watcher reruns for each), the error
    patterns found in logs, and the last 30 lines of ``launcher.log``.

    Args:
        run_root: Run root; STATUS.md is written here and launcher.log read.
        total: Number of planned experiments.
        events: Launcher events as returned by ``parse_launcher``.
        active: Run ids with a live train process.
        completed: Run ids that already have a summary.
        pending: Run ids without a summary.
        errors: Run id -> matched error text.
        reruns: Run id -> number of watcher-initiated reruns.
    """
    now = time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
    failed = sorted(set(events["FAIL"]))
    retrying = sorted(set(events["RETRY"]))

    report: list[str] = [
        "# Ablation Status",
        "",
        f"Updated: {now}",
        f"Run root: `{run_root}`",
        "",
    ]

    # Metrics table.
    report += [
        "| Metric | Count |",
        "| --- | ---: |",
        f"| Total planned | {total} |",
        f"| Completed summaries | {len(completed)} |",
        f"| Active train processes | {len(active)} |",
        f"| Pending/not summarized | {len(pending)} |",
        f"| Launcher final fails | {len(failed)} |",
        f"| Launcher retries | {len(retrying)} |",
        f"| Logs with error patterns | {len(errors)} |",
        "",
    ]

    report += ["## Active", ""]
    report.extend(f"- `{run_id}`" for run_id in active)
    report.append("")

    report += ["## Final Fails", ""]
    report.extend(
        f"- `{run_id}` watcher_reruns={reruns.get(run_id, 0)}" for run_id in failed
    )
    report.append("")

    report += ["## Error Patterns", ""]
    report.extend(f"- `{run_id}`: `{err}`" for run_id, err in sorted(errors.items()))
    report.append("")

    # The tail is added as ONE pre-joined entry so that an empty launcher log
    # still produces a (blank) line between the code fences.
    launcher_tail = read_text(run_root / "launcher.log").splitlines()[-30:]
    report += [
        "## Recent Launcher",
        "",
        "```text",
        "\n".join(launcher_tail),
        "```",
    ]

    status_path = run_root / "STATUS.md"
    status_path.write_text("\n".join(report) + "\n", encoding="utf-8")
|
|
|
|
|
|
def main() -> None:
    """Poll a run root forever, refreshing STATUS.md and rerunning final failures.

    On every poll the persisted state is reloaded, ``launcher.log`` is
    parsed, live train processes and error logs are scanned, finished
    watcher-started reruns are reaped, and a rerun is started for each run
    the launcher marked ``[FAIL]`` that is a planned experiment, has no
    ``summary.json``, is not currently running, and has not yet used up
    ``--watcher-retries`` reruns. The loop sleeps ``--interval`` seconds
    between polls, with a floor of 5 seconds, and never returns.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--run-root", required=True)
    cli.add_argument("--interval", type=int, default=60)
    cli.add_argument("--watcher-retries", type=int, default=1)
    cli.add_argument("--groups", nargs="+", default=["all"])
    cli.add_argument("--bench", nargs="+", default=["searchqa", "spreadsheetbench"])
    args = cli.parse_args()

    run_root = Path(args.run_root).resolve()
    logs_dir = run_root / "logs"
    logs_dir.mkdir(parents=True, exist_ok=True)
    state_path = run_root / "watcher_state.json"
    launcher_log = run_root / "launcher.log"

    groups = set(args.groups)
    if "all" in groups:
        # "all" expands to every known ablation group.
        groups = {"default", "split", "mbs", "lr", "sched", "slown", "mod", "smodel"}

    experiments = {}
    for exp in build_matrix(groups, args.bench, run_root, include_duplicate_defaults=False):
        experiments[exp.run_id] = exp

    poll_seconds = max(5, int(args.interval))
    active_reruns: dict[str, subprocess.Popen] = {}

    while True:
        state = load_state(state_path)
        reruns = state.setdefault("reruns", {})
        events = parse_launcher(launcher_log)
        active = active_run_ids(run_root)
        completed = sorted(
            run_id
            for run_id in experiments
            if (run_root / run_id / "summary.json").exists()
        )
        pending = sorted(set(experiments) - set(completed))
        errors = scan_errors(logs_dir)

        # Reap watcher-started reruns that have exited and record their exit code.
        for run_id in list(active_reruns):
            rc = active_reruns[run_id].poll()
            if rc is not None:
                del active_reruns[run_id]
                done_log_path = logs_dir / f"{run_id}.watcher.log"
                with open(done_log_path, "a", encoding="utf-8") as done_log:
                    done_log.write(f"\n[WATCHER_DONE] rc={rc} time={time.time()}\n")

        # Start a rerun for every eligible final failure.
        for run_id in sorted(set(events["FAIL"])):
            ineligible = (
                run_id not in experiments
                or (run_root / run_id / "summary.json").exists()
                or run_id in active
                or run_id in active_reruns
            )
            if ineligible:
                continue
            count = int(reruns.get(run_id, 0))
            if count >= args.watcher_retries:
                continue

            # Persist the attempt BEFORE launching the process.
            reruns[run_id] = count + 1
            save_state(state_path, state)

            log_path = logs_dir / f"{run_id}.watchrerun{count + 1}.log"
            with open(log_path, "w", encoding="utf-8") as sink:
                sink.write(f"[WATCHER_START] run_id={run_id} attempt={count + 1}\n")
                # Flush so the marker precedes anything the child writes.
                sink.flush()
                child = subprocess.Popen(
                    command_for(experiments[run_id]),
                    cwd=PROJECT_ROOT,
                    stdout=sink,
                    stderr=subprocess.STDOUT,
                    text=True,
                    close_fds=True,
                )
            active_reruns[run_id] = child

        save_state(state_path, state)
        write_status(
            run_root=run_root,
            total=len(experiments),
            events=events,
            active=active,
            completed=completed,
            pending=pending,
            errors=errors,
            reruns={key: int(value) for key, value in reruns.items()},
        )
        time.sleep(poll_seconds)
|
|
|
|
|
|
# Script entry point: start the watcher only when executed directly.
if __name__ == "__main__":
    main()
|