mirror of
https://github.com/datascale-ai/opentalking.git
synced 2026-07-03 15:22:34 +08:00
1176 lines
48 KiB
Python
1176 lines
48 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import csv
|
|
import json
|
|
import math
|
|
import os
|
|
import platform
|
|
import re
|
|
import shutil
|
|
import shlex
|
|
import signal
|
|
import subprocess
|
|
import tarfile
|
|
import time
|
|
import wave
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from urllib import request as urlrequest
|
|
from urllib.request import Request, urlopen
|
|
|
|
import yaml
|
|
|
|
|
|
def run(cmd: list[str], *, cwd: Path | None = None, env: dict[str, str] | None = None) -> subprocess.CompletedProcess[str]:
    """Execute ``cmd`` directly (no shell) and capture its output as text.

    stdout and stderr are captured separately. The exit status is not checked;
    callers inspect ``returncode`` themselves.
    """
    captured = subprocess.PIPE
    return subprocess.run(
        cmd,
        cwd=cwd,
        env=env,
        text=True,
        stdout=captured,
        stderr=captured,
    )
|
|
|
|
|
|
def run_shell(cmd: str, *, cwd: Path | None = None, env: dict[str, str] | None = None) -> subprocess.CompletedProcess[str]:
    """Run ``cmd`` through a bash login shell (``bash -lc``) and capture its output.

    Intended for pipelines and redirections. The exit status is not checked.
    """
    argv = ["bash", "-lc", cmd]
    return subprocess.run(argv, cwd=cwd, env=env, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
|
|
|
|
def read_yaml(path: Path) -> dict[str, Any]:
    """Load a YAML benchmark config whose top level must be a mapping.

    An empty document yields ``{}``. Any other non-mapping top-level value
    aborts the run with ``SystemExit``.
    """
    with path.open("r", encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    config = loaded or {}
    if isinstance(config, dict):
        return config
    raise SystemExit(f"benchmark config must be a mapping: {path}")
|
|
|
|
|
|
def resolve_path(value: str, repo: Path) -> str:
    """Turn a configured path into an absolute path string.

    ``~`` is expanded first. Absolute inputs are returned unchanged (not
    resolved); relative inputs are joined to ``repo`` and resolved.
    """
    candidate = Path(value).expanduser()
    resolved = candidate if candidate.is_absolute() else (repo / candidate).resolve()
    return str(resolved)
|
|
|
|
|
|
# Built-in end-to-end models whose technical route is fixed regardless of
# config; all three are mouth-inpainting pipelines. Keys are lowercase ids.
E2E_TECHNICAL_ROUTE_BY_MODEL = dict.fromkeys(("wav2lip", "musetalk", "quicktalk"), "mouth inpainting")
|
|
|
|
|
|
def technical_route_for_model(model: str, model_cfg: dict[str, Any], cfg: dict[str, Any]) -> str:
    """Return the technical-route label reported for ``model``.

    Built-in models listed in ``E2E_TECHNICAL_ROUTE_BY_MODEL`` (matched after
    stripping whitespace and lowercasing) use their fixed label. Otherwise the
    per-model ``technical_route`` wins, then the global one, else ``""``.
    """
    builtin = E2E_TECHNICAL_ROUTE_BY_MODEL.get(model.strip().lower())
    if builtin:
        return builtin
    configured = model_cfg.get("technical_route") or cfg.get("technical_route") or ""
    return str(configured)
|
|
|
|
|
|
def percentile(values: list[float], q: float) -> float | None:
    """Return the ``q`` quantile (``0 <= q <= 1``) of ``values`` using linear interpolation.

    Returns ``None`` for an empty list. When the fractional rank lands exactly
    on an element, that element is returned as-is (no float conversion).
    """
    if len(values) == 0:
        return None
    ordered = sorted(values)
    rank = (len(ordered) - 1) * q
    below, above = math.floor(rank), math.ceil(rank)
    if below == above:
        return ordered[below]
    weight = rank - below
    return ordered[below] * (above - rank) + ordered[above] * weight
|
|
|
|
|
|
def http_json(url: str, *, method: str = "GET", payload: dict[str, Any] | None = None, timeout: float = 30.0) -> dict[str, Any]:
    """Issue an HTTP request and return the JSON-decoded response body.

    When ``payload`` is given it is sent as UTF-8 JSON (non-ASCII kept
    verbatim) with ``Content-Type: application/json``. Errors raised by
    ``urlopen`` propagate to the caller.
    """
    headers = {"Accept": "application/json"}
    body: bytes | None = None
    if payload is not None:
        body = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        headers["Content-Type"] = "application/json"
    request_obj = Request(url, data=body, headers=headers, method=method)
    with urlopen(request_obj, timeout=timeout) as response:
        raw = response.read()
    return json.loads(raw.decode("utf-8"))
|
|
|
|
|
|
def upload_reference(api_base_url: str, avatar_id: str, image_path: Path, timeout: float = 30.0) -> dict[str, Any]:
    """Upload a reference image for ``avatar_id`` as ``multipart/form-data``.

    POSTs to ``<api_base_url>/sessions/customize/reference`` with the form
    fields ``avatar_id`` and ``reference_image`` and returns the decoded JSON
    response.

    The image part's Content-Type is guessed from the file extension (only
    ``image/*`` guesses are used; anything else falls back to ``image/png`` as
    before), so JPEG/WebP references are not mislabelled. Characters that
    would terminate the quoted filename or the header line are percent-encoded.
    """
    import mimetypes

    boundary = f"----opentalking-benchmark-{int(time.time() * 1000)}"
    image = image_path.read_bytes()
    guessed_type = mimetypes.guess_type(image_path.name)[0]
    content_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/png"
    # Same escaping browsers apply to multipart filenames; keeps the header well-formed.
    filename = image_path.name.replace('"', "%22").replace("\r", "%0D").replace("\n", "%0A")
    parts = [
        f"--{boundary}\r\n".encode(),
        b'Content-Disposition: form-data; name="avatar_id"\r\n\r\n',
        avatar_id.encode("utf-8"),
        b"\r\n",
        f"--{boundary}\r\n".encode(),
        f'Content-Disposition: form-data; name="reference_image"; filename="{filename}"\r\n'.encode(),
        f"Content-Type: {content_type}\r\n\r\n".encode(),
        image,
        b"\r\n",
        f"--{boundary}--\r\n".encode(),
    ]
    body = b"".join(parts)
    req = urlrequest.Request(
        api_base_url.rstrip("/") + "/sessions/customize/reference",
        data=body,
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
        method="POST",
    )
    with urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
|
|
|
|
|
|
def wait_for_health(api_base_url: str, timeout: float = 180.0) -> None:
    """Poll ``<api_base_url>/health`` about once per second until it answers.

    Returns on the first successful JSON response. If ``timeout`` seconds pass
    first, raises ``TimeoutError`` carrying the most recent failure.
    """
    health_url = api_base_url.rstrip("/") + "/health"
    give_up_at = time.perf_counter() + timeout
    last_error: Exception | None = None
    while time.perf_counter() < give_up_at:
        try:
            http_json(health_url, timeout=5.0)
        except Exception as exc:  # server still booting: remember why and retry
            last_error = exc
            time.sleep(1.0)
        else:
            return
    raise TimeoutError(f"OpenTalking health not ready: {last_error}")
|
|
|
|
|
|
def wait_for_model(api_base_url: str, model: str, backend: str, timeout: float = 180.0) -> dict[str, Any]:
    """Poll ``<api_base_url>/models`` until ``model`` is connected on ``backend``.

    The endpoint may return a mapping with a ``"statuses"`` list or a bare
    list of status objects; both shapes are accepted. Returns the first
    status dict whose ``id`` equals ``model``, whose ``backend`` equals
    ``backend`` and whose ``connected`` flag is truthy. Polls about once per
    second.

    Note: ``backend == "mock"`` is never accepted, so a mock run always ends
    in ``TimeoutError``.

    Raises:
        TimeoutError: after ``timeout`` seconds; the message includes the last
            response or exception seen.
    """
    models_url = api_base_url.rstrip("/") + "/models"
    deadline = time.perf_counter() + timeout
    last: Any = None
    while time.perf_counter() < deadline:
        try:
            data: Any = http_json(models_url, timeout=10.0)
            last = data
            # Dispatch on the payload shape explicitly: calling ``.get`` on a
            # bare-list response would raise, so a list could never match.
            if isinstance(data, dict):
                statuses = data.get("statuses") or []
            elif isinstance(data, list):
                statuses = data
            else:
                statuses = []
            for item in statuses:
                if not isinstance(item, dict) or item.get("id") != model:
                    continue
                if item.get("connected") and item.get("backend") == backend and backend != "mock":
                    return item
        except Exception as exc:
            last = exc
        time.sleep(1.0)
    raise TimeoutError(f"model {model}/{backend} not connected at /models; last={last}")
|
|
|
|
|
|
def gpu_info(index: int) -> dict[str, Any]:
    """Describe GPU ``index`` via ``nvidia-smi``: name, total memory (GiB), driver version.

    Never raises for probe problems. If nvidia-smi cannot be executed, exits
    non-zero, prints nothing, or prints output that does not parse, the result
    is ``{"index": index, "error": <message>}`` instead.
    """
    try:
        cp = run(["nvidia-smi", f"--id={index}", "--query-gpu=name,memory.total,driver_version", "--format=csv,noheader,nounits"])
    except OSError as exc:  # nvidia-smi not installed / not executable
        return {"index": index, "error": str(exc)}
    if cp.returncode or not cp.stdout.strip():
        return {"index": index, "error": cp.stderr.strip()}
    parts = [p.strip() for p in cp.stdout.strip().splitlines()[0].split(",")]
    try:
        name, total_mib, driver = parts[0], float(parts[1]), parts[2]
    except (IndexError, ValueError):  # e.g. "[N/A]" or a truncated row
        return {"index": index, "error": f"unexpected nvidia-smi output: {cp.stdout.strip()}"}
    return {"index": index, "name": name, "memory_total_gb": round(total_mib / 1024.0, 3), "driver": driver}
|
|
|
|
|
|
def gpu_device_mem_gb(index: int) -> float | None:
    """Return device-wide used memory of GPU ``index`` in GiB, or ``None`` if unavailable.

    ``None`` is returned when nvidia-smi cannot be executed, fails, prints
    nothing, or prints a non-numeric value (such as ``[N/A]``). A periodic
    sampler therefore keeps running instead of crashing.
    """
    try:
        cp = run(["nvidia-smi", f"--id={index}", "--query-gpu=memory.used", "--format=csv,noheader,nounits"])
    except OSError:
        return None
    output = cp.stdout.strip()
    if cp.returncode or not output:
        return None
    try:
        used_mib = float(output.splitlines()[0])
    except ValueError:
        return None
    return round(used_mib / 1024.0, 3)
|
|
|
|
|
|
def gpu_process_mem_gb(pids: list[int], gpu_index: int | None = None) -> float | None:
    """Total GPU memory (GiB) held by ``pids``, optionally only on GPU ``gpu_index``.

    Returns ``None`` when ``pids`` is empty or the per-process query fails,
    and ``0.0`` when none of the processes currently appear in the query.
    """
    rows = gpu_process_mem_records(pids, gpu_index=gpu_index) if pids else None
    if rows is None:
        return None
    used_mb = sum(float(row["used_memory_mb"]) for row in rows)
    return round(used_mb / 1024.0, 3)
|
|
|
|
|
|
def retry_gpu_process_mem_gb(
    pids: list[int],
    *,
    gpu_index: int | None = None,
    attempts: int = 10,
    interval: float = 0.3,
) -> float | None:
    """Query ``gpu_process_mem_gb`` up to ``attempts`` times until it yields a value.

    Sleeps ``interval`` seconds between attempts, but not after the last one.
    With no ``pids`` the query is deterministically ``None``, so it returns
    ``None`` immediately instead of waiting out every retry.
    """
    if not pids:
        return None
    tries = max(1, attempts)
    for attempt in range(tries):
        value = gpu_process_mem_gb(pids, gpu_index=gpu_index)
        if value is not None:
            return value
        if attempt + 1 < tries:
            time.sleep(interval)
    return None
|
|
|
|
|
|
def pids_for_ports(ports: list[int]) -> list[int]:
    """Return the sorted, de-duplicated PIDs of processes listening on ``ports`` (TCP).

    Each port is queried with ``ss -ltnp`` (stderr discarded). Only PIDs that
    ``ss`` prints as ``pid=<n>`` are collected.
    """
    listeners: set[int] = set()
    for port in ports:
        output = run_shell(f"ss -ltnp '( sport = :{int(port)} )' 2>/dev/null").stdout
        listeners.update(int(match) for match in re.findall(r"pid=(\d+)", output))
    return sorted(listeners)
|
|
|
|
|
|
def proc_tree(root_pid: int) -> list[int]:
    """Return ``root_pid`` plus all of its current descendants, sorted.

    Builds a parent -> children map from ``/proc/<pid>/stat`` and walks it.
    On systems without ``/proc`` only ``[root_pid]`` is returned.
    """
    proc_root = Path("/proc")
    if not proc_root.exists():
        return [root_pid]
    children: dict[int, list[int]] = {}
    for proc in proc_root.iterdir():
        if not proc.name.isdigit():
            continue
        try:
            stat = (proc / "stat").read_text()
            # Field 2 (comm) is parenthesised and may itself contain spaces or
            # ')'; the fields after the *last* ')' are: state, ppid, ...
            fields = stat[stat.rindex(")") + 1:].split()
            ppid = int(fields[1])
        except (OSError, ValueError, IndexError):
            continue  # process vanished mid-scan or stat is unreadable
        children.setdefault(ppid, []).append(int(proc.name))
    found = {root_pid}
    stack = [root_pid]
    while stack:
        for child in children.get(stack.pop(), []):
            if child not in found:
                found.add(child)
                stack.append(child)
    return sorted(found)
|
|
|
|
|
|
def read_pid_file(path: Path) -> int | None:
    """Return the PID stored in ``path`` if that process is currently alive.

    The first whitespace-separated token is parsed as an integer. ``None`` is
    returned for a missing, unreadable, blank or non-numeric file, or when
    ``/proc/<pid>`` does not exist.
    """
    try:
        tokens = path.read_text(encoding="utf-8").split()
        if not tokens:
            return None
        pid = int(tokens[0])
    except Exception:
        return None
    alive = Path(f"/proc/{pid}").exists()
    return pid if alive else None
|
|
|
|
|
|
def resolve_pid_file(path: str, repo: Path) -> Path:
    """Map a configured pid-file path to its on-disk location.

    After ``~`` expansion, absolute paths are kept as-is. Paths starting with
    ``run/`` live beside the repo (``repo.parent / path``, unresolved). Any
    other relative path is resolved against ``repo``.
    """
    candidate = Path(path).expanduser()
    if candidate.is_absolute():
        return candidate
    beside_repo = str(candidate).startswith("run/")
    return repo.parent / candidate if beside_repo else (repo / candidate).resolve()
|
|
|
|
|
|
def default_model_pid_files(model: str, backend: str, repo: Path) -> list[Path]:
    """Pid files that OmniRT quickstart scripts use for ``model``.

    Only the ``omnirt`` backend has pid files; any other backend yields ``[]``.
    Files live in ``<repo>/../run``. Models without a known mapping default to
    ``omnirt-<model>.pid``; MuseTalk also has a websocket-server pid file.
    """
    if backend != "omnirt":
        return []
    known = {
        "wav2lip": ("omnirt-wav2lip.pid",),
        "quicktalk": ("omnirt-quicktalk.pid",),
        "musetalk": ("omnirt-musetalk.pid", "omnirt-musetalk-ws.pid"),
        "flashtalk": ("omnirt-flashtalk.pid",),
    }
    names = known.get(model, (f"omnirt-{model}.pid",))
    run_dir = repo.parent / "run"
    return [run_dir / name for name in names]
|
|
|
|
|
|
def collect_related_pids(repo: Path, cfg: dict[str, Any], model: str, backend: str, model_cfg: dict[str, Any], ports: list[int]) -> dict[str, Any]:
    """Gather every PID that belongs to the system under test.

    Root PIDs come from four sources:
    - the OpenTalking API and web pid files;
    - the model's default OmniRT pid files;
    - any ``pid_files`` listed in ``model_cfg``;
    - processes listening on ``ports``.

    Each root is expanded to its whole process tree.

    Returns a dict with:
    - ``pids``: the sorted union of all trees;
    - ``root_pids``: pid -> origin (a pid-file path or ``"port-listener"``), ordered by pid;
    - ``pid_files``: pid-file path -> live pid, or ``None``.
    """
    run_dir = repo.parent / "run"
    api_port = int(cfg.get("api_port", 8010))
    web_port = int(cfg.get("web_port", 5184))
    candidates = [
        run_dir / f"opentalking-api-{api_port}.pid",
        run_dir / f"opentalking-web-{web_port}.pid",
        *default_model_pid_files(model, backend, repo),
    ]
    extra = model_cfg.get("pid_files")
    if isinstance(extra, list):
        candidates += [resolve_pid_file(str(entry), repo) for entry in extra]

    pid_file_map: dict[str, int | None] = {}
    origins: dict[int, str] = {}
    for candidate in candidates:
        live = read_pid_file(candidate)
        pid_file_map[str(candidate)] = live
        if live is not None:
            origins[live] = str(candidate)
    for listener in pids_for_ports(ports):
        origins.setdefault(listener, "port-listener")

    tree: set[int] = set()
    for root in origins:
        tree.update(proc_tree(root))
    return {"pids": sorted(tree), "root_pids": dict(sorted(origins.items())), "pid_files": pid_file_map}
|
|
|
|
|
|
def gpu_process_mem_records(pids: list[int], gpu_index: int | None = None) -> list[dict[str, Any]] | None:
    """Per-process GPU memory usage for ``pids`` as reported by ``nvidia-smi``.

    Rows come from ``--query-compute-apps`` and are kept when their pid is in
    ``pids``. When ``gpu_index`` is given, rows are further restricted to that
    GPU's UUID. If the index cannot be mapped to a UUID, no GPU filter is
    applied.

    Returns a list of dicts with keys ``gpu_uuid``, ``pid``, ``process_name``,
    ``used_memory_mb`` and ``used_memory_gb``. Returns ``None`` if ``pids`` is
    empty or the compute-apps query fails. Rows that do not parse are skipped.
    """
    if not pids:
        return None
    wanted = {int(pid) for pid in pids}
    apps = run(["nvidia-smi", "--query-compute-apps=gpu_uuid,pid,process_name,used_memory", "--format=csv,noheader,nounits"])
    if apps.returncode:
        return None

    target_uuid: str | None = None
    if gpu_index is not None:
        listing = run(["nvidia-smi", "--query-gpu=index,uuid", "--format=csv,noheader,nounits"])
        for row in listing.stdout.splitlines():
            cells = [cell.strip() for cell in row.split(",")]
            if len(cells) < 2:
                continue
            try:
                index = int(cells[0])
            except ValueError:
                continue
            if index == gpu_index:
                target_uuid = cells[1]  # last matching row wins

    records: list[dict[str, Any]] = []
    for row in apps.stdout.splitlines():
        cells = [cell.strip() for cell in row.split(",")]
        if len(cells) < 4:
            continue
        uuid, raw_pid, name, raw_used = cells[0], cells[1], cells[2], cells[3]
        try:
            pid = int(raw_pid)
            used_mb = float(raw_used)
        except ValueError:
            continue
        if pid not in wanted or (target_uuid is not None and uuid != target_uuid):
            continue
        records.append({
            "gpu_uuid": uuid,
            "pid": pid,
            "process_name": name,
            "used_memory_mb": used_mb,
            "used_memory_gb": round(used_mb / 1024.0, 3),
        })
    return records
|
|
|
|
|
|
def rss_gb_for_pids(pids: list[int]) -> float | None:
    """Sum the resident set size (``VmRSS``) of ``pids`` in GiB.

    Values are read from ``/proc/<pid>/status`` (reported there in kB).
    Returns ``None`` when none of the pids has a status file. A pid whose file
    exists but cannot be read or parsed contributes nothing.
    """
    total_kb = 0
    any_alive = False
    for pid in pids:
        status_path = Path(f"/proc/{pid}/status")
        if not status_path.exists():
            continue
        any_alive = True
        try:
            rss_kb = next(
                (int(line.split()[1]) for line in status_path.read_text().splitlines() if line.startswith("VmRSS:")),
                0,
            )
        except Exception:
            continue
        total_kb += rss_kb
    if not any_alive:
        return None
    return round(total_kb / 1024.0 / 1024.0, 3)
|
|
|
|
|
|
class ResourceSampler:
    """Track peak GPU and host memory while a benchmark runs.

    On each tick it measures three values and keeps the maximum of each:
    - device-wide used VRAM of ``gpu_index``;
    - summed VRAM of the tracked processes on that GPU;
    - summed RSS of the tracked processes.

    The tracked pids are either the fixed ``pids`` list or, when
    ``pid_provider`` is given, the result of calling it on every tick.

    Run ``sample()`` as an asyncio task and call ``stop()`` to end it.
    """

    def __init__(
        self,
        *,
        gpu_index: int,
        pids: list[int] | None = None,
        pid_provider: Any | None = None,
        interval: float = 0.2,
    ) -> None:
        self.gpu_index = gpu_index
        self.pids = pids or []
        # Optional zero-argument callable returning an iterable of pids. It is
        # invoked from an executor thread by sample().
        self.pid_provider = pid_provider
        self.interval = interval
        self.max_device_vram_gb: float | None = None
        self.max_process_vram_gb: float | None = None
        self.max_cpu_gb: float | None = None
        # Pid set that was being tracked when max_process_vram_gb was observed.
        self.max_process_vram_pids: list[int] = []
        # Last non-empty pid set; reused when the provider returns nothing.
        self.latest_pids: list[int] = list(self.pids)
        self._stop = asyncio.Event()

    def current_pids(self) -> list[int]:
        """Return the pids to measure on this tick.

        Without a provider, this is a copy of the fixed ``pids``. With one, its
        result (de-duplicated and sorted) is used and remembered in
        ``latest_pids``. If the provider raises, the fixed ``pids`` are used
        instead. If the result is empty, the last non-empty set is reused.
        """
        if self.pid_provider is None:
            return list(self.pids)
        try:
            pids = list(self.pid_provider())
        except Exception:
            pids = list(self.pids)
        if pids:
            self.latest_pids = sorted({int(pid) for pid in pids})
            return self.latest_pids
        return list(self.latest_pids)

    def _probe(self) -> tuple[list[int], float | None, float | None, float | None]:
        """Take one blocking measurement: ``(pids, device_vram_gb, process_vram_gb, rss_gb)``."""
        pids = self.current_pids()
        device = gpu_device_mem_gb(self.gpu_index)
        process = gpu_process_mem_gb(pids, gpu_index=self.gpu_index)
        cpu = rss_gb_for_pids(pids)
        return pids, device, process, cpu

    def _record(self, pids: list[int], device: float | None, process: float | None, cpu: float | None) -> None:
        """Fold one measurement into the running maxima (``None`` values are ignored)."""
        if device is not None:
            self.max_device_vram_gb = device if self.max_device_vram_gb is None else max(self.max_device_vram_gb, device)
        if process is not None and (self.max_process_vram_gb is None or process > self.max_process_vram_gb):
            self.max_process_vram_gb = process
            self.max_process_vram_pids = list(pids)
        if cpu is not None:
            self.max_cpu_gb = cpu if self.max_cpu_gb is None else max(self.max_cpu_gb, cpu)

    async def sample(self) -> None:
        """Measure every ``interval`` seconds until ``stop()`` is called.

        The probes run subprocesses (nvidia-smi, ss) and read /proc, which
        blocks. They therefore run in the default executor. This keeps the
        event loop, and with it the SSE/WebRTC receivers whose timings the
        benchmark measures, free during sampling. Maxima are updated on the
        loop thread.
        """
        loop = asyncio.get_running_loop()
        while not self._stop.is_set():
            measurement = await loop.run_in_executor(None, self._probe)
            self._record(*measurement)
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self.interval)
            except asyncio.TimeoutError:
                pass

    def stop(self) -> None:
        """Ask ``sample()`` to exit after its current measurement."""
        self._stop.set()
|
|
|
|
|
|
async def sse_events(api_base_url: str, session_id: str, sink: list[dict[str, Any]], stop: asyncio.Event) -> None:
    """Stream ``/sessions/<session_id>/events`` (Server-Sent Events) into ``sink``.

    Each dispatched event is appended as
    ``{"event": name, "data": payload, "received_unix": time}``. ``payload`` is
    the JSON-decoded data when possible; otherwise it is the raw text, with
    multiple ``data:`` lines joined by newlines.

    ``stop`` is only checked as lines arrive, so an idle stream is ended by
    cancelling the task.
    """
    import httpx

    url = f"{api_base_url.rstrip('/')}/sessions/{session_id}/events"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as resp:
            resp.raise_for_status()
            event_name = "message"
            pending: list[str] = []
            async for line in resp.aiter_lines():
                if stop.is_set():
                    return
                if line.startswith("event:"):
                    event_name = line.split(":", 1)[1].strip()
                elif line.startswith("data:"):
                    pending.append(line.split(":", 1)[1].strip())
                elif line == "":
                    # Blank line terminates an event: dispatch it, then reset.
                    if pending:
                        joined = "\n".join(pending)
                        try:
                            payload: Any = json.loads(joined)
                        except Exception:
                            payload = joined
                        sink.append({"event": event_name, "data": payload, "received_unix": time.time()})
                    event_name = "message"
                    pending = []
|
|
|
|
|
|
async def setup_webrtc(api_base_url: str, session_id: str, first_video: asyncio.Event, first_video_time: dict[str, float], video_frames: dict[str, Any]) -> tuple[Any, dict[str, Any]]:
    """Open a receive-only WebRTC connection to an OpenTalking session.

    Creates an aiortc peer connection with recvonly video and audio
    transceivers. It posts the SDP offer to
    ``/sessions/<session_id>/webrtc/offer`` and applies the returned answer.

    Incoming tracks are registered in the returned ``track_state``:
    - ``source_tracks``: kind -> remote track;
    - ``record_tracks``: kind -> unbuffered relay subscription;
    - ``audio_ready`` / ``video_ready``: events set when each kind arrives.

    For the video track, a background task:
    - counts frames in ``video_frames["count"]``;
    - stores the first decoded frame (BGR) in ``video_frames["first_frame"]``;
    - while ``video_frames["capture"]`` is truthy, keeps up to
      ``video_frames["max_frames"]`` (default 250) frames in ``video_frames["frames"]``;
    - on the first frame, records ``first_video_time["unix"]`` (taken
      immediately after the frame is received) and sets ``first_video``.

    Returns ``(pc, track_state)``. ``track_state["tasks"]`` holds the running
    receiver tasks so they are not garbage-collected mid-stream.
    """
    from aiortc import RTCPeerConnection, RTCSessionDescription
    from aiortc.contrib.media import MediaRelay

    pc = RTCPeerConnection()
    relay = MediaRelay()
    pc.addTransceiver("video", direction="recvonly")
    pc.addTransceiver("audio", direction="recvonly")
    track_state: dict[str, Any] = {
        "relay": relay,
        "source_tracks": {},
        "record_tracks": {},
        "audio_ready": asyncio.Event(),
        "video_ready": asyncio.Event(),
        # Strong references to receiver tasks: the event loop keeps only weak
        # references, so an unreferenced task may be collected while running.
        "tasks": set(),
    }

    @pc.on("track")
    def on_track(track: Any) -> None:
        track_state["source_tracks"][track.kind] = track
        track_state["record_tracks"][track.kind] = relay.subscribe(track, buffered=False)
        if track.kind == "audio" and not track_state["audio_ready"].is_set():
            track_state["audio_ready"].set()
        if track.kind == "video" and not track_state["video_ready"].is_set():
            track_state["video_ready"].set()
        if track.kind != "video":
            return
        observer_track = relay.subscribe(track)

        async def recv_loop() -> None:
            while True:
                try:
                    frame = await observer_track.recv()
                except Exception:
                    return  # track ended or failed: stop observing
                # Timestamp before any frame conversion so TTFV is not inflated.
                received_unix = time.time()
                video_frames["count"] = video_frames.get("count", 0) + 1
                capture = video_frames.get("capture")
                frames = video_frames.setdefault("frames", [])
                if capture and len(frames) < int(video_frames.get("max_frames", 250)):
                    try:
                        frames.append(frame.to_ndarray(format="bgr24"))
                    except Exception:
                        pass
                if "first_frame" not in video_frames:
                    try:
                        video_frames["first_frame"] = frame.to_ndarray(format="bgr24")
                    except Exception:
                        pass
                if not first_video.is_set():
                    first_video_time["unix"] = received_unix
                    first_video.set()

        task = asyncio.create_task(recv_loop())
        track_state["tasks"].add(task)
        task.add_done_callback(track_state["tasks"].discard)

    offer = await pc.createOffer()
    await pc.setLocalDescription(offer)
    offer_url = api_base_url.rstrip("/") + f"/sessions/{session_id}/webrtc/offer"
    offer_payload = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
    # http_json blocks; run it in the default executor so other tasks on this
    # loop (event/frame receivers, resource samplers) keep accurate timestamps.
    answer = await asyncio.get_running_loop().run_in_executor(
        None,
        lambda: http_json(offer_url, method="POST", payload=offer_payload, timeout=30.0),
    )
    await pc.setRemoteDescription(RTCSessionDescription(sdp=answer["sdp"], type=answer["type"]))
    return pc, track_state
|
|
|
|
|
|
def write_video_sample(path: Path, frames: list[Any], fps: float = 25.0) -> bool:
    """Encode ``frames`` (BGR images) into an ``mp4v`` video at ``path``.

    The output size is taken from the first frame. Returns ``True`` on success.
    Returns ``False`` when:
    - there are no frames;
    - OpenCV/NumPy are unavailable;
    - the writer cannot be opened;
    - writing or finalizing fails.

    The writer is released on every path, so a failure mid-stream does not
    leak the encoder or the file handle.
    """
    if not frames:
        return False
    try:
        import cv2
        import numpy as np
    except Exception:
        return False
    writer = None
    try:
        first = np.asarray(frames[0])
        height, width = first.shape[:2]
        writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), float(fps), (width, height))
        if not writer.isOpened():
            return False
        for item in frames:
            writer.write(np.asarray(item))
        writer.release()  # finalize; a failure here still reports False
        writer = None
        return True
    except Exception:
        return False
    finally:
        if writer is not None:
            try:
                writer.release()
            except Exception:
                pass
|
|
|
|
|
|
def mux_audio_into_video(video_path: Path, audio_path: Path) -> bool:
    """Replace ``video_path`` with a copy that carries ``audio_path`` as its audio track.

    The video stream is copied and the audio is encoded to AAC. The output
    stops at the shorter of the two inputs. The result is written to a
    temporary sibling file and then moved over ``video_path``.

    Returns ``True`` on success. Returns ``False`` (leaving ``video_path``
    untouched and removing any partial temp file) when either input is
    missing, the audio file is empty, ffmpeg cannot be started, or ffmpeg fails.
    """
    if not video_path.exists() or not audio_path.exists() or audio_path.stat().st_size == 0:
        return False
    tmp_path = video_path.with_name(video_path.stem + ".with_audio.mp4")
    try:
        cp = run([
            "ffmpeg", "-y",
            "-i", str(video_path),
            "-i", str(audio_path),
            "-map", "0:v:0",
            "-map", "1:a:0",
            "-c:v", "copy",
            "-c:a", "aac",
            "-shortest",
            str(tmp_path),
        ])
    except OSError:  # ffmpeg not installed / not executable
        return False
    if cp.returncode:
        tmp_path.unlink(missing_ok=True)  # drop the partial output
        return False
    tmp_path.replace(video_path)
    return True
|
|
|
|
|
|
def media_file_has_streams(path: Path) -> bool:
    """Return ``True`` if ``path`` is a non-empty media file with both audio and video streams.

    Stream kinds are listed with ``ffprobe``. A missing or empty file, or a
    failed probe, yields ``False``.
    """
    if not (path.exists() and path.stat().st_size > 0):
        return False
    probe = run([
        "ffprobe",
        "-v",
        "error",
        "-show_entries",
        "stream=codec_type",
        "-of",
        "csv=p=0",
        str(path),
    ])
    if probe.returncode != 0:
        return False
    kinds = {line.strip() for line in probe.stdout.splitlines()} - {""}
    return {"audio", "video"} <= kinds
|
|
|
|
|
|
def probe_media_streams(path: Path) -> dict[str, dict[str, float]]:
    """Return start time and duration (seconds) of the first audio and first video stream.

    The result maps ``"audio"`` / ``"video"`` to ``{"start": ..., "duration": ...}``.
    A kind is absent if the file has no such stream.

    Fallbacks:
    - a stream without its own duration uses the container duration, or 0.0;
    - a missing start time is 0.0;
    - values ffprobe reports as missing or non-numeric (e.g. ``"N/A"``) count
      as absent instead of raising ``ValueError``.

    Raises:
        RuntimeError: if ffprobe exits with a non-zero status.
    """
    def to_seconds(value: Any) -> float | None:
        # ffprobe reports times as strings; unusable values become None.
        if value is None or value == "":
            return None
        try:
            seconds = float(value)
        except (TypeError, ValueError):
            return None
        return seconds if math.isfinite(seconds) else None

    cp = run([
        "ffprobe",
        "-v",
        "error",
        "-show_entries",
        "stream=codec_type,start_time,duration",
        "-show_entries",
        "format=duration",
        "-of",
        "json",
        str(path),
    ])
    if cp.returncode:
        raise RuntimeError(f"ffprobe failed for {path}: {cp.stderr.strip()}")
    data = json.loads(cp.stdout or "{}")
    format_duration = to_seconds((data.get("format") or {}).get("duration")) or 0.0
    streams: dict[str, dict[str, float]] = {}
    for stream in data.get("streams") or []:
        kind = stream.get("codec_type")
        if kind not in {"audio", "video"} or kind in streams:
            continue  # only the first stream of each kind is reported
        start = to_seconds(stream.get("start_time"))
        duration = to_seconds(stream.get("duration"))
        streams[kind] = {
            "start": 0.0 if start is None else start,
            "duration": format_duration if duration is None else duration,
        }
    return streams
|
|
|
|
|
|
def normalize_webrtc_sample(path: Path) -> bool:
    """Trim a recorded WebRTC sample so audio and video cover the same interval.

    The interval kept is the overlap of the audio and video streams, computed
    from ffprobe start/duration. That span is re-encoded (H.264 ``yuv420p`` +
    AAC, ``+faststart``) with timestamps reset to zero and written back to
    ``path``. The untouched recording is kept beside it as
    ``<stem>.raw<suffix>``.

    Returns ``True`` if the normalized file has both streams. Returns
    ``False``, leaving ``path`` as it was, when:
    - a stream is missing;
    - the overlap is 0.25 s or shorter;
    - ffmpeg fails.

    If launching ffmpeg raises, the original file is restored before the
    exception propagates.
    """
    if not media_file_has_streams(path):
        return False
    streams = probe_media_streams(path)
    audio = streams.get("audio")
    video = streams.get("video")
    if not audio or not video:
        return False
    common_start = max(audio["start"], video["start"])
    common_end = min(audio["start"] + audio["duration"], video["start"] + video["duration"])
    duration = common_end - common_start
    if duration <= 0.25:
        return False

    raw_path = path.with_name(path.stem + ".raw" + path.suffix)
    tmp_path = path.with_name(path.stem + ".normalized" + path.suffix)
    video_offset = max(0.0, common_start - video["start"])
    audio_offset = max(0.0, common_start - audio["start"])
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(raw_path),
        "-map",
        "0:v:0",
        "-map",
        "0:a:0",
        "-filter:v",
        f"trim=start={video_offset:.6f}:duration={duration:.6f},setpts=PTS-STARTPTS",
        "-filter:a",
        f"atrim=start={audio_offset:.6f}:duration={duration:.6f},asetpts=PTS-STARTPTS",
        "-c:v",
        "libx264",
        "-pix_fmt",
        "yuv420p",
        "-preset",
        "veryfast",
        "-c:a",
        "aac",
        "-shortest",
        "-movflags",
        "+faststart",
        str(tmp_path),
    ]

    path.replace(raw_path)
    try:
        cp = run(cmd)
    except BaseException:
        # Never lose the recording: put it back before propagating.
        raw_path.replace(path)
        tmp_path.unlink(missing_ok=True)
        raise
    if cp.returncode:
        raw_path.replace(path)
        tmp_path.unlink(missing_ok=True)
        return False
    tmp_path.replace(path)
    return media_file_has_streams(path)
|
|
|
|
|
|
def prepare_benchmark_avatar(repo: Path, base_avatar_id: str, *, model: str, timestamp: str) -> tuple[str, Path]:
    """Clone an example avatar into a fresh benchmark avatar.

    Copies ``<repo>/examples/avatars/<base_avatar_id>`` (without any
    ``reference_custom.*`` files) to ``<repo>/examples/avatars/benchmark-<timestamp>-<model>``.
    In the target name, runs of characters outside ``[A-Za-z0-9_-]`` in
    ``model`` become ``-``. An existing directory of that name is replaced.

    The copy's ``manifest.json`` receives the new ``id`` and ``name``, and
    records the source avatar id under ``metadata.base_avatar_id``.

    Returns ``(avatar_id, target_dir)``.

    Raises:
        RuntimeError: if ``base_avatar_id`` escapes the avatars directory or
            does not name a directory containing ``manifest.json``.
    """
    avatars_root = repo / "examples" / "avatars"
    source_dir = (avatars_root / base_avatar_id).resolve()
    try:
        source_dir.relative_to(avatars_root.resolve())
    except ValueError as exc:
        raise RuntimeError(f"invalid base avatar id: {base_avatar_id}") from exc
    if not (source_dir.is_dir() and (source_dir / "manifest.json").is_file()):
        raise RuntimeError(f"base avatar not found: {base_avatar_id}")

    model_slug = re.sub(r"[^A-Za-z0-9_-]+", "-", model).strip("-") or "model"
    avatar_id = f"benchmark-{timestamp}-{model_slug}"
    target_dir = avatars_root / avatar_id
    if target_dir.exists():
        shutil.rmtree(target_dir)
    shutil.copytree(source_dir, target_dir, ignore=shutil.ignore_patterns("reference_custom.*"))

    manifest_file = target_dir / "manifest.json"
    manifest = json.loads(manifest_file.read_text(encoding="utf-8"))
    metadata = dict(manifest.get("metadata") or {})
    metadata["base_avatar_id"] = manifest.get("id") or base_avatar_id  # read before id is overwritten
    manifest.update(id=avatar_id, name=f"Benchmark {base_avatar_id} {model_slug}", metadata=metadata)
    manifest_file.write_text(json.dumps(manifest, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
    return avatar_id, target_dir
|
|
|
|
|
|
def load_audio_duration(path: Path, target_path: Path, seconds: float) -> float:
    """Transcode at most ``seconds`` of ``path`` into 16 kHz mono WAV at ``target_path``.

    Returns the written WAV's duration in seconds, rounded to milliseconds.
    Raises ``RuntimeError`` with ffmpeg's stderr if the conversion fails.
    """
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", str(path), "-t", str(seconds), "-ac", "1", "-ar", "16000", "-f", "wav", str(target_path)]
    result = run(ffmpeg_cmd)
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip())
    with wave.open(str(target_path), "rb") as wav:
        frame_count, rate = wav.getnframes(), wav.getframerate()
    return round(frame_count / float(rate), 3)
|
|
|
|
|
|
def omnirt_port_from_config(model_cfg: dict[str, Any], cfg: dict[str, Any]) -> int:
    """Extract the OmniRT server port from the configured ``omnirt`` URL.

    The per-model ``omnirt`` value takes precedence over the global one. A
    proper URL is parsed with ``urllib.parse.urlsplit``, so URLs with a path
    such as ``http://host:9000/v1`` work. Scheme-less values like
    ``host:9000`` fall back to taking the text after the last ``:``.

    Raises:
        RuntimeError: if no URL is configured or no port can be parsed.
    """
    from urllib.parse import urlsplit

    url = str(model_cfg.get("omnirt") or cfg.get("omnirt") or "")
    if not url:
        raise RuntimeError("backend=omnirt requires an omnirt URL in model config")
    try:
        port = urlsplit(url).port
    except ValueError:  # malformed netloc or out-of-range port
        port = None
    if port is not None:
        return port
    try:
        return int(url.rstrip("/").rsplit(":", 1)[1])
    except Exception as exc:
        raise RuntimeError(f"cannot parse OmniRT port from URL: {url}") from exc
|
|
|
|
|
|
def stop_pid(pid: int, timeout: float = 20.0) -> None:
    """Terminate ``pid`` and its descendants: SIGTERM, wait, then SIGKILL.

    The process tree is captured once up front, and signals go to those pids
    in descending numeric order. Waits up to ``timeout`` seconds, polling
    ``/proc`` every 0.2 s, for all of them to exit. Survivors are then sent
    SIGKILL. Pids that are already gone or not signalable are skipped.
    """
    victims = sorted(proc_tree(pid), reverse=True)

    def broadcast(sig: int) -> None:
        for victim in victims:
            try:
                os.kill(victim, sig)
            except (ProcessLookupError, PermissionError):
                pass

    broadcast(signal.SIGTERM)
    deadline = time.perf_counter() + timeout
    while time.perf_counter() < deadline:
        if all(not Path(f"/proc/{victim}").exists() for victim in victims):
            return
        time.sleep(0.2)
    broadcast(signal.SIGKILL)
|
|
|
|
|
|
def stop_pid_files(pid_files: list[Path]) -> None:
    """For each pid file, stop its live process tree (if any), then delete the file.

    Already-missing pid files are ignored.
    """
    for pid_file in pid_files:
        live_pid = read_pid_file(pid_file)
        if live_pid is not None:
            stop_pid(live_pid)
        pid_file.unlink(missing_ok=True)
|
|
|
|
|
|
def omnirt_extra_ports(model: str, model_cfg: dict[str, Any], port: int) -> list[int]:
    """Ports an OmniRT deployment of ``model`` listens on.

    Always includes the main OmniRT ``port``. MuseTalk additionally listens on
    ``model_cfg["musetalk_port"]`` (default 8766).
    """
    if model != "musetalk":
        return [port]
    return [port, int(model_cfg.get("musetalk_port", 8766))]
|
|
|
|
|
|
def stop_omnirt_model(repo: Path, cfg: dict[str, Any], model: str, backend: str, model_cfg: dict[str, Any]) -> None:
    """Shut down a benchmark-managed OmniRT server for ``model``.

    Does nothing unless ``backend == "omnirt"``. Otherwise it stops the
    processes named by the model's default and configured pid files, removing
    those files. It then stops anything still listening on the OmniRT port(s).
    """
    if backend != "omnirt":
        return
    port = omnirt_port_from_config(model_cfg, cfg)
    targets = default_model_pid_files(model, backend, repo)
    configured = model_cfg.get("pid_files")
    if isinstance(configured, list):
        targets += [resolve_pid_file(str(entry), repo) for entry in configured]
    stop_pid_files(targets)
    for listener in pids_for_ports(omnirt_extra_ports(model, model_cfg, port)):
        stop_pid(listener)
|
|
|
|
|
|
def start_omnirt_model(repo: Path, cfg: dict[str, Any], model: str, backend: str, model_cfg: dict[str, Any], gpu_index: int, out_dir: Path) -> None:
    """Launch a benchmark-managed OmniRT server for ``model`` via its quickstart script.

    Does nothing unless ``backend == "omnirt"``. Otherwise it:
    - builds an environment pinned to one GPU (``CUDA_VISIBLE_DEVICES=gpu_index``)
      and to the configured OmniRT port;
    - writes ``<out_dir>/logs/benchmark-quickstart.env``, which sources the
      repo's ``scripts/quickstart/env`` (if present) and then re-exports the
      benchmark's overrides;
    - runs ``scripts/quickstart/start_omnirt_<model>.sh`` to completion, with
      output appended to ``<out_dir>/logs/start-omnirt.log``.

    Supported models: ``wav2lip``, ``musetalk``, ``quicktalk``.

    Raises:
        RuntimeError: if the OmniRT URL has no parsable port, the model is not
            supported here, or the start script exits non-zero.
    """
    if backend != "omnirt":
        return
    port = omnirt_port_from_config(model_cfg, cfg)
    # Start from the caller's environment, then pin host/port/GPU and caches.
    env = os.environ.copy()
    home = str(Path.home())
    repo_parent = str(repo.parent)
    env["PATH"] = f"{home}/.local/bin:/usr/local/bin:/usr/bin:/bin:" + env.get("PATH", "")
    env["DIGITAL_HUMAN_HOME"] = repo_parent
    env["OMNIRT_PORT"] = str(port)
    env["OMNIRT_HOST"] = "0.0.0.0"
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
    env["TMPDIR"] = str(repo.parent / "tmp")
    env["PIP_CACHE_DIR"] = env.get("PIP_CACHE_DIR", f"{home}/.cache/pip")
    env["UV_CACHE_DIR"] = env.get("UV_CACHE_DIR", f"{home}/.cache/uv")
    # Package-index mirrors are only defaults; an inherited value wins.
    env.setdefault("UV_DEFAULT_INDEX", "https://pypi.tuna.tsinghua.edu.cn/simple")
    env.setdefault("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple")
    Path(env["TMPDIR"]).mkdir(parents=True, exist_ok=True)

    # Pick the per-model launcher; dependency installation is skipped (--skip-install).
    if model == "wav2lip":
        env["OMNIRT_MODEL_ROOT"] = str(repo.parent / "models")
        cmd = ["bash", "scripts/quickstart/start_omnirt_wav2lip.sh", "--device", "cuda", "--port", str(port), "--skip-install"]
    elif model == "musetalk":
        env["OMNIRT_MODEL_ROOT"] = str(repo.parent / "models")
        # MuseTalk runs a second server on its own port (default 8766).
        musetalk_port = str(model_cfg.get("musetalk_port", 8766))
        env["OMNIRT_MUSETALK_PORT"] = musetalk_port
        cmd = ["bash", "scripts/quickstart/start_omnirt_musetalk.sh", "--device", "cuda", "--port", str(port), "--musetalk-port", musetalk_port, "--skip-install"]
    elif model == "quicktalk":
        # NOTE(review): QuickTalk weights are read from <repo>/models, whereas
        # wav2lip/musetalk use <repo>/../models — confirm this is intended.
        quicktalk_root = repo / "models" / "quicktalk" / "checkpoints"
        env["OMNIRT_MODEL_ROOT"] = str(repo / "models")
        env["OMNIRT_QUICKTALK_MODEL_ROOT"] = str(quicktalk_root)
        env["OMNIRT_QUICKTALK_CHECKPOINT"] = str(quicktalk_root / "quicktalk.pth")
        # With CUDA_VISIBLE_DEVICES pinned above, the selected GPU is presumably exposed as cuda:0.
        cmd = ["bash", "scripts/quickstart/start_omnirt_quicktalk.sh", "--device", "cuda:0", "--port", str(port), "--skip-install"]
    else:
        raise RuntimeError(f"unsupported benchmark-managed OmniRT model: {model}")

    # Write an env file that sources the repo's quickstart env (if any) first
    # and exports the benchmark overrides afterwards, so the overrides win when
    # the file is sourced. Its path is passed as OPENTALKING_QUICKSTART_ENV;
    # presumably the quickstart scripts source it — TODO confirm.
    original_env = repo / "scripts" / "quickstart" / "env"
    benchmark_env = out_dir / "logs" / "benchmark-quickstart.env"
    override_keys = [
        "DIGITAL_HUMAN_HOME",
        "OMNIRT_MODEL_ROOT",
        "OMNIRT_PORT",
        "OMNIRT_HOST",
        "CUDA_VISIBLE_DEVICES",
        "TMPDIR",
        "PIP_CACHE_DIR",
        "UV_CACHE_DIR",
        "UV_DEFAULT_INDEX",
        "PIP_INDEX_URL",
        "OMNIRT_MUSETALK_PORT",
        "OMNIRT_QUICKTALK_MODEL_ROOT",
        "OMNIRT_QUICKTALK_CHECKPOINT",
    ]
    lines = ["# Generated by benchmark_opentalking_e2e.py; do not edit."]
    if original_env.exists():
        lines.append(f"source {shlex.quote(str(original_env))}")
    # Only keys actually set for this model are exported (shell-quoted).
    for key in override_keys:
        if key in env:
            lines.append(f"export {key}={shlex.quote(str(env[key]))}")
    benchmark_env.parent.mkdir(parents=True, exist_ok=True)
    benchmark_env.write_text("\n".join(lines) + "\n", encoding="utf-8")
    env["OPENTALKING_QUICKSTART_ENV"] = str(benchmark_env)

    # Run the start script synchronously; stdout and stderr are appended to the log.
    log_path = out_dir / "logs" / "start-omnirt.log"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("ab") as handle:
        cp = subprocess.run(cmd, cwd=repo, env=env, stdout=handle, stderr=subprocess.STDOUT)
    if cp.returncode:
        raise RuntimeError(f"failed to start OmniRT {model}; see {log_path}")
|
|
|
|
|
|
def start_opentalking(repo: Path, cfg: dict[str, Any], model: str, backend: str, model_cfg: dict[str, Any], out_dir: Path) -> subprocess.Popen[bytes]:
    """Launch OpenTalking via ``scripts/start_unified.sh`` in the background.

    The process runs in its own session with
    ``OPENTALKING_BENCHMARK_TIMING=1``. TTS provider, voice and model come from
    ``cfg``; the provider defaults to ``edge``. API/web ports and host come
    from ``cfg`` (defaults 8010, 5184, ``0.0.0.0``). For ``backend == "omnirt"``,
    the OmniRT URL from ``model_cfg`` or ``cfg`` is passed with ``--omnirt``.
    stdout and stderr are appended to ``<out_dir>/logs/start-opentalking.log``.

    Returns the ``Popen`` handle without waiting for the process.

    Raises:
        RuntimeError: if ``backend == "omnirt"`` but no OmniRT URL is configured.
    """
    env = os.environ.copy()
    env["OPENTALKING_BENCHMARK_TIMING"] = "1"
    env["OPENTALKING_TTS_PROVIDER"] = str(cfg.get("tts_provider") or "edge")
    if cfg.get("tts_voice"):
        env["OPENTALKING_TTS_VOICE"] = str(cfg["tts_voice"])
    if cfg.get("tts_model"):
        env["OPENTALKING_TTS_MODEL"] = str(cfg["tts_model"])
    cmd = [
        "bash",
        "scripts/start_unified.sh",
        "--backend",
        backend,
        "--model",
        model,
        "--api-port",
        str(cfg.get("api_port", 8010)),
        "--web-port",
        str(cfg.get("web_port", 5184)),
        "--host",
        str(cfg.get("host", "0.0.0.0")),
    ]
    if backend == "omnirt":
        omnirt = str(model_cfg.get("omnirt") or cfg.get("omnirt") or "")
        if not omnirt:
            raise RuntimeError("backend=omnirt requires model omnirt URL")
        cmd.extend(["--omnirt", omnirt])
    log_path = out_dir / "logs" / "start-opentalking.log"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("ab") as handle:
        # The child receives its own copy of the descriptor at spawn time, so
        # the parent's handle can be closed right away instead of leaking.
        return subprocess.Popen(cmd, cwd=repo, env=env, stdout=handle, stderr=subprocess.STDOUT, start_new_session=True)
|
|
|
|
|
|
def stop_opentalking(repo: Path, cfg: dict[str, Any]) -> None:
    """Stop the OpenTalking API and web dev server for the configured ports.

    Runs a best-effort bash script (``set +e``; its status and output are
    ignored) that:
    1. stops the processes recorded in ``<repo>/../run/opentalking-api-<api_port>.pid``
       and ``opentalking-web-<web_port>.pid`` (SIGTERM, up to ~10 s grace, then
       SIGKILL) and removes those files;
    2. SIGTERMs ``opentalking-unified`` processes from the repo's ``.venv``
       whose environment contains ``OPENTALKING_UNIFIED_PORT=<api_port>``;
    3. SIGTERMs the repo's Vite dev server started with ``--port <web_port>``.
    """
    api_port = int(cfg.get("api_port", 8010))
    web_port = int(cfg.get("web_port", 5184))
    # shlex.quote rather than Python repr: repr is not shell quoting, so paths
    # containing quotes, '$' or backslashes would otherwise be mangled.
    repo_root = shlex.quote(str(repo))
    script = f"""
set +e
repo_root={repo_root}
home_dir="$(cd "$repo_root/.." && pwd)"
run_dir="$home_dir/run"
stop_pid_file() {{
  pid_file="$1"
  if [ -f "$pid_file" ]; then
    pid="$(cat "$pid_file" 2>/dev/null)"
    if [ -n "$pid" ] && kill -0 "$pid" >/dev/null 2>&1; then
      kill "$pid" >/dev/null 2>&1 || true
      for _ in $(seq 1 20); do
        kill -0 "$pid" >/dev/null 2>&1 || break
        sleep 0.5
      done
      kill -0 "$pid" >/dev/null 2>&1 && kill -9 "$pid" >/dev/null 2>&1 || true
    fi
    rm -f "$pid_file"
  fi
}}
stop_pid_file "$run_dir/opentalking-api-{api_port}.pid"
stop_pid_file "$run_dir/opentalking-web-{web_port}.pid"
for pid in $(pgrep -f "$repo_root/.venv/bin/.*opentalking-unified" || true); do
  [ "$pid" = "$$" ] && continue
  if tr '\\0' '\\n' < "/proc/$pid/environ" 2>/dev/null | grep -qx "OPENTALKING_UNIFIED_PORT={api_port}"; then
    kill "$pid" >/dev/null 2>&1 || true
  fi
done
for pid in $(pgrep -f "$repo_root/apps/web/node_modules/.bin/vite .*--port {web_port}" || true); do
  [ "$pid" = "$$" ] && continue
  kill "$pid" >/dev/null 2>&1 || true
done
"""
    run_shell(script, cwd=repo)
|
|
|
|
|
|
def write_csv(path: Path, row: dict[str, Any]) -> None:
    """Write ``row`` as a single-row CSV report with a fixed column set.

    Columns appear in the order listed below. The headers are the report's
    (mostly Chinese) field names: test date, tester, model, technical route,
    backend, hardware, OS, driver environment, commit, input type, output
    resolution, output FPS, chunk size, cold-start time, warm-up time, TTFA,
    TTFV, first-turn total latency, steady-state FPS, RTF, idle VRAM and peak
    inference VRAM.

    Missing columns are written as ``"not_measured"``, keys outside the column
    set are ignored, and the file is overwritten.
    """
    columns = (
        "测试日期", "测试人", "模型", "技术路线", "backend", "硬件", "OS", "驱动环境", "commit",
        "输入类型", "输出分辨率", "输出 FPS", "chunk size", "冷启动时间", "预热时间", "TTFA",
        "TTFV", "首轮总延迟", "稳态 FPS", "RTF", "idle 显存", "推理峰值显存",
    )
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns, restval="not_measured")
        writer.writeheader()
        writer.writerow({name: row[name] for name in columns if name in row})
|
|
|
|
|
|
async def run_once(args: argparse.Namespace) -> None:
    """Run one end-to-end OpenTalking benchmark and write its artifacts.

    Steps, in order:
      1. Load the YAML config and resolve model, backend, tester and GPU index.
         Pick the output directory: ``--out-dir``, or
         ``outputs/benchmarks/opentalking-e2e/<ts>-<gpu>-<model>-<backend>``
         under the repo. Stage the reference image and the input audio there.
      2. Stop any running OpenTalking. For the ``omnirt`` backend without
         ``--reuse-omnirt``, also stop OmniRT and then start it again.
         Start OpenTalking and measure cold start until the API is healthy
         and ``wait_for_model`` returns.
      3. Create a session on a temporary copy of the avatar, start the SSE
         listener task and the WebRTC connection, and issue a warm-up
         ``speak``.
      4. Issue the measured ``speak`` while a ``ResourceSampler`` task records
         memory. Wait for the first WebRTC video frame and a ``speech.timing``
         event, then assemble the result row.
      5. Write ``result.json``, ``result.csv``, a Markdown report and a tar.gz
         of the output directory, and print a JSON summary to stdout.

    Cleanup runs whether or not the benchmark succeeds. It covers the sampler
    task, WebRTC connection, SSE task, temporary avatar directory, OpenTalking
    processes, and any OmniRT started by this run unless ``--keep-omnirt``
    is set.

    Raises:
        SystemExit: neither ``--tester`` nor ``tester`` in the config is set.
        RuntimeError: uploading the reference image fails, or related-process
            GPU memory could not be measured.
        TimeoutError: the first WebRTC video frame or ``speech.timing`` was not
            received within ``--timeout`` seconds.
    """
    repo = Path(args.repo_root).resolve()
    cfg = read_yaml(Path(args.config).resolve() if Path(args.config).is_absolute() else repo / args.config)
    model = args.model or str(cfg.get("model"))
    backend = args.backend or str(cfg.get("backend"))
    if not args.tester and not cfg.get("tester"):
        raise SystemExit("--tester or tester in config is required")
    tester = args.tester or str(cfg["tester"])
    models = cfg.get("models") if isinstance(cfg.get("models"), dict) else {}
    model_cfg = dict(models.get(model, {}))
    gpu_index = int(args.gpu_index if args.gpu_index is not None else model_cfg.get("gpu", cfg.get("gpu_index", 0)))
    model_cfg.setdefault("backend", backend)
    # A per-model backend in the config overrides the CLI/global backend.
    backend = str(model_cfg.get("backend") or backend)
    api_base_url = str(args.api_base_url or cfg.get("api_base_url") or f"http://127.0.0.1:{cfg.get('api_port', 8010)}")
    hardware = gpu_info(gpu_index)
    label = re.sub(r"[^A-Za-z0-9]+", "_", str(hardware.get("name", "unknown_gpu"))).strip("_")
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    out_dir = (
        Path(args.out_dir).resolve()
        if args.out_dir
        else repo / "outputs" / "benchmarks" / "opentalking-e2e" / f"{timestamp}-{label}-{model}-{backend}"
    )
    for sub in ("logs", "raw", "samples"):
        (out_dir / sub).mkdir(parents=True, exist_ok=True)

    input_cfg = cfg.get("input") if isinstance(cfg.get("input"), dict) else {}
    audio_src = Path(resolve_path(str(input_cfg.get("audio_path", "configs/benchmark/input/ttsmaker-file.mp3")), repo))
    ref_src = Path(resolve_path(str(input_cfg.get("ref_image", "configs/benchmark/input/reference.png")), repo))
    shutil.copy2(ref_src, out_dir / "input_reference.png")
    audio_duration = load_audio_duration(audio_src, out_dir / "input_audio.wav", float(input_cfg.get("audio_duration_seconds", 7.0)))

    # Start from a clean slate so cold start includes service startup.
    stop_opentalking(repo, cfg)
    manage_omnirt = backend == "omnirt" and not args.reuse_omnirt
    omnirt_started_by_benchmark = False
    if manage_omnirt:
        stop_omnirt_model(repo, cfg, model, backend, model_cfg)

    # Everything the ``finally`` block may need to release. These are
    # initialised up front so cleanup is correct wherever a failure happens.
    process: subprocess.Popen[bytes] | None = None
    benchmark_avatar_dir: Path | None = None
    pc: Any = None
    stop_sse: asyncio.Event | None = None
    sse_task: asyncio.Task[Any] | None = None
    sampler: ResourceSampler | None = None
    sampler_task: asyncio.Task[Any] | None = None

    cold_start0 = time.perf_counter()
    try:
        if manage_omnirt:
            # Set before starting, so a partially started OmniRT is still torn
            # down if startup (or the OpenTalking start below) fails.
            omnirt_started_by_benchmark = True
            start_omnirt_model(repo, cfg, model, backend, model_cfg, gpu_index, out_dir)
        process = start_opentalking(repo, cfg, model, backend, model_cfg, out_dir)

        wait_for_health(api_base_url, timeout=float(args.timeout))
        model_status = wait_for_model(api_base_url, model, backend, timeout=float(args.timeout))
        cold_start = time.perf_counter() - cold_start0

        ports = [int(cfg.get("api_port", 8010))]
        if backend == "omnirt" and model_cfg.get("omnirt"):
            # Take the port from the text after the last ':' (e.g. "http://host:9000").
            # Skip it when that text is not a plain integer.
            try:
                ports.append(int(str(model_cfg["omnirt"]).rstrip("/").rsplit(":", 1)[1]))
            except (IndexError, ValueError):
                pass
        pid_info = collect_related_pids(repo, cfg, model, backend, model_cfg, ports)
        pids = pid_info["pids"]

        def current_related_pids() -> list[int]:
            # Re-scan on every call so later measurements see the current related-process set.
            return collect_related_pids(repo, cfg, model, backend, model_cfg, ports)["pids"]

        service_ready_device_vram = gpu_device_mem_gb(gpu_index)
        service_ready_proc_vram = retry_gpu_process_mem_gb(pids, gpu_index=gpu_index)

        base_avatar_id = str(args.avatar_id or cfg.get("avatar_id", "office-woman"))
        avatar_id, benchmark_avatar_dir = prepare_benchmark_avatar(repo, base_avatar_id, model=model, timestamp=timestamp)
        try:
            upload_reference(api_base_url, avatar_id, ref_src)
        except Exception as exc:
            raise RuntimeError(f"failed to upload benchmark reference image: {exc}") from exc
        create_payload = {
            "avatar_id": avatar_id,
            "model": model,
            "tts_provider": cfg.get("tts_provider"),
            "tts_voice": cfg.get("tts_voice"),
        }
        session = http_json(api_base_url.rstrip("/") + "/sessions", method="POST", payload={k: v for k, v in create_payload.items() if v}, timeout=60.0)
        session_id = session["session_id"]

        events: list[dict[str, Any]] = []
        stop_sse = asyncio.Event()
        sse_task = asyncio.create_task(sse_events(api_base_url, session_id, events, stop_sse))
        await asyncio.sleep(0.5)
        first_video = asyncio.Event()
        first_video_time: dict[str, float] = {}
        video_frames: dict[str, Any] = {"capture": False, "max_frames": 250, "frames": []}
        pc, track_state = await setup_webrtc(api_base_url, session_id, first_video, first_video_time, video_frames)

        # Warm-up turn. Poll until a speech.timing event arrives or --timeout elapses.
        warm_start = time.perf_counter()
        speak_common = {
            "tts_provider": cfg.get("tts_provider"),
            "voice": cfg.get("tts_voice"),
            "tts_model": cfg.get("tts_model") or None,
        }
        warm_payload = {"text": str(cfg.get("warmup_prompt", "你好,这是预热测试。")), **speak_common}
        http_json(api_base_url.rstrip("/") + f"/sessions/{session_id}/speak", method="POST", payload={k: v for k, v in warm_payload.items() if v}, timeout=30.0)
        warm_deadline = time.perf_counter() + float(args.timeout)
        while time.perf_counter() < warm_deadline and not any(e["event"] == "speech.timing" for e in events):
            await asyncio.sleep(0.2)
        warmup_seconds = time.perf_counter() - warm_start

        post_warmup_pids = current_related_pids()
        idle_device_vram = gpu_device_mem_gb(gpu_index)
        idle_proc_vram = retry_gpu_process_mem_gb(post_warmup_pids, gpu_index=gpu_index)
        idle_vram_measured = idle_proc_vram is not None

        # Reset the shared observation state so only the measured turn is counted.
        events.clear()
        first_video.clear()
        first_video_time.clear()
        video_frames["count"] = 0
        video_frames["frames"] = []
        video_frames.pop("first_frame", None)
        video_frames["capture"] = True

        # Give both media tracks up to 10 s to report ready. Proceed anyway on
        # timeout, because the first-frame wait below enforces the real deadline.
        try:
            await asyncio.wait_for(
                asyncio.gather(
                    track_state["video_ready"].wait(),
                    track_state["audio_ready"].wait(),
                ),
                timeout=10.0,
            )
        except asyncio.TimeoutError:
            pass

        sampler = ResourceSampler(gpu_index=gpu_index, pids=post_warmup_pids, pid_provider=current_related_pids)
        sampler_task = asyncio.create_task(sampler.sample())
        speak_payload = {"text": str(cfg.get("prompt", "OpenTalking benchmark fixed input")), **speak_common}
        http_json(api_base_url.rstrip("/") + f"/sessions/{session_id}/speak", method="POST", payload={k: v for k, v in speak_payload.items() if v}, timeout=30.0)
        try:
            await asyncio.wait_for(first_video.wait(), timeout=float(args.timeout))
        except asyncio.TimeoutError as exc:
            raise TimeoutError("first WebRTC video frame was not received") from exc

        video_mock_path = out_dir / "samples" / "video_output_mocked.txt"
        video_mock_path.write_text(
            "\n".join(
                [
                    "E2E sample video output is intentionally mocked/disabled.",
                    "The benchmark still observes the OpenTalking WebRTC video track for first-frame and timing metrics.",
                    "A recorded MP4 is not equivalent to the browser's real-time WebRTC playback because container muxing preserves track timestamp offsets.",
                ]
            )
            + "\n",
            encoding="utf-8",
        )

        # Keep the most recent speech.timing event. Once speech.ended has been
        # seen, allow at most ~1 s more, and stop early once at least
        # min(50, max_frames) frames have been captured.
        deadline = time.perf_counter() + float(args.timeout)
        timing: dict[str, Any] | None = None
        while time.perf_counter() < deadline:
            for event in events:
                if event["event"] == "speech.timing":
                    timing = dict(event["data"])
                elif event["event"] == "speech.ended" and timing and first_video.is_set():
                    deadline = min(deadline, time.perf_counter() + 1.0)
            if timing and first_video.is_set():
                if len(video_frames.get("frames", [])) >= min(50, int(video_frames.get("max_frames", 250))):
                    if any(e["event"] == "speech.ended" for e in events):
                        break
            await asyncio.sleep(0.2)
        video_frames["capture"] = False
        sampler.stop()
        await sampler_task
        if timing is None:
            raise TimeoutError("speech.timing was not received")
        if idle_proc_vram is None or sampler.max_process_vram_gb is None:
            raise RuntimeError(
                "related-process GPU memory was not measured; "
                f"idle={idle_proc_vram!r}, peak={sampler.max_process_vram_gb!r}, pids={pids!r}"
            )
        peak_process_vram = max(idle_proc_vram, sampler.max_process_vram_gb)

        # The WebRTC track can deliver an idle/reference frame before speech starts.
        # Use the server-side speech media milestone for the required first-response
        # metric, and keep the aiortc frame observation as transport evidence.
        e2e_first = timing.get("e2e_first_response_ms")
        chunk_lat = [float(v) for v in timing.get("chunk_latency_ms", [])]
        actual_width = timing.get("output_width") or model_cfg.get("width", "not_measured")
        actual_height = timing.get("output_height") or model_cfg.get("height", "not_measured")
        actual_fps = timing.get("output_fps") or model_cfg.get("fps", model_status.get("fps", "not_measured"))
        if timing.get("chunk_samples") and timing.get("sample_rate"):
            actual_chunk = f"{round(float(timing['chunk_samples']) / float(timing['sample_rate']) * 1000)}ms"
        else:
            actual_chunk = model_cfg.get("chunk_size", "not_measured")
        result = {
            "测试日期": cfg.get("test_date") or datetime.now().strftime("%Y-%m-%d"),
            "测试人": tester,
            "模型": model,
            "技术路线": technical_route_for_model(model, model_cfg, cfg),
            "backend": backend,
            "硬件": hardware.get("name"),
            "OS": platform.platform(),
            "驱动环境": f"driver {hardware.get('driver')} / torch unknown",
            "commit": f"{run(['git', 'rev-parse', 'HEAD'], cwd=repo).stdout.strip()} + {run(['git', '-C', str(repo.parent / 'omnirt'), 'rev-parse', 'HEAD']).stdout.strip()}",
            "输入类型": cfg.get("input_type", "audio+image"),
            "输出分辨率": f"{actual_width}x{actual_height}",
            "输出 FPS": actual_fps,
            "chunk size": actual_chunk,
            "冷启动时间": round(cold_start, 3),
            "预热时间": round(warmup_seconds, 3),
            "TTFA": round(float(timing.get("ttfa_ms") or 0.0), 3),
            "TTFV": round(float(timing.get("ttfv_ms") or 0.0), 3),
            "首轮总延迟": round(float(e2e_first or 0.0), 3),
            "稳态 FPS": round(float(timing.get("steady_fps") or 0.0), 3),
            "RTF": round(float(timing.get("rtf") or 0.0), 4),
            "idle 显存": idle_proc_vram,
            "推理峰值显存": peak_process_vram,
        }
        raw = {
            "result": result,
            "timing": timing,
            "events": events,
            "model_status": model_status,
            "resource": {
                "gpu_index": gpu_index,
                "vram_result_scope": "related_process_only",
                "vram_definition": {
                    "idle_vram_gb": "OpenTalking + OmniRT related PID GPU memory on the target GPU after warmup and before the measured speak request",
                    "peak_inference_vram_gb": "Peak GPU memory of the same related PID set on the target GPU during the measured speak request",
                },
                "vram_measurement_status": {
                    "idle_process_measured": idle_vram_measured,
                    "peak_process_measured": sampler.max_process_vram_gb is not None,
                    "device_values_are_diagnostic_only": True,
                },
                "service_ready_device_vram_gb_diagnostic": service_ready_device_vram,
                "service_ready_process_vram_gb": service_ready_proc_vram,
                "post_warmup_idle_device_vram_gb_diagnostic": idle_device_vram,
                "idle_process_vram_gb": idle_proc_vram,
                "peak_device_vram_gb_diagnostic": sampler.max_device_vram_gb,
                "peak_process_vram_gb": peak_process_vram,
                "speak_sample_peak_process_vram_gb": sampler.max_process_vram_gb,
                "peak_cpu_gb": sampler.max_cpu_gb,
                "pids": pids,
                "post_warmup_pids": post_warmup_pids,
                "latest_pids": sampler.latest_pids,
                "peak_process_vram_pids": sampler.max_process_vram_pids,
                "root_pids": pid_info["root_pids"],
                "pid_files": pid_info["pid_files"],
                "ports": ports,
                "nvidia_smi_process_records": gpu_process_mem_records(sampler.latest_pids or pids, gpu_index=gpu_index) or [],
                "nvidia_smi_snapshot": run(["nvidia-smi"]).stdout,
            },
            "avatar": {"base_avatar_id": base_avatar_id, "benchmark_avatar_id": avatar_id, "benchmark_avatar_dir": str(benchmark_avatar_dir)},
            "input": {"audio": str(audio_src), "reference": str(ref_src), "duration_seconds": audio_duration},
            "output_video_frames_observed": video_frames.get("count", 0),
            "output_sample_path": "",
            "output_video_mocked": True,
            "output_video_mock_note": str(video_mock_path),
        }
        (out_dir / "result.json").write_text(json.dumps(raw, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
        write_csv(out_dir / "result.csv", result)

        # Markdown report: every result field, plus chunk-latency percentiles,
        # the log directory path, and the (mocked) sample-output path.
        md = ["# OpenTalking E2E Benchmark", ""]
        for key, value in result.items():
            md.append(f"- {key}: {value}")
        md += [
            f"- chunk latency p50: {round(percentile(chunk_lat, 0.50) or 0.0, 3)}ms",
            f"- chunk latency p95: {round(percentile(chunk_lat, 0.95) or 0.0, 3)}ms",
            f"- 日志路径: {out_dir / 'logs'}",
            f"- 输出样例路径: mocked ({video_mock_path})",
            "",
        ]
        report = out_dir / f"{label}_opentalking_e2e_benchmark_result.md"
        report.write_text("\n".join(md), encoding="utf-8")
        archive = out_dir / f"{label}_opentalking_e2e_artifacts.tar.gz"
        with tarfile.open(archive, "w:gz") as tar:
            for path in sorted(out_dir.rglob("*")):
                if path != archive and path.is_file():
                    tar.add(path, arcname=path.relative_to(out_dir))
        print(json.dumps({"output_dir": str(out_dir), "result": result, "raw": str(out_dir / "result.json")}, ensure_ascii=False, indent=2))
    finally:
        # Per-session resources first. Previously these were only released on
        # the success path, so a failure leaked the sampler, WebRTC and SSE tasks.
        if sampler_task is not None and not sampler_task.done():
            if sampler is not None:
                sampler.stop()
            sampler_task.cancel()
        if pc is not None:
            try:
                await pc.close()
            except Exception:
                pass  # best-effort; never mask the benchmark's own exception
        if stop_sse is not None:
            stop_sse.set()
        if sse_task is not None:
            sse_task.cancel()
        if benchmark_avatar_dir is not None:
            shutil.rmtree(benchmark_avatar_dir, ignore_errors=True)
        stop_opentalking(repo, cfg)
        if process is not None and process.poll() is None:
            try:
                # OpenTalking is signalled as a process group so any children exit too.
                # NOTE(review): relies on start_opentalking using start_new_session=True.
                os.killpg(process.pid, signal.SIGTERM)
            except OSError:
                process.terminate()
        if omnirt_started_by_benchmark and not args.keep_omnirt:
            stop_omnirt_model(repo, cfg, model, backend, model_cfg)
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse options and run one E2E benchmark.

    An empty-string option means "not given on the command line". run_once
    then falls back to the config file or a built-in default. ``--gpu-index``
    uses ``None`` for the same purpose.
    """
    cli = argparse.ArgumentParser()
    text_options = (
        ("--config", "configs/benchmark/opentalking-e2e.yaml"),
        ("--repo-root", "."),
        ("--out-dir", ""),
        ("--api-base-url", ""),
        ("--backend", ""),
        ("--model", ""),
        ("--avatar-id", ""),
        ("--tester", ""),
    )
    for flag, fallback in text_options:
        cli.add_argument(flag, default=fallback)
    cli.add_argument("--gpu-index", type=int, default=None)
    cli.add_argument("--timeout", type=float, default=240.0)
    switches = (
        ("--reuse-omnirt", "reuse an already-running OmniRT service; cold start then excludes OmniRT startup"),
        ("--keep-omnirt", "keep benchmark-started OmniRT service running after the test"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)
    asyncio.run(run_once(cli.parse_args()))
|
|
|
|
|
|
# Run the benchmark only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|