feat: offline runtime providers (CosyVoice TRT/spk sidecar, SenseVoice, shared voice assets) (#124)

平台中立的本地运行时增强,Linux/Windows 本地与一键部署包通用:
- 共享音色库 voice_assets + 跨 provider 复用
- CosyVoice 本地 sidecar: TRT/autocast_fp16 加速, zero-shot spk 预存, 独立 venv
- SenseVoice 本地 ASR: model_dir 绝对路径修复 (避免被 funasr 当 repo_id)
- DashScope STT key 兜底, QuickTalk CUDA OOM 回退 CPU, Mem0 不可用回退 SQLite
- 前端: 解除人设/驱动模型耦合, 音色展示去技术前缀, mock TTS provider
- 5 个内置 system 音色资产

不含 Windows 部署包(bundle/)、行尾归一、平台专属脚本。
This commit is contained in:
zyairehhh
2026-06-26 19:08:22 +08:00
committed by GitHub
parent c112faccac
commit faad141b36
45 changed files with 1726 additions and 553 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import tempfile
from dataclasses import asdict
from pathlib import Path
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
from pydantic import BaseModel
@@ -162,27 +163,32 @@ async def _add_uploaded_document(
filename = file.filename or "document.txt"
mime_type = file.content_type or "application/octet-stream"
total = 0
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-", delete=True) as tmp:
while True:
chunk = await file.read(1024 * 1024)
if not chunk:
break
total += len(chunk)
if total > MAX_DOCUMENT_BYTES:
raise HTTPException(status_code=413, detail="document is larger than 20MB")
tmp.write(chunk)
tmp.flush()
tmp_path: Path | None = None
try:
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-", delete=False) as tmp:
tmp_path = Path(tmp.name)
while True:
chunk = await file.read(1024 * 1024)
if not chunk:
break
total += len(chunk)
if total > MAX_DOCUMENT_BYTES:
raise HTTPException(status_code=413, detail="document is larger than 20MB")
tmp.write(chunk)
try:
doc = await store.add_document(
kb_id=kb_id,
filename=filename,
mime_type=mime_type,
source_path=tmp.name,
source_path=tmp_path,
)
except DuplicateKnowledgeDocumentError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
finally:
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
return KnowledgeDocumentResponse(**asdict(doc))
@@ -194,26 +200,31 @@ async def _add_uploaded_file(
filename = file.filename or "document.txt"
mime_type = file.content_type or "application/octet-stream"
total = 0
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-file-", delete=True) as tmp:
while True:
chunk = await file.read(1024 * 1024)
if not chunk:
break
total += len(chunk)
if total > MAX_DOCUMENT_BYTES:
raise HTTPException(status_code=413, detail="document is larger than 20MB")
tmp.write(chunk)
tmp.flush()
tmp_path: Path | None = None
try:
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-file-", delete=False) as tmp:
tmp_path = Path(tmp.name)
while True:
chunk = await file.read(1024 * 1024)
if not chunk:
break
total += len(chunk)
if total > MAX_DOCUMENT_BYTES:
raise HTTPException(status_code=413, detail="document is larger than 20MB")
tmp.write(chunk)
try:
doc = await store.add_file(
filename=filename,
mime_type=mime_type,
source_path=tmp.name,
source_path=tmp_path,
)
except DuplicateKnowledgeDocumentError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
finally:
if tmp_path is not None:
tmp_path.unlink(missing_ok=True)
return KnowledgeDocumentResponse(**asdict(doc))

View File

@@ -706,12 +706,33 @@ def _prewarm_local_backend(
settings=settings,
overwrite=overwrite,
)
return _prewarm_local_adapter(
model,
avatar_dir,
settings,
prepared_cache=prepared_cache,
)
try:
return _prewarm_local_adapter(
model,
avatar_dir,
settings,
prepared_cache=prepared_cache,
)
except RuntimeError as exc:
if model == "quicktalk" and "out of memory" in str(exc).lower():
cache = {
"model": model,
"status": "ready",
"source_mode": "local",
"frames": prepared_cache.frames if prepared_cache is not None else None,
"detail": "local adapter prepared QuickTalk assets but warmup ran out of memory",
"prepared_status": prepared_cache.status if prepared_cache is not None else None,
}
runtime = {
"type": "local_prewarm_result",
"backend": "local",
"model": model,
"warmed": False,
"elapsed_ms": 0.0,
"message": str(exc),
}
return cache, runtime
raise
def _quicktalk_runtime_payload(
@@ -1063,11 +1084,12 @@ async def prewarm_avatar(avatar_id: str, request: Request) -> dict[str, Any]:
)
except Exception as exc: # noqa: BLE001
raise HTTPException(status_code=500, detail=f"failed to prewarm local {model}: {exc}") from exc
runtime_status = "failed" if not bool(runtime.get("warmed", True)) else "ready"
return {
"avatar_id": avatar_id,
"model": model,
"status": "ready",
"runtime_status": "ready",
"runtime_status": runtime_status,
"cache": cache_response,
"runtime": runtime,
}

View File

@@ -50,6 +50,7 @@ def _runtime_status_payload(request: Request) -> dict[str, Any]:
tts_provider_list = [tts_provider, *tts_provider_list]
tts_status_providers = [*tts_provider_list]
for provider in (
"mock",
"local_cosyvoice",
"indextts",
"dashscope",
@@ -63,9 +64,11 @@ def _runtime_status_payload(request: Request) -> dict[str, Any]:
tts_status_providers.append(provider)
tts_provider_map = {provider: tts_provider_config(provider) for provider in tts_status_providers}
tts_effective = tts_provider_map.get(tts_provider, tts)
llm_key = os.environ.get("OPENTALKING_LLM_API_KEY", "").strip() or str(
getattr(settings, "llm_api_key", "") or ""
).strip()
llm_key = (
os.environ.get("OPENTALKING_LLM_API_KEY", "").strip()
or os.environ.get("DASHSCOPE_API_KEY", "").strip()
or str(getattr(settings, "llm_api_key", "") or "").strip()
)
ignored_legacy_env = [name for name in _IGNORED_LEGACY_ENV if os.environ.get(name)]
quicktalk_backend = os.environ.get("OPENTALKING_QUICKTALK_BACKEND", "").strip() or str(
getattr(settings, "quicktalk_backend", "") or ""

View File

@@ -15,6 +15,7 @@ from opentalking.persona.wechat_import import WeChatImportJobRegistry
from opentalking.providers.memory.decision_agent import MemoryDecisionAgent
from opentalking.providers.memory.factory import build_memory_provider
from opentalking.providers.memory.runtime import MemoryRuntime, normalize_memory_scope
from opentalking.providers.memory.sqlite_provider import SQLiteMemoryProvider
router = APIRouter(prefix="/memory", tags=["memory"])
@@ -38,16 +39,6 @@ class MemoryImportRequest(BaseModel):
source: str | None = None
class WeChatSpeakerSelectionRequest(BaseModel):
target_speaker_id: str
class WeChatImportCommitRequest(BaseModel):
persona_id: str
persona_name: str | None = None
description: str | None = None
def _profile(value: str | None) -> str:
return (value or get_settings().memory_default_profile_id or "default").strip() or "default"
@@ -76,6 +67,28 @@ def _wechat_registry(request: Request) -> WeChatImportJobRegistry:
return registry
def _fallback_memory_provider() -> SQLiteMemoryProvider:
settings = get_settings()
return SQLiteMemoryProvider(settings.memory_sqlite_path)
async def _memory_provider():
try:
return build_memory_provider()
except Exception: # noqa: BLE001
return _fallback_memory_provider()
class WeChatSpeakerSelectionRequest(BaseModel):
target_speaker_id: str
class WeChatImportCommitRequest(BaseModel):
persona_id: str
persona_name: str | None = None
description: str | None = None
@router.post("/wechat-import")
async def create_wechat_import_job(
request: Request,
@@ -172,7 +185,7 @@ async def list_libraries(
profile_id: str = Query("default"),
character_id: str = Query(...),
) -> dict[str, list[dict[str, object]]]:
provider = build_memory_provider()
provider = await _memory_provider()
items = await provider.list_libraries(
profile_id=_profile(profile_id),
character_id=_ensure_character(character_id),
@@ -182,7 +195,7 @@ async def list_libraries(
@router.post("/libraries")
async def create_library(body: MemoryLibraryRequest) -> dict[str, object]:
provider = build_memory_provider()
provider = await _memory_provider()
library = await provider.create_library(
library_id=_library_id(body.id),
name=(body.name or "").strip() or None,
@@ -198,7 +211,7 @@ async def list_items(
profile_id: str = Query("default"),
character_id: str = Query(...),
) -> dict[str, list[dict[str, object]]]:
provider = build_memory_provider()
provider = await _memory_provider()
items = await provider.list_items(
library_id=library_id,
profile_id=_profile(profile_id),
@@ -214,7 +227,7 @@ async def delete_item(
profile_id: str = Query("default"),
character_id: str = Query(...),
) -> dict[str, bool]:
provider = build_memory_provider()
provider = await _memory_provider()
deleted = await provider.delete_item(
library_id=library_id,
item_id=item_id,
@@ -237,7 +250,7 @@ async def import_items(library_id: str, body: MemoryImportRequest) -> dict[str,
)
runtime = MemoryRuntime(
scope=scope,
provider=build_memory_provider(),
provider=await _memory_provider(),
decision_agent=MemoryDecisionAgent(),
)
imported = await runtime.import_turns(

View File

@@ -7,6 +7,7 @@ import time
from pathlib import Path
from typing import Any, Optional
import httpx
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
@@ -20,10 +21,11 @@ from opentalking.providers.stt.factory import (
)
from opentalking.providers.tts.factory import tts_enabled_providers, tts_provider_config
from opentalking.providers.tts.providers import normalize_tts_provider
from opentalking.providers.tts.voice_assets import iter_voice_assets, resolve_voice_asset
router = APIRouter(prefix="/runtime-config", tags=["runtime-config"])
_ENV_PATH = Path(__file__).resolve().parents[3] / ".env"
_ENV_PATH = Path(os.environ.get("OPENTALKING_ENV_FILE") or Path(__file__).resolve().parents[3] / ".env")
_ENV_REF_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}|\$([A-Za-z_][A-Za-z0-9_]*)")
_RUNTIME_ENV_KEYS = {
@@ -127,6 +129,20 @@ def _strip(value: str | None) -> str:
return (value or "").strip()
def _local_cosyvoice_default_voice() -> str:
for asset in iter_voice_assets(provider="local_cosyvoice", sources=("system", "clones")):
if asset.voice_id:
return asset.voice_id
return "local-default"
def _normalize_local_cosyvoice_voice(value: str | None) -> str:
voice = _strip(value)
if voice and resolve_voice_asset(voice, provider="local_cosyvoice") is not None:
return voice
return _local_cosyvoice_default_voice()
def _unquote_env_value(value: str) -> str:
if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
return value[1:-1]
@@ -189,6 +205,81 @@ def _enabled_provider_csv(current: list[str], provider: str) -> str:
return ",".join(providers)
def _path_exists(raw: str) -> bool:
if not raw:
return False
try:
return Path(raw).expanduser().exists()
except OSError:
return False
def _ensure_sensevoice_available(values: dict[str, str], settings: Any) -> None:
status = stt_provider_config("sensevoice")
model = _env_value(
values,
"OPENTALKING_STT_SENSEVOICE_MODEL",
_settings_value(settings, "stt_sensevoice_model", "iic/SenseVoiceSmall"),
)
model_dir = str(status.get("model_dir") or "").strip()
candidates = [model_dir]
root = os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip() or _settings_value(
settings,
"local_audio_model_root",
"",
)
if root and model:
candidates.append(str(Path(root).expanduser() / model.replace("/", "__")))
if any(_path_exists(candidate) for candidate in candidates):
return
raise HTTPException(
status_code=400,
detail="本地 ASR SenseVoice 模型未就绪,请在启动时启用本地 ASR 或确认模型已下载。",
)
def _local_cosyvoice_health_url(service_url: str) -> str:
value = service_url.strip()
if not value:
return ""
if value.rstrip("/").endswith("/synthesize"):
return value.rstrip("/")[: -len("/synthesize")] + "/health"
return value.rstrip("/") + "/health"
async def _ensure_local_cosyvoice_available(values: dict[str, str], settings: Any) -> None:
service_url = _env_value(
values,
"OPENTALKING_TTS_LOCAL_COSYVOICE_SERVICE_URL",
_settings_value(settings, "tts_local_cosyvoice_service_url"),
)
health_url = _local_cosyvoice_health_url(service_url)
if not health_url:
raise HTTPException(
status_code=400,
detail="本地 TTS local_cosyvoice 未启动或未配置服务地址,请在启动时启用本地 TTS或改用 API/Edge TTS。",
)
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(1.0, connect=0.5)) as client:
response = await client.get(health_url)
response.raise_for_status()
except Exception as exc:
raise HTTPException(
status_code=400,
detail="本地 TTS local_cosyvoice 未启动或不可用,请在启动时启用本地 TTS或改用 API/Edge TTS。",
) from exc
async def _validate_local_provider_switches(updates: dict[str, str], request: Request) -> None:
settings = getattr(request.app.state, "settings", None) or get_settings()
_, current_values = _read_env_lines(_ENV_PATH)
values = {**current_values, **updates}
if updates.get("OPENTALKING_STT_DEFAULT_PROVIDER") == "sensevoice":
_ensure_sensevoice_available(values, settings)
if updates.get("OPENTALKING_TTS_DEFAULT_PROVIDER") == "local_cosyvoice":
await _ensure_local_cosyvoice_available(values, settings)
def _write_env_updates(path: Path, updates: dict[str, str]) -> None:
lines, _ = _read_env_lines(path)
if path.exists():
@@ -301,7 +392,7 @@ def _current_tts_payload(provider: str, settings: Any, values: dict[str, str]) -
elif provider == "local_cosyvoice":
base_url = _env_value(values, "OPENTALKING_TTS_LOCAL_COSYVOICE_SERVICE_URL", _settings_value(settings, "tts_local_cosyvoice_service_url"))
model = _env_value(values, "OPENTALKING_TTS_LOCAL_COSYVOICE_MODEL", _settings_value(settings, "tts_local_cosyvoice_model", "FunAudioLLM/Fun-CosyVoice3-0.5B-2512"))
voice = _env_value(values, "OPENTALKING_TTS_VOICE", _settings_value(settings, "tts_voice"))
voice = _normalize_local_cosyvoice_voice(_env_value(values, "OPENTALKING_TTS_VOICE", _settings_value(settings, "tts_voice")))
key = ""
elif provider == "indextts":
base_url = (
@@ -502,7 +593,9 @@ def _build_updates(payload: RuntimeConfigPayload) -> dict[str, str]:
}.get(tts_provider)
if key:
updates[key] = value
if value := _strip(payload.tts_voice):
if tts_provider == "local_cosyvoice":
updates["OPENTALKING_TTS_VOICE"] = _normalize_local_cosyvoice_voice(payload.tts_voice)
elif value := _strip(payload.tts_voice):
updates["OPENTALKING_TTS_VOICE"] = value
if tts_provider == "edge":
updates["OPENTALKING_TTS_EDGE_VOICE"] = value
@@ -592,6 +685,7 @@ async def apply_runtime_config(payload: RuntimeConfigPayload, request: Request)
unknown = set(updates) - _RUNTIME_ENV_KEYS
if unknown:
raise HTTPException(status_code=400, detail=f"unsupported runtime config keys: {', '.join(sorted(unknown))}")
await _validate_local_provider_switches(updates, request)
_write_env_updates(_ENV_PATH, updates)
_, values = _read_env_lines(_ENV_PATH)
for key in _RUNTIME_ENV_KEYS:

View File

@@ -87,6 +87,11 @@ def _normalize_preview_request(
return text, voice, provider, model, indextts_config
def _is_local_cosyvoice_input_error(exc: Exception) -> bool:
message = str(exc)
return isinstance(exc, ValueError) or "HTTP 400" in message or "请求无效" in message or "请先选择本地音色" in message
def _config_from_form_value(raw: object) -> dict[str, Any] | None:
text = str(raw or "").strip()
if not text:
@@ -234,6 +239,29 @@ async def preview_tts(request: Request) -> Response:
if sample_limit is not None and total_samples >= sample_limit:
break
except Exception as exc:
logger.exception(
"tts preview failed | provider=%s voice_id=%s model=%s error=%s",
provider or "default",
voice or "default",
model or "default",
exc,
)
if provider == "local_cosyvoice":
if _is_local_cosyvoice_input_error(exc):
raise HTTPException(
status_code=400,
detail=str(exc),
) from exc
error_text = str(exc)
if "CosyVoice returned no audio" in error_text or "本地 CosyVoice 返回空音频" in error_text:
raise HTTPException(
status_code=502,
detail=f"TTS preview failed: {error_text}",
) from exc
raise HTTPException(
status_code=502,
detail=f"TTS preview failed: 本地 CosyVoice 服务不可用(可能已退出/内存不足);前端可回退 Edge 试听。{exc}",
) from exc
raise HTTPException(status_code=502, detail=f"TTS preview failed: {exc}") from exc
finally:
close = getattr(tts, "aclose", None)

View File

@@ -7,7 +7,6 @@ import base64
import io
import json
import logging
import os
import re
import shutil
import uuid
@@ -21,6 +20,15 @@ from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, Reque
from fastapi.responses import FileResponse, JSONResponse
from opentalking.providers.tts.dashscope_qwen import clone as bailian_clone
from opentalking.providers.tts.voice_assets import (
INDEXTTS_PROVIDER,
INDEXTTS_PROVIDERS,
LOCAL_COSYVOICE_PROVIDER,
bundled_system_voice_root,
iter_voice_assets,
local_audio_model_root,
system_voice_roots,
)
from opentalking.voice.store import delete_entry, get_entry, init_voice_store, insert_clone, list_voices
log = logging.getLogger(__name__)
@@ -31,13 +39,16 @@ _UPLOAD_DIR = Path("data/voice_uploads")
LOCAL_COSYVOICE_SAMPLE_TEXT = "你好,今天阳光很好,我正在用自然清晰的声音,记录这一段音色。"
LOCAL_COSYVOICE_MIN_ACTIVE_SEC = 2.0
LOCAL_COSYVOICE_MIN_RMS_DBFS = -45.0
_TECHNICAL_VOICE_LABEL_PREFIX_RE = re.compile(
r"^\s*(?:IndexTTS|CosyVoice|local_cosyvoice|local_indextts|local|本地)\s*[-_/:·|]?\s*",
re.IGNORECASE,
)
_TECHNICAL_VOICE_LABEL_SUFFIX_RE = re.compile(
r"\s*[(]\s*(?:IndexTTS|CosyVoice|local_cosyvoice|local_indextts|local|本地)\s*[)]\s*$",
re.IGNORECASE,
)
LOCAL_COSYVOICE_MIN_RECOGNIZED_CHARS = 4
LOCAL_COSYVOICE_MIN_TEXT_OVERLAP = 0.45
INDEXTTS_PROVIDER = "indextts"
INDEXTTS_LEGACY_PROVIDERS = {"local_indextts", "omnirt_indextts"}
INDEXTTS_PROVIDERS = {INDEXTTS_PROVIDER, *INDEXTTS_LEGACY_PROVIDERS}
class VoiceItem(TypedDict):
id: int
user_id: int
@@ -61,16 +72,7 @@ def _upload_dir() -> Path:
def _local_audio_model_root() -> Path:
raw = os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
try:
from opentalking.core.config import get_settings
raw = raw or (get_settings().local_audio_model_root or "").strip()
except Exception:
pass
if not raw:
raw = "./models/local-audio"
return Path(raw).expanduser().resolve()
return local_audio_model_root()
def _safe_local_voice_id(label: str) -> str:
@@ -259,47 +261,38 @@ def _remove_local_cosyvoice_prompt(voice_id: str) -> None:
def _bundled_system_voice_root() -> Path:
return Path(__file__).resolve().parents[3] / "opentalking" / "assets" / "voices" / "system"
return bundled_system_voice_root()
def _system_voice_roots() -> list[Path]:
roots = [_local_audio_model_root() / "voices" / "system", _bundled_system_voice_root()]
out: list[Path] = []
seen: set[Path] = set()
for root in roots:
try:
resolved = root.resolve()
except OSError:
resolved = root
if resolved in seen:
continue
seen.add(resolved)
out.append(root)
return out
return system_voice_roots(_local_audio_model_root())
def _public_voice_label(label: str, *, fallback: str) -> str:
cleaned = _TECHNICAL_VOICE_LABEL_PREFIX_RE.sub("", label or "").strip()
cleaned = _TECHNICAL_VOICE_LABEL_SUFFIX_RE.sub("", cleaned).strip()
return cleaned or fallback
def _local_cosyvoice_system_voice_items() -> list[VoiceItem]:
root = _local_audio_model_root() / "voices" / "system"
if not root.is_dir():
return []
items: list[VoiceItem] = []
for idx, voice_dir in enumerate(sorted(p for p in root.iterdir() if p.is_dir()), start=1):
voice_id = voice_dir.name
for idx, asset in enumerate(
iter_voice_assets(
provider=LOCAL_COSYVOICE_PROVIDER,
sources=("system",),
model_root=_local_audio_model_root(),
require_prompt_text=True,
),
start=1,
):
voice_id = asset.voice_id
if not re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
continue
if not (voice_dir / "prompt.wav").is_file() or not (voice_dir / "prompt.txt").is_file():
continue
label = voice_id
target_model: str | None = None
meta_path = voice_dir / "meta.json"
if meta_path.is_file():
try:
meta = json.loads(meta_path.read_text(encoding="utf-8"))
label = str(meta.get("display_label") or meta.get("label") or label)
tm = str(meta.get("target_model") or "").strip()
target_model = tm or None
except Exception:
pass
meta = asset.meta
label = _public_voice_label(str(meta.get("display_label") or meta.get("label") or label), fallback=voice_id)
tm = str(meta.get("target_model") or "").strip()
target_model = tm or None
items.append(
{
"id": -idx,
@@ -314,48 +307,26 @@ def _local_cosyvoice_system_voice_items() -> list[VoiceItem]:
return items
def _local_indextts_voice_items(provider: str, source: str) -> list[VoiceItem]:
if provider not in INDEXTTS_PROVIDERS or source not in {"system", "clones"}:
def _local_indextts_voice_items(source: str) -> list[VoiceItem]:
if source not in {"system", "clones"}:
return []
roots = [_local_audio_model_root() / "voices" / source]
if source == "system":
roots = _system_voice_roots()
items: list[VoiceItem] = []
idx = 0
for root in roots:
if not root.is_dir():
continue
for voice_dir in sorted(p for p in root.iterdir() if p.is_dir()):
idx += 1
voice_id = voice_dir.name
if not re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
seen: set[str] = set()
for provider in sorted(INDEXTTS_PROVIDERS):
for asset in iter_voice_assets(provider=provider, sources=(source,), model_root=_local_audio_model_root()):
voice_id = asset.voice_id
if voice_id in seen or not re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
continue
if not (voice_dir / "prompt.wav").is_file():
continue
label = voice_id
target_model: str | None = None
meta_path = voice_dir / "meta.json"
if meta_path.is_file():
try:
meta = json.loads(meta_path.read_text(encoding="utf-8"))
raw_providers = meta.get("providers")
requested = {provider, *INDEXTTS_LEGACY_PROVIDERS} if provider == INDEXTTS_PROVIDER else {provider}
if isinstance(raw_providers, list):
allowed = {str(item).strip().lower() for item in raw_providers}
if allowed and not (allowed & requested):
continue
elif str(meta.get("provider") or "").strip().lower() not in {"", *INDEXTTS_PROVIDERS}:
continue
label = str(meta.get("display_label") or meta.get("label") or label)
tm = str(meta.get("target_model") or "").strip()
target_model = tm or None
except Exception:
pass
seen.add(voice_id)
meta = asset.meta
label = _public_voice_label(str(meta.get("display_label") or meta.get("label") or voice_id), fallback=voice_id)
tm = str(meta.get("target_model") or "").strip()
target_model = tm or None
items.append(
{
"id": -idx,
"id": -len(items) - 1,
"user_id": 1,
"provider": provider,
"provider": INDEXTTS_PROVIDER,
"voice_id": voice_id,
"display_label": label,
"target_model": target_model,
@@ -430,7 +401,7 @@ async def get_voices(provider: str | None = None) -> JSONResponse:
existing.add(key)
if public_p is None or public_p == INDEXTTS_PROVIDER:
for source in ("system", "clones"):
for item in _local_indextts_voice_items(INDEXTTS_PROVIDER, source):
for item in _local_indextts_voice_items(source):
key = (item["provider"], item["voice_id"])
if key not in existing:
items.append(item)

View File

@@ -205,17 +205,16 @@ def test_get_voices_includes_local_cosyvoice_system_voice_dirs(tmp_path, monkeyp
response = TestClient(app).get("/voices?provider=local_cosyvoice")
assert response.status_code == 200
assert response.json()["items"] == [
{
"id": -1,
"user_id": 1,
"provider": "local_cosyvoice",
"voice_id": "local-female-standard",
"display_label": "标准女声",
"target_model": None,
"source": "system",
}
]
item = next(item for item in response.json()["items"] if item["voice_id"] == "local-female-standard")
assert item == {
"id": -1,
"user_id": 1,
"provider": "local_cosyvoice",
"voice_id": "local-female-standard",
"display_label": "标准女声",
"target_model": None,
"source": "system",
}
@pytest.mark.parametrize("provider", ["indextts", "local_indextts", "omnirt_indextts"])
@@ -242,7 +241,7 @@ def test_get_voices_includes_indextts_system_voice_dirs(provider: str, tmp_path,
"user_id": 1,
"provider": "indextts",
"voice_id": "indextts-clear-cn",
"display_label": "IndexTTS 清晰中文",
"display_label": "清晰中文",
"target_model": "IndexTeam/IndexTTS-2",
"source": "system",
} in response.json()["items"]
@@ -426,12 +425,13 @@ def test_get_voices_includes_indextts_clone_voice_dirs(provider: str, tmp_path,
response = TestClient(app).get(f"/voices?provider={provider}")
assert response.status_code == 200
assert {
"id": -1,
item = next(item for item in response.json()["items"] if item["voice_id"] == "indextts-cloned-cn")
assert item == {
"id": item["id"],
"user_id": 1,
"provider": "indextts",
"voice_id": "indextts-cloned-cn",
"display_label": "IndexTTS 复刻中文",
"display_label": "复刻中文",
"target_model": "IndexTeam/IndexTTS-2",
"source": "clone",
} in response.json()["items"]
}

View File

@@ -166,13 +166,13 @@ function mergeVoiceCatalogIntoOptions(
const extras: VoiceOpt[] = [];
for (const r of catalog) {
if (r.provider !== cp) continue;
if (activeModel && r.target_model && r.target_model !== activeModel) continue;
if (activeModel && r.target_model && r.target_model !== activeModel && !(ttsProvider === "local_cosyvoice" && r.source === "system")) continue;
if (cloneOnly && r.source !== "clone") continue;
if (staticIds.has(r.voice_id)) continue;
extras.push({
id: r.voice_id,
label: r.source === "clone" ? `复刻 · ${r.display_label}` : r.display_label,
targetModel: r.target_model,
targetModel: ttsProvider === "local_cosyvoice" && r.source === "system" ? undefined : r.target_model,
});
staticIds.add(r.voice_id);
}
@@ -983,7 +983,6 @@ export default function App() {
return false;
});
const [voiceCloneOpen, setVoiceCloneOpen] = useState(false);
const [promptSaving, setPromptSaving] = useState(false);
const [referenceSaving, setReferenceSaving] = useState(false);
const [panelTab, setPanelTab] = useState<PanelTab>("chat");
const [sessionPanelCollapsed, setSessionPanelCollapsed] = useState(() => {
@@ -1621,7 +1620,7 @@ export default function App() {
}
}, [sessionPanelCollapsed]);
const [llmSystemPrompt, setLlmSystemPrompt] = useState<string>(() => {
const [llmSystemPrompt] = useState<string>(() => {
try {
return window.localStorage.getItem(LLM_SYSTEM_PROMPT_STORAGE_KEY) ?? "";
} catch {
@@ -2081,9 +2080,19 @@ export default function App() {
clearSubtitleState();
if (msgId) {
if (finalText) {
setMessages((prev) =>
prev.map((m) => (m.id === msgId ? { ...m, text: finalText } : m)),
);
setMessages((prev) => {
let updated = false;
const next = prev.map((m) => {
if (m.id !== msgId) return m;
updated = true;
return { ...m, text: finalText };
});
if (updated) return next;
return [
...prev,
{ id: makeId(), role: "assistant", text: finalText, timestamp: Date.now() },
];
});
} else {
setMessages((prev) => prev.filter((m) => m.id !== msgId));
}
@@ -2311,27 +2320,6 @@ export default function App() {
}
}, [connection, fasterliveportraitConfig, model, notify]);
const handleSavePrompt = useCallback(async () => {
setPromptSaving(true);
try {
await apiPost("/sessions/customize/prompt", {
avatar_id: avatarId,
llm_system_prompt: llmSystemPrompt,
});
const sid = sessionIdRef.current;
if (sid) await releaseSession(sid);
resetLiveState(true);
setConnection("idle");
notify("System Prompt 已保存,页面即将刷新并在新会话生效。", "success");
window.setTimeout(() => window.location.reload(), 900);
} catch (e) {
console.warn("save prompt failed", e);
notify("保存 Prompt 失败,请查看后端日志。", "error");
} finally {
setPromptSaving(false);
}
}, [avatarId, llmSystemPrompt, notify, releaseSession, resetLiveState]);
const handleCreateCustomAvatar = useCallback(async (file: File, name: string) => {
const trimmedName = name.trim();
if (!trimmedName) {
@@ -2442,7 +2430,8 @@ export default function App() {
notify("正在播放试听音频。", "success");
} catch (e) {
console.warn("tts preview failed", e);
notify("试听失败,请确认音色、模型和后端密钥配置。", "error");
const detail = apiErrorMessage(e, "请确认音色、模型和后端密钥配置。");
notify(`试听失败:${detail}`, "error");
} finally {
setTtsPreviewing(false);
}
@@ -3012,10 +3001,6 @@ export default function App() {
onMemoryLibrarySelect={setMemoryLibraryId}
onMemoryEnabledChange={setMemoryEnabled}
onManageMemoryLibraries={() => void handleManageMemoryLibraries()}
llmSystemPrompt={llmSystemPrompt}
onLlmSystemPromptChange={setLlmSystemPrompt}
onSavePrompt={() => void handleSavePrompt()}
promptSaving={promptSaving}
onOpenVoiceClone={() => setVoiceCloneOpen(true)}
/>
</div>
@@ -3169,15 +3154,11 @@ export default function App() {
modelBadge={selectedModelBadge}
queueInfo={queueInfo}
prewarmState={selectedPrewarmState}
agentConfig={agentConfig}
onAgentConfigChange={setAgentConfig}
knowledgeBases={knowledgeBaseSummaries}
personas={personas}
selectedPersonaId={selectedPersonaId}
personaImporting={personaImporting}
onPersonaChange={handlePersonaChange}
onPersonaImport={handlePersonaImport}
memorySummary={memorySummary}
onAvatarChange={handleAvatarChange}
onStart={() => void handleStart()}
onCustomAvatarCreate={(file, name) => void handleCreateCustomAvatar(file, name)}

View File

@@ -1,5 +1,5 @@
import { useEffect, useRef, useState, type ChangeEvent } from "react";
import type { AvatarSummary, KnowledgeBaseSummary, PersonaSummary } from "../lib/api";
import type { AvatarSummary, PersonaSummary } from "../lib/api";
import { buildApiUrl } from "../lib/api";
import type { ModelConnectionBadge } from "../lib/modelStatus";
@@ -27,14 +27,6 @@ type AvatarSelectionStageProps = {
onCustomAvatarCreate: (file: File, name: string) => void;
onAvatarDelete?: (avatar: AvatarSummary) => void;
referenceSaving?: boolean;
memorySummary?: {
enabled: boolean;
libraryName: string | null;
memoryCount: number | null;
};
agentConfig: AgentConfig;
onAgentConfigChange: (next: AgentConfig) => void;
knowledgeBases: KnowledgeBaseSummary[];
personas: PersonaSummary[];
selectedPersonaId: string;
personaImporting?: boolean;
@@ -84,10 +76,6 @@ export function AvatarSelectionStage({
onCustomAvatarCreate,
onAvatarDelete,
referenceSaving = false,
memorySummary,
agentConfig,
onAgentConfigChange,
knowledgeBases,
personas,
selectedPersonaId,
personaImporting = false,
@@ -106,20 +94,7 @@ export function AvatarSelectionStage({
});
const [customFile, setCustomFile] = useState<File | null>(null);
const [customPreviewUrl, setCustomPreviewUrl] = useState<string | null>(null);
const selectedKnowledgeBaseIds = agentConfig.knowledgeBaseIds;
const selectedPersona = personas.find((persona) => persona.id === selectedPersonaId) ?? null;
const knowledgeBasesById = new Map(knowledgeBases.map((kb) => [kb.id, kb]));
const selectedKnowledgeBases = selectedKnowledgeBaseIds.map((id) => (
knowledgeBasesById.get(id) ?? {
id,
name: id,
document_count: 0,
ready_document_count: 0,
error_document_count: 0,
created_at: "",
updated_at: "",
}
));
const configDisabled = loading || queued || prewarmState === "preparing";
const baseDisabled = loading || queued || prewarmState === "preparing" || !selectedAvatar || !modelConnected;
const startLabel = queued
@@ -163,15 +138,6 @@ export function AvatarSelectionStage({
if (file) onPersonaImport(file);
};
const updateKnowledgeBaseIds = (nextIds: string[]) => {
const deduped = Array.from(new Set(nextIds.filter((id) => id.trim())));
onAgentConfigChange({
...agentConfig,
knowledgeEnabled: deduped.length > 0,
knowledgeBaseIds: deduped,
});
};
return (
<div className="relative h-full min-h-[520px] overflow-hidden bg-white">
<div className="grid h-full min-h-0 gap-5 p-4 sm:p-5 xl:grid-cols-[minmax(28rem,1.15fr)_minmax(20rem,0.85fr)] xl:p-6">
@@ -371,78 +337,6 @@ export function AvatarSelectionStage({
<p className="mt-1 truncate text-sm font-semibold text-slate-950">{selectedVoiceLabel}</p>
</div>
</div>
<div className="mb-3 rounded-lg border border-slate-200 bg-slate-50 px-3 py-2.5">
<div className="flex items-center justify-between gap-2">
<p className="text-xs font-semibold text-slate-600">Agent </p>
<span className="shrink-0 rounded-full border border-slate-200 bg-white px-2 py-0.5 text-[11px] font-medium text-slate-500">
{knowledgeBases.length}
</span>
</div>
<div className="mt-2 min-h-20 rounded-md border border-slate-200 bg-white p-2">
<div className="mb-2 flex items-center justify-between gap-2">
<p className="text-[11px] font-semibold uppercase tracking-wide text-slate-500"></p>
<span className="text-[11px] text-slate-400">{selectedKnowledgeBases.length} </span>
</div>
{selectedKnowledgeBases.length ? (
<div className="flex flex-wrap gap-1.5">
{selectedKnowledgeBases.map((kb) => (
<div
key={kb.id}
className="inline-flex max-w-full items-center gap-1.5 rounded-full border border-cyan-200 bg-cyan-50 px-2 py-1 text-xs text-cyan-800"
>
<span className="min-w-0 truncate font-medium">{kb.name}</span>
<button
type="button"
disabled={baseDisabled}
onClick={() =>
updateKnowledgeBaseIds(selectedKnowledgeBaseIds.filter((id) => id !== kb.id))
}
className="shrink-0 rounded-full px-1.5 py-0.5 font-semibold text-cyan-700 transition hover:bg-cyan-100 disabled:opacity-50"
aria-label={`移除 ${kb.name}`}
>
×
</button>
</div>
))}
</div>
) : (
<p className="px-1 py-2 text-xs text-slate-400"></p>
)}
</div>
</div>
<div className="mb-3 rounded-lg border border-slate-200 bg-slate-50 px-3 py-2.5">
<div className="flex items-center justify-between gap-2">
<p className="text-xs font-semibold text-slate-600"></p>
<span
className={`shrink-0 rounded-full border px-2 py-0.5 text-[11px] font-medium ${
memorySummary?.enabled
? "border-cyan-200 bg-white text-cyan-700"
: "border-slate-200 bg-white text-slate-500"
}`}
>
{memorySummary?.enabled ? "已挂载" : "未挂载"}
</span>
</div>
<div className="mt-2 min-h-16 rounded-md border border-slate-200 bg-white p-2">
<div className="mb-2 flex items-center justify-between gap-2">
<p className="text-[11px] font-semibold uppercase tracking-wide text-slate-500">
</p>
<span className="text-[11px] text-slate-400">
{memorySummary?.enabled && memorySummary.memoryCount !== null
? `${memorySummary.memoryCount}`
: "0 条"}
</span>
</div>
{memorySummary?.enabled && memorySummary.libraryName ? (
<div className="inline-flex max-w-full items-center gap-1.5 rounded-full border border-cyan-200 bg-cyan-50 px-2 py-1 text-xs text-cyan-800">
<span className="min-w-0 truncate font-medium">{memorySummary.libraryName}</span>
</div>
) : (
<p className="px-1 py-2 text-xs text-slate-400"></p>
)}
</div>
</div>
{queued ? (
<p className="mb-3 rounded-lg border border-amber-200 bg-amber-50 px-3 py-2 text-sm font-medium text-amber-800">
{queueInfo?.position ?? 1} ...

View File

@@ -186,10 +186,6 @@ interface SettingsPanelProps {
qwenVoice: string;
onQwenVoiceChange: (voiceId: string) => void;
qwenVoiceOptions: VoiceOpt[];
llmSystemPrompt: string;
onLlmSystemPromptChange: (value: string) => void;
onSavePrompt: () => void;
promptSaving?: boolean;
onOpenVoiceClone?: () => void;
voiceApplyNotice?: string | null;
ttsPreviewText: string;
@@ -415,10 +411,6 @@ export function SettingsPanel({
qwenVoice,
onQwenVoiceChange,
qwenVoiceOptions,
llmSystemPrompt,
onLlmSystemPromptChange,
onSavePrompt,
promptSaving = false,
onOpenVoiceClone,
voiceApplyNotice = null,
ttsPreviewText,
@@ -525,6 +517,7 @@ export function SettingsPanel({
}));
const providerHasSingleModel = (provider: TtsProviderExtended) => {
if (provider === "edge" || provider === "openai_compatible") return true;
if (provider === "local_cosyvoice" || provider === "indextts") return true;
if (provider !== ttsProvider) return false;
return qwenModelColumnOptions.length <= 1;
};
@@ -1083,33 +1076,6 @@ export function SettingsPanel({
</div>
</SettingsSection>
<SettingsSection
id="role"
title="人设"
open={openSections.role}
onToggle={toggleSection}
>
<div className="space-y-3">
<label className="block">
<span className="mb-1.5 block text-xs text-slate-500"></span>
<textarea
value={llmSystemPrompt}
onChange={(e) => onLlmSystemPromptChange(e.target.value)}
rows={5}
className="w-full resize-none rounded-lg border border-slate-200 bg-slate-50 px-3 py-2.5 text-sm text-slate-800 outline-none transition placeholder:text-slate-400 focus:border-cyan-300 focus:bg-white"
placeholder={"你可以在这里定义数字人的身份、说话风格和边界。\n\n示例你是一位温和专业的产品讲解员回答简洁自然优先用中文回复。遇到不确定的问题先说明不确定再给出可执行建议。"}
/>
</label>
<button
type="button"
onClick={onSavePrompt}
disabled={promptSaving}
className="w-full rounded-lg bg-slate-950 px-3 py-2.5 text-sm font-semibold text-white transition hover:bg-slate-800 disabled:cursor-not-allowed disabled:opacity-60"
>
{promptSaving ? "保存中..." : "保存人设"}
</button>
</div>
</SettingsSection>
</div>
</aside>
);

View File

@@ -106,10 +106,10 @@ def _read_text_file(path: Path) -> str:
raw = path.read_bytes()
for encoding in ("utf-8", "utf-8-sig", "gb18030"):
try:
return raw.decode(encoding)
return raw.decode(encoding).replace("\r\n", "\n").replace("\r", "\n")
except UnicodeDecodeError:
continue
return raw.decode("utf-8", errors="replace")
return raw.decode("utf-8", errors="replace").replace("\r\n", "\n").replace("\r", "\n")
def _has_enough_text(text: str) -> bool:

View File

@@ -0,0 +1,9 @@
{
"voice_id": "cosyvoice-official-zero-shot",
"display_label": "官方示例女声(本地)",
"provider": "local_cosyvoice",
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
"mode": "zero_shot",
"source": "system",
"provenance": "FunAudioLLM/CosyVoice runtime asset prompt"
}

View File

@@ -0,0 +1 @@
希望你以后能够做的比我还好呦。

View File

@@ -0,0 +1,19 @@
{
"voice_id": "local-anchor-cherry",
"display_label": "主播女声 Cherry本地",
"provider": "local_cosyvoice",
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
"mode": "zero_shot",
"source": "system",
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-anchor-cherry/prompt.wav",
"prompt_text": "你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。",
"prompt_source": "generated_once_from_dashscope_qwen_tts",
"qwen_reference_voice": "Cherry",
"avatar_hint": "主播",
"role": "anchor",
"style": "清晰、亲和、适合新闻播报和产品讲解",
"duration_sec": 6.32,
"sample_rate": 24000,
"bytes": 303404,
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
}

View File

@@ -0,0 +1 @@
你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。

View File

@@ -0,0 +1,19 @@
{
"voice_id": "local-ancient-jada",
"display_label": "古风女声 Jada本地",
"provider": "local_cosyvoice",
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
"mode": "zero_shot",
"source": "system",
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-ancient-jada/prompt.wav",
"prompt_text": "你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。",
"prompt_source": "generated_once_from_dashscope_qwen_tts",
"qwen_reference_voice": "Jada",
"avatar_hint": "古装美女",
"role": "ancient-beauty",
"style": "柔和、叙事感、适合文化讲解",
"duration_sec": 6.8,
"sample_rate": 24000,
"bytes": 326444,
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
}

View File

@@ -0,0 +1 @@
你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。

View File

@@ -0,0 +1,19 @@
{
"voice_id": "local-anime-ethan",
"display_label": "动漫男声 Ethan本地",
"provider": "local_cosyvoice",
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
"mode": "zero_shot",
"source": "system",
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-anime-ethan/prompt.wav",
"prompt_text": "你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。",
"prompt_source": "generated_once_from_dashscope_qwen_tts",
"qwen_reference_voice": "Ethan",
"avatar_hint": "动漫帅哥",
"role": "anime-handsome-guy",
"style": "年轻、明亮、适合轻量互动和二次元角色",
"duration_sec": 5.12,
"sample_rate": 24000,
"bytes": 245804,
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
}

View File

@@ -0,0 +1 @@
你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。

View File

@@ -0,0 +1,19 @@
{
"voice_id": "local-office-serena",
"display_label": "职场女声 Serena本地",
"provider": "local_cosyvoice",
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
"mode": "zero_shot",
"source": "system",
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-office-serena/prompt.wav",
"prompt_text": "你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。",
"prompt_source": "generated_once_from_dashscope_qwen_tts",
"qwen_reference_voice": "Serena",
"avatar_hint": "职场女",
"role": "office-woman",
"style": "稳重、专业、适合客服和商务说明",
"duration_sec": 5.52,
"sample_rate": 24000,
"bytes": 265004,
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
}

View File

@@ -0,0 +1 @@
你好欢迎来到OpenTalking。我会用自然清晰的声音为你介绍今天的内容。

View File

@@ -306,7 +306,8 @@ def _legacy_env_mapping() -> dict[str, str]:
def _load_legacy_dotenv_source() -> dict[str, Any]:
values = dotenv_values(".env")
env_file = os.environ.get("OPENTALKING_ENV_FILE", ".env")
values = dotenv_values(env_file)
mapping = _legacy_env_mapping()
return {
target: value
@@ -323,7 +324,7 @@ def _load_legacy_env_source() -> dict[str, Any]:
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="OPENTALKING_",
env_file=".env",
env_file=os.environ.get("OPENTALKING_ENV_FILE", ".env"),
extra="ignore",
)

View File

@@ -4,10 +4,12 @@ import json
import re
import shutil
import uuid
from datetime import UTC, datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
UTC = timezone.utc
_ALLOWED_KINDS = {"realtime_dialogue", "video_clone", "video_creation"}
_SAFE_ID_RE = re.compile(r"^[a-zA-Z0-9_.-]{1,128}$")
_MIME_EXTENSIONS = {

View File

@@ -378,6 +378,22 @@ def _explicit_cuda_available(device: str) -> bool:
return False
def _is_cuda_oom(exc: BaseException) -> bool:
message = str(exc).lower()
return (
"out of memory" in message
or "cudaerrormemoryallocation" in message
or "cuda error: out of memory" in message
)
def _fallback_quicktalk_device(device: str) -> str | None:
raw = (device or "").strip().lower()
if raw.startswith("cuda"):
return "cpu"
return None
@register_model("quicktalk")
class QuickTalkAdapter:
"""QuickTalk realtime worker integrated into OpenTalking's model API."""
@@ -588,6 +604,7 @@ class QuickTalkAdapter:
cache_disabled = _env_value("OPENTALKING_QUICKTALK_WORKER_CACHE", "1") == "0"
worker: RealtimeV3Worker | None = None
cache_key_to_store = cache_key
if not cache_disabled:
worker = _WORKER_CACHE.get(cache_key)
if worker is not None:
@@ -607,34 +624,133 @@ class QuickTalkAdapter:
"quicktalk worker cache MISS — building (avatar=%s)",
bundle.manifest.id,
)
worker = RealtimeV3Worker(
asset_root=asset_root,
template_video=template_video,
face_cache_dir=face_cache_dir,
face_cache_file=face_cache_file,
device=self._device,
output_transform=self._output_transform,
scale_h=self._scale_h,
scale_w=self._scale_w,
resolution=self._resolution,
max_template_seconds=max_template_seconds,
neck_fade_start=self._neck_fade_start,
neck_fade_end=self._neck_fade_end,
hubert_device=self._hubert_device,
model_backend=self._model_backend,
)
try:
worker = RealtimeV3Worker(
asset_root=asset_root,
template_video=template_video,
face_cache_dir=face_cache_dir,
face_cache_file=face_cache_file,
device=self._device,
output_transform=self._output_transform,
scale_h=self._scale_h,
scale_w=self._scale_w,
resolution=self._resolution,
max_template_seconds=max_template_seconds,
neck_fade_start=self._neck_fade_start,
neck_fade_end=self._neck_fade_end,
hubert_device=self._hubert_device,
model_backend=self._model_backend,
)
except Exception as exc: # noqa: BLE001
fallback_device = _fallback_quicktalk_device(self._device)
if fallback_device is None or not _is_cuda_oom(exc):
raise
log.warning(
"QuickTalk worker build hit CUDA OOM; retrying on CPU (avatar=%s)",
bundle.manifest.id,
exc_info=True,
)
self._device = fallback_device
self._hubert_device = fallback_device
cache_key_to_store = _worker_cache_key(
asset_root=asset_root,
template_video=template_video,
face_cache_dir=face_cache_dir,
face_cache_file=face_cache_file,
device=self._device,
output_transform=self._output_transform,
scale_h=self._scale_h,
scale_w=self._scale_w,
resolution=self._resolution,
max_template_seconds=max_template_seconds,
neck_fade_start=self._neck_fade_start,
neck_fade_end=self._neck_fade_end,
hubert_device=self._hubert_device,
model_backend=self._model_backend,
)
worker = RealtimeV3Worker(
asset_root=asset_root,
template_video=template_video,
face_cache_dir=face_cache_dir,
face_cache_file=face_cache_file,
device=self._device,
output_transform=self._output_transform,
scale_h=self._scale_h,
scale_w=self._scale_w,
resolution=self._resolution,
max_template_seconds=max_template_seconds,
neck_fade_start=self._neck_fade_start,
neck_fade_end=self._neck_fade_end,
hubert_device=self._hubert_device,
model_backend=self._model_backend,
)
if not cache_disabled:
_WORKER_CACHE[cache_key] = worker
_WORKER_CACHE[cache_key_to_store] = worker
_enforce_worker_cache_limit()
session_state = worker.make_state()
return QuickTalkState(
manifest=bundle.manifest,
worker=worker,
fps=worker.fps,
extra={},
session_state=session_state,
)
try:
session_state = worker.make_state()
return QuickTalkState(
manifest=bundle.manifest,
worker=worker,
fps=worker.fps,
extra={},
session_state=session_state,
)
except Exception as exc: # noqa: BLE001
fallback_device = _fallback_quicktalk_device(self._device)
if fallback_device is None or not _is_cuda_oom(exc):
raise
log.warning(
"QuickTalk avatar load hit CUDA OOM; retrying on CPU (avatar=%s)",
bundle.manifest.id,
exc_info=True,
)
self._device = fallback_device
self._hubert_device = fallback_device
fallback_key = _worker_cache_key(
asset_root=asset_root,
template_video=template_video,
face_cache_dir=face_cache_dir,
face_cache_file=face_cache_file,
device=self._device,
output_transform=self._output_transform,
scale_h=self._scale_h,
scale_w=self._scale_w,
resolution=self._resolution,
max_template_seconds=max_template_seconds,
neck_fade_start=self._neck_fade_start,
neck_fade_end=self._neck_fade_end,
hubert_device=self._hubert_device,
model_backend=self._model_backend,
)
worker = RealtimeV3Worker(
asset_root=asset_root,
template_video=template_video,
face_cache_dir=face_cache_dir,
face_cache_file=face_cache_file,
device=self._device,
output_transform=self._output_transform,
scale_h=self._scale_h,
scale_w=self._scale_w,
resolution=self._resolution,
max_template_seconds=max_template_seconds,
neck_fade_start=self._neck_fade_start,
neck_fade_end=self._neck_fade_end,
hubert_device=self._hubert_device,
model_backend=self._model_backend,
)
if not cache_disabled:
_WORKER_CACHE[fallback_key] = worker
_enforce_worker_cache_limit()
session_state = worker.make_state()
return QuickTalkState(
manifest=bundle.manifest,
worker=worker,
fps=worker.fps,
extra={"device_fallback": "cpu"},
session_state=session_state,
)
def warmup(self, avatar_state: QuickTalkState | None = None) -> None:
if avatar_state is None:

View File

@@ -6,7 +6,11 @@ from typing import Any
from opentalking.core.config import Settings, get_settings
from opentalking.providers.memory.base import MemoryProvider
from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider, Mem0MemoryProvider
from opentalking.providers.memory.mem0_provider import (
InMemoryMemoryProvider,
Mem0MemoryProvider,
Mem0UnavailableError,
)
from opentalking.providers.memory.noop import NoopMemoryProvider
from opentalking.providers.memory.sqlite_provider import SQLiteMemoryProvider
@@ -110,7 +114,10 @@ def build_memory_provider() -> MemoryProvider:
if provider in {"sqlite", "local"}:
return SQLiteMemoryProvider(settings.memory_sqlite_path)
if provider == "mem0":
return Mem0MemoryProvider(config=_mem0_config(settings))
try:
return Mem0MemoryProvider(config=_mem0_config(settings))
except Mem0UnavailableError:
return SQLiteMemoryProvider(settings.memory_sqlite_path)
if provider in {"memory", "inmemory", "in-memory"}:
return InMemoryMemoryProvider()
raise ValueError(f"unsupported memory provider: {settings.memory_provider}")

View File

@@ -267,20 +267,22 @@ def _stt_model_dir(provider: str, model: str | None = None) -> str:
provider = normalize_stt_provider(provider, default="dashscope") or "dashscope"
direct = _provider_env(provider, "MODEL_DIR") or _settings_value(f"stt_{provider}_model_dir", "")
if direct:
return direct
return str(Path(direct).expanduser().resolve())
if provider in LOCAL_STT_PROVIDERS:
return str(_local_path_for_model((model or _stt_model(provider)).strip()))
return ""
return ""
def _local_path_for_model(model: str) -> Path:
path = Path(model).expanduser()
if path.is_absolute() or path.exists():
return path
if path.is_absolute():
return path.resolve()
root = _model_root().expanduser()
if path.exists():
return path.resolve()
if "/" in model:
return _model_root() / model.replace("/", "__")
return _model_root() / model
return (root / model.replace("/", "__")).resolve()
return (root / model).resolve()
def _write_pcm_queue_to_wav(
@@ -335,9 +337,9 @@ class LocalFunASRSTTAdapter:
def _runtime_model_name(self) -> str:
if self.model_dir:
return self.model_dir
return str(Path(self.model_dir).expanduser().resolve())
local_path = _local_path_for_model(self.model)
return str(local_path) if local_path.exists() else self.model
return str(local_path)
def _load_runtime(self) -> Any:
if self._runtime is not None:

View File

@@ -6,6 +6,7 @@ from opentalking.providers.tts import ( # noqa: F401 side-effect imports
dashscope_sambert,
edge,
elevenlabs,
mock,
)
from opentalking.providers.tts.edge.adapter import EdgeTTSAdapter
from opentalking.providers.tts.factory import build_tts_adapter, create_tts_adapter

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import os
import posixpath
from collections.abc import Mapping
from pathlib import Path
from types import SimpleNamespace
@@ -42,6 +43,13 @@ def _provider_env(provider: str, field: str) -> str:
return os.environ.get(f"OPENTALKING_TTS_{key_provider}_{field}", "").strip()
def _join_config_path(root: str, *parts: str) -> str:
value = root.strip()
if "/" in value and "\\" not in value:
return posixpath.join(value.rstrip("/"), *parts)
return str(Path(value).expanduser().joinpath(*parts))
def _configured_tts_provider_values() -> tuple[str, ...]:
return tuple(
value
@@ -118,20 +126,19 @@ def _local_cosyvoice_service_url() -> str:
)
def _local_audio_model_root() -> Path:
raw = (
def _local_audio_model_root() -> str:
return (
os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
or _settings_value("local_audio_model_root", "")
or "./models/local-audio"
)
return Path(raw).expanduser()
def _local_cosyvoice_model_dir(model: str) -> str:
return (
_provider_env("local_cosyvoice", "MODEL_DIR")
or _settings_value("tts_local_cosyvoice_model_dir", "")
or str(_local_audio_model_root() / model.replace("/", "__"))
or _join_config_path(_local_audio_model_root(), model.replace("/", "__"))
)
@@ -184,10 +191,11 @@ def _local_cosyvoice_fp16() -> str:
def _local_audio_asset_dir(name: str, required_file: str, *fallback_names: str) -> str:
root = _local_audio_model_root()
for candidate_name in (name, *fallback_names):
candidate = root / candidate_name
candidate_value = _join_config_path(root, candidate_name)
candidate = Path(candidate_value).expanduser()
if (candidate / required_file).is_file():
return str(candidate)
return str(root / name)
return candidate_value
return _join_config_path(root, name)
def _local_audio_asset_file_dir(name: str, relative_file: str, *fallback_names: str) -> str:
@@ -206,7 +214,7 @@ def _local_indextts_model_dir(model: str) -> str:
return (
_provider_env("local_indextts", "MODEL_DIR")
or _settings_value("tts_local_indextts_model_dir", "")
or str(_local_audio_model_root() / model.replace("/", "__"))
or _join_config_path(_local_audio_model_root(), model.replace("/", "__"))
)
@@ -214,7 +222,7 @@ def _local_indextts_cfg_path(model_dir: str) -> str:
return (
_provider_env("local_indextts", "CFG_PATH")
or _settings_value("tts_local_indextts_cfg_path", "")
or str(Path(model_dir) / "config.yaml")
or _join_config_path(model_dir, "config.yaml")
)
@@ -236,7 +244,7 @@ def _local_indextts_w2v_bert_dir() -> str:
return (
_provider_env("local_indextts", "W2V_BERT_DIR")
or _settings_value("tts_local_indextts_w2v_bert_dir", "")
or str(_local_audio_model_root() / "facebook__w2v-bert-2.0")
or _join_config_path(_local_audio_model_root(), "facebook__w2v-bert-2.0")
)
@@ -252,7 +260,7 @@ def _local_indextts_campplus_dir() -> str:
return (
_provider_env("local_indextts", "CAMPPLUS_DIR")
or _settings_value("tts_local_indextts_campplus_dir", "")
or str(_local_audio_model_root() / "funasr__campplus")
or _join_config_path(_local_audio_model_root(), "funasr__campplus")
)
@@ -260,7 +268,7 @@ def _local_indextts_bigvgan_dir() -> str:
return (
_provider_env("local_indextts", "BIGVGAN_DIR")
or _settings_value("tts_local_indextts_bigvgan_dir", "")
or str(_local_audio_model_root() / "nvidia__bigvgan_v2_22khz_80band_256x")
or _join_config_path(_local_audio_model_root(), "nvidia__bigvgan_v2_22khz_80band_256x")
)
@@ -519,6 +527,16 @@ def tts_enabled_providers() -> list[str]:
def tts_provider_config(provider: str) -> dict[str, str | bool | int | float]:
p = normalize_tts_provider(provider, default=None) or _provider()
if p == "mock":
return {
"provider": p,
"model": "mock",
"model_dir": "",
"voice": "mock",
"device": "",
"key_set": False,
"service_url_set": True,
}
if p == "indextts":
resolved = _resolve_indextts_provider(p)
config = dict(tts_provider_config(resolved))
@@ -1260,6 +1278,15 @@ def build_tts_adapter(
tts_model=effective_tts_model,
)
if provider == "mock":
from opentalking.providers.tts.mock.adapter import MockTTSAdapter
return MockTTSAdapter(
default_voice=default_voice or "mock",
sample_rate=sample_rate,
chunk_ms=chunk_ms,
)
# For dashscope/bailian/etc., delegate to create_tts_adapter
if provider in _QWEN_RT or provider in _COSY_WS or provider in _SAMBERT or provider in _LOCAL or provider in _OMNIRT or provider in _INDEXTTS:
return create_tts_adapter(
@@ -1272,6 +1299,14 @@ def build_tts_adapter(
)
if provider in _CORE:
if provider == "mock":
from opentalking.providers.tts.mock.adapter import MockTTSAdapter
return MockTTSAdapter(
default_voice=default_voice or getattr(settings, "tts_voice", None) or "mock",
sample_rate=sample_rate,
chunk_ms=chunk_ms,
)
return EdgeTTSAdapter(
default_voice=default_voice or getattr(settings, "tts_voice", None) or _edge_default_voice(),
sample_rate=sample_rate,

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import inspect
import importlib
import io
import json
@@ -14,6 +15,11 @@ import httpx
import numpy as np
from opentalking.core.types.frames import AudioChunk
from opentalking.providers.tts.voice_assets import (
LOCAL_COSYVOICE_PROVIDER,
local_audio_model_root,
resolve_voice_asset,
)
def _settings_value(name: str, default: str = "") -> str:
@@ -93,12 +99,7 @@ def _resolve_service_url_for_model(model: str, default_model: str, default_url:
def _model_root() -> Path:
raw = (
os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
or _settings_value("local_audio_model_root", "")
or "./models/local-audio"
)
return Path(raw).expanduser()
return local_audio_model_root()
def _resolve_model_path(model: str) -> str:
@@ -114,39 +115,70 @@ def _resolve_local_voice_prompt(voice: str | None) -> dict[str, str] | None:
return None
if not all(ch.isalnum() or ch in {"_", "-"} for ch in voice_id):
return None
for base in (_model_root() / "voices" / "clones", _model_root() / "voices" / "system"):
voice_dir = base / voice_id
prompt_audio = voice_dir / "prompt.wav"
prompt_text = voice_dir / "prompt.txt"
if not prompt_audio.is_file() or not prompt_text.is_file():
continue
result = {"prompt_audio": str(prompt_audio)}
text = prompt_text.read_text(encoding="utf-8").strip()
if text:
result["prompt_text"] = text
meta_path = voice_dir / "meta.json"
if meta_path.is_file():
try:
meta = json.loads(meta_path.read_text(encoding="utf-8"))
except Exception:
meta = {}
for key in ("mode", "instruction"):
value = str(meta.get(key) or "").strip()
if value:
result[key] = value
if result.get("prompt_text") or result.get("mode") in {"cross_lingual", "instruct"}:
return result
asset = resolve_voice_asset(
voice_id,
provider=LOCAL_COSYVOICE_PROVIDER,
sources=("clones", "system"),
model_root=_model_root(),
require_prompt_text=True,
)
if asset is None or asset.prompt_text is None:
return None
result = {"prompt_audio": str(asset.prompt_audio)}
try:
text = asset.prompt_text.read_text(encoding="utf-8").strip()
except OSError:
text = ""
if text:
result["prompt_text"] = text
for key in ("mode", "instruction"):
value = str(asset.meta.get(key) or "").strip()
if value:
result[key] = value
if result.get("prompt_text") or result.get("mode") in {"cross_lingual", "instruct"}:
return result
return None
class LocalCosyVoiceInputError(ValueError):
"""Invalid local CosyVoice request, usually an unavailable prompt voice."""
def _is_service_default_voice(voice: str | None) -> bool:
voice_id = (voice or "").strip()
return not voice_id or voice_id == "local-default"
def _service_default_prompt_configured() -> bool:
mode = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_MODE", "zero_shot").strip().lower()
prompt_audio = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_AUDIO", "").strip()
prompt_text = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_TEXT", "").strip()
if mode in {"cross_lingual", "instruct"}:
return bool(prompt_audio)
return bool(prompt_audio and prompt_text)
def _missing_local_voice_message(voice: str | None) -> str:
voice_text = (voice or "").strip() or "未选择"
return f"本地 CosyVoice 音色 {voice_text!r} 没有可用 prompt请先选择本地音色。"
def _local_cosyvoice_http_400_message(*, voice: str | None, service_url: str, detail: str) -> str:
suffix = f" sidecar detail: {detail.strip()}" if detail.strip() else ""
return (
f"本地 CosyVoice 请求无效:{_missing_local_voice_message(voice)} "
f"HTTP 400 from {service_url}.{suffix}"
).strip()
def _env_device() -> str:
return (
os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_DEVICE", "").strip()
or _settings_value("tts_local_cosyvoice_device", "")
or os.environ.get("OPENTALKING_LOCAL_TTS_DEVICE", "").strip()
or os.environ.get("OPENTALKING_LOCAL_AUDIO_DEVICE", "").strip()
or _settings_value("local_audio_device", "auto")
or "auto"
or _settings_value("local_audio_device", "cpu")
or "cpu"
)
@@ -299,15 +331,16 @@ class LocalCosyVoiceTTSAdapter:
default_service_url,
)
self.device = _env_device()
self.fp16 = _local_cosyvoice_fp16(self.device)
self.load_jit = _local_cosyvoice_bool("LOAD_JIT", "tts_local_cosyvoice_load_jit", False)
self.load_trt = _local_cosyvoice_bool("LOAD_TRT", "tts_local_cosyvoice_load_trt", False)
self.load_vllm = _local_cosyvoice_bool("LOAD_VLLM", "tts_local_cosyvoice_load_vllm", False)
self.fp16 = _local_cosyvoice_fp16(self.device)
self.trt_concurrent = max(
1,
_local_cosyvoice_int("TRT_CONCURRENT", "tts_local_cosyvoice_trt_concurrent", 1),
)
self._engine: Any | None = None
self._voice_payload_cache: dict[str, dict[str, str]] = {}
async def synthesize_stream(self, text: str, voice: str | None = None) -> AsyncIterator[AudioChunk]:
if not text.strip():
@@ -321,76 +354,109 @@ class LocalCosyVoiceTTSAdapter:
yield chunk
async def _synthesize_via_service(self, text: str, voice: str | None = None) -> AsyncIterator[AudioChunk]:
timeout = httpx.Timeout(connect=30.0, read=180.0, write=30.0, pool=30.0)
timeout = httpx.Timeout(connect=30.0, read=600.0, write=30.0, pool=30.0)
effective_voice = voice or self.default_voice
payload = {
"text": text,
"voice": voice or self.default_voice,
"voice": effective_voice,
"model": self.model,
"sample_rate": self.sample_rate,
}
local_prompt = _resolve_local_voice_prompt(voice or self.default_voice)
local_prompt = _resolve_local_voice_prompt(effective_voice)
if local_prompt is not None:
payload.update(local_prompt)
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("POST", self.service_url, json=payload) as resp:
resp.raise_for_status()
input_format = _audio_format_from_content_type(resp.headers.get("content-type"))
if input_format == "pcm":
source_sr = _source_sample_rate_from_headers(resp.headers, self.sample_rate)
pending = b""
async for data in resp.aiter_bytes():
if not data:
continue
data = pending + data
if len(data) % 2:
pending = data[-1:]
data = data[:-1]
else:
pending = b""
if not data:
continue
pcm = np.frombuffer(data, dtype="<i2").astype(np.int16, copy=False)
payload["zero_shot_spk_id"] = effective_voice
elif not _is_service_default_voice(effective_voice):
raise LocalCosyVoiceInputError(_missing_local_voice_message(effective_voice))
try:
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("POST", self.service_url, json=payload) as resp:
resp.raise_for_status()
input_format = _audio_format_from_content_type(resp.headers.get("content-type"))
if input_format == "pcm":
source_sr = _source_sample_rate_from_headers(resp.headers, self.sample_rate)
pending = b""
async for data in resp.aiter_bytes():
if not data:
continue
data = pending + data
if len(data) % 2:
pending = data[-1:]
data = data[:-1]
else:
pending = b""
if not data:
continue
pcm = np.frombuffer(data, dtype="<i2").astype(np.int16, copy=False)
pcm = _resample_linear(pcm, source_sr, self.sample_rate)
for chunk in _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms):
yield chunk
return
if input_format == "wav":
raw = await resp.aread()
with wave.open(io.BytesIO(raw), "rb") as wf:
source_sr = int(wf.getframerate())
channels = int(wf.getnchannels())
sample_width = int(wf.getsampwidth())
pcm_bytes = wf.readframes(wf.getnframes())
if sample_width != 2:
raise RuntimeError(f"Unsupported WAV sample width for local CosyVoice: {sample_width}")
pcm = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.int16, copy=False)
if channels > 1:
frame_count = pcm.size // channels
pcm = (
pcm[: frame_count * channels]
.reshape(frame_count, channels)
.mean(axis=1)
.astype(np.int16)
)
pcm = _resample_linear(pcm, source_sr, self.sample_rate)
for chunk in _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms):
yield chunk
return
if input_format == "wav":
raw = await resp.aread()
with wave.open(io.BytesIO(raw), "rb") as wf:
source_sr = int(wf.getframerate())
channels = int(wf.getnchannels())
sample_width = int(wf.getsampwidth())
pcm_bytes = wf.readframes(wf.getnframes())
if sample_width != 2:
raise RuntimeError(f"Unsupported WAV sample width for local CosyVoice: {sample_width}")
pcm = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.int16, copy=False)
if channels > 1:
frame_count = pcm.size // channels
pcm = (
pcm[: frame_count * channels]
.reshape(frame_count, channels)
.mean(axis=1)
.astype(np.int16)
)
pcm = _resample_linear(pcm, source_sr, self.sample_rate)
for chunk in _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms):
return
from opentalking.providers.tts.edge.adapter import _stream_decode_audio_to_pcm_chunks
async def _audio_iter() -> AsyncIterator[bytes]:
async for data in resp.aiter_bytes():
if data:
yield data
async for chunk in _stream_decode_audio_to_pcm_chunks(
_audio_iter(),
self.sample_rate,
self.chunk_ms,
input_format=input_format,
):
yield chunk
return
from opentalking.providers.tts.edge.adapter import _stream_decode_audio_to_pcm_chunks
async def _audio_iter() -> AsyncIterator[bytes]:
async for data in resp.aiter_bytes():
if data:
yield data
async for chunk in _stream_decode_audio_to_pcm_chunks(
_audio_iter(),
self.sample_rate,
self.chunk_ms,
input_format=input_format,
):
yield chunk
except httpx.HTTPStatusError as exc:
detail = ""
try:
detail = exc.response.text[:500]
except Exception:
detail = ""
if exc.response.status_code == 400:
raise LocalCosyVoiceInputError(
_local_cosyvoice_http_400_message(
voice=effective_voice,
service_url=self.service_url,
detail=detail,
)
) from exc
if "CosyVoice returned no audio" in detail:
raise RuntimeError(
"本地 CosyVoice 返回空音频,模型推理状态异常;请重启本地 TTS 服务后重试。"
f" HTTP {exc.response.status_code} from {self.service_url}. {detail}".strip()
) from exc
raise RuntimeError(
"本地 CosyVoice 服务不可用(可能已退出/内存不足)。"
f" HTTP {exc.response.status_code} from {self.service_url}. {detail}".strip()
) from exc
except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout, httpx.RemoteProtocolError) as exc:
raise RuntimeError(
"本地 CosyVoice 服务不可用(可能已退出/内存不足)。"
f" 无法连接或读取 {self.service_url}: {type(exc).__name__}: {exc}"
) from exc
def _load_engine(self) -> Any:
if self._engine is not None:
@@ -432,6 +498,28 @@ class LocalCosyVoiceTTSAdapter:
return str(item)
return requested or "中文女"
def _callable_supports_keyword(self, fn: Any, keyword: str) -> bool:
try:
signature = inspect.signature(fn)
except (TypeError, ValueError):
return False
return keyword in signature.parameters or any(
param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
)
def _resolved_voice_payload(self, voice: str | None) -> tuple[str | None, dict[str, str] | None]:
voice_id = (voice or "").strip()
if not voice_id or voice_id == "local-default":
return None, None
cached = self._voice_payload_cache.get(voice_id)
if cached is not None:
return voice_id, dict(cached)
payload = _resolve_local_voice_prompt(voice_id)
if payload is None:
return voice_id, None
self._voice_payload_cache[voice_id] = dict(payload)
return voice_id, payload
def _synthesize_in_process(self, text: str, voice: str) -> list[AudioChunk]:
engine = self._load_engine()
spk_id = self._available_voice(engine, voice)
@@ -440,7 +528,34 @@ class LocalCosyVoiceTTSAdapter:
raise RuntimeError("CosyVoice runtime does not expose inference_sft().")
sr = int(getattr(engine, "sample_rate", 22050) or 22050)
pcm_parts: list[np.ndarray] = []
for item in infer(text, spk_id, stream=False):
voice_id, payload = self._resolved_voice_payload(spk_id)
if payload is not None:
add_zero_shot_spk = getattr(engine, "add_zero_shot_spk", None)
if callable(add_zero_shot_spk):
try:
prompt_text = payload.get("prompt_text", "")
prompt_audio = payload["prompt_audio"]
if self._callable_supports_keyword(add_zero_shot_spk, "zero_shot_spk_id"):
add_zero_shot_spk(prompt_text, prompt_audio, zero_shot_spk_id=voice_id)
else:
add_zero_shot_spk(prompt_text, prompt_audio, voice_id)
save_spkinfo = getattr(engine, "save_spkinfo", None)
if callable(save_spkinfo):
save_spkinfo()
except Exception:
pass
if payload is not None and self._callable_supports_keyword(infer, "zero_shot_spk_id"):
iterator = infer(text, "", "", stream=False, zero_shot_spk_id=voice_id)
elif payload is not None:
iterator = infer(
text,
payload.get("prompt_text", ""),
payload["prompt_audio"],
stream=False,
)
else:
iterator = infer(text, spk_id, stream=False)
for item in iterator:
speech = item.get("tts_speech") if isinstance(item, dict) else item
if hasattr(speech, "detach"):
speech = speech.detach().cpu().numpy()

View File

@@ -17,6 +17,7 @@ import httpx
from opentalking.core.types.frames import AudioChunk
from opentalking.providers.tts.indextts_config import indextts_infer_kwargs, normalize_indextts_config
from opentalking.providers.tts.voice_assets import INDEXTTS_PROVIDER, resolve_voice_asset
_ENGINE_CACHE_LOCK = threading.Lock()
@@ -113,6 +114,24 @@ def _read_wav_bytes_i16(raw: bytes) -> tuple[np.ndarray, int]:
return _read_wav_handle_i16(wf)
def _write_wav_i16(path: str | Path, data: np.ndarray, sample_rate: int) -> None:
pcm = np.asarray(data)
if pcm.ndim == 2 and pcm.shape[0] == 1:
pcm = pcm[0]
elif pcm.ndim == 2:
pcm = pcm.T.reshape(-1)
if np.issubdtype(pcm.dtype, np.floating):
pcm = np.clip(pcm, -1.0, 1.0)
pcm = np.round(pcm * 32767.0).astype("<i2")
else:
pcm = np.clip(pcm, -32768, 32767).astype("<i2")
with wave.open(str(path), "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(int(sample_rate))
wf.writeframes(pcm.reshape(-1).tobytes())
class LocalIndexTTSAdapter:
"""In-process IndexTTS2 adapter for OpenTalking local TTS mode."""
@@ -226,11 +245,14 @@ class LocalIndexTTSAdapter:
def _resolve_voice_prompt(self, voice: str | None) -> Path | None:
voice_id = (voice or "").strip()
if voice_id and re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
voices_root = _local_audio_model_root() / "voices"
for root in (voices_root / "clones", voices_root / "system", _bundled_system_voice_root()):
prompt = root / voice_id / "prompt.wav"
if prompt.is_file():
return prompt
asset = resolve_voice_asset(
voice_id,
provider=INDEXTTS_PROVIDER,
sources=("clones", "system"),
model_root=_local_audio_model_root(),
)
if asset is not None:
return asset.prompt_audio
if self.prompt_audio:
return Path(self.prompt_audio)
return None
@@ -369,15 +391,8 @@ class LocalIndexTTSAdapter:
text = str(exc)
if "TorchCodec is required" not in text and "libtorchcodec" not in text:
raise
import soundfile as sf
data = tensor.detach().cpu().numpy() if hasattr(tensor, "detach") else np.asarray(tensor)
data = np.asarray(data)
if data.ndim == 2 and data.shape[0] == 1:
data = data[0]
elif data.ndim == 2:
data = data.T
sf.write(path, data, sample_rate, subtype="PCM_16")
_write_wav_i16(path, data, sample_rate)
return None
torchaudio.save = save
@@ -435,10 +450,15 @@ class LocalIndexTTSAdapter:
if not callable(infer):
raise RuntimeError("IndexTTS2 runtime does not expose infer().")
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
fd, tmp_name = tempfile.mkstemp(suffix=".wav")
os.close(fd)
tmp_path = Path(tmp_name)
try:
with engine_lock:
infer(str(prompt), text, tmp.name, **indextts_infer_kwargs(self.indextts_config))
pcm, source_sr = _read_wav_i16(Path(tmp.name))
infer(str(prompt), text, tmp_name, **indextts_infer_kwargs(self.indextts_config))
pcm, source_sr = _read_wav_i16(tmp_path)
finally:
tmp_path.unlink(missing_ok=True)
pcm = _resample_linear(pcm, source_sr, self.sample_rate)
return _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms)

View File

@@ -0,0 +1,6 @@
from opentalking.core.registry import register
from opentalking.providers.tts.mock.adapter import MockTTSAdapter
register("tts", "mock")(MockTTSAdapter)
__all__ = ["MockTTSAdapter"]

View File

@@ -0,0 +1,50 @@
from __future__ import annotations
import math
from collections.abc import AsyncIterator
import numpy as np
from opentalking.core.types.frames import AudioChunk
def _tone_frequency(text: str) -> float:
checksum = sum(ord(ch) for ch in text)
return 180.0 + float(checksum % 220)
def _sine_wave(text: str, sample_rate: int, chunk_ms: float) -> np.ndarray:
duration_sec = max(0.4, min(2.0, len(text) * 0.06))
sample_count = max(1, int(sample_rate * duration_sec))
t = np.arange(sample_count, dtype=np.float32) / sample_rate
freq = _tone_frequency(text)
carrier = np.sin(2.0 * math.pi * freq * t)
envelope = np.linspace(0.15, 0.95, sample_count, dtype=np.float32)
audio = carrier * envelope * 0.22
return np.clip(audio * 32767.0, -32768, 32767).astype(np.int16)
class MockTTSAdapter:
def __init__(
self,
default_voice: str = "mock",
sample_rate: int = 16000,
chunk_ms: float = 40.0,
) -> None:
self.default_voice = default_voice
self.sample_rate = sample_rate
self.chunk_ms = chunk_ms
async def synthesize_stream(self, text: str, voice: str | None = None) -> AsyncIterator[AudioChunk]:
del voice
pcm = _sine_wave(text, self.sample_rate, self.chunk_ms)
chunk_samples = max(1, int(self.sample_rate * (self.chunk_ms / 1000.0)))
for start in range(0, len(pcm), chunk_samples):
part = pcm[start : start + chunk_samples]
if part.size == 0:
continue
yield AudioChunk(
data=part.copy(),
sample_rate=self.sample_rate,
duration_ms=1000.0 * float(part.size) / float(self.sample_rate),
)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
CORE_TTS_PROVIDERS = frozenset({"auto", "edge", "elevenlabs"})
CORE_TTS_PROVIDERS = frozenset({"auto", "edge", "elevenlabs", "mock"})
OPENAI_COMPATIBLE_TTS_PROVIDERS = frozenset({"openai_compatible"})
XIAOMI_MIMO_TTS_PROVIDERS = frozenset({"xiaomi_mimo", "xiaomi", "mimo"})
QWEN_TTS_PROVIDERS = frozenset({"dashscope", "bailian", "qwen", "qwen_tts"})

View File

@@ -0,0 +1,191 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
INDEXTTS_PROVIDER = "indextts"
INDEXTTS_LEGACY_PROVIDERS = {"local_indextts", "omnirt_indextts"}
INDEXTTS_PROVIDERS = {INDEXTTS_PROVIDER, *INDEXTTS_LEGACY_PROVIDERS}
LOCAL_COSYVOICE_PROVIDER = "local_cosyvoice"
@dataclass(frozen=True)
class VoiceAsset:
voice_id: str
source: str
root: Path
path: Path
prompt_audio: Path
prompt_text: Path | None
meta: dict[str, Any]
bundled_system: bool = False
def local_audio_model_root() -> Path:
raw = os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
try:
from opentalking.core.config import get_settings
raw = raw or (get_settings().local_audio_model_root or "").strip()
except Exception:
pass
return Path(raw or "./models/local-audio").expanduser().resolve()
def bundled_system_voice_root() -> Path:
return Path(__file__).resolve().parents[2] / "assets" / "voices" / "system"
def system_voice_roots(model_root: Path | None = None) -> list[Path]:
root = model_root or local_audio_model_root()
roots = [root / "voices" / "system", bundled_system_voice_root()]
out: list[Path] = []
seen: set[Path] = set()
for item in roots:
try:
resolved = item.resolve()
except OSError:
resolved = item
if resolved in seen:
continue
seen.add(resolved)
out.append(item)
return out
def clone_voice_roots(model_root: Path | None = None) -> list[Path]:
root = model_root or local_audio_model_root()
return [root / "voices" / "clones"]
def _provider_aliases(provider: str) -> set[str]:
normalized = provider.strip().lower()
if normalized in INDEXTTS_PROVIDERS:
return {INDEXTTS_PROVIDER, *INDEXTTS_LEGACY_PROVIDERS}
return {normalized}
def _truthy_meta_flag(value: object) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "on", "universal", "zero_shot"}
return False
def voice_applies_to_provider(meta: dict[str, Any], provider: str, *, bundled_system: bool = False) -> bool:
normalized = provider.strip().lower()
if not normalized:
return True
if bundled_system and normalized == LOCAL_COSYVOICE_PROVIDER:
return True
if any(_truthy_meta_flag(meta.get(key)) for key in ("universal", "compatible", "zero_shot_compatible")):
return True
aliases = _provider_aliases(normalized)
raw_providers = meta.get("providers")
if isinstance(raw_providers, list):
allowed = {str(item).strip().lower() for item in raw_providers if str(item).strip()}
if allowed:
return bool(allowed & aliases)
raw_provider = str(meta.get("provider") or "").strip().lower()
if not raw_provider:
return True
if normalized == LOCAL_COSYVOICE_PROVIDER and raw_provider in INDEXTTS_PROVIDERS and not raw_providers:
return True
return raw_provider in aliases
def read_voice_meta(voice_dir: Path) -> dict[str, Any]:
meta_path = voice_dir / "meta.json"
if not meta_path.is_file():
return {}
try:
parsed = json.loads(meta_path.read_text(encoding="utf-8"))
except Exception:
return {}
return parsed if isinstance(parsed, dict) else {}
def iter_voice_assets(
*,
provider: str,
sources: Iterable[str] = ("clones", "system"),
model_root: Path | None = None,
require_prompt_text: bool = False,
) -> list[VoiceAsset]:
root = model_root or local_audio_model_root()
roots: list[tuple[str, Path, bool]] = []
for source in sources:
if source == "clones":
roots.extend(("clones", item, False) for item in clone_voice_roots(root))
elif source == "system":
bundled = bundled_system_voice_root()
for item in system_voice_roots(root):
roots.append(("system", item, _same_path(item, bundled)))
assets: list[VoiceAsset] = []
seen: set[tuple[str, str]] = set()
for source, voice_root, bundled_system in roots:
if not voice_root.is_dir():
continue
for voice_dir in sorted(path for path in voice_root.iterdir() if path.is_dir()):
voice_id = voice_dir.name
key = (source, voice_id)
if key in seen:
continue
prompt_audio = voice_dir / "prompt.wav"
prompt_text = voice_dir / "prompt.txt"
if not prompt_audio.is_file():
continue
if require_prompt_text and not prompt_text.is_file():
continue
meta = read_voice_meta(voice_dir)
if not voice_applies_to_provider(meta, provider, bundled_system=bundled_system):
continue
seen.add(key)
assets.append(
VoiceAsset(
voice_id=voice_id,
source=source,
root=voice_root,
path=voice_dir,
prompt_audio=prompt_audio,
prompt_text=prompt_text if prompt_text.is_file() else None,
meta=meta,
bundled_system=bundled_system,
)
)
return assets
def resolve_voice_asset(
voice_id: str,
*,
provider: str,
sources: Iterable[str] = ("clones", "system"),
model_root: Path | None = None,
require_prompt_text: bool = False,
) -> VoiceAsset | None:
wanted = voice_id.strip()
if not wanted:
return None
for asset in iter_voice_assets(
provider=provider,
sources=sources,
model_root=model_root,
require_prompt_text=require_prompt_text,
):
if asset.voice_id == wanted:
return asset
return None
def _same_path(left: Path, right: Path) -> bool:
try:
return left.resolve() == right.resolve()
except OSError:
return left == right

View File

@@ -117,10 +117,24 @@ def _fit_reference_driver_pcm(pcm: np.ndarray, total_samples: int) -> np.ndarray
return np.tile(source, repeats)[:target].astype(np.int16, copy=False)
def _read_pcm16_mono_wav(path: Path) -> np.ndarray | None:
try:
with wave.open(str(path), "rb") as wf:
if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getframerate() != 16000:
return None
raw = wf.readframes(wf.getnframes())
except (wave.Error, OSError):
return None
return np.frombuffer(raw, dtype="<i2").copy()
async def _load_reference_driver_pcm(settings: object, total_samples: int) -> np.ndarray | None:
path = _reference_driver_audio_path(settings)
if not path.is_file():
return None
direct_pcm = _read_pcm16_mono_wav(path)
if direct_pcm is not None:
return _fit_reference_driver_pcm(direct_pcm, total_samples)
try:
pcm = await decode_audio_file_to_pcm_i16(path)
return _fit_reference_driver_pcm(pcm, total_samples)
@@ -415,6 +429,12 @@ def _quicktalk_prepared_template_video(settings: object, avatar_path: Path) -> P
path = avatar_path / name
if path.is_file():
return path.resolve()
source_dir = avatar_path / "source"
if source_dir.is_dir():
for pattern in ("*.mp4", "*.mov", "*.webm", "*.avi"):
for candidate in sorted(source_dir.glob(pattern)):
if candidate.is_file():
return candidate.resolve()
return None
@@ -492,12 +512,13 @@ def _init_session_kwargs(
fasterliveportrait_config: Mapping[str, object] | None,
) -> dict[str, object]:
kwargs: dict[str, object] = {"avatar_path": avatar_path}
if model == "quicktalk":
kwargs.update(_quicktalk_init_session_kwargs(settings, avatar_path))
if not _remote_audio2video_backend(backend):
return kwargs
kwargs["ref_image"] = _reference_image_path(avatar_path)
if model == "quicktalk":
kwargs.update(_quicktalk_init_session_kwargs(settings, avatar_path))
return kwargs
if model != "fasterliveportrait":
return kwargs

View File

@@ -9,6 +9,11 @@ from pathlib import Path
DEFAULT_ROOT = Path(os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "./models/local-audio"))
DEFAULT_REUSE_ROOTS = (
Path("./models"),
Path("/root/models"),
Path.home() / ".cache" / "opentalking" / "models",
)
MODELS: dict[str, tuple[str, str]] = {
"sensevoice-small": ("modelscope", "iic/SenseVoiceSmall"),
@@ -58,6 +63,30 @@ HF_ALLOW_PATTERNS: dict[str, list[str]] = {
],
}
MODEL_HINTS: dict[str, tuple[str, ...]] = {
"sensevoice-small": ("iic__SenseVoiceSmall", "sensevoice", "SenseVoiceSmall"),
"fun-cosyvoice3-0.5b-2512": (
"FunAudioLLM__Fun-CosyVoice3-0.5B-2512",
"Fun-CosyVoice3-0.5B-2512",
"cosyvoice",
),
"indextts2": ("IndexTeam__IndexTTS-2", "IndexTTS-2"),
"indextts2-w2v-bert": ("facebook__w2v-bert-2.0", "w2v-bert-2.0"),
"indextts2-maskgct": ("amphion__MaskGCT", "amphion__MaskGCT-ms"),
"indextts2-campplus": ("funasr__campplus", "campplus"),
"indextts2-bigvgan": ("nvidia__bigvgan_v2_22khz_80band_256x", "bigvgan_v2_22khz_80band_256x"),
}
MODEL_REQUIRED_FILES: dict[str, tuple[str, ...]] = {
"sensevoice-small": ("model.pt", "config.yaml", "configuration.json"),
"fun-cosyvoice3-0.5b-2512": ("cosyvoice3.yaml", "flow.pt", "hift.pt", "llm.pt"),
"indextts2": ("config.yaml", "model.pt"),
"indextts2-w2v-bert": ("model.safetensors", "conformer_shaw.pt"),
"indextts2-maskgct": ("semantic_codec/model.safetensors", "acoustic_codec/model.safetensors"),
"indextts2-campplus": ("campplus_cn_common.bin", "config.yaml"),
"indextts2-bigvgan": ("bigvgan_generator.pt",),
}
def default_model_keys() -> list[str]:
return ["sensevoice-small", "fun-cosyvoice3-0.5b-2512"]
@@ -71,6 +100,61 @@ def _target(root: Path, model_id: str) -> Path:
return root / model_id.replace("/", "__")
def _reuse_root_values(raw: str | None) -> list[Path]:
if not raw:
return [root for root in DEFAULT_REUSE_ROOTS if root is not None]
roots: list[Path] = []
for chunk in raw.replace(";", os.pathsep).replace(",", os.pathsep).split(os.pathsep):
value = chunk.strip()
if value:
roots.append(Path(value).expanduser())
return roots
def _model_hints(model_key: str) -> tuple[str, ...]:
return MODEL_HINTS.get(model_key, (MODELS[model_key][1].replace("/", "__"),))
def _required_files(model_key: str) -> tuple[str, ...]:
return MODEL_REQUIRED_FILES.get(model_key, ())
def _is_model_ready(path: Path, *, model_key: str) -> bool:
if not path.exists():
return False
required = _required_files(model_key)
if not required:
return path.is_dir() or path.is_file()
return all((path / relative).exists() for relative in required)
def _find_reusable_source(model_key: str, roots: list[Path]) -> Path | None:
for root in roots:
for hint in _model_hints(model_key):
candidate = root / hint
if _is_model_ready(candidate, model_key=model_key):
return candidate
nested = candidate / "checkpoints"
if _is_model_ready(nested, model_key=model_key):
return nested
return None
def _mirror_existing_source(source: Path, target: Path) -> None:
target.parent.mkdir(parents=True, exist_ok=True)
if target.exists():
return
try:
target.symlink_to(source, target_is_directory=source.is_dir())
return
except Exception:
pass
if source.is_dir():
shutil.copytree(source, target)
else:
shutil.copy2(source, target)
def _download_modelscope(model_id: str, target: Path) -> None:
from modelscope import snapshot_download
@@ -99,6 +183,12 @@ def _git_lfs_pull_if_needed(target: Path) -> None:
def main() -> None:
parser = argparse.ArgumentParser(description="Download the supported local STT/TTS model weights.")
parser.add_argument("--root", type=Path, default=DEFAULT_ROOT)
parser.add_argument(
"--reuse-root",
action="append",
dest="reuse_roots",
help="Search these roots first and reuse existing weights instead of downloading.",
)
parser.add_argument(
"--model",
action="append",
@@ -110,6 +200,9 @@ def main() -> None:
root = args.root.expanduser().resolve()
root.mkdir(parents=True, exist_ok=True)
selected = args.model or default_model_keys()
reuse_roots = [root] + _reuse_root_values(os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_SEARCH_ROOTS"))
if args.reuse_roots:
reuse_roots.extend(Path(value).expanduser().resolve() for value in args.reuse_roots)
failures: list[tuple[str, str]] = []
for key in selected:
@@ -118,6 +211,17 @@ def main() -> None:
print(f"[{key}] {source}:{model_id} -> {target}", flush=True)
target.mkdir(parents=True, exist_ok=True)
try:
if _is_model_ready(target, model_key=key):
print(f"[{key}] reusing existing target: {target}", flush=True)
continue
reusable = _find_reusable_source(key, reuse_roots)
if reusable is not None:
print(f"[{key}] reusing existing source: {reusable}", flush=True)
if reusable.resolve() != target.resolve():
_mirror_existing_source(reusable, target)
continue
if source == "modelscope":
_download_modelscope(model_id, target)
else:

View File

@@ -1,13 +1,17 @@
from __future__ import annotations
import argparse
import gc
import importlib
import importlib.util
import io
import inspect
import hashlib
import os
import sys
import threading
import time
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import Any
@@ -19,6 +23,29 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
def _load_voice_assets_module():
module_name = "_opentalking_voice_assets_local_cosyvoice"
module = sys.modules.get(module_name)
if module is not None:
return module
module_path = Path(__file__).resolve().parents[1] / "opentalking" / "providers" / "tts" / "voice_assets.py"
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load voice assets module from {module_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
_voice_assets = _load_voice_assets_module()
LOCAL_COSYVOICE_PROVIDER = _voice_assets.LOCAL_COSYVOICE_PROVIDER
VoiceAsset = _voice_assets.VoiceAsset
iter_voice_assets = _voice_assets.iter_voice_assets
local_audio_model_root = _voice_assets.local_audio_model_root
resolve_voice_asset = _voice_assets.resolve_voice_asset
def _soundfile_load_wav(wav: str, target_sr: int):
import torch
@@ -128,6 +155,7 @@ def _patch_cosyvoice_load_wav() -> None:
class SynthesizeRequest(BaseModel):
text: str
voice: str | None = None
zero_shot_spk_id: str | None = None
model: str | None = None
sample_rate: int | None = None
prompt_audio: str | None = None
@@ -150,6 +178,34 @@ def _cosyvoice_flow(cosyvoice: Any) -> Any | None:
return getattr(model, "flow", None)
def _callable_supports_keyword(fn: Any, name: str) -> bool:
try:
signature = inspect.signature(fn)
except (TypeError, ValueError):
return False
return name in signature.parameters or any(
param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
)
def _voice_signature(asset: VoiceAsset) -> tuple[str, int, int, str]:
try:
stat = asset.prompt_audio.stat()
except OSError:
stat = None
try:
prompt_text = asset.prompt_text.read_text(encoding="utf-8").strip() if asset.prompt_text else ""
except OSError:
prompt_text = ""
digest = hashlib.sha1(prompt_text.encode("utf-8")).hexdigest()
return (
str(asset.prompt_audio.resolve()),
int(getattr(stat, "st_mtime_ns", 0) or 0),
int(getattr(stat, "st_size", 0) or 0),
digest,
)
def current_streaming_tuning(cosyvoice: Any) -> dict[str, Any]:
model = _cosyvoice_model(cosyvoice)
return {
@@ -195,6 +251,20 @@ def ensure_cosyvoice_flow_half(cosyvoice: Any) -> bool:
return True
def _is_cuda_runtime_incompatibility(exc: BaseException) -> bool:
text = f"{type(exc).__name__}: {exc}".lower()
return any(
marker in text
for marker in (
"no kernel image is available for execution on the device",
"cuda error",
"invalid device function",
"tensorrt",
"trt",
)
)
def reset_streaming_tuning(cosyvoice: Any) -> dict[str, Any]:
model = _cosyvoice_model(cosyvoice)
baseline = getattr(model, "_opentalking_streaming_tuning", None)
@@ -359,6 +429,7 @@ class CosyVoiceService:
*,
model_dir: str,
runtime_dir: str,
audio_root: str | None = None,
device: str,
prompt_audio: str,
prompt_text: str,
@@ -373,12 +444,15 @@ class CosyVoiceService:
token_max_hop_len: int | None = None,
stream_scale_factor: int | None = None,
flow_n_timesteps: int | None = None,
max_token_text_ratio: float | None = 6.0,
max_token_text_ratio: float | None = None,
min_token_text_ratio: float | None = None,
mask_stop_tokens: bool = True,
mask_stop_tokens: bool = False,
use_zero_shot_spk_id: bool = False,
precache_system_spks: bool = False,
) -> None:
self.model_dir = model_dir
self.runtime_dir = runtime_dir
self.audio_root = audio_root or ""
self.device = device
self.prompt_audio = prompt_audio
self.prompt_text = prompt_text
@@ -396,6 +470,8 @@ class CosyVoiceService:
self.max_token_text_ratio = max_token_text_ratio
self.min_token_text_ratio = min_token_text_ratio
self.mask_stop_tokens = mask_stop_tokens
self.use_zero_shot_spk_id = use_zero_shot_spk_id
self.precache_system_spks = precache_system_spks
self._model: Any | None = None
self._model_lock = threading.Lock()
self._loaded_model_kwargs: dict[str, Any] = {}
@@ -403,6 +479,60 @@ class CosyVoiceService:
self._flow_tuning: dict[str, Any] = {}
self._llm_token_ratio_tuning: dict[str, Any] = {}
self._llm_stop_token_patch: dict[str, Any] = {}
self._zero_shot_spk_cache: dict[str, tuple[str, int, int, str]] = {}
def _audio_root(self) -> Path:
if self.audio_root.strip():
return Path(self.audio_root).expanduser().resolve()
return local_audio_model_root()
def _resolve_voice_asset(self, voice_id: str | None) -> VoiceAsset | None:
voice_key = (voice_id or "").strip()
if not voice_key or voice_key == "local-default":
return None
return resolve_voice_asset(
voice_key,
provider=LOCAL_COSYVOICE_PROVIDER,
sources=("clones", "system"),
model_root=self._audio_root(),
require_prompt_text=True,
)
def _ensure_zero_shot_spk_registered(self, model: Any, voice_id: str, asset: VoiceAsset) -> bool:
if not voice_id or asset.prompt_text is None:
return False
add_zero_shot_spk = getattr(model, "add_zero_shot_spk", None)
if not callable(add_zero_shot_spk):
return False
signature = _voice_signature(asset)
if self._zero_shot_spk_cache.get(voice_id) == signature:
return True
prompt_text = asset.prompt_text.read_text(encoding="utf-8").strip()
if not prompt_text:
return False
prompt_text = self._prompt_text_for_zero_shot(prompt_text)
prompt_audio = str(asset.prompt_audio)
if _callable_supports_keyword(add_zero_shot_spk, "zero_shot_spk_id"):
add_zero_shot_spk(prompt_text, prompt_audio, zero_shot_spk_id=voice_id)
else:
add_zero_shot_spk(prompt_text, prompt_audio, voice_id)
self._zero_shot_spk_cache[voice_id] = signature
print(f"zero_shot_spk registered voice_id={voice_id} prompt_audio={prompt_audio}", flush=True)
save_spkinfo = getattr(model, "save_spkinfo", None)
if callable(save_spkinfo):
save_spkinfo()
return True
def _precache_system_zero_shot_spks(self, model: Any) -> None:
assets = iter_voice_assets(
provider=LOCAL_COSYVOICE_PROVIDER,
sources=("system",),
model_root=self._audio_root(),
require_prompt_text=True,
)
for asset in assets:
self._ensure_zero_shot_spk_registered(model, asset.voice_id, asset)
def model(self) -> Any:
if self._model is not None:
@@ -438,11 +568,52 @@ class CosyVoiceService:
"fp16": self.fp16,
"trt_concurrent": self.trt_concurrent,
}
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
try:
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
except Exception as exc:
if not self.load_trt:
raise
print(
"CosyVoice TensorRT startup failed; falling back to non-TRT runtime: "
f"{type(exc).__name__}: {exc}",
flush=True,
)
self.load_trt = False
model_kwargs["load_trt"] = False
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
flow_half_applied = False
if self.load_trt and self.fp16:
flow_half_applied = ensure_cosyvoice_flow_half(self._model)
try:
flow_half_applied = ensure_cosyvoice_flow_half(self._model)
except Exception as exc:
if not _is_cuda_runtime_incompatibility(exc):
raise
print(
"CosyVoice TensorRT/FP16 startup failed after load; "
"falling back to non-TRT fp32 runtime: "
f"{type(exc).__name__}: {exc}",
flush=True,
)
self.load_trt = False
self.fp16 = False
old_model = self._model
self._model = None
del old_model
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
model_kwargs["load_trt"] = False
model_kwargs["fp16"] = False
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
self._zero_shot_spk_cache.clear()
self._apply_runtime_tuning()
if self.precache_system_spks:
self._precache_system_zero_shot_spks(self._model)
# Keep the service zero-shot first so it does not require precomputed spk2info.pt.
print(
"loaded cosyvoice "
@@ -516,6 +687,35 @@ class CosyVoiceService:
),
}
def reset_model_after_empty_audio(self, *, reason: str) -> None:
with self._model_lock:
old_model = self._model
self._model = None
self._loaded_model_kwargs = {}
if self.load_trt or self.fp16:
print(
"cosyvoice empty audio recovery: disabling TRT/FP16 for retry",
flush=True,
)
self.load_trt = False
self.fp16 = False
self._zero_shot_spk_cache.clear()
self._streaming_tuning = {}
self._flow_tuning = {}
self._llm_token_ratio_tuning = {}
self._llm_stop_token_patch = {}
if old_model is not None:
del old_model
gc.collect()
try:
import torch
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception:
pass
print(f"cosyvoice model reset after empty audio: {reason}", flush=True)
def _to_wav_bytes(self, speech: Any, sample_rate: int) -> bytes:
if hasattr(speech, "detach"):
speech = speech.detach().cpu().numpy()
@@ -552,6 +752,14 @@ class CosyVoiceService:
return f"You are a helpful assistant.<|endofprompt|>{text}"
return "You are a helpful assistant.<|endofprompt|>"
def _asset_prompt_text(self, asset: VoiceAsset, fallback_prompt_text: str = "") -> str:
prompt_text = ""
if asset.prompt_text is not None:
prompt_text = asset.prompt_text.read_text(encoding="utf-8").strip()
if not prompt_text:
prompt_text = fallback_prompt_text.strip()
return self._prompt_text_for_zero_shot(prompt_text)
def synthesize_wav(self, req: SynthesizeRequest) -> tuple[bytes, int, float]:
text = req.text.strip()
if not text:
@@ -559,6 +767,7 @@ class CosyVoiceService:
prompt_audio = (req.prompt_audio or self.prompt_audio).strip()
prompt_text = (req.prompt_text or self.prompt_text).strip()
mode = (req.mode or self.mode).strip().lower()
voice_id = (req.zero_shot_spk_id or req.voice or "").strip()
model = self.model()
sample_rate = int(getattr(model, "sample_rate", 24000) or 24000)
t0 = time.perf_counter()
@@ -572,17 +781,40 @@ class CosyVoiceService:
instruction = (req.instruction or self.instruction).strip()
iterator = model.inference_instruct2(text, instruction, prompt_audio, stream=False)
else:
if not prompt_audio or not prompt_text:
raise HTTPException(
status_code=400,
detail="zero_shot mode requires prompt_audio and prompt_text",
asset = self._resolve_voice_asset(voice_id)
if asset is not None:
asset_prompt_text = self._asset_prompt_text(asset, prompt_text)
asset_prompt_audio = str(asset.prompt_audio)
if (
self.use_zero_shot_spk_id
and _callable_supports_keyword(model.inference_zero_shot, "zero_shot_spk_id")
and self._ensure_zero_shot_spk_registered(model, asset.voice_id, asset)
):
iterator = model.inference_zero_shot(text, "", "", stream=False, zero_shot_spk_id=asset.voice_id)
else:
iterator = model.inference_zero_shot(
text,
asset_prompt_text,
asset_prompt_audio,
stream=False,
)
else:
if not prompt_audio or not prompt_text:
raise HTTPException(
status_code=400,
detail="zero_shot mode requires prompt_audio and prompt_text",
)
iterator = model.inference_zero_shot(
text,
self._prompt_text_for_zero_shot(prompt_text),
prompt_audio,
stream=False,
)
if asset is not None:
print(
f"zero_shot {'spk_id' if self.use_zero_shot_spk_id else 'prompt_path'} voice_id={asset.voice_id} stream=False prompt_audio={asset.prompt_audio}",
flush=True,
)
iterator = model.inference_zero_shot(
text,
self._prompt_text_for_zero_shot(prompt_text),
prompt_audio,
stream=False,
)
parts: list[np.ndarray] = []
with self._model_lock:
for item in _with_request_streaming_tuning(model, iterator):
@@ -595,17 +827,22 @@ class CosyVoiceService:
wav_bytes = self._to_wav_bytes(np.concatenate(parts), sample_rate)
return wav_bytes, sample_rate, time.perf_counter() - t0
def _streaming_iterator(self, req: SynthesizeRequest) -> tuple[Iterator[Any], int, int, float, Any]:
def _streaming_iterator(
self,
req: SynthesizeRequest,
) -> tuple[Iterator[Any], int, int, float, Any, Callable[[], Iterator[Any]] | None]:
text = req.text.strip()
if not text:
raise HTTPException(status_code=400, detail="text is required")
prompt_audio = (req.prompt_audio or self.prompt_audio).strip()
prompt_text = (req.prompt_text or self.prompt_text).strip()
mode = (req.mode or self.mode).strip().lower()
voice_id = (req.zero_shot_spk_id or req.voice or "").strip()
model = self.model()
source_sr = int(getattr(model, "sample_rate", 24000) or 24000)
target_sr = int(req.sample_rate or source_sr)
t0 = time.perf_counter()
fallback_iterator_factory: Callable[[], Iterator[Any]] | None = None
if mode == "cross_lingual":
if not prompt_audio:
raise HTTPException(status_code=400, detail="prompt_audio is required")
@@ -616,47 +853,111 @@ class CosyVoiceService:
instruction = (req.instruction or self.instruction).strip()
iterator = model.inference_instruct2(text, instruction, prompt_audio, stream=True)
else:
if not prompt_audio or not prompt_text:
raise HTTPException(
status_code=400,
detail="zero_shot mode requires prompt_audio and prompt_text",
asset = self._resolve_voice_asset(voice_id)
if asset is not None:
asset_prompt_text = self._asset_prompt_text(asset, prompt_text)
asset_prompt_audio = str(asset.prompt_audio)
if (
self.use_zero_shot_spk_id
and _callable_supports_keyword(model.inference_zero_shot, "zero_shot_spk_id")
and self._ensure_zero_shot_spk_registered(model, asset.voice_id, asset)
):
iterator = model.inference_zero_shot(text, "", "", stream=True, zero_shot_spk_id=asset.voice_id)
def fallback_iterator(
*,
text: str = text,
prompt_text: str = asset_prompt_text,
prompt_audio: str = asset_prompt_audio,
voice_id: str = asset.voice_id,
) -> Iterator[Any]:
self._zero_shot_spk_cache.pop(voice_id, None)
print(
"zero_shot_spk_id produced no audio; falling back to prompt "
f"voice_id={voice_id} prompt_audio={prompt_audio}",
flush=True,
)
return model.inference_zero_shot(text, prompt_text, prompt_audio, stream=True)
fallback_iterator_factory = fallback_iterator
else:
iterator = model.inference_zero_shot(
text,
asset_prompt_text,
asset_prompt_audio,
stream=True,
)
else:
if not prompt_audio or not prompt_text:
raise HTTPException(
status_code=400,
detail="zero_shot mode requires prompt_audio and prompt_text",
)
iterator = model.inference_zero_shot(
text,
self._prompt_text_for_zero_shot(prompt_text),
prompt_audio,
stream=True,
)
iterator = model.inference_zero_shot(
text,
self._prompt_text_for_zero_shot(prompt_text),
prompt_audio,
stream=True,
)
return iterator, source_sr, target_sr, t0, model
if asset is not None:
print(
f"zero_shot {'spk_id' if self.use_zero_shot_spk_id else 'prompt_path'} voice_id={asset.voice_id} stream=True prompt_audio={asset.prompt_audio}",
flush=True,
)
return iterator, source_sr, target_sr, t0, model, fallback_iterator_factory
def synthesize_pcm_stream(self, req: SynthesizeRequest) -> tuple[Iterator[bytes], int]:
iterator, source_sr, target_sr, t0, model = self._streaming_iterator(req)
iterator, source_sr, target_sr, t0, model, fallback_iterator_factory = self._streaming_iterator(req)
def generate() -> Iterator[bytes]:
first = True
chunks = 0
samples = 0
with self._model_lock:
tuned_iterator = _with_request_streaming_tuning(model, iterator)
output_sr = target_sr
def emit(
tuned_iterator: Iterator[Any],
*,
source_sr_for_attempt: int,
target_sr_for_attempt: int,
t0_for_attempt: float,
) -> Iterator[bytes]:
nonlocal first, chunks, samples, output_sr
output_sr = target_sr_for_attempt
for item in tuned_iterator:
speech = item.get("tts_speech") if isinstance(item, dict) else item
pcm = self._audio_to_i16(speech)
pcm = self._resample_linear(pcm, source_sr, target_sr)
pcm = self._resample_linear(pcm, source_sr_for_attempt, target_sr_for_attempt)
if pcm.size == 0:
continue
if first:
print(
f"first_pcm chars={len(req.text.strip())} sr={target_sr} seconds={time.perf_counter() - t0:.3f}",
f"first_pcm chars={len(req.text.strip())} sr={target_sr_for_attempt} seconds={time.perf_counter() - t0_for_attempt:.3f}",
flush=True,
)
first = False
chunks += 1
samples += int(pcm.size)
yield pcm.astype("<i2", copy=False).tobytes()
with self._model_lock:
yield from emit(
_with_request_streaming_tuning(model, iterator),
source_sr_for_attempt=source_sr,
target_sr_for_attempt=target_sr,
t0_for_attempt=t0,
)
if chunks == 0 and fallback_iterator_factory is not None:
yield from emit(
_with_request_streaming_tuning(model, fallback_iterator_factory()),
source_sr_for_attempt=source_sr,
target_sr_for_attempt=target_sr,
t0_for_attempt=t0,
)
if chunks == 0:
raise RuntimeError("CosyVoice returned no audio")
print(
f"synth_stream chars={len(req.text.strip())} sr={target_sr} chunks={chunks} audio_seconds={samples / target_sr:.3f} wall_seconds={time.perf_counter() - t0:.3f}",
f"synth_stream chars={len(req.text.strip())} sr={output_sr} chunks={chunks} audio_seconds={samples / output_sr:.3f} wall_seconds={time.perf_counter() - t0:.3f}",
flush=True,
)
@@ -683,17 +984,47 @@ def create_app(service: CosyVoiceService) -> FastAPI:
@app.post("/synthesize")
def synthesize(req: SynthesizeRequest) -> StreamingResponse:
try:
def open_stream() -> tuple[Iterator[bytes], bytes, int]:
stream, sr = service.synthesize_pcm_stream(req)
iterator = iter(stream)
first = next(iterator)
return iterator, first, sr
try:
iterator, first_chunk, sr = open_stream()
except HTTPException:
raise
except Exception as exc:
raise HTTPException(
status_code=500,
detail=f"cosyvoice synth failed: {type(exc).__name__}: {exc}",
) from exc
if "CosyVoice returned no audio" in str(exc):
reset = getattr(service, "reset_model_after_empty_audio", None)
if callable(reset):
reset(reason=str(exc))
try:
iterator, first_chunk, sr = open_stream()
except HTTPException:
raise
except Exception as retry_exc:
raise HTTPException(
status_code=500,
detail=f"cosyvoice synth failed after model reset: {type(retry_exc).__name__}: {retry_exc}",
) from retry_exc
else:
raise HTTPException(
status_code=500,
detail=f"cosyvoice synth failed: {type(exc).__name__}: {exc}",
) from exc
else:
raise HTTPException(
status_code=500,
detail=f"cosyvoice synth failed: {type(exc).__name__}: {exc}",
) from exc
def response_stream() -> Iterator[bytes]:
yield first_chunk
yield from iterator
return StreamingResponse(
stream,
response_stream(),
media_type=f"audio/L16; rate={sr}; channels=1",
headers={"X-Audio-Sample-Rate": str(sr)},
)
@@ -705,6 +1036,55 @@ def _local_audio_root() -> Path:
return Path(os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "./models/local-audio")).expanduser()
def _default_system_voice_prompt(root: Path) -> tuple[str, str] | None:
repo_root = Path(__file__).resolve().parents[1]
voice_roots = [
root / "voices" / "system",
repo_root / "opentalking" / "assets" / "voices" / "system",
]
seen: set[Path] = set()
for voice_root in voice_roots:
try:
resolved = voice_root.resolve()
except OSError:
resolved = voice_root
if resolved in seen or not voice_root.is_dir():
continue
seen.add(resolved)
for voice_dir in sorted(path for path in voice_root.iterdir() if path.is_dir()):
prompt_audio = voice_dir / "prompt.wav"
prompt_text = voice_dir / "prompt.txt"
if not prompt_audio.is_file() or not prompt_text.is_file():
continue
try:
text = prompt_text.read_text(encoding="utf-8").strip()
except OSError:
text = ""
if text:
print(f"using default CosyVoice system voice prompt: {voice_dir.name}", flush=True)
return str(prompt_audio), text
return None
def _torch_cuda_supports_device(device: str) -> tuple[bool, str]:
if not device.startswith("cuda"):
return True, ""
try:
import torch
if not torch.cuda.is_available():
return False, "torch.cuda.is_available() is false"
index = int(device.split(":", 1)[1]) if ":" in device else 0
major, minor = torch.cuda.get_device_capability(index)
wanted = f"sm_{major}{minor}"
arch_list = set(torch.cuda.get_arch_list() or [])
if arch_list and wanted not in arch_list:
return False, f"device capability {wanted} is not in torch arch list {sorted(arch_list)}"
except Exception as exc:
return False, f"failed to inspect torch CUDA support: {type(exc).__name__}: {exc}"
return True, ""
def _env_bool(name: str, default: bool = False) -> bool:
raw = os.environ.get(name, "").strip().lower()
if not raw:
@@ -733,6 +1113,31 @@ def build_service_from_env() -> CosyVoiceService:
fp16_raw = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_FP16", "auto").strip().lower()
fp16 = device.startswith("cuda") if fp16_raw == "auto" else fp16_raw not in {"0", "false", "no", "off"}
root = _local_audio_root()
load_trt = _env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_TRT", False)
cuda_supported, cuda_reason = _torch_cuda_supports_device(device)
if not cuda_supported:
print(
"CosyVoice CUDA runtime is not compatible with this torch build; "
f"falling back to CPU runtime: {cuda_reason}",
flush=True,
)
device = "cpu"
fp16 = False
load_trt = False
os.environ["OPENTALKING_TTS_LOCAL_COSYVOICE_PRELOAD"] = "0"
mode = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_MODE", "zero_shot")
prompt_audio = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_AUDIO", "").strip()
prompt_text = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_TEXT", "").strip()
normalized_mode = mode.strip().lower()
if (
(normalized_mode in {"cross_lingual", "instruct"} and not prompt_audio)
or (normalized_mode not in {"cross_lingual", "instruct"} and (not prompt_audio or not prompt_text))
):
default_prompt = _default_system_voice_prompt(root)
if default_prompt is not None:
default_audio, default_text = default_prompt
prompt_audio = prompt_audio or default_audio
prompt_text = prompt_text or default_text
return CosyVoiceService(
model_dir=os.environ.get(
"OPENTALKING_TTS_LOCAL_COSYVOICE_MODEL_DIR",
@@ -742,26 +1147,29 @@ def build_service_from_env() -> CosyVoiceService:
"OPENTALKING_TTS_LOCAL_COSYVOICE_RUNTIME_DIR",
str(root / "runtime" / "CosyVoice"),
),
audio_root=str(root),
device=device,
prompt_audio=os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_AUDIO", ""),
prompt_text=os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_TEXT", ""),
mode=os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_MODE", "zero_shot"),
prompt_audio=prompt_audio,
prompt_text=prompt_text,
mode=mode,
instruction=os.environ.get(
"OPENTALKING_TTS_LOCAL_COSYVOICE_INSTRUCTION",
"You are a helpful assistant.<|endofprompt|>",
),
fp16=fp16,
load_jit=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_JIT", False),
load_trt=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_TRT", False),
load_trt=load_trt,
load_vllm=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_VLLM", False),
trt_concurrent=int(os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_TRT_CONCURRENT", "1") or "1"),
token_hop_len=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_TOKEN_HOP_LEN"),
token_max_hop_len=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_TOKEN_MAX_HOP_LEN"),
stream_scale_factor=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_STREAM_SCALE_FACTOR"),
flow_n_timesteps=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_FLOW_N_TIMESTEPS"),
max_token_text_ratio=_env_optional_float("OPENTALKING_TTS_LOCAL_COSYVOICE_MAX_TOKEN_TEXT_RATIO", 6.0),
max_token_text_ratio=_env_optional_float("OPENTALKING_TTS_LOCAL_COSYVOICE_MAX_TOKEN_TEXT_RATIO"),
min_token_text_ratio=_env_optional_float("OPENTALKING_TTS_LOCAL_COSYVOICE_MIN_TOKEN_TEXT_RATIO"),
mask_stop_tokens=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_MASK_STOP_TOKENS", True),
mask_stop_tokens=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_MASK_STOP_TOKENS", False),
use_zero_shot_spk_id=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_USE_SPK_ID", False),
precache_system_spks=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_PRECACHE_SPKS", False),
)

View File

@@ -12,7 +12,6 @@ from pathlib import Path
from typing import Any
import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
@@ -63,10 +62,28 @@ def _resample_linear(pcm: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
def _to_wav_bytes(pcm: np.ndarray, sr: int) -> bytes:
buf = io.BytesIO()
sf.write(buf, np.asarray(pcm, dtype=np.int16), sr, format="WAV", subtype="PCM_16")
_write_wav_i16(buf, np.asarray(pcm, dtype=np.int16), sr)
return buf.getvalue()
def _write_wav_i16(path_or_file: str | Path | io.BytesIO, data: np.ndarray, sample_rate: int) -> None:
pcm = np.asarray(data)
if pcm.ndim == 2 and pcm.shape[0] == 1:
pcm = pcm[0]
elif pcm.ndim == 2:
pcm = pcm.T.reshape(-1)
if np.issubdtype(pcm.dtype, np.floating):
pcm = np.clip(pcm, -1.0, 1.0)
pcm = np.round(pcm * 32767.0).astype("<i2")
else:
pcm = np.clip(pcm, -32768, 32767).astype("<i2")
with wave.open(path_or_file if hasattr(path_or_file, "write") else str(path_or_file), "wb") as wf:
wf.setnchannels(1)
wf.setsampwidth(2)
wf.setframerate(int(sample_rate))
wf.writeframes(pcm.reshape(-1).tobytes())
class LocalIndexTTSService:
def __init__(
self,
@@ -216,7 +233,7 @@ class LocalIndexTTSService:
data = data[0]
elif data.ndim == 2:
data = data.T
sf.write(path, data, sample_rate, subtype="PCM_16")
_write_wav_i16(path, data, sample_rate)
return None
torchaudio.save = save
@@ -286,10 +303,15 @@ class LocalIndexTTSService:
raise HTTPException(status_code=400, detail=f"prompt_audio does not exist: {prompt_audio}")
target_sr = int(req.sample_rate or 16000)
t0 = time.perf_counter()
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
fd, tmp_name = tempfile.mkstemp(suffix=".wav")
os.close(fd)
tmp_path = Path(tmp_name)
try:
with self._lock:
self.model().infer(prompt_audio, text, tmp.name, **self._infer_kwargs(req))
pcm, sr = _audio_to_i16(Path(tmp.name))
self.model().infer(prompt_audio, text, tmp_name, **self._infer_kwargs(req))
pcm, sr = _audio_to_i16(tmp_path)
finally:
tmp_path.unlink(missing_ok=True)
pcm = _resample_linear(pcm, sr, target_sr)
elapsed = time.perf_counter() - t0
print(

View File

@@ -55,23 +55,22 @@ def test_app_clears_stale_subtitle_state_on_context_reset() -> None:
assert model_clear_idx < model_set_idx
def test_avatar_selection_stage_supports_manage_knowledge_bases() -> None:
source = (ROOT / "apps/web/src/components/AvatarSelectionStage.tsx").read_text(encoding="utf-8")
def test_settings_panel_supports_knowledge_base_selection() -> None:
source = (ROOT / "apps/web/src/components/SettingsPanel.tsx").read_text(encoding="utf-8")
assert "knowledgeBaseIds" in source
assert "selected knowledge bases" not in source.lower()
assert "可用知识库" not in source
agent_idx = source.index("Agent 增强")
start_idx = source.index("{queued ?", agent_idx)
agent_block = source[agent_idx:start_idx]
assert "当前形象知识库" in agent_block
assert "{knowledgeBases.length} 个知识库" in agent_block
assert "{selectedKnowledgeBaseIds.length || 0} 个知识库" not in agent_block
assert "可用知识库" not in agent_block
assert "onManageKnowledgeBases" not in agent_block
assert "flex flex-wrap" in agent_block
assert "inline-flex max-w-full" in agent_block
knowledge_idx = source.index('title="知识库"')
model_idx = source.index('title="驱动模型"')
knowledge_block = source[knowledge_idx:model_idx]
assert "{knowledgeBases.length}知识库" in knowledge_block
assert "onManageKnowledgeBases" in knowledge_block
assert "selectedKnowledgeBaseSet.has(knowledgeBase.id)" in knowledge_block
assert "disabled={!knowledgeBaseReady}" in knowledge_block
assert "已选" in knowledge_block
assert "可用知识库" not in knowledge_block
def test_settings_panel_places_knowledge_between_avatar_and_model() -> None: