mirror of
https://github.com/datascale-ai/opentalking.git
synced 2026-07-05 00:25:28 +08:00
feat: offline runtime providers (CosyVoice TRT/spk sidecar, SenseVoice, shared voice assets) (#124)
平台中立的本地运行时增强,Linux/Windows 本地与一键部署包通用:
- 共享音色库 voice_assets + 跨 provider 复用
- CosyVoice 本地 sidecar: TRT/autocast_fp16 加速, zero-shot spk 预存, 独立 venv
- SenseVoice 本地 ASR: model_dir 绝对路径修复 (避免被 funasr 当 repo_id)
- DashScope STT key 兜底, QuickTalk CUDA OOM 回退 CPU, Mem0 不可用回退 SQLite
- 前端: 解除人设/驱动模型耦合, 音色展示去技术前缀, mock TTS provider
- 5 个内置 system 音色资产

不含 Windows 部署包(bundle/)、行尾归一、平台专属脚本。
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Query, UploadFile
|
||||
from pydantic import BaseModel
|
||||
@@ -162,27 +163,32 @@ async def _add_uploaded_document(
|
||||
filename = file.filename or "document.txt"
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
total = 0
|
||||
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-", delete=True) as tmp:
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > MAX_DOCUMENT_BYTES:
|
||||
raise HTTPException(status_code=413, detail="document is larger than 20MB")
|
||||
tmp.write(chunk)
|
||||
tmp.flush()
|
||||
tmp_path: Path | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-", delete=False) as tmp:
|
||||
tmp_path = Path(tmp.name)
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > MAX_DOCUMENT_BYTES:
|
||||
raise HTTPException(status_code=413, detail="document is larger than 20MB")
|
||||
tmp.write(chunk)
|
||||
try:
|
||||
doc = await store.add_document(
|
||||
kb_id=kb_id,
|
||||
filename=filename,
|
||||
mime_type=mime_type,
|
||||
source_path=tmp.name,
|
||||
source_path=tmp_path,
|
||||
)
|
||||
except DuplicateKnowledgeDocumentError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
finally:
|
||||
if tmp_path is not None:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
return KnowledgeDocumentResponse(**asdict(doc))
|
||||
|
||||
|
||||
@@ -194,26 +200,31 @@ async def _add_uploaded_file(
|
||||
filename = file.filename or "document.txt"
|
||||
mime_type = file.content_type or "application/octet-stream"
|
||||
total = 0
|
||||
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-file-", delete=True) as tmp:
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > MAX_DOCUMENT_BYTES:
|
||||
raise HTTPException(status_code=413, detail="document is larger than 20MB")
|
||||
tmp.write(chunk)
|
||||
tmp.flush()
|
||||
tmp_path: Path | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(prefix="opentalking-kb-file-", delete=False) as tmp:
|
||||
tmp_path = Path(tmp.name)
|
||||
while True:
|
||||
chunk = await file.read(1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
total += len(chunk)
|
||||
if total > MAX_DOCUMENT_BYTES:
|
||||
raise HTTPException(status_code=413, detail="document is larger than 20MB")
|
||||
tmp.write(chunk)
|
||||
try:
|
||||
doc = await store.add_file(
|
||||
filename=filename,
|
||||
mime_type=mime_type,
|
||||
source_path=tmp.name,
|
||||
source_path=tmp_path,
|
||||
)
|
||||
except DuplicateKnowledgeDocumentError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
finally:
|
||||
if tmp_path is not None:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
return KnowledgeDocumentResponse(**asdict(doc))
|
||||
|
||||
|
||||
|
||||
@@ -706,12 +706,33 @@ def _prewarm_local_backend(
|
||||
settings=settings,
|
||||
overwrite=overwrite,
|
||||
)
|
||||
return _prewarm_local_adapter(
|
||||
model,
|
||||
avatar_dir,
|
||||
settings,
|
||||
prepared_cache=prepared_cache,
|
||||
)
|
||||
try:
|
||||
return _prewarm_local_adapter(
|
||||
model,
|
||||
avatar_dir,
|
||||
settings,
|
||||
prepared_cache=prepared_cache,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if model == "quicktalk" and "out of memory" in str(exc).lower():
|
||||
cache = {
|
||||
"model": model,
|
||||
"status": "ready",
|
||||
"source_mode": "local",
|
||||
"frames": prepared_cache.frames if prepared_cache is not None else None,
|
||||
"detail": "local adapter prepared QuickTalk assets but warmup ran out of memory",
|
||||
"prepared_status": prepared_cache.status if prepared_cache is not None else None,
|
||||
}
|
||||
runtime = {
|
||||
"type": "local_prewarm_result",
|
||||
"backend": "local",
|
||||
"model": model,
|
||||
"warmed": False,
|
||||
"elapsed_ms": 0.0,
|
||||
"message": str(exc),
|
||||
}
|
||||
return cache, runtime
|
||||
raise
|
||||
|
||||
|
||||
def _quicktalk_runtime_payload(
|
||||
@@ -1063,11 +1084,12 @@ async def prewarm_avatar(avatar_id: str, request: Request) -> dict[str, Any]:
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
raise HTTPException(status_code=500, detail=f"failed to prewarm local {model}: {exc}") from exc
|
||||
runtime_status = "failed" if not bool(runtime.get("warmed", True)) else "ready"
|
||||
return {
|
||||
"avatar_id": avatar_id,
|
||||
"model": model,
|
||||
"status": "ready",
|
||||
"runtime_status": "ready",
|
||||
"runtime_status": runtime_status,
|
||||
"cache": cache_response,
|
||||
"runtime": runtime,
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ def _runtime_status_payload(request: Request) -> dict[str, Any]:
|
||||
tts_provider_list = [tts_provider, *tts_provider_list]
|
||||
tts_status_providers = [*tts_provider_list]
|
||||
for provider in (
|
||||
"mock",
|
||||
"local_cosyvoice",
|
||||
"indextts",
|
||||
"dashscope",
|
||||
@@ -63,9 +64,11 @@ def _runtime_status_payload(request: Request) -> dict[str, Any]:
|
||||
tts_status_providers.append(provider)
|
||||
tts_provider_map = {provider: tts_provider_config(provider) for provider in tts_status_providers}
|
||||
tts_effective = tts_provider_map.get(tts_provider, tts)
|
||||
llm_key = os.environ.get("OPENTALKING_LLM_API_KEY", "").strip() or str(
|
||||
getattr(settings, "llm_api_key", "") or ""
|
||||
).strip()
|
||||
llm_key = (
|
||||
os.environ.get("OPENTALKING_LLM_API_KEY", "").strip()
|
||||
or os.environ.get("DASHSCOPE_API_KEY", "").strip()
|
||||
or str(getattr(settings, "llm_api_key", "") or "").strip()
|
||||
)
|
||||
ignored_legacy_env = [name for name in _IGNORED_LEGACY_ENV if os.environ.get(name)]
|
||||
quicktalk_backend = os.environ.get("OPENTALKING_QUICKTALK_BACKEND", "").strip() or str(
|
||||
getattr(settings, "quicktalk_backend", "") or ""
|
||||
|
||||
@@ -15,6 +15,7 @@ from opentalking.persona.wechat_import import WeChatImportJobRegistry
|
||||
from opentalking.providers.memory.decision_agent import MemoryDecisionAgent
|
||||
from opentalking.providers.memory.factory import build_memory_provider
|
||||
from opentalking.providers.memory.runtime import MemoryRuntime, normalize_memory_scope
|
||||
from opentalking.providers.memory.sqlite_provider import SQLiteMemoryProvider
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
@@ -38,16 +39,6 @@ class MemoryImportRequest(BaseModel):
|
||||
source: str | None = None
|
||||
|
||||
|
||||
class WeChatSpeakerSelectionRequest(BaseModel):
|
||||
target_speaker_id: str
|
||||
|
||||
|
||||
class WeChatImportCommitRequest(BaseModel):
|
||||
persona_id: str
|
||||
persona_name: str | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
def _profile(value: str | None) -> str:
|
||||
return (value or get_settings().memory_default_profile_id or "default").strip() or "default"
|
||||
|
||||
@@ -76,6 +67,28 @@ def _wechat_registry(request: Request) -> WeChatImportJobRegistry:
|
||||
return registry
|
||||
|
||||
|
||||
def _fallback_memory_provider() -> SQLiteMemoryProvider:
    """Build the SQLite-backed memory provider used as a fallback.

    The database location is read from ``memory_sqlite_path`` on the
    application settings returned by ``get_settings()``.
    """
    sqlite_path = get_settings().memory_sqlite_path
    return SQLiteMemoryProvider(sqlite_path)
|
||||
|
||||
|
||||
async def _memory_provider():
    """Return the configured memory provider, falling back to SQLite.

    ``build_memory_provider()`` is tried first. If it raises for any reason,
    the failure is logged (rather than dropped silently) and the provider
    from ``_fallback_memory_provider()`` is returned instead, so the memory
    endpoints keep working.
    """
    try:
        return build_memory_provider()
    except Exception:  # noqa: BLE001
        # Function-scope import: logging is stdlib and this keeps the block
        # independent of whatever the module imports at the top.
        import logging

        logging.getLogger(__name__).warning(
            "configured memory provider is unavailable; falling back to SQLite",
            exc_info=True,
        )
        return _fallback_memory_provider()
|
||||
|
||||
|
||||
# Request body model with a single required field.
# NOTE(review): presumably used to pick the target speaker of a WeChat
# import job; the consuming endpoint is outside this view - confirm.
class WeChatSpeakerSelectionRequest(BaseModel):
    # Required; no default is declared.
    target_speaker_id: str
|
||||
|
||||
|
||||
# Request body model: one required identifier plus two optional text fields.
# NOTE(review): presumably used when committing a WeChat import into a
# persona; the consuming endpoint is outside this view - confirm.
class WeChatImportCommitRequest(BaseModel):
    # Required; no default is declared.
    persona_id: str
    # Optional; defaults to None when omitted.
    persona_name: str | None = None
    # Optional; defaults to None when omitted.
    description: str | None = None
|
||||
|
||||
|
||||
@router.post("/wechat-import")
|
||||
async def create_wechat_import_job(
|
||||
request: Request,
|
||||
@@ -172,7 +185,7 @@ async def list_libraries(
|
||||
profile_id: str = Query("default"),
|
||||
character_id: str = Query(...),
|
||||
) -> dict[str, list[dict[str, object]]]:
|
||||
provider = build_memory_provider()
|
||||
provider = await _memory_provider()
|
||||
items = await provider.list_libraries(
|
||||
profile_id=_profile(profile_id),
|
||||
character_id=_ensure_character(character_id),
|
||||
@@ -182,7 +195,7 @@ async def list_libraries(
|
||||
|
||||
@router.post("/libraries")
|
||||
async def create_library(body: MemoryLibraryRequest) -> dict[str, object]:
|
||||
provider = build_memory_provider()
|
||||
provider = await _memory_provider()
|
||||
library = await provider.create_library(
|
||||
library_id=_library_id(body.id),
|
||||
name=(body.name or "").strip() or None,
|
||||
@@ -198,7 +211,7 @@ async def list_items(
|
||||
profile_id: str = Query("default"),
|
||||
character_id: str = Query(...),
|
||||
) -> dict[str, list[dict[str, object]]]:
|
||||
provider = build_memory_provider()
|
||||
provider = await _memory_provider()
|
||||
items = await provider.list_items(
|
||||
library_id=library_id,
|
||||
profile_id=_profile(profile_id),
|
||||
@@ -214,7 +227,7 @@ async def delete_item(
|
||||
profile_id: str = Query("default"),
|
||||
character_id: str = Query(...),
|
||||
) -> dict[str, bool]:
|
||||
provider = build_memory_provider()
|
||||
provider = await _memory_provider()
|
||||
deleted = await provider.delete_item(
|
||||
library_id=library_id,
|
||||
item_id=item_id,
|
||||
@@ -237,7 +250,7 @@ async def import_items(library_id: str, body: MemoryImportRequest) -> dict[str,
|
||||
)
|
||||
runtime = MemoryRuntime(
|
||||
scope=scope,
|
||||
provider=build_memory_provider(),
|
||||
provider=await _memory_provider(),
|
||||
decision_agent=MemoryDecisionAgent(),
|
||||
)
|
||||
imported = await runtime.import_turns(
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -20,10 +21,11 @@ from opentalking.providers.stt.factory import (
|
||||
)
|
||||
from opentalking.providers.tts.factory import tts_enabled_providers, tts_provider_config
|
||||
from opentalking.providers.tts.providers import normalize_tts_provider
|
||||
from opentalking.providers.tts.voice_assets import iter_voice_assets, resolve_voice_asset
|
||||
|
||||
router = APIRouter(prefix="/runtime-config", tags=["runtime-config"])
|
||||
|
||||
_ENV_PATH = Path(__file__).resolve().parents[3] / ".env"
|
||||
_ENV_PATH = Path(os.environ.get("OPENTALKING_ENV_FILE") or Path(__file__).resolve().parents[3] / ".env")
|
||||
_ENV_REF_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}|\$([A-Za-z_][A-Za-z0-9_]*)")
|
||||
|
||||
_RUNTIME_ENV_KEYS = {
|
||||
@@ -127,6 +129,20 @@ def _strip(value: str | None) -> str:
|
||||
return (value or "").strip()
|
||||
|
||||
|
||||
def _local_cosyvoice_default_voice() -> str:
    """Pick the default voice id for the ``local_cosyvoice`` provider.

    Returns the first non-empty ``voice_id`` yielded by ``iter_voice_assets``
    for the ``system`` and ``clones`` sources, or ``"local-default"`` when no
    asset supplies one.
    """
    candidates = iter_voice_assets(provider="local_cosyvoice", sources=("system", "clones"))
    return next((item.voice_id for item in candidates if item.voice_id), "local-default")
|
||||
|
||||
|
||||
def _normalize_local_cosyvoice_voice(value: str | None) -> str:
    """Map a requested voice onto one that ``local_cosyvoice`` can resolve.

    The stripped ``value`` is returned unchanged when ``resolve_voice_asset``
    finds it for the ``local_cosyvoice`` provider. An empty or unresolvable
    value is replaced by ``_local_cosyvoice_default_voice()``.
    """
    candidate = _strip(value)
    if not candidate:
        return _local_cosyvoice_default_voice()
    if resolve_voice_asset(candidate, provider="local_cosyvoice") is None:
        return _local_cosyvoice_default_voice()
    return candidate
|
||||
|
||||
|
||||
def _unquote_env_value(value: str) -> str:
|
||||
if len(value) >= 2 and value[0] == value[-1] and value[0] in {'"', "'"}:
|
||||
return value[1:-1]
|
||||
@@ -189,6 +205,81 @@ def _enabled_provider_csv(current: list[str], provider: str) -> str:
|
||||
return ",".join(providers)
|
||||
|
||||
|
||||
def _path_exists(raw: str) -> bool:
|
||||
if not raw:
|
||||
return False
|
||||
try:
|
||||
return Path(raw).expanduser().exists()
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_sensevoice_available(values: dict[str, str], settings: Any) -> None:
    """Reject a switch to SenseVoice when no local model directory exists.

    Two candidate locations are checked with ``_path_exists``:

    * the ``model_dir`` reported by ``stt_provider_config("sensevoice")``;
    * ``<root>/<model>`` where ``root`` comes from the
      ``OPENTALKING_LOCAL_AUDIO_MODEL_ROOT`` environment variable (falling
      back to the ``local_audio_model_root`` setting) and ``model`` is the
      configured model name with ``/`` replaced by ``__``.

    Raises:
        HTTPException: status 400 when neither candidate exists.
    """
    provider_status = stt_provider_config("sensevoice")
    model_name = _env_value(
        values,
        "OPENTALKING_STT_SENSEVOICE_MODEL",
        _settings_value(settings, "stt_sensevoice_model", "iic/SenseVoiceSmall"),
    )
    locations = [str(provider_status.get("model_dir") or "").strip()]

    model_root = os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
    if not model_root:
        model_root = _settings_value(settings, "local_audio_model_root", "")
    if model_root and model_name:
        flattened = model_name.replace("/", "__")
        locations.append(str(Path(model_root).expanduser() / flattened))

    for location in locations:
        if _path_exists(location):
            return
    raise HTTPException(
        status_code=400,
        detail="本地 ASR SenseVoice 模型未就绪,请在启动时启用本地 ASR 或确认模型已下载。",
    )
|
||||
|
||||
|
||||
def _local_cosyvoice_health_url(service_url: str) -> str:
|
||||
value = service_url.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if value.rstrip("/").endswith("/synthesize"):
|
||||
return value.rstrip("/")[: -len("/synthesize")] + "/health"
|
||||
return value.rstrip("/") + "/health"
|
||||
|
||||
|
||||
async def _ensure_local_cosyvoice_available(values: dict[str, str], settings: Any) -> None:
    """Reject a switch to ``local_cosyvoice`` when its service is not reachable.

    The service URL is read from ``OPENTALKING_TTS_LOCAL_COSYVOICE_SERVICE_URL``
    in ``values`` (falling back to the ``tts_local_cosyvoice_service_url``
    setting), converted to a health URL, and probed with a short-timeout GET.

    Raises:
        HTTPException: status 400 when no URL is configured, or when the
            probe fails for any reason (connection error, timeout, or a
            response rejected by ``raise_for_status()``).
    """
    configured_url = _env_value(
        values,
        "OPENTALKING_TTS_LOCAL_COSYVOICE_SERVICE_URL",
        _settings_value(settings, "tts_local_cosyvoice_service_url"),
    )
    probe_url = _local_cosyvoice_health_url(configured_url)
    if not probe_url:
        raise HTTPException(
            status_code=400,
            detail="本地 TTS local_cosyvoice 未启动或未配置服务地址,请在启动时启用本地 TTS,或改用 API/Edge TTS。",
        )

    try:
        # Short timeouts keep the config request responsive when the sidecar is down.
        probe_timeout = httpx.Timeout(1.0, connect=0.5)
        async with httpx.AsyncClient(timeout=probe_timeout) as http:
            health = await http.get(probe_url)
            health.raise_for_status()
    except Exception as exc:
        raise HTTPException(
            status_code=400,
            detail="本地 TTS local_cosyvoice 未启动或不可用,请在启动时启用本地 TTS,或改用 API/Edge TTS。",
        ) from exc
|
||||
|
||||
|
||||
async def _validate_local_provider_switches(updates: dict[str, str], request: Request) -> None:
    """Check local providers before a default-provider switch is applied.

    ``updates`` is layered over the values currently read from ``_ENV_PATH``.
    When ``updates`` selects ``sensevoice`` as the default STT provider or
    ``local_cosyvoice`` as the default TTS provider, the matching
    availability check runs and may raise ``HTTPException``.
    """
    settings = getattr(request.app.state, "settings", None) or get_settings()
    _, persisted = _read_env_lines(_ENV_PATH)
    merged = {**persisted, **updates}

    # Only a provider named in this request is validated, not the persisted one.
    wants_sensevoice = updates.get("OPENTALKING_STT_DEFAULT_PROVIDER") == "sensevoice"
    wants_local_cosyvoice = updates.get("OPENTALKING_TTS_DEFAULT_PROVIDER") == "local_cosyvoice"

    if wants_sensevoice:
        _ensure_sensevoice_available(merged, settings)
    if wants_local_cosyvoice:
        await _ensure_local_cosyvoice_available(merged, settings)
|
||||
|
||||
|
||||
def _write_env_updates(path: Path, updates: dict[str, str]) -> None:
|
||||
lines, _ = _read_env_lines(path)
|
||||
if path.exists():
|
||||
@@ -301,7 +392,7 @@ def _current_tts_payload(provider: str, settings: Any, values: dict[str, str]) -
|
||||
elif provider == "local_cosyvoice":
|
||||
base_url = _env_value(values, "OPENTALKING_TTS_LOCAL_COSYVOICE_SERVICE_URL", _settings_value(settings, "tts_local_cosyvoice_service_url"))
|
||||
model = _env_value(values, "OPENTALKING_TTS_LOCAL_COSYVOICE_MODEL", _settings_value(settings, "tts_local_cosyvoice_model", "FunAudioLLM/Fun-CosyVoice3-0.5B-2512"))
|
||||
voice = _env_value(values, "OPENTALKING_TTS_VOICE", _settings_value(settings, "tts_voice"))
|
||||
voice = _normalize_local_cosyvoice_voice(_env_value(values, "OPENTALKING_TTS_VOICE", _settings_value(settings, "tts_voice")))
|
||||
key = ""
|
||||
elif provider == "indextts":
|
||||
base_url = (
|
||||
@@ -502,7 +593,9 @@ def _build_updates(payload: RuntimeConfigPayload) -> dict[str, str]:
|
||||
}.get(tts_provider)
|
||||
if key:
|
||||
updates[key] = value
|
||||
if value := _strip(payload.tts_voice):
|
||||
if tts_provider == "local_cosyvoice":
|
||||
updates["OPENTALKING_TTS_VOICE"] = _normalize_local_cosyvoice_voice(payload.tts_voice)
|
||||
elif value := _strip(payload.tts_voice):
|
||||
updates["OPENTALKING_TTS_VOICE"] = value
|
||||
if tts_provider == "edge":
|
||||
updates["OPENTALKING_TTS_EDGE_VOICE"] = value
|
||||
@@ -592,6 +685,7 @@ async def apply_runtime_config(payload: RuntimeConfigPayload, request: Request)
|
||||
unknown = set(updates) - _RUNTIME_ENV_KEYS
|
||||
if unknown:
|
||||
raise HTTPException(status_code=400, detail=f"unsupported runtime config keys: {', '.join(sorted(unknown))}")
|
||||
await _validate_local_provider_switches(updates, request)
|
||||
_write_env_updates(_ENV_PATH, updates)
|
||||
_, values = _read_env_lines(_ENV_PATH)
|
||||
for key in _RUNTIME_ENV_KEYS:
|
||||
|
||||
@@ -87,6 +87,11 @@ def _normalize_preview_request(
|
||||
return text, voice, provider, model, indextts_config
|
||||
|
||||
|
||||
def _is_local_cosyvoice_input_error(exc: Exception) -> bool:
|
||||
message = str(exc)
|
||||
return isinstance(exc, ValueError) or "HTTP 400" in message or "请求无效" in message or "请先选择本地音色" in message
|
||||
|
||||
|
||||
def _config_from_form_value(raw: object) -> dict[str, Any] | None:
|
||||
text = str(raw or "").strip()
|
||||
if not text:
|
||||
@@ -234,6 +239,29 @@ async def preview_tts(request: Request) -> Response:
|
||||
if sample_limit is not None and total_samples >= sample_limit:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"tts preview failed | provider=%s voice_id=%s model=%s error=%s",
|
||||
provider or "default",
|
||||
voice or "default",
|
||||
model or "default",
|
||||
exc,
|
||||
)
|
||||
if provider == "local_cosyvoice":
|
||||
if _is_local_cosyvoice_input_error(exc):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
error_text = str(exc)
|
||||
if "CosyVoice returned no audio" in error_text or "本地 CosyVoice 返回空音频" in error_text:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"TTS preview failed: {error_text}",
|
||||
) from exc
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"TTS preview failed: 本地 CosyVoice 服务不可用(可能已退出/内存不足);前端可回退 Edge 试听。{exc}",
|
||||
) from exc
|
||||
raise HTTPException(status_code=502, detail=f"TTS preview failed: {exc}") from exc
|
||||
finally:
|
||||
close = getattr(tts, "aclose", None)
|
||||
|
||||
@@ -7,7 +7,6 @@ import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import uuid
|
||||
@@ -21,6 +20,15 @@ from fastapi import APIRouter, BackgroundTasks, File, Form, HTTPException, Reque
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from opentalking.providers.tts.dashscope_qwen import clone as bailian_clone
|
||||
from opentalking.providers.tts.voice_assets import (
|
||||
INDEXTTS_PROVIDER,
|
||||
INDEXTTS_PROVIDERS,
|
||||
LOCAL_COSYVOICE_PROVIDER,
|
||||
bundled_system_voice_root,
|
||||
iter_voice_assets,
|
||||
local_audio_model_root,
|
||||
system_voice_roots,
|
||||
)
|
||||
from opentalking.voice.store import delete_entry, get_entry, init_voice_store, insert_clone, list_voices
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -31,13 +39,16 @@ _UPLOAD_DIR = Path("data/voice_uploads")
|
||||
LOCAL_COSYVOICE_SAMPLE_TEXT = "你好,今天阳光很好,我正在用自然清晰的声音,记录这一段音色。"
|
||||
LOCAL_COSYVOICE_MIN_ACTIVE_SEC = 2.0
|
||||
LOCAL_COSYVOICE_MIN_RMS_DBFS = -45.0
|
||||
_TECHNICAL_VOICE_LABEL_PREFIX_RE = re.compile(
|
||||
r"^\s*(?:IndexTTS|CosyVoice|local_cosyvoice|local_indextts|local|本地)\s*[-_/::·|]?\s*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_TECHNICAL_VOICE_LABEL_SUFFIX_RE = re.compile(
|
||||
r"\s*[((]\s*(?:IndexTTS|CosyVoice|local_cosyvoice|local_indextts|local|本地)\s*[))]\s*$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
LOCAL_COSYVOICE_MIN_RECOGNIZED_CHARS = 4
|
||||
LOCAL_COSYVOICE_MIN_TEXT_OVERLAP = 0.45
|
||||
INDEXTTS_PROVIDER = "indextts"
|
||||
INDEXTTS_LEGACY_PROVIDERS = {"local_indextts", "omnirt_indextts"}
|
||||
INDEXTTS_PROVIDERS = {INDEXTTS_PROVIDER, *INDEXTTS_LEGACY_PROVIDERS}
|
||||
|
||||
|
||||
class VoiceItem(TypedDict):
|
||||
id: int
|
||||
user_id: int
|
||||
@@ -61,16 +72,7 @@ def _upload_dir() -> Path:
|
||||
|
||||
|
||||
def _local_audio_model_root() -> Path:
|
||||
raw = os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
|
||||
try:
|
||||
from opentalking.core.config import get_settings
|
||||
|
||||
raw = raw or (get_settings().local_audio_model_root or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
if not raw:
|
||||
raw = "./models/local-audio"
|
||||
return Path(raw).expanduser().resolve()
|
||||
return local_audio_model_root()
|
||||
|
||||
|
||||
def _safe_local_voice_id(label: str) -> str:
|
||||
@@ -259,47 +261,38 @@ def _remove_local_cosyvoice_prompt(voice_id: str) -> None:
|
||||
|
||||
|
||||
def _bundled_system_voice_root() -> Path:
|
||||
return Path(__file__).resolve().parents[3] / "opentalking" / "assets" / "voices" / "system"
|
||||
return bundled_system_voice_root()
|
||||
|
||||
|
||||
def _system_voice_roots() -> list[Path]:
|
||||
roots = [_local_audio_model_root() / "voices" / "system", _bundled_system_voice_root()]
|
||||
out: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for root in roots:
|
||||
try:
|
||||
resolved = root.resolve()
|
||||
except OSError:
|
||||
resolved = root
|
||||
if resolved in seen:
|
||||
continue
|
||||
seen.add(resolved)
|
||||
out.append(root)
|
||||
return out
|
||||
return system_voice_roots(_local_audio_model_root())
|
||||
|
||||
|
||||
def _public_voice_label(label: str, *, fallback: str) -> str:
    """Return ``label`` with technical prefix/suffix markers removed.

    The prefix pattern is applied first, then the suffix pattern, stripping
    surrounding whitespace after each step. ``fallback`` is returned when
    nothing is left (an empty or ``None`` label is treated as ``""``).
    """
    without_prefix = _TECHNICAL_VOICE_LABEL_PREFIX_RE.sub("", label or "").strip()
    public = _TECHNICAL_VOICE_LABEL_SUFFIX_RE.sub("", without_prefix).strip()
    if public:
        return public
    return fallback
|
||||
|
||||
|
||||
def _local_cosyvoice_system_voice_items() -> list[VoiceItem]:
|
||||
root = _local_audio_model_root() / "voices" / "system"
|
||||
if not root.is_dir():
|
||||
return []
|
||||
items: list[VoiceItem] = []
|
||||
for idx, voice_dir in enumerate(sorted(p for p in root.iterdir() if p.is_dir()), start=1):
|
||||
voice_id = voice_dir.name
|
||||
for idx, asset in enumerate(
|
||||
iter_voice_assets(
|
||||
provider=LOCAL_COSYVOICE_PROVIDER,
|
||||
sources=("system",),
|
||||
model_root=_local_audio_model_root(),
|
||||
require_prompt_text=True,
|
||||
),
|
||||
start=1,
|
||||
):
|
||||
voice_id = asset.voice_id
|
||||
if not re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
|
||||
continue
|
||||
if not (voice_dir / "prompt.wav").is_file() or not (voice_dir / "prompt.txt").is_file():
|
||||
continue
|
||||
label = voice_id
|
||||
target_model: str | None = None
|
||||
meta_path = voice_dir / "meta.json"
|
||||
if meta_path.is_file():
|
||||
try:
|
||||
meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
label = str(meta.get("display_label") or meta.get("label") or label)
|
||||
tm = str(meta.get("target_model") or "").strip()
|
||||
target_model = tm or None
|
||||
except Exception:
|
||||
pass
|
||||
meta = asset.meta
|
||||
label = _public_voice_label(str(meta.get("display_label") or meta.get("label") or label), fallback=voice_id)
|
||||
tm = str(meta.get("target_model") or "").strip()
|
||||
target_model = tm or None
|
||||
items.append(
|
||||
{
|
||||
"id": -idx,
|
||||
@@ -314,48 +307,26 @@ def _local_cosyvoice_system_voice_items() -> list[VoiceItem]:
|
||||
return items
|
||||
|
||||
|
||||
def _local_indextts_voice_items(provider: str, source: str) -> list[VoiceItem]:
|
||||
if provider not in INDEXTTS_PROVIDERS or source not in {"system", "clones"}:
|
||||
def _local_indextts_voice_items(source: str) -> list[VoiceItem]:
|
||||
if source not in {"system", "clones"}:
|
||||
return []
|
||||
roots = [_local_audio_model_root() / "voices" / source]
|
||||
if source == "system":
|
||||
roots = _system_voice_roots()
|
||||
items: list[VoiceItem] = []
|
||||
idx = 0
|
||||
for root in roots:
|
||||
if not root.is_dir():
|
||||
continue
|
||||
for voice_dir in sorted(p for p in root.iterdir() if p.is_dir()):
|
||||
idx += 1
|
||||
voice_id = voice_dir.name
|
||||
if not re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
|
||||
seen: set[str] = set()
|
||||
for provider in sorted(INDEXTTS_PROVIDERS):
|
||||
for asset in iter_voice_assets(provider=provider, sources=(source,), model_root=_local_audio_model_root()):
|
||||
voice_id = asset.voice_id
|
||||
if voice_id in seen or not re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
|
||||
continue
|
||||
if not (voice_dir / "prompt.wav").is_file():
|
||||
continue
|
||||
label = voice_id
|
||||
target_model: str | None = None
|
||||
meta_path = voice_dir / "meta.json"
|
||||
if meta_path.is_file():
|
||||
try:
|
||||
meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
raw_providers = meta.get("providers")
|
||||
requested = {provider, *INDEXTTS_LEGACY_PROVIDERS} if provider == INDEXTTS_PROVIDER else {provider}
|
||||
if isinstance(raw_providers, list):
|
||||
allowed = {str(item).strip().lower() for item in raw_providers}
|
||||
if allowed and not (allowed & requested):
|
||||
continue
|
||||
elif str(meta.get("provider") or "").strip().lower() not in {"", *INDEXTTS_PROVIDERS}:
|
||||
continue
|
||||
label = str(meta.get("display_label") or meta.get("label") or label)
|
||||
tm = str(meta.get("target_model") or "").strip()
|
||||
target_model = tm or None
|
||||
except Exception:
|
||||
pass
|
||||
seen.add(voice_id)
|
||||
meta = asset.meta
|
||||
label = _public_voice_label(str(meta.get("display_label") or meta.get("label") or voice_id), fallback=voice_id)
|
||||
tm = str(meta.get("target_model") or "").strip()
|
||||
target_model = tm or None
|
||||
items.append(
|
||||
{
|
||||
"id": -idx,
|
||||
"id": -len(items) - 1,
|
||||
"user_id": 1,
|
||||
"provider": provider,
|
||||
"provider": INDEXTTS_PROVIDER,
|
||||
"voice_id": voice_id,
|
||||
"display_label": label,
|
||||
"target_model": target_model,
|
||||
@@ -430,7 +401,7 @@ async def get_voices(provider: str | None = None) -> JSONResponse:
|
||||
existing.add(key)
|
||||
if public_p is None or public_p == INDEXTTS_PROVIDER:
|
||||
for source in ("system", "clones"):
|
||||
for item in _local_indextts_voice_items(INDEXTTS_PROVIDER, source):
|
||||
for item in _local_indextts_voice_items(source):
|
||||
key = (item["provider"], item["voice_id"])
|
||||
if key not in existing:
|
||||
items.append(item)
|
||||
|
||||
@@ -205,17 +205,16 @@ def test_get_voices_includes_local_cosyvoice_system_voice_dirs(tmp_path, monkeyp
|
||||
response = TestClient(app).get("/voices?provider=local_cosyvoice")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["items"] == [
|
||||
{
|
||||
"id": -1,
|
||||
"user_id": 1,
|
||||
"provider": "local_cosyvoice",
|
||||
"voice_id": "local-female-standard",
|
||||
"display_label": "标准女声",
|
||||
"target_model": None,
|
||||
"source": "system",
|
||||
}
|
||||
]
|
||||
item = next(item for item in response.json()["items"] if item["voice_id"] == "local-female-standard")
|
||||
assert item == {
|
||||
"id": -1,
|
||||
"user_id": 1,
|
||||
"provider": "local_cosyvoice",
|
||||
"voice_id": "local-female-standard",
|
||||
"display_label": "标准女声",
|
||||
"target_model": None,
|
||||
"source": "system",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["indextts", "local_indextts", "omnirt_indextts"])
|
||||
@@ -242,7 +241,7 @@ def test_get_voices_includes_indextts_system_voice_dirs(provider: str, tmp_path,
|
||||
"user_id": 1,
|
||||
"provider": "indextts",
|
||||
"voice_id": "indextts-clear-cn",
|
||||
"display_label": "IndexTTS 清晰中文",
|
||||
"display_label": "清晰中文",
|
||||
"target_model": "IndexTeam/IndexTTS-2",
|
||||
"source": "system",
|
||||
} in response.json()["items"]
|
||||
@@ -426,12 +425,13 @@ def test_get_voices_includes_indextts_clone_voice_dirs(provider: str, tmp_path,
|
||||
response = TestClient(app).get(f"/voices?provider={provider}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert {
|
||||
"id": -1,
|
||||
item = next(item for item in response.json()["items"] if item["voice_id"] == "indextts-cloned-cn")
|
||||
assert item == {
|
||||
"id": item["id"],
|
||||
"user_id": 1,
|
||||
"provider": "indextts",
|
||||
"voice_id": "indextts-cloned-cn",
|
||||
"display_label": "IndexTTS 复刻中文",
|
||||
"display_label": "复刻中文",
|
||||
"target_model": "IndexTeam/IndexTTS-2",
|
||||
"source": "clone",
|
||||
} in response.json()["items"]
|
||||
}
|
||||
|
||||
@@ -166,13 +166,13 @@ function mergeVoiceCatalogIntoOptions(
|
||||
const extras: VoiceOpt[] = [];
|
||||
for (const r of catalog) {
|
||||
if (r.provider !== cp) continue;
|
||||
if (activeModel && r.target_model && r.target_model !== activeModel) continue;
|
||||
if (activeModel && r.target_model && r.target_model !== activeModel && !(ttsProvider === "local_cosyvoice" && r.source === "system")) continue;
|
||||
if (cloneOnly && r.source !== "clone") continue;
|
||||
if (staticIds.has(r.voice_id)) continue;
|
||||
extras.push({
|
||||
id: r.voice_id,
|
||||
label: r.source === "clone" ? `复刻 · ${r.display_label}` : r.display_label,
|
||||
targetModel: r.target_model,
|
||||
targetModel: ttsProvider === "local_cosyvoice" && r.source === "system" ? undefined : r.target_model,
|
||||
});
|
||||
staticIds.add(r.voice_id);
|
||||
}
|
||||
@@ -983,7 +983,6 @@ export default function App() {
|
||||
return false;
|
||||
});
|
||||
const [voiceCloneOpen, setVoiceCloneOpen] = useState(false);
|
||||
const [promptSaving, setPromptSaving] = useState(false);
|
||||
const [referenceSaving, setReferenceSaving] = useState(false);
|
||||
const [panelTab, setPanelTab] = useState<PanelTab>("chat");
|
||||
const [sessionPanelCollapsed, setSessionPanelCollapsed] = useState(() => {
|
||||
@@ -1621,7 +1620,7 @@ export default function App() {
|
||||
}
|
||||
}, [sessionPanelCollapsed]);
|
||||
|
||||
const [llmSystemPrompt, setLlmSystemPrompt] = useState<string>(() => {
|
||||
const [llmSystemPrompt] = useState<string>(() => {
|
||||
try {
|
||||
return window.localStorage.getItem(LLM_SYSTEM_PROMPT_STORAGE_KEY) ?? "";
|
||||
} catch {
|
||||
@@ -2081,9 +2080,19 @@ export default function App() {
|
||||
clearSubtitleState();
|
||||
if (msgId) {
|
||||
if (finalText) {
|
||||
setMessages((prev) =>
|
||||
prev.map((m) => (m.id === msgId ? { ...m, text: finalText } : m)),
|
||||
);
|
||||
setMessages((prev) => {
|
||||
let updated = false;
|
||||
const next = prev.map((m) => {
|
||||
if (m.id !== msgId) return m;
|
||||
updated = true;
|
||||
return { ...m, text: finalText };
|
||||
});
|
||||
if (updated) return next;
|
||||
return [
|
||||
...prev,
|
||||
{ id: makeId(), role: "assistant", text: finalText, timestamp: Date.now() },
|
||||
];
|
||||
});
|
||||
} else {
|
||||
setMessages((prev) => prev.filter((m) => m.id !== msgId));
|
||||
}
|
||||
@@ -2311,27 +2320,6 @@ export default function App() {
|
||||
}
|
||||
}, [connection, fasterliveportraitConfig, model, notify]);
|
||||
|
||||
const handleSavePrompt = useCallback(async () => {
|
||||
setPromptSaving(true);
|
||||
try {
|
||||
await apiPost("/sessions/customize/prompt", {
|
||||
avatar_id: avatarId,
|
||||
llm_system_prompt: llmSystemPrompt,
|
||||
});
|
||||
const sid = sessionIdRef.current;
|
||||
if (sid) await releaseSession(sid);
|
||||
resetLiveState(true);
|
||||
setConnection("idle");
|
||||
notify("System Prompt 已保存,页面即将刷新并在新会话生效。", "success");
|
||||
window.setTimeout(() => window.location.reload(), 900);
|
||||
} catch (e) {
|
||||
console.warn("save prompt failed", e);
|
||||
notify("保存 Prompt 失败,请查看后端日志。", "error");
|
||||
} finally {
|
||||
setPromptSaving(false);
|
||||
}
|
||||
}, [avatarId, llmSystemPrompt, notify, releaseSession, resetLiveState]);
|
||||
|
||||
const handleCreateCustomAvatar = useCallback(async (file: File, name: string) => {
|
||||
const trimmedName = name.trim();
|
||||
if (!trimmedName) {
|
||||
@@ -2442,7 +2430,8 @@ export default function App() {
|
||||
notify("正在播放试听音频。", "success");
|
||||
} catch (e) {
|
||||
console.warn("tts preview failed", e);
|
||||
notify("试听失败,请确认音色、模型和后端密钥配置。", "error");
|
||||
const detail = apiErrorMessage(e, "请确认音色、模型和后端密钥配置。");
|
||||
notify(`试听失败:${detail}`, "error");
|
||||
} finally {
|
||||
setTtsPreviewing(false);
|
||||
}
|
||||
@@ -3012,10 +3001,6 @@ export default function App() {
|
||||
onMemoryLibrarySelect={setMemoryLibraryId}
|
||||
onMemoryEnabledChange={setMemoryEnabled}
|
||||
onManageMemoryLibraries={() => void handleManageMemoryLibraries()}
|
||||
llmSystemPrompt={llmSystemPrompt}
|
||||
onLlmSystemPromptChange={setLlmSystemPrompt}
|
||||
onSavePrompt={() => void handleSavePrompt()}
|
||||
promptSaving={promptSaving}
|
||||
onOpenVoiceClone={() => setVoiceCloneOpen(true)}
|
||||
/>
|
||||
</div>
|
||||
@@ -3169,15 +3154,11 @@ export default function App() {
|
||||
modelBadge={selectedModelBadge}
|
||||
queueInfo={queueInfo}
|
||||
prewarmState={selectedPrewarmState}
|
||||
agentConfig={agentConfig}
|
||||
onAgentConfigChange={setAgentConfig}
|
||||
knowledgeBases={knowledgeBaseSummaries}
|
||||
personas={personas}
|
||||
selectedPersonaId={selectedPersonaId}
|
||||
personaImporting={personaImporting}
|
||||
onPersonaChange={handlePersonaChange}
|
||||
onPersonaImport={handlePersonaImport}
|
||||
memorySummary={memorySummary}
|
||||
onAvatarChange={handleAvatarChange}
|
||||
onStart={() => void handleStart()}
|
||||
onCustomAvatarCreate={(file, name) => void handleCreateCustomAvatar(file, name)}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useEffect, useRef, useState, type ChangeEvent } from "react";
|
||||
import type { AvatarSummary, KnowledgeBaseSummary, PersonaSummary } from "../lib/api";
|
||||
import type { AvatarSummary, PersonaSummary } from "../lib/api";
|
||||
import { buildApiUrl } from "../lib/api";
|
||||
import type { ModelConnectionBadge } from "../lib/modelStatus";
|
||||
|
||||
@@ -27,14 +27,6 @@ type AvatarSelectionStageProps = {
|
||||
onCustomAvatarCreate: (file: File, name: string) => void;
|
||||
onAvatarDelete?: (avatar: AvatarSummary) => void;
|
||||
referenceSaving?: boolean;
|
||||
memorySummary?: {
|
||||
enabled: boolean;
|
||||
libraryName: string | null;
|
||||
memoryCount: number | null;
|
||||
};
|
||||
agentConfig: AgentConfig;
|
||||
onAgentConfigChange: (next: AgentConfig) => void;
|
||||
knowledgeBases: KnowledgeBaseSummary[];
|
||||
personas: PersonaSummary[];
|
||||
selectedPersonaId: string;
|
||||
personaImporting?: boolean;
|
||||
@@ -84,10 +76,6 @@ export function AvatarSelectionStage({
|
||||
onCustomAvatarCreate,
|
||||
onAvatarDelete,
|
||||
referenceSaving = false,
|
||||
memorySummary,
|
||||
agentConfig,
|
||||
onAgentConfigChange,
|
||||
knowledgeBases,
|
||||
personas,
|
||||
selectedPersonaId,
|
||||
personaImporting = false,
|
||||
@@ -106,20 +94,7 @@ export function AvatarSelectionStage({
|
||||
});
|
||||
const [customFile, setCustomFile] = useState<File | null>(null);
|
||||
const [customPreviewUrl, setCustomPreviewUrl] = useState<string | null>(null);
|
||||
const selectedKnowledgeBaseIds = agentConfig.knowledgeBaseIds;
|
||||
const selectedPersona = personas.find((persona) => persona.id === selectedPersonaId) ?? null;
|
||||
const knowledgeBasesById = new Map(knowledgeBases.map((kb) => [kb.id, kb]));
|
||||
const selectedKnowledgeBases = selectedKnowledgeBaseIds.map((id) => (
|
||||
knowledgeBasesById.get(id) ?? {
|
||||
id,
|
||||
name: id,
|
||||
document_count: 0,
|
||||
ready_document_count: 0,
|
||||
error_document_count: 0,
|
||||
created_at: "",
|
||||
updated_at: "",
|
||||
}
|
||||
));
|
||||
const configDisabled = loading || queued || prewarmState === "preparing";
|
||||
const baseDisabled = loading || queued || prewarmState === "preparing" || !selectedAvatar || !modelConnected;
|
||||
const startLabel = queued
|
||||
@@ -163,15 +138,6 @@ export function AvatarSelectionStage({
|
||||
if (file) onPersonaImport(file);
|
||||
};
|
||||
|
||||
const updateKnowledgeBaseIds = (nextIds: string[]) => {
|
||||
const deduped = Array.from(new Set(nextIds.filter((id) => id.trim())));
|
||||
onAgentConfigChange({
|
||||
...agentConfig,
|
||||
knowledgeEnabled: deduped.length > 0,
|
||||
knowledgeBaseIds: deduped,
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="relative h-full min-h-[520px] overflow-hidden bg-white">
|
||||
<div className="grid h-full min-h-0 gap-5 p-4 sm:p-5 xl:grid-cols-[minmax(28rem,1.15fr)_minmax(20rem,0.85fr)] xl:p-6">
|
||||
@@ -371,78 +337,6 @@ export function AvatarSelectionStage({
|
||||
<p className="mt-1 truncate text-sm font-semibold text-slate-950">{selectedVoiceLabel}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-3 rounded-lg border border-slate-200 bg-slate-50 px-3 py-2.5">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<p className="text-xs font-semibold text-slate-600">Agent 增强</p>
|
||||
<span className="shrink-0 rounded-full border border-slate-200 bg-white px-2 py-0.5 text-[11px] font-medium text-slate-500">
|
||||
{knowledgeBases.length} 个知识库
|
||||
</span>
|
||||
</div>
|
||||
<div className="mt-2 min-h-20 rounded-md border border-slate-200 bg-white p-2">
|
||||
<div className="mb-2 flex items-center justify-between gap-2">
|
||||
<p className="text-[11px] font-semibold uppercase tracking-wide text-slate-500">当前形象知识库</p>
|
||||
<span className="text-[11px] text-slate-400">{selectedKnowledgeBases.length} 项</span>
|
||||
</div>
|
||||
{selectedKnowledgeBases.length ? (
|
||||
<div className="flex flex-wrap gap-1.5">
|
||||
{selectedKnowledgeBases.map((kb) => (
|
||||
<div
|
||||
key={kb.id}
|
||||
className="inline-flex max-w-full items-center gap-1.5 rounded-full border border-cyan-200 bg-cyan-50 px-2 py-1 text-xs text-cyan-800"
|
||||
>
|
||||
<span className="min-w-0 truncate font-medium">{kb.name}</span>
|
||||
<button
|
||||
type="button"
|
||||
disabled={baseDisabled}
|
||||
onClick={() =>
|
||||
updateKnowledgeBaseIds(selectedKnowledgeBaseIds.filter((id) => id !== kb.id))
|
||||
}
|
||||
className="shrink-0 rounded-full px-1.5 py-0.5 font-semibold text-cyan-700 transition hover:bg-cyan-100 disabled:opacity-50"
|
||||
aria-label={`移除 ${kb.name}`}
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<p className="px-1 py-2 text-xs text-slate-400">未选择知识库</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-3 rounded-lg border border-slate-200 bg-slate-50 px-3 py-2.5">
|
||||
<div className="flex items-center justify-between gap-2">
|
||||
<p className="text-xs font-semibold text-slate-600">记忆库</p>
|
||||
<span
|
||||
className={`shrink-0 rounded-full border px-2 py-0.5 text-[11px] font-medium ${
|
||||
memorySummary?.enabled
|
||||
? "border-cyan-200 bg-white text-cyan-700"
|
||||
: "border-slate-200 bg-white text-slate-500"
|
||||
}`}
|
||||
>
|
||||
{memorySummary?.enabled ? "已挂载" : "未挂载"}
|
||||
</span>
|
||||
</div>
|
||||
<div className="mt-2 min-h-16 rounded-md border border-slate-200 bg-white p-2">
|
||||
<div className="mb-2 flex items-center justify-between gap-2">
|
||||
<p className="text-[11px] font-semibold uppercase tracking-wide text-slate-500">
|
||||
当前形象记忆库
|
||||
</p>
|
||||
<span className="text-[11px] text-slate-400">
|
||||
{memorySummary?.enabled && memorySummary.memoryCount !== null
|
||||
? `${memorySummary.memoryCount} 条`
|
||||
: "0 条"}
|
||||
</span>
|
||||
</div>
|
||||
{memorySummary?.enabled && memorySummary.libraryName ? (
|
||||
<div className="inline-flex max-w-full items-center gap-1.5 rounded-full border border-cyan-200 bg-cyan-50 px-2 py-1 text-xs text-cyan-800">
|
||||
<span className="min-w-0 truncate font-medium">{memorySummary.libraryName}</span>
|
||||
</div>
|
||||
) : (
|
||||
<p className="px-1 py-2 text-xs text-slate-400">未选择记忆库</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{queued ? (
|
||||
<p className="mb-3 rounded-lg border border-amber-200 bg-amber-50 px-3 py-2 text-sm font-medium text-amber-800">
|
||||
前面还有 {queueInfo?.position ?? 1} 人,请稍候...
|
||||
|
||||
@@ -186,10 +186,6 @@ interface SettingsPanelProps {
|
||||
qwenVoice: string;
|
||||
onQwenVoiceChange: (voiceId: string) => void;
|
||||
qwenVoiceOptions: VoiceOpt[];
|
||||
llmSystemPrompt: string;
|
||||
onLlmSystemPromptChange: (value: string) => void;
|
||||
onSavePrompt: () => void;
|
||||
promptSaving?: boolean;
|
||||
onOpenVoiceClone?: () => void;
|
||||
voiceApplyNotice?: string | null;
|
||||
ttsPreviewText: string;
|
||||
@@ -415,10 +411,6 @@ export function SettingsPanel({
|
||||
qwenVoice,
|
||||
onQwenVoiceChange,
|
||||
qwenVoiceOptions,
|
||||
llmSystemPrompt,
|
||||
onLlmSystemPromptChange,
|
||||
onSavePrompt,
|
||||
promptSaving = false,
|
||||
onOpenVoiceClone,
|
||||
voiceApplyNotice = null,
|
||||
ttsPreviewText,
|
||||
@@ -525,6 +517,7 @@ export function SettingsPanel({
|
||||
}));
|
||||
const providerHasSingleModel = (provider: TtsProviderExtended) => {
|
||||
if (provider === "edge" || provider === "openai_compatible") return true;
|
||||
if (provider === "local_cosyvoice" || provider === "indextts") return true;
|
||||
if (provider !== ttsProvider) return false;
|
||||
return qwenModelColumnOptions.length <= 1;
|
||||
};
|
||||
@@ -1083,33 +1076,6 @@ export function SettingsPanel({
|
||||
</div>
|
||||
</SettingsSection>
|
||||
|
||||
<SettingsSection
|
||||
id="role"
|
||||
title="人设"
|
||||
open={openSections.role}
|
||||
onToggle={toggleSection}
|
||||
>
|
||||
<div className="space-y-3">
|
||||
<label className="block">
|
||||
<span className="mb-1.5 block text-xs text-slate-500">人设定义</span>
|
||||
<textarea
|
||||
value={llmSystemPrompt}
|
||||
onChange={(e) => onLlmSystemPromptChange(e.target.value)}
|
||||
rows={5}
|
||||
className="w-full resize-none rounded-lg border border-slate-200 bg-slate-50 px-3 py-2.5 text-sm text-slate-800 outline-none transition placeholder:text-slate-400 focus:border-cyan-300 focus:bg-white"
|
||||
placeholder={"你可以在这里定义数字人的身份、说话风格和边界。\n\n示例:你是一位温和专业的产品讲解员,回答简洁自然,优先用中文回复。遇到不确定的问题先说明不确定,再给出可执行建议。"}
|
||||
/>
|
||||
</label>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onSavePrompt}
|
||||
disabled={promptSaving}
|
||||
className="w-full rounded-lg bg-slate-950 px-3 py-2.5 text-sm font-semibold text-white transition hover:bg-slate-800 disabled:cursor-not-allowed disabled:opacity-60"
|
||||
>
|
||||
{promptSaving ? "保存中..." : "保存人设"}
|
||||
</button>
|
||||
</div>
|
||||
</SettingsSection>
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
|
||||
@@ -106,10 +106,10 @@ def _read_text_file(path: Path) -> str:
|
||||
raw = path.read_bytes()
|
||||
for encoding in ("utf-8", "utf-8-sig", "gb18030"):
|
||||
try:
|
||||
return raw.decode(encoding)
|
||||
return raw.decode(encoding).replace("\r\n", "\n").replace("\r", "\n")
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return raw.decode("utf-8", errors="replace")
|
||||
return raw.decode("utf-8", errors="replace").replace("\r\n", "\n").replace("\r", "\n")
|
||||
|
||||
|
||||
def _has_enough_text(text: str) -> bool:
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"voice_id": "cosyvoice-official-zero-shot",
|
||||
"display_label": "官方示例女声(本地)",
|
||||
"provider": "local_cosyvoice",
|
||||
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
|
||||
"mode": "zero_shot",
|
||||
"source": "system",
|
||||
"provenance": "FunAudioLLM/CosyVoice runtime asset prompt"
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
希望你以后能够做的比我还好呦。
|
||||
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"voice_id": "local-anchor-cherry",
|
||||
"display_label": "主播女声 Cherry(本地)",
|
||||
"provider": "local_cosyvoice",
|
||||
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
|
||||
"mode": "zero_shot",
|
||||
"source": "system",
|
||||
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-anchor-cherry/prompt.wav",
|
||||
"prompt_text": "你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。",
|
||||
"prompt_source": "generated_once_from_dashscope_qwen_tts",
|
||||
"qwen_reference_voice": "Cherry",
|
||||
"avatar_hint": "主播",
|
||||
"role": "anchor",
|
||||
"style": "清晰、亲和、适合新闻播报和产品讲解",
|
||||
"duration_sec": 6.32,
|
||||
"sample_rate": 24000,
|
||||
"bytes": 303404,
|
||||
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。
|
||||
BIN
opentalking/assets/voices/system/local-anchor-cherry/prompt.wav
Normal file
BIN
opentalking/assets/voices/system/local-anchor-cherry/prompt.wav
Normal file
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"voice_id": "local-ancient-jada",
|
||||
"display_label": "古风女声 Jada(本地)",
|
||||
"provider": "local_cosyvoice",
|
||||
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
|
||||
"mode": "zero_shot",
|
||||
"source": "system",
|
||||
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-ancient-jada/prompt.wav",
|
||||
"prompt_text": "你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。",
|
||||
"prompt_source": "generated_once_from_dashscope_qwen_tts",
|
||||
"qwen_reference_voice": "Jada",
|
||||
"avatar_hint": "古装美女",
|
||||
"role": "ancient-beauty",
|
||||
"style": "柔和、叙事感、适合文化讲解",
|
||||
"duration_sec": 6.8,
|
||||
"sample_rate": 24000,
|
||||
"bytes": 326444,
|
||||
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。
|
||||
BIN
opentalking/assets/voices/system/local-ancient-jada/prompt.wav
Normal file
BIN
opentalking/assets/voices/system/local-ancient-jada/prompt.wav
Normal file
Binary file not shown.
19
opentalking/assets/voices/system/local-anime-ethan/meta.json
Normal file
19
opentalking/assets/voices/system/local-anime-ethan/meta.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"voice_id": "local-anime-ethan",
|
||||
"display_label": "动漫男声 Ethan(本地)",
|
||||
"provider": "local_cosyvoice",
|
||||
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
|
||||
"mode": "zero_shot",
|
||||
"source": "system",
|
||||
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-anime-ethan/prompt.wav",
|
||||
"prompt_text": "你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。",
|
||||
"prompt_source": "generated_once_from_dashscope_qwen_tts",
|
||||
"qwen_reference_voice": "Ethan",
|
||||
"avatar_hint": "动漫帅哥",
|
||||
"role": "anime-handsome-guy",
|
||||
"style": "年轻、明亮、适合轻量互动和二次元角色",
|
||||
"duration_sec": 5.12,
|
||||
"sample_rate": 24000,
|
||||
"bytes": 245804,
|
||||
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。
|
||||
BIN
opentalking/assets/voices/system/local-anime-ethan/prompt.wav
Normal file
BIN
opentalking/assets/voices/system/local-anime-ethan/prompt.wav
Normal file
Binary file not shown.
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"voice_id": "local-office-serena",
|
||||
"display_label": "职场女声 Serena(本地)",
|
||||
"provider": "local_cosyvoice",
|
||||
"target_model": "FunAudioLLM/Fun-CosyVoice3-0.5B-2512",
|
||||
"mode": "zero_shot",
|
||||
"source": "system",
|
||||
"prompt_audio": "/data2/zhongyi/model/opentalking-local-audio/voices/system/local-office-serena/prompt.wav",
|
||||
"prompt_text": "你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。",
|
||||
"prompt_source": "generated_once_from_dashscope_qwen_tts",
|
||||
"qwen_reference_voice": "Serena",
|
||||
"avatar_hint": "职场女",
|
||||
"role": "office-woman",
|
||||
"style": "稳重、专业、适合客服和商务说明",
|
||||
"duration_sec": 5.52,
|
||||
"sample_rate": 24000,
|
||||
"bytes": 265004,
|
||||
"storage_note": "prompt.wav is a local Qwen-generated reference sample; kept small as one WAV per built-in local voice."
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
你好,欢迎来到OpenTalking。我会用自然清晰的声音,为你介绍今天的内容。
|
||||
BIN
opentalking/assets/voices/system/local-office-serena/prompt.wav
Normal file
BIN
opentalking/assets/voices/system/local-office-serena/prompt.wav
Normal file
Binary file not shown.
@@ -306,7 +306,8 @@ def _legacy_env_mapping() -> dict[str, str]:
|
||||
|
||||
|
||||
def _load_legacy_dotenv_source() -> dict[str, Any]:
|
||||
values = dotenv_values(".env")
|
||||
env_file = os.environ.get("OPENTALKING_ENV_FILE", ".env")
|
||||
values = dotenv_values(env_file)
|
||||
mapping = _legacy_env_mapping()
|
||||
return {
|
||||
target: value
|
||||
@@ -323,7 +324,7 @@ def _load_legacy_env_source() -> dict[str, Any]:
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="OPENTALKING_",
|
||||
env_file=".env",
|
||||
env_file=os.environ.get("OPENTALKING_ENV_FILE", ".env"),
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
@@ -4,10 +4,12 @@ import json
|
||||
import re
|
||||
import shutil
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
UTC = timezone.utc
|
||||
|
||||
_ALLOWED_KINDS = {"realtime_dialogue", "video_clone", "video_creation"}
|
||||
_SAFE_ID_RE = re.compile(r"^[a-zA-Z0-9_.-]{1,128}$")
|
||||
_MIME_EXTENSIONS = {
|
||||
|
||||
@@ -378,6 +378,22 @@ def _explicit_cuda_available(device: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_cuda_oom(exc: BaseException) -> bool:
|
||||
message = str(exc).lower()
|
||||
return (
|
||||
"out of memory" in message
|
||||
or "cudaerrormemoryallocation" in message
|
||||
or "cuda error: out of memory" in message
|
||||
)
|
||||
|
||||
|
||||
def _fallback_quicktalk_device(device: str) -> str | None:
|
||||
raw = (device or "").strip().lower()
|
||||
if raw.startswith("cuda"):
|
||||
return "cpu"
|
||||
return None
|
||||
|
||||
|
||||
@register_model("quicktalk")
|
||||
class QuickTalkAdapter:
|
||||
"""QuickTalk realtime worker integrated into OpenTalking's model API."""
|
||||
@@ -588,6 +604,7 @@ class QuickTalkAdapter:
|
||||
cache_disabled = _env_value("OPENTALKING_QUICKTALK_WORKER_CACHE", "1") == "0"
|
||||
|
||||
worker: RealtimeV3Worker | None = None
|
||||
cache_key_to_store = cache_key
|
||||
if not cache_disabled:
|
||||
worker = _WORKER_CACHE.get(cache_key)
|
||||
if worker is not None:
|
||||
@@ -607,34 +624,133 @@ class QuickTalkAdapter:
|
||||
"quicktalk worker cache MISS — building (avatar=%s)",
|
||||
bundle.manifest.id,
|
||||
)
|
||||
worker = RealtimeV3Worker(
|
||||
asset_root=asset_root,
|
||||
template_video=template_video,
|
||||
face_cache_dir=face_cache_dir,
|
||||
face_cache_file=face_cache_file,
|
||||
device=self._device,
|
||||
output_transform=self._output_transform,
|
||||
scale_h=self._scale_h,
|
||||
scale_w=self._scale_w,
|
||||
resolution=self._resolution,
|
||||
max_template_seconds=max_template_seconds,
|
||||
neck_fade_start=self._neck_fade_start,
|
||||
neck_fade_end=self._neck_fade_end,
|
||||
hubert_device=self._hubert_device,
|
||||
model_backend=self._model_backend,
|
||||
)
|
||||
try:
|
||||
worker = RealtimeV3Worker(
|
||||
asset_root=asset_root,
|
||||
template_video=template_video,
|
||||
face_cache_dir=face_cache_dir,
|
||||
face_cache_file=face_cache_file,
|
||||
device=self._device,
|
||||
output_transform=self._output_transform,
|
||||
scale_h=self._scale_h,
|
||||
scale_w=self._scale_w,
|
||||
resolution=self._resolution,
|
||||
max_template_seconds=max_template_seconds,
|
||||
neck_fade_start=self._neck_fade_start,
|
||||
neck_fade_end=self._neck_fade_end,
|
||||
hubert_device=self._hubert_device,
|
||||
model_backend=self._model_backend,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
fallback_device = _fallback_quicktalk_device(self._device)
|
||||
if fallback_device is None or not _is_cuda_oom(exc):
|
||||
raise
|
||||
log.warning(
|
||||
"QuickTalk worker build hit CUDA OOM; retrying on CPU (avatar=%s)",
|
||||
bundle.manifest.id,
|
||||
exc_info=True,
|
||||
)
|
||||
self._device = fallback_device
|
||||
self._hubert_device = fallback_device
|
||||
cache_key_to_store = _worker_cache_key(
|
||||
asset_root=asset_root,
|
||||
template_video=template_video,
|
||||
face_cache_dir=face_cache_dir,
|
||||
face_cache_file=face_cache_file,
|
||||
device=self._device,
|
||||
output_transform=self._output_transform,
|
||||
scale_h=self._scale_h,
|
||||
scale_w=self._scale_w,
|
||||
resolution=self._resolution,
|
||||
max_template_seconds=max_template_seconds,
|
||||
neck_fade_start=self._neck_fade_start,
|
||||
neck_fade_end=self._neck_fade_end,
|
||||
hubert_device=self._hubert_device,
|
||||
model_backend=self._model_backend,
|
||||
)
|
||||
worker = RealtimeV3Worker(
|
||||
asset_root=asset_root,
|
||||
template_video=template_video,
|
||||
face_cache_dir=face_cache_dir,
|
||||
face_cache_file=face_cache_file,
|
||||
device=self._device,
|
||||
output_transform=self._output_transform,
|
||||
scale_h=self._scale_h,
|
||||
scale_w=self._scale_w,
|
||||
resolution=self._resolution,
|
||||
max_template_seconds=max_template_seconds,
|
||||
neck_fade_start=self._neck_fade_start,
|
||||
neck_fade_end=self._neck_fade_end,
|
||||
hubert_device=self._hubert_device,
|
||||
model_backend=self._model_backend,
|
||||
)
|
||||
if not cache_disabled:
|
||||
_WORKER_CACHE[cache_key] = worker
|
||||
_WORKER_CACHE[cache_key_to_store] = worker
|
||||
_enforce_worker_cache_limit()
|
||||
|
||||
session_state = worker.make_state()
|
||||
return QuickTalkState(
|
||||
manifest=bundle.manifest,
|
||||
worker=worker,
|
||||
fps=worker.fps,
|
||||
extra={},
|
||||
session_state=session_state,
|
||||
)
|
||||
try:
|
||||
session_state = worker.make_state()
|
||||
return QuickTalkState(
|
||||
manifest=bundle.manifest,
|
||||
worker=worker,
|
||||
fps=worker.fps,
|
||||
extra={},
|
||||
session_state=session_state,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
fallback_device = _fallback_quicktalk_device(self._device)
|
||||
if fallback_device is None or not _is_cuda_oom(exc):
|
||||
raise
|
||||
log.warning(
|
||||
"QuickTalk avatar load hit CUDA OOM; retrying on CPU (avatar=%s)",
|
||||
bundle.manifest.id,
|
||||
exc_info=True,
|
||||
)
|
||||
self._device = fallback_device
|
||||
self._hubert_device = fallback_device
|
||||
fallback_key = _worker_cache_key(
|
||||
asset_root=asset_root,
|
||||
template_video=template_video,
|
||||
face_cache_dir=face_cache_dir,
|
||||
face_cache_file=face_cache_file,
|
||||
device=self._device,
|
||||
output_transform=self._output_transform,
|
||||
scale_h=self._scale_h,
|
||||
scale_w=self._scale_w,
|
||||
resolution=self._resolution,
|
||||
max_template_seconds=max_template_seconds,
|
||||
neck_fade_start=self._neck_fade_start,
|
||||
neck_fade_end=self._neck_fade_end,
|
||||
hubert_device=self._hubert_device,
|
||||
model_backend=self._model_backend,
|
||||
)
|
||||
worker = RealtimeV3Worker(
|
||||
asset_root=asset_root,
|
||||
template_video=template_video,
|
||||
face_cache_dir=face_cache_dir,
|
||||
face_cache_file=face_cache_file,
|
||||
device=self._device,
|
||||
output_transform=self._output_transform,
|
||||
scale_h=self._scale_h,
|
||||
scale_w=self._scale_w,
|
||||
resolution=self._resolution,
|
||||
max_template_seconds=max_template_seconds,
|
||||
neck_fade_start=self._neck_fade_start,
|
||||
neck_fade_end=self._neck_fade_end,
|
||||
hubert_device=self._hubert_device,
|
||||
model_backend=self._model_backend,
|
||||
)
|
||||
if not cache_disabled:
|
||||
_WORKER_CACHE[fallback_key] = worker
|
||||
_enforce_worker_cache_limit()
|
||||
session_state = worker.make_state()
|
||||
return QuickTalkState(
|
||||
manifest=bundle.manifest,
|
||||
worker=worker,
|
||||
fps=worker.fps,
|
||||
extra={"device_fallback": "cpu"},
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
def warmup(self, avatar_state: QuickTalkState | None = None) -> None:
|
||||
if avatar_state is None:
|
||||
|
||||
@@ -6,7 +6,11 @@ from typing import Any
|
||||
|
||||
from opentalking.core.config import Settings, get_settings
|
||||
from opentalking.providers.memory.base import MemoryProvider
|
||||
from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider, Mem0MemoryProvider
|
||||
from opentalking.providers.memory.mem0_provider import (
|
||||
InMemoryMemoryProvider,
|
||||
Mem0MemoryProvider,
|
||||
Mem0UnavailableError,
|
||||
)
|
||||
from opentalking.providers.memory.noop import NoopMemoryProvider
|
||||
from opentalking.providers.memory.sqlite_provider import SQLiteMemoryProvider
|
||||
|
||||
@@ -110,7 +114,10 @@ def build_memory_provider() -> MemoryProvider:
|
||||
if provider in {"sqlite", "local"}:
|
||||
return SQLiteMemoryProvider(settings.memory_sqlite_path)
|
||||
if provider == "mem0":
|
||||
return Mem0MemoryProvider(config=_mem0_config(settings))
|
||||
try:
|
||||
return Mem0MemoryProvider(config=_mem0_config(settings))
|
||||
except Mem0UnavailableError:
|
||||
return SQLiteMemoryProvider(settings.memory_sqlite_path)
|
||||
if provider in {"memory", "inmemory", "in-memory"}:
|
||||
return InMemoryMemoryProvider()
|
||||
raise ValueError(f"unsupported memory provider: {settings.memory_provider}")
|
||||
|
||||
@@ -267,20 +267,22 @@ def _stt_model_dir(provider: str, model: str | None = None) -> str:
|
||||
provider = normalize_stt_provider(provider, default="dashscope") or "dashscope"
|
||||
direct = _provider_env(provider, "MODEL_DIR") or _settings_value(f"stt_{provider}_model_dir", "")
|
||||
if direct:
|
||||
return direct
|
||||
return str(Path(direct).expanduser().resolve())
|
||||
if provider in LOCAL_STT_PROVIDERS:
|
||||
return str(_local_path_for_model((model or _stt_model(provider)).strip()))
|
||||
return ""
|
||||
return ""
|
||||
|
||||
|
||||
def _local_path_for_model(model: str) -> Path:
|
||||
path = Path(model).expanduser()
|
||||
if path.is_absolute() or path.exists():
|
||||
return path
|
||||
if path.is_absolute():
|
||||
return path.resolve()
|
||||
root = _model_root().expanduser()
|
||||
if path.exists():
|
||||
return path.resolve()
|
||||
if "/" in model:
|
||||
return _model_root() / model.replace("/", "__")
|
||||
return _model_root() / model
|
||||
return (root / model.replace("/", "__")).resolve()
|
||||
return (root / model).resolve()
|
||||
|
||||
|
||||
def _write_pcm_queue_to_wav(
|
||||
@@ -335,9 +337,9 @@ class LocalFunASRSTTAdapter:
|
||||
|
||||
def _runtime_model_name(self) -> str:
|
||||
if self.model_dir:
|
||||
return self.model_dir
|
||||
return str(Path(self.model_dir).expanduser().resolve())
|
||||
local_path = _local_path_for_model(self.model)
|
||||
return str(local_path) if local_path.exists() else self.model
|
||||
return str(local_path)
|
||||
|
||||
def _load_runtime(self) -> Any:
|
||||
if self._runtime is not None:
|
||||
|
||||
@@ -6,6 +6,7 @@ from opentalking.providers.tts import ( # noqa: F401 side-effect imports
|
||||
dashscope_sambert,
|
||||
edge,
|
||||
elevenlabs,
|
||||
mock,
|
||||
)
|
||||
from opentalking.providers.tts.edge.adapter import EdgeTTSAdapter
|
||||
from opentalking.providers.tts.factory import build_tts_adapter, create_tts_adapter
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import posixpath
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -42,6 +43,13 @@ def _provider_env(provider: str, field: str) -> str:
|
||||
return os.environ.get(f"OPENTALKING_TTS_{key_provider}_{field}", "").strip()
|
||||
|
||||
|
||||
def _join_config_path(root: str, *parts: str) -> str:
|
||||
value = root.strip()
|
||||
if "/" in value and "\\" not in value:
|
||||
return posixpath.join(value.rstrip("/"), *parts)
|
||||
return str(Path(value).expanduser().joinpath(*parts))
|
||||
|
||||
|
||||
def _configured_tts_provider_values() -> tuple[str, ...]:
|
||||
return tuple(
|
||||
value
|
||||
@@ -118,20 +126,19 @@ def _local_cosyvoice_service_url() -> str:
|
||||
)
|
||||
|
||||
|
||||
def _local_audio_model_root() -> Path:
|
||||
raw = (
|
||||
def _local_audio_model_root() -> str:
|
||||
return (
|
||||
os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
|
||||
or _settings_value("local_audio_model_root", "")
|
||||
or "./models/local-audio"
|
||||
)
|
||||
return Path(raw).expanduser()
|
||||
|
||||
|
||||
def _local_cosyvoice_model_dir(model: str) -> str:
|
||||
return (
|
||||
_provider_env("local_cosyvoice", "MODEL_DIR")
|
||||
or _settings_value("tts_local_cosyvoice_model_dir", "")
|
||||
or str(_local_audio_model_root() / model.replace("/", "__"))
|
||||
or _join_config_path(_local_audio_model_root(), model.replace("/", "__"))
|
||||
)
|
||||
|
||||
|
||||
@@ -184,10 +191,11 @@ def _local_cosyvoice_fp16() -> str:
|
||||
def _local_audio_asset_dir(name: str, required_file: str, *fallback_names: str) -> str:
|
||||
root = _local_audio_model_root()
|
||||
for candidate_name in (name, *fallback_names):
|
||||
candidate = root / candidate_name
|
||||
candidate_value = _join_config_path(root, candidate_name)
|
||||
candidate = Path(candidate_value).expanduser()
|
||||
if (candidate / required_file).is_file():
|
||||
return str(candidate)
|
||||
return str(root / name)
|
||||
return candidate_value
|
||||
return _join_config_path(root, name)
|
||||
|
||||
|
||||
def _local_audio_asset_file_dir(name: str, relative_file: str, *fallback_names: str) -> str:
|
||||
@@ -206,7 +214,7 @@ def _local_indextts_model_dir(model: str) -> str:
|
||||
return (
|
||||
_provider_env("local_indextts", "MODEL_DIR")
|
||||
or _settings_value("tts_local_indextts_model_dir", "")
|
||||
or str(_local_audio_model_root() / model.replace("/", "__"))
|
||||
or _join_config_path(_local_audio_model_root(), model.replace("/", "__"))
|
||||
)
|
||||
|
||||
|
||||
@@ -214,7 +222,7 @@ def _local_indextts_cfg_path(model_dir: str) -> str:
|
||||
return (
|
||||
_provider_env("local_indextts", "CFG_PATH")
|
||||
or _settings_value("tts_local_indextts_cfg_path", "")
|
||||
or str(Path(model_dir) / "config.yaml")
|
||||
or _join_config_path(model_dir, "config.yaml")
|
||||
)
|
||||
|
||||
|
||||
@@ -236,7 +244,7 @@ def _local_indextts_w2v_bert_dir() -> str:
|
||||
return (
|
||||
_provider_env("local_indextts", "W2V_BERT_DIR")
|
||||
or _settings_value("tts_local_indextts_w2v_bert_dir", "")
|
||||
or str(_local_audio_model_root() / "facebook__w2v-bert-2.0")
|
||||
or _join_config_path(_local_audio_model_root(), "facebook__w2v-bert-2.0")
|
||||
)
|
||||
|
||||
|
||||
@@ -252,7 +260,7 @@ def _local_indextts_campplus_dir() -> str:
|
||||
return (
|
||||
_provider_env("local_indextts", "CAMPPLUS_DIR")
|
||||
or _settings_value("tts_local_indextts_campplus_dir", "")
|
||||
or str(_local_audio_model_root() / "funasr__campplus")
|
||||
or _join_config_path(_local_audio_model_root(), "funasr__campplus")
|
||||
)
|
||||
|
||||
|
||||
@@ -260,7 +268,7 @@ def _local_indextts_bigvgan_dir() -> str:
|
||||
return (
|
||||
_provider_env("local_indextts", "BIGVGAN_DIR")
|
||||
or _settings_value("tts_local_indextts_bigvgan_dir", "")
|
||||
or str(_local_audio_model_root() / "nvidia__bigvgan_v2_22khz_80band_256x")
|
||||
or _join_config_path(_local_audio_model_root(), "nvidia__bigvgan_v2_22khz_80band_256x")
|
||||
)
|
||||
|
||||
|
||||
@@ -519,6 +527,16 @@ def tts_enabled_providers() -> list[str]:
|
||||
|
||||
def tts_provider_config(provider: str) -> dict[str, str | bool | int | float]:
|
||||
p = normalize_tts_provider(provider, default=None) or _provider()
|
||||
if p == "mock":
|
||||
return {
|
||||
"provider": p,
|
||||
"model": "mock",
|
||||
"model_dir": "",
|
||||
"voice": "mock",
|
||||
"device": "",
|
||||
"key_set": False,
|
||||
"service_url_set": True,
|
||||
}
|
||||
if p == "indextts":
|
||||
resolved = _resolve_indextts_provider(p)
|
||||
config = dict(tts_provider_config(resolved))
|
||||
@@ -1260,6 +1278,15 @@ def build_tts_adapter(
|
||||
tts_model=effective_tts_model,
|
||||
)
|
||||
|
||||
if provider == "mock":
|
||||
from opentalking.providers.tts.mock.adapter import MockTTSAdapter
|
||||
|
||||
return MockTTSAdapter(
|
||||
default_voice=default_voice or "mock",
|
||||
sample_rate=sample_rate,
|
||||
chunk_ms=chunk_ms,
|
||||
)
|
||||
|
||||
# For dashscope/bailian/etc., delegate to create_tts_adapter
|
||||
if provider in _QWEN_RT or provider in _COSY_WS or provider in _SAMBERT or provider in _LOCAL or provider in _OMNIRT or provider in _INDEXTTS:
|
||||
return create_tts_adapter(
|
||||
@@ -1272,6 +1299,14 @@ def build_tts_adapter(
|
||||
)
|
||||
|
||||
if provider in _CORE:
|
||||
if provider == "mock":
|
||||
from opentalking.providers.tts.mock.adapter import MockTTSAdapter
|
||||
|
||||
return MockTTSAdapter(
|
||||
default_voice=default_voice or getattr(settings, "tts_voice", None) or "mock",
|
||||
sample_rate=sample_rate,
|
||||
chunk_ms=chunk_ms,
|
||||
)
|
||||
return EdgeTTSAdapter(
|
||||
default_voice=default_voice or getattr(settings, "tts_voice", None) or _edge_default_voice(),
|
||||
sample_rate=sample_rate,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
@@ -14,6 +15,11 @@ import httpx
|
||||
import numpy as np
|
||||
|
||||
from opentalking.core.types.frames import AudioChunk
|
||||
from opentalking.providers.tts.voice_assets import (
|
||||
LOCAL_COSYVOICE_PROVIDER,
|
||||
local_audio_model_root,
|
||||
resolve_voice_asset,
|
||||
)
|
||||
|
||||
|
||||
def _settings_value(name: str, default: str = "") -> str:
|
||||
@@ -93,12 +99,7 @@ def _resolve_service_url_for_model(model: str, default_model: str, default_url:
|
||||
|
||||
|
||||
def _model_root() -> Path:
|
||||
raw = (
|
||||
os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
|
||||
or _settings_value("local_audio_model_root", "")
|
||||
or "./models/local-audio"
|
||||
)
|
||||
return Path(raw).expanduser()
|
||||
return local_audio_model_root()
|
||||
|
||||
|
||||
def _resolve_model_path(model: str) -> str:
|
||||
@@ -114,39 +115,70 @@ def _resolve_local_voice_prompt(voice: str | None) -> dict[str, str] | None:
|
||||
return None
|
||||
if not all(ch.isalnum() or ch in {"_", "-"} for ch in voice_id):
|
||||
return None
|
||||
for base in (_model_root() / "voices" / "clones", _model_root() / "voices" / "system"):
|
||||
voice_dir = base / voice_id
|
||||
prompt_audio = voice_dir / "prompt.wav"
|
||||
prompt_text = voice_dir / "prompt.txt"
|
||||
if not prompt_audio.is_file() or not prompt_text.is_file():
|
||||
continue
|
||||
result = {"prompt_audio": str(prompt_audio)}
|
||||
text = prompt_text.read_text(encoding="utf-8").strip()
|
||||
if text:
|
||||
result["prompt_text"] = text
|
||||
meta_path = voice_dir / "meta.json"
|
||||
if meta_path.is_file():
|
||||
try:
|
||||
meta = json.loads(meta_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
meta = {}
|
||||
for key in ("mode", "instruction"):
|
||||
value = str(meta.get(key) or "").strip()
|
||||
if value:
|
||||
result[key] = value
|
||||
if result.get("prompt_text") or result.get("mode") in {"cross_lingual", "instruct"}:
|
||||
return result
|
||||
asset = resolve_voice_asset(
|
||||
voice_id,
|
||||
provider=LOCAL_COSYVOICE_PROVIDER,
|
||||
sources=("clones", "system"),
|
||||
model_root=_model_root(),
|
||||
require_prompt_text=True,
|
||||
)
|
||||
if asset is None or asset.prompt_text is None:
|
||||
return None
|
||||
result = {"prompt_audio": str(asset.prompt_audio)}
|
||||
try:
|
||||
text = asset.prompt_text.read_text(encoding="utf-8").strip()
|
||||
except OSError:
|
||||
text = ""
|
||||
if text:
|
||||
result["prompt_text"] = text
|
||||
for key in ("mode", "instruction"):
|
||||
value = str(asset.meta.get(key) or "").strip()
|
||||
if value:
|
||||
result[key] = value
|
||||
if result.get("prompt_text") or result.get("mode") in {"cross_lingual", "instruct"}:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
class LocalCosyVoiceInputError(ValueError):
    """Signal that a local CosyVoice request is invalid as submitted.

    The typical cause is a requested voice that has no usable prompt asset.
    """
|
||||
|
||||
|
||||
def _is_service_default_voice(voice: str | None) -> bool:
    """Tell whether *voice* selects the sidecar's built-in default voice.

    ``None``, blank strings and the literal ``"local-default"`` (after
    trimming whitespace) all count as "no explicit voice requested".
    """
    return (voice or "").strip() in ("", "local-default")
|
||||
|
||||
|
||||
def _service_default_prompt_configured() -> bool:
    """Report whether the environment configures a default CosyVoice prompt.

    A prompt audio path is always required. A prompt text is required as well
    unless the configured mode is ``cross_lingual`` or ``instruct``; the mode
    defaults to ``zero_shot`` when the variable is unset.
    """
    has_audio = bool(os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_AUDIO", "").strip())
    mode = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_MODE", "zero_shot").strip().lower()
    if mode == "cross_lingual" or mode == "instruct":
        # These modes do not need a transcript of the prompt audio.
        return has_audio
    has_text = bool(os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_TEXT", "").strip())
    return has_audio and has_text
|
||||
|
||||
|
||||
def _missing_local_voice_message(voice: str | None) -> str:
    """Build the user-facing message for a voice that has no usable prompt.

    A missing or blank voice is reported with a "not selected" placeholder
    instead of an empty name.
    """
    shown = (voice or "").strip()
    if not shown:
        shown = "未选择"
    return f"本地 CosyVoice 音色 {shown!r} 没有可用 prompt；请先选择本地音色。"
|
||||
|
||||
|
||||
def _local_cosyvoice_http_400_message(*, voice: str | None, service_url: str, detail: str) -> str:
    """Compose the error text shown when the sidecar answers HTTP 400.

    The message combines the missing-voice explanation, the service URL and,
    when non-blank, the trimmed detail returned by the sidecar.
    """
    trimmed_detail = detail.strip()
    pieces = [
        f"本地 CosyVoice 请求无效：{_missing_local_voice_message(voice)} ",
        f"HTTP 400 from {service_url}.",
    ]
    if trimmed_detail:
        pieces.append(f" sidecar detail: {trimmed_detail}")
    return "".join(pieces).strip()
|
||||
|
||||
|
||||
def _env_device() -> str:
|
||||
return (
|
||||
os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_DEVICE", "").strip()
|
||||
or _settings_value("tts_local_cosyvoice_device", "")
|
||||
or os.environ.get("OPENTALKING_LOCAL_TTS_DEVICE", "").strip()
|
||||
or os.environ.get("OPENTALKING_LOCAL_AUDIO_DEVICE", "").strip()
|
||||
or _settings_value("local_audio_device", "auto")
|
||||
or "auto"
|
||||
or _settings_value("local_audio_device", "cpu")
|
||||
or "cpu"
|
||||
)
|
||||
|
||||
|
||||
@@ -299,15 +331,16 @@ class LocalCosyVoiceTTSAdapter:
|
||||
default_service_url,
|
||||
)
|
||||
self.device = _env_device()
|
||||
self.fp16 = _local_cosyvoice_fp16(self.device)
|
||||
self.load_jit = _local_cosyvoice_bool("LOAD_JIT", "tts_local_cosyvoice_load_jit", False)
|
||||
self.load_trt = _local_cosyvoice_bool("LOAD_TRT", "tts_local_cosyvoice_load_trt", False)
|
||||
self.load_vllm = _local_cosyvoice_bool("LOAD_VLLM", "tts_local_cosyvoice_load_vllm", False)
|
||||
self.fp16 = _local_cosyvoice_fp16(self.device)
|
||||
self.trt_concurrent = max(
|
||||
1,
|
||||
_local_cosyvoice_int("TRT_CONCURRENT", "tts_local_cosyvoice_trt_concurrent", 1),
|
||||
)
|
||||
self._engine: Any | None = None
|
||||
self._voice_payload_cache: dict[str, dict[str, str]] = {}
|
||||
|
||||
async def synthesize_stream(self, text: str, voice: str | None = None) -> AsyncIterator[AudioChunk]:
|
||||
if not text.strip():
|
||||
@@ -321,76 +354,109 @@ class LocalCosyVoiceTTSAdapter:
|
||||
yield chunk
|
||||
|
||||
async def _synthesize_via_service(self, text: str, voice: str | None = None) -> AsyncIterator[AudioChunk]:
    """Stream audio for *text* from the local CosyVoice HTTP sidecar.

    The request carries the text, voice, model and sample rate. When a local
    prompt asset resolves for the voice, its fields are merged into the
    payload together with ``zero_shot_spk_id``. The response is handled by
    content type: raw PCM is resampled and re-chunked as it arrives, WAV is
    read fully and downmixed to mono, anything else goes through the shared
    streaming decoder.

    Raises:
        LocalCosyVoiceInputError: the voice has no usable prompt and is not
            the service default, or the sidecar answered HTTP 400.
        RuntimeError: the sidecar answered another error status, could not be
            reached, or returned WAV data that is not 16-bit.
    """
    timeout = httpx.Timeout(connect=30.0, read=600.0, write=30.0, pool=30.0)
    effective_voice = voice or self.default_voice
    payload = {
        "text": text,
        "voice": effective_voice,
        "model": self.model,
        "sample_rate": self.sample_rate,
    }
    local_prompt = _resolve_local_voice_prompt(effective_voice)
    if local_prompt is not None:
        payload.update(local_prompt)
        payload["zero_shot_spk_id"] = effective_voice
    elif not _is_service_default_voice(effective_voice):
        raise LocalCosyVoiceInputError(_missing_local_voice_message(effective_voice))
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("POST", self.service_url, json=payload) as resp:
                if not 200 <= resp.status_code < 300:
                    # A streamed response has no body until it is read, and it
                    # is closed once this context exits. Buffer the error body
                    # now so the handler below can actually inspect it.
                    await resp.aread()
                resp.raise_for_status()
                input_format = _audio_format_from_content_type(resp.headers.get("content-type"))
                if input_format == "pcm":
                    source_sr = _source_sample_rate_from_headers(resp.headers, self.sample_rate)
                    pending = b""
                    async for data in resp.aiter_bytes():
                        if not data:
                            continue
                        data = pending + data
                        if len(data) % 2:
                            # Carry a dangling half-sample over to the next chunk.
                            pending = data[-1:]
                            data = data[:-1]
                        else:
                            pending = b""
                        if not data:
                            continue
                        pcm = np.frombuffer(data, dtype="<i2").astype(np.int16, copy=False)
                        pcm = _resample_linear(pcm, source_sr, self.sample_rate)
                        for chunk in _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms):
                            yield chunk
                    return
                if input_format == "wav":
                    raw = await resp.aread()
                    with wave.open(io.BytesIO(raw), "rb") as wf:
                        source_sr = int(wf.getframerate())
                        channels = int(wf.getnchannels())
                        sample_width = int(wf.getsampwidth())
                        pcm_bytes = wf.readframes(wf.getnframes())
                    if sample_width != 2:
                        raise RuntimeError(f"Unsupported WAV sample width for local CosyVoice: {sample_width}")
                    pcm = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.int16, copy=False)
                    if channels > 1:
                        # Downmix interleaved frames to mono by averaging channels.
                        frame_count = pcm.size // channels
                        pcm = (
                            pcm[: frame_count * channels]
                            .reshape(frame_count, channels)
                            .mean(axis=1)
                            .astype(np.int16)
                        )
                    pcm = _resample_linear(pcm, source_sr, self.sample_rate)
                    for chunk in _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms):
                        yield chunk
                    return

                from opentalking.providers.tts.edge.adapter import _stream_decode_audio_to_pcm_chunks

                async def _audio_iter() -> AsyncIterator[bytes]:
                    async for data in resp.aiter_bytes():
                        if data:
                            yield data

                async for chunk in _stream_decode_audio_to_pcm_chunks(
                    _audio_iter(),
                    self.sample_rate,
                    self.chunk_ms,
                    input_format=input_format,
                ):
                    yield chunk
    except httpx.HTTPStatusError as exc:
        detail = ""
        try:
            detail = exc.response.text[:500]
        except Exception:
            # Best effort only: an undecodable body must not mask the HTTP error.
            detail = ""
        if exc.response.status_code == 400:
            raise LocalCosyVoiceInputError(
                _local_cosyvoice_http_400_message(
                    voice=effective_voice,
                    service_url=self.service_url,
                    detail=detail,
                )
            ) from exc
        if "CosyVoice returned no audio" in detail:
            raise RuntimeError(
                "本地 CosyVoice 返回空音频，模型推理状态异常；请重启本地 TTS 服务后重试。"
                f" HTTP {exc.response.status_code} from {self.service_url}. {detail}".strip()
            ) from exc
        raise RuntimeError(
            "本地 CosyVoice 服务不可用（可能已退出/内存不足）。"
            f" HTTP {exc.response.status_code} from {self.service_url}. {detail}".strip()
        ) from exc
    except (httpx.ConnectError, httpx.ConnectTimeout, httpx.ReadTimeout, httpx.RemoteProtocolError) as exc:
        raise RuntimeError(
            "本地 CosyVoice 服务不可用（可能已退出/内存不足）。"
            f" 无法连接或读取 {self.service_url}: {type(exc).__name__}: {exc}"
        ) from exc
|
||||
|
||||
def _load_engine(self) -> Any:
|
||||
if self._engine is not None:
|
||||
@@ -432,6 +498,28 @@ class LocalCosyVoiceTTSAdapter:
|
||||
return str(item)
|
||||
return requested or "中文女"
|
||||
|
||||
def _callable_supports_keyword(self, fn: Any, keyword: str) -> bool:
    """Check whether *fn* can be called with the keyword argument *keyword*.

    True when the signature names the parameter or accepts ``**kwargs``.
    False when the signature cannot be inspected at all.
    """
    try:
        params = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False
    if keyword in params:
        return True
    for param in params.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return True
    return False
|
||||
|
||||
def _resolved_voice_payload(self, voice: str | None) -> tuple[str | None, dict[str, str] | None]:
    """Resolve *voice* to ``(voice_id, prompt_payload)`` with per-adapter caching.

    Blank voices and ``"local-default"`` yield ``(None, None)``. A voice with
    no resolvable prompt yields ``(voice_id, None)`` and is not cached.
    Cache hits are returned as copies so callers cannot mutate the cache.
    """
    key = (voice or "").strip()
    if key in ("", "local-default"):
        return None, None
    cache = self._voice_payload_cache
    hit = cache.get(key)
    if hit is None:
        fresh = _resolve_local_voice_prompt(key)
        if fresh is not None:
            cache[key] = dict(fresh)
        return key, fresh
    return key, dict(hit)
|
||||
|
||||
def _synthesize_in_process(self, text: str, voice: str) -> list[AudioChunk]:
|
||||
engine = self._load_engine()
|
||||
spk_id = self._available_voice(engine, voice)
|
||||
@@ -440,7 +528,34 @@ class LocalCosyVoiceTTSAdapter:
|
||||
raise RuntimeError("CosyVoice runtime does not expose inference_sft().")
|
||||
sr = int(getattr(engine, "sample_rate", 22050) or 22050)
|
||||
pcm_parts: list[np.ndarray] = []
|
||||
for item in infer(text, spk_id, stream=False):
|
||||
voice_id, payload = self._resolved_voice_payload(spk_id)
|
||||
if payload is not None:
|
||||
add_zero_shot_spk = getattr(engine, "add_zero_shot_spk", None)
|
||||
if callable(add_zero_shot_spk):
|
||||
try:
|
||||
prompt_text = payload.get("prompt_text", "")
|
||||
prompt_audio = payload["prompt_audio"]
|
||||
if self._callable_supports_keyword(add_zero_shot_spk, "zero_shot_spk_id"):
|
||||
add_zero_shot_spk(prompt_text, prompt_audio, zero_shot_spk_id=voice_id)
|
||||
else:
|
||||
add_zero_shot_spk(prompt_text, prompt_audio, voice_id)
|
||||
save_spkinfo = getattr(engine, "save_spkinfo", None)
|
||||
if callable(save_spkinfo):
|
||||
save_spkinfo()
|
||||
except Exception:
|
||||
pass
|
||||
if payload is not None and self._callable_supports_keyword(infer, "zero_shot_spk_id"):
|
||||
iterator = infer(text, "", "", stream=False, zero_shot_spk_id=voice_id)
|
||||
elif payload is not None:
|
||||
iterator = infer(
|
||||
text,
|
||||
payload.get("prompt_text", ""),
|
||||
payload["prompt_audio"],
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
iterator = infer(text, spk_id, stream=False)
|
||||
for item in iterator:
|
||||
speech = item.get("tts_speech") if isinstance(item, dict) else item
|
||||
if hasattr(speech, "detach"):
|
||||
speech = speech.detach().cpu().numpy()
|
||||
|
||||
@@ -17,6 +17,7 @@ import httpx
|
||||
|
||||
from opentalking.core.types.frames import AudioChunk
|
||||
from opentalking.providers.tts.indextts_config import indextts_infer_kwargs, normalize_indextts_config
|
||||
from opentalking.providers.tts.voice_assets import INDEXTTS_PROVIDER, resolve_voice_asset
|
||||
|
||||
|
||||
_ENGINE_CACHE_LOCK = threading.Lock()
|
||||
@@ -113,6 +114,24 @@ def _read_wav_bytes_i16(raw: bytes) -> tuple[np.ndarray, int]:
|
||||
return _read_wav_handle_i16(wf)
|
||||
|
||||
|
||||
def _write_wav_i16(path: str | Path, data: np.ndarray, sample_rate: int) -> None:
    """Write *data* to *path* as a 16-bit PCM WAV file using only the stdlib.

    Args:
        path: Destination file.
        data: Samples as a 1-D array or a 2-D ``(channels, frames)`` array.
            Floating-point input is clipped to [-1, 1] and scaled to int16;
            other dtypes are clipped to the int16 range.
        sample_rate: Frame rate stored in the WAV header.

    A 2-D input with more than one row is written as interleaved
    multi-channel audio with a matching channel count in the header, rather
    than being flattened into a file that claims to be mono.
    """
    pcm = np.asarray(data)
    channels = 1
    if pcm.ndim == 2:
        if pcm.shape[0] <= 1:
            pcm = pcm.reshape(-1)
        else:
            channels = int(pcm.shape[0])
            # (channels, frames) -> frame-major interleaving, as WAV expects.
            pcm = pcm.T.reshape(-1)
    if np.issubdtype(pcm.dtype, np.floating):
        pcm = np.clip(pcm, -1.0, 1.0)
        pcm = np.round(pcm * 32767.0).astype("<i2")
    else:
        pcm = np.clip(pcm, -32768, 32767).astype("<i2")
    with wave.open(str(path), "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)
        wf.setframerate(int(sample_rate))
        wf.writeframes(pcm.reshape(-1).tobytes())
|
||||
|
||||
|
||||
class LocalIndexTTSAdapter:
|
||||
"""In-process IndexTTS2 adapter for OpenTalking local TTS mode."""
|
||||
|
||||
@@ -226,11 +245,14 @@ class LocalIndexTTSAdapter:
|
||||
def _resolve_voice_prompt(self, voice: str | None) -> Path | None:
|
||||
voice_id = (voice or "").strip()
|
||||
if voice_id and re.fullmatch(r"[A-Za-z0-9_-]{3,80}", voice_id):
|
||||
voices_root = _local_audio_model_root() / "voices"
|
||||
for root in (voices_root / "clones", voices_root / "system", _bundled_system_voice_root()):
|
||||
prompt = root / voice_id / "prompt.wav"
|
||||
if prompt.is_file():
|
||||
return prompt
|
||||
asset = resolve_voice_asset(
|
||||
voice_id,
|
||||
provider=INDEXTTS_PROVIDER,
|
||||
sources=("clones", "system"),
|
||||
model_root=_local_audio_model_root(),
|
||||
)
|
||||
if asset is not None:
|
||||
return asset.prompt_audio
|
||||
if self.prompt_audio:
|
||||
return Path(self.prompt_audio)
|
||||
return None
|
||||
@@ -369,15 +391,8 @@ class LocalIndexTTSAdapter:
|
||||
text = str(exc)
|
||||
if "TorchCodec is required" not in text and "libtorchcodec" not in text:
|
||||
raise
|
||||
import soundfile as sf
|
||||
|
||||
data = tensor.detach().cpu().numpy() if hasattr(tensor, "detach") else np.asarray(tensor)
|
||||
data = np.asarray(data)
|
||||
if data.ndim == 2 and data.shape[0] == 1:
|
||||
data = data[0]
|
||||
elif data.ndim == 2:
|
||||
data = data.T
|
||||
sf.write(path, data, sample_rate, subtype="PCM_16")
|
||||
_write_wav_i16(path, data, sample_rate)
|
||||
return None
|
||||
|
||||
torchaudio.save = save
|
||||
@@ -435,10 +450,15 @@ class LocalIndexTTSAdapter:
|
||||
if not callable(infer):
|
||||
raise RuntimeError("IndexTTS2 runtime does not expose infer().")
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
|
||||
fd, tmp_name = tempfile.mkstemp(suffix=".wav")
|
||||
os.close(fd)
|
||||
tmp_path = Path(tmp_name)
|
||||
try:
|
||||
with engine_lock:
|
||||
infer(str(prompt), text, tmp.name, **indextts_infer_kwargs(self.indextts_config))
|
||||
pcm, source_sr = _read_wav_i16(Path(tmp.name))
|
||||
infer(str(prompt), text, tmp_name, **indextts_infer_kwargs(self.indextts_config))
|
||||
pcm, source_sr = _read_wav_i16(tmp_path)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
pcm = _resample_linear(pcm, source_sr, self.sample_rate)
|
||||
return _split_pcm_chunks(pcm, self.sample_rate, self.chunk_ms)
|
||||
|
||||
6
opentalking/providers/tts/mock/__init__.py
Normal file
6
opentalking/providers/tts/mock/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from opentalking.core.registry import register
|
||||
from opentalking.providers.tts.mock.adapter import MockTTSAdapter
|
||||
|
||||
register("tts", "mock")(MockTTSAdapter)
|
||||
|
||||
__all__ = ["MockTTSAdapter"]
|
||||
50
opentalking/providers/tts/mock/adapter.py
Normal file
50
opentalking/providers/tts/mock/adapter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from opentalking.core.types.frames import AudioChunk
|
||||
|
||||
|
||||
def _tone_frequency(text: str) -> float:
    """Derive a deterministic tone frequency in [180, 400) Hz from *text*.

    The frequency is 180 plus the sum of the text's code points modulo 220.
    """
    code_point_total = sum(map(ord, text))
    return 180.0 + float(code_point_total % 220)
|
||||
|
||||
|
||||
def _sine_wave(text: str, sample_rate: int, chunk_ms: float) -> np.ndarray:
    """Render *text* as an int16 sine tone whose length grows with the text.

    Duration is 0.06 s per character, clamped to the range 0.4-2.0 s. The
    tone rises linearly in amplitude and is scaled to 22% of full scale.
    ``chunk_ms`` is accepted for signature compatibility and is not used.
    """
    seconds = min(max(len(text) * 0.06, 0.4), 2.0)
    n_samples = max(1, int(sample_rate * seconds))
    timeline = np.arange(n_samples, dtype=np.float32) / sample_rate
    angular_rate = 2.0 * math.pi * _tone_frequency(text)
    tone = np.sin(angular_rate * timeline)
    gain = np.linspace(0.15, 0.95, n_samples, dtype=np.float32)
    scaled = tone * gain * 0.22
    return np.clip(scaled * 32767.0, -32768, 32767).astype(np.int16)
|
||||
|
||||
|
||||
class MockTTSAdapter:
    """TTS stand-in that emits a synthetic tone instead of real speech."""

    def __init__(
        self,
        default_voice: str = "mock",
        sample_rate: int = 16000,
        chunk_ms: float = 40.0,
    ) -> None:
        """Store the voice name, output sample rate and chunk length in ms."""
        self.default_voice = default_voice
        self.sample_rate = sample_rate
        self.chunk_ms = chunk_ms

    async def synthesize_stream(self, text: str, voice: str | None = None) -> AsyncIterator[AudioChunk]:
        """Yield the tone generated for *text* in chunks of ``chunk_ms``.

        The requested voice is ignored. The final chunk may be shorter than
        the others; its ``duration_ms`` reflects its real length.
        """
        del voice
        samples = _sine_wave(text, self.sample_rate, self.chunk_ms)
        step = max(1, int(self.sample_rate * (self.chunk_ms / 1000.0)))
        # range() never yields an offset at or past the end, so every window
        # holds at least one sample.
        for offset in range(0, len(samples), step):
            window = samples[offset : offset + step]
            yield AudioChunk(
                data=window.copy(),
                sample_rate=self.sample_rate,
                duration_ms=1000.0 * float(window.size) / float(self.sample_rate),
            )
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
CORE_TTS_PROVIDERS = frozenset({"auto", "edge", "elevenlabs"})
|
||||
CORE_TTS_PROVIDERS = frozenset({"auto", "edge", "elevenlabs", "mock"})
|
||||
OPENAI_COMPATIBLE_TTS_PROVIDERS = frozenset({"openai_compatible"})
|
||||
XIAOMI_MIMO_TTS_PROVIDERS = frozenset({"xiaomi_mimo", "xiaomi", "mimo"})
|
||||
QWEN_TTS_PROVIDERS = frozenset({"dashscope", "bailian", "qwen", "qwen_tts"})
|
||||
|
||||
191
opentalking/providers/tts/voice_assets.py
Normal file
191
opentalking/providers/tts/voice_assets.py
Normal file
@@ -0,0 +1,191 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
INDEXTTS_PROVIDER = "indextts"
|
||||
INDEXTTS_LEGACY_PROVIDERS = {"local_indextts", "omnirt_indextts"}
|
||||
INDEXTTS_PROVIDERS = {INDEXTTS_PROVIDER, *INDEXTTS_LEGACY_PROVIDERS}
|
||||
LOCAL_COSYVOICE_PROVIDER = "local_cosyvoice"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VoiceAsset:
    """Immutable description of one voice prompt directory found on disk."""

    # Directory name of the voice; used as its identifier.
    voice_id: str
    # Which collection the voice came from: "clones" or "system".
    source: str
    # Collection directory that contains the voice directory.
    root: Path
    # The voice directory itself.
    path: Path
    # Path to the voice's prompt.wav.
    prompt_audio: Path
    # Path to prompt.txt, or None when that file does not exist.
    prompt_text: Path | None
    # Parsed meta.json content; empty when missing or invalid.
    meta: dict[str, Any]
    # True when the voice lives in the system voice root bundled with the package.
    bundled_system: bool = False
|
||||
|
||||
|
||||
def local_audio_model_root() -> Path:
    """Return the absolute directory that holds local audio models.

    Precedence: the ``OPENTALKING_LOCAL_AUDIO_MODEL_ROOT`` environment
    variable, then the ``local_audio_model_root`` setting, then
    ``./models/local-audio``. Settings lookup is best effort: any failure to
    import or read them falls through to the next source.
    """
    configured = os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "").strip()
    try:
        from opentalking.core.config import get_settings

        if not configured:
            configured = (get_settings().local_audio_model_root or "").strip()
    except Exception:
        pass
    if not configured:
        configured = "./models/local-audio"
    return Path(configured).expanduser().resolve()
|
||||
|
||||
|
||||
def bundled_system_voice_root() -> Path:
    """Return the ``assets/voices/system`` directory located relative to this module."""
    anchor = Path(__file__).resolve().parents[2]
    return anchor.joinpath("assets", "voices", "system")
|
||||
|
||||
|
||||
def system_voice_roots(model_root: Path | None = None) -> list[Path]:
    """List the directories searched for system voices, without duplicates.

    The model root's ``voices/system`` comes first, then the bundled system
    voice directory. Two entries that resolve to the same location are
    reported once, keeping the first.
    """
    base = model_root or local_audio_model_root()
    unique: dict[Path, Path] = {}
    for candidate in (base / "voices" / "system", bundled_system_voice_root()):
        try:
            identity = candidate.resolve()
        except OSError:
            identity = candidate
        # setdefault keeps the first candidate seen for each resolved location.
        unique.setdefault(identity, candidate)
    return list(unique.values())
|
||||
|
||||
|
||||
def clone_voice_roots(model_root: Path | None = None) -> list[Path]:
    """List the directories searched for cloned voices.

    Only ``<model root>/voices/clones`` is returned; the model root defaults
    to the configured local audio model root.
    """
    base = model_root or local_audio_model_root()
    return [base.joinpath("voices", "clones")]
|
||||
|
||||
|
||||
def _provider_aliases(provider: str) -> set[str]:
    """Return every provider name treated as equivalent to *provider*.

    All IndexTTS names (current and legacy) are aliases of each other. Any
    other provider is only its own alias, in normalized lower-case form.
    """
    key = provider.strip().lower()
    if key not in INDEXTTS_PROVIDERS:
        return {key}
    return {INDEXTTS_PROVIDER, *INDEXTTS_LEGACY_PROVIDERS}
|
||||
|
||||
|
||||
def _truthy_meta_flag(value: object) -> bool:
    """Interpret a ``meta.json`` flag value as a boolean.

    Booleans pass through. Strings are true when, trimmed and lower-cased,
    they are one of the accepted markers. Every other type is false.
    """
    if isinstance(value, bool):
        return value
    if not isinstance(value, str):
        return False
    marker = value.strip().lower()
    return marker in {"1", "true", "yes", "on", "universal", "zero_shot"}
|
||||
|
||||
|
||||
def voice_applies_to_provider(meta: dict[str, Any], provider: str, *, bundled_system: bool = False) -> bool:
    """Decide whether a voice with metadata *meta* may be used by *provider*.

    Checks, in order:
      1. A blank provider matches everything.
      2. Bundled system voices always apply to local CosyVoice.
      3. A truthy ``universal`` / ``compatible`` / ``zero_shot_compatible``
         flag matches everything.
      4. A non-empty ``providers`` list decides by intersection with the
         provider's aliases.
      5. Otherwise ``provider`` decides: missing means unrestricted, an
         IndexTTS voice without a ``providers`` value is also offered to
         local CosyVoice, and anything else must match an alias.
    """
    wanted = provider.strip().lower()
    if not wanted:
        return True
    if bundled_system and wanted == LOCAL_COSYVOICE_PROVIDER:
        return True
    for flag_name in ("universal", "compatible", "zero_shot_compatible"):
        if _truthy_meta_flag(meta.get(flag_name)):
            return True

    aliases = _provider_aliases(wanted)
    raw_providers = meta.get("providers")
    if isinstance(raw_providers, list):
        allowed = {str(item).strip().lower() for item in raw_providers if str(item).strip()}
        if allowed:
            return not allowed.isdisjoint(aliases)

    declared = str(meta.get("provider") or "").strip().lower()
    if not declared:
        return True
    shared_with_cosyvoice = (
        wanted == LOCAL_COSYVOICE_PROVIDER and declared in INDEXTTS_PROVIDERS and not raw_providers
    )
    return shared_with_cosyvoice or declared in aliases
|
||||
|
||||
|
||||
def read_voice_meta(voice_dir: Path) -> dict[str, Any]:
    """Load ``meta.json`` from *voice_dir* as a dictionary.

    Returns an empty dict when the file is missing, unreadable, not valid
    JSON, or does not contain a JSON object. Reading is deliberately best
    effort so one broken voice directory cannot break voice discovery.
    """
    meta_path = voice_dir / "meta.json"
    if not meta_path.is_file():
        return {}
    try:
        # "utf-8-sig" drops a leading BOM, which json.loads() would otherwise
        # reject and which would silently discard the voice's metadata.
        parsed = json.loads(meta_path.read_text(encoding="utf-8-sig"))
    except Exception:
        return {}
    return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def iter_voice_assets(
    *,
    provider: str,
    sources: Iterable[str] = ("clones", "system"),
    model_root: Path | None = None,
    require_prompt_text: bool = False,
) -> list[VoiceAsset]:
    """Collect the voice assets on disk that *provider* may use.

    Args:
        provider: Provider name used to filter voices by their metadata.
        sources: Collections to scan, in priority order; ``"clones"`` and
            ``"system"`` are recognised, other values are ignored.
        model_root: Model root to scan; defaults to the configured root.
        require_prompt_text: Skip voices that have no ``prompt.txt``.

    Returns:
        Assets in scan order. A voice needs a ``prompt.wav`` to be listed.
        Within one source, the first usable directory for a voice id wins.
    """
    base = model_root or local_audio_model_root()

    search_plan: list[tuple[str, Path, bool]] = []
    for source in sources:
        if source == "clones":
            search_plan += [("clones", folder, False) for folder in clone_voice_roots(base)]
        elif source == "system":
            shipped = bundled_system_voice_root()
            search_plan += [
                ("system", folder, _same_path(folder, shipped)) for folder in system_voice_roots(base)
            ]

    found: list[VoiceAsset] = []
    taken: set[tuple[str, str]] = set()
    for source, voice_root, is_bundled in search_plan:
        if not voice_root.is_dir():
            continue
        voice_dirs = sorted(entry for entry in voice_root.iterdir() if entry.is_dir())
        for voice_dir in voice_dirs:
            identity = (source, voice_dir.name)
            if identity in taken:
                continue
            audio_file = voice_dir / "prompt.wav"
            if not audio_file.is_file():
                continue
            text_file = voice_dir / "prompt.txt"
            has_text = text_file.is_file()
            if require_prompt_text and not has_text:
                continue
            meta = read_voice_meta(voice_dir)
            if not voice_applies_to_provider(meta, provider, bundled_system=is_bundled):
                continue
            # Claim the id only once a usable voice is found, so a rejected
            # directory does not hide a same-named voice in a later root.
            taken.add(identity)
            found.append(
                VoiceAsset(
                    voice_id=voice_dir.name,
                    source=source,
                    root=voice_root,
                    path=voice_dir,
                    prompt_audio=audio_file,
                    prompt_text=text_file if has_text else None,
                    meta=meta,
                    bundled_system=is_bundled,
                )
            )
    return found
|
||||
|
||||
|
||||
def resolve_voice_asset(
    voice_id: str,
    *,
    provider: str,
    sources: Iterable[str] = ("clones", "system"),
    model_root: Path | None = None,
    require_prompt_text: bool = False,
) -> VoiceAsset | None:
    """Find the first voice asset whose id equals *voice_id* (trimmed).

    The remaining arguments are forwarded to ``iter_voice_assets``. Returns
    ``None`` for a blank id or when no matching asset exists.
    """
    wanted = voice_id.strip()
    if not wanted:
        return None
    candidates = iter_voice_assets(
        provider=provider,
        sources=sources,
        model_root=model_root,
        require_prompt_text=require_prompt_text,
    )
    return next((asset for asset in candidates if asset.voice_id == wanted), None)
|
||||
|
||||
|
||||
def _same_path(left: Path, right: Path) -> bool:
    """Compare two paths by resolved location, or literally if resolving fails."""
    try:
        resolved_left, resolved_right = left.resolve(), right.resolve()
    except OSError:
        return left == right
    return resolved_left == resolved_right
|
||||
@@ -117,10 +117,24 @@ def _fit_reference_driver_pcm(pcm: np.ndarray, total_samples: int) -> np.ndarray
|
||||
return np.tile(source, repeats)[:target].astype(np.int16, copy=False)
|
||||
|
||||
|
||||
def _read_pcm16_mono_wav(path: Path) -> np.ndarray | None:
    """Read *path* directly when it is already 16 kHz, mono, 16-bit PCM WAV.

    Returns the samples as a writable int16 array. Returns ``None`` when the
    file has another format or cannot be parsed as WAV at all, including
    empty and truncated files.
    """
    try:
        with wave.open(str(path), "rb") as wf:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getframerate() != 16000:
                return None
            raw = wf.readframes(wf.getnframes())
    except (wave.Error, EOFError, OSError):
        # wave raises EOFError, not wave.Error, when the header is cut short.
        return None
    # A truncated data chunk can end mid-sample; keep whole samples only so
    # np.frombuffer does not reject the buffer.
    usable = len(raw) - (len(raw) % 2)
    return np.frombuffer(raw[:usable], dtype="<i2").copy()
|
||||
|
||||
|
||||
async def _load_reference_driver_pcm(settings: object, total_samples: int) -> np.ndarray | None:
|
||||
path = _reference_driver_audio_path(settings)
|
||||
if not path.is_file():
|
||||
return None
|
||||
direct_pcm = _read_pcm16_mono_wav(path)
|
||||
if direct_pcm is not None:
|
||||
return _fit_reference_driver_pcm(direct_pcm, total_samples)
|
||||
try:
|
||||
pcm = await decode_audio_file_to_pcm_i16(path)
|
||||
return _fit_reference_driver_pcm(pcm, total_samples)
|
||||
@@ -415,6 +429,12 @@ def _quicktalk_prepared_template_video(settings: object, avatar_path: Path) -> P
|
||||
path = avatar_path / name
|
||||
if path.is_file():
|
||||
return path.resolve()
|
||||
source_dir = avatar_path / "source"
|
||||
if source_dir.is_dir():
|
||||
for pattern in ("*.mp4", "*.mov", "*.webm", "*.avi"):
|
||||
for candidate in sorted(source_dir.glob(pattern)):
|
||||
if candidate.is_file():
|
||||
return candidate.resolve()
|
||||
return None
|
||||
|
||||
|
||||
@@ -492,12 +512,13 @@ def _init_session_kwargs(
|
||||
fasterliveportrait_config: Mapping[str, object] | None,
|
||||
) -> dict[str, object]:
|
||||
kwargs: dict[str, object] = {"avatar_path": avatar_path}
|
||||
if model == "quicktalk":
|
||||
kwargs.update(_quicktalk_init_session_kwargs(settings, avatar_path))
|
||||
if not _remote_audio2video_backend(backend):
|
||||
return kwargs
|
||||
|
||||
kwargs["ref_image"] = _reference_image_path(avatar_path)
|
||||
if model == "quicktalk":
|
||||
kwargs.update(_quicktalk_init_session_kwargs(settings, avatar_path))
|
||||
return kwargs
|
||||
if model != "fasterliveportrait":
|
||||
return kwargs
|
||||
|
||||
@@ -9,6 +9,11 @@ from pathlib import Path
|
||||
|
||||
|
||||
# Download root for local audio models; the environment variable
# OPENTALKING_LOCAL_AUDIO_MODEL_ROOT overrides the relative default.
DEFAULT_ROOT = Path(os.getenv("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "./models/local-audio"))

# Directories searched for existing weights when no explicit search roots
# are supplied.
DEFAULT_REUSE_ROOTS = (
    Path("./models"),
    Path("/root/models"),
    Path.home().joinpath(".cache", "opentalking", "models"),
)
|
||||
|
||||
MODELS: dict[str, tuple[str, str]] = {
|
||||
"sensevoice-small": ("modelscope", "iic/SenseVoiceSmall"),
|
||||
@@ -58,6 +63,30 @@ HF_ALLOW_PATTERNS: dict[str, list[str]] = {
|
||||
],
|
||||
}
|
||||
|
||||
# Directory names, per model key, under which a reuse root may already hold
# the weights. Entries are probed in the order listed.
MODEL_HINTS: dict[str, tuple[str, ...]] = {
    "sensevoice-small": (
        "iic__SenseVoiceSmall",
        "sensevoice",
        "SenseVoiceSmall",
    ),
    "fun-cosyvoice3-0.5b-2512": (
        "FunAudioLLM__Fun-CosyVoice3-0.5B-2512",
        "Fun-CosyVoice3-0.5B-2512",
        "cosyvoice",
    ),
    "indextts2": (
        "IndexTeam__IndexTTS-2",
        "IndexTTS-2",
    ),
    "indextts2-w2v-bert": (
        "facebook__w2v-bert-2.0",
        "w2v-bert-2.0",
    ),
    "indextts2-maskgct": (
        "amphion__MaskGCT",
        "amphion__MaskGCT-ms",
    ),
    "indextts2-campplus": (
        "funasr__campplus",
        "campplus",
    ),
    "indextts2-bigvgan": (
        "nvidia__bigvgan_v2_22khz_80band_256x",
        "bigvgan_v2_22khz_80band_256x",
    ),
}
|
||||
|
||||
# Relative paths that must all exist inside a model directory before that
# directory is treated as a usable copy of the model.
MODEL_REQUIRED_FILES: dict[str, tuple[str, ...]] = {
    "sensevoice-small": (
        "model.pt",
        "config.yaml",
        "configuration.json",
    ),
    "fun-cosyvoice3-0.5b-2512": (
        "cosyvoice3.yaml",
        "flow.pt",
        "hift.pt",
        "llm.pt",
    ),
    "indextts2": (
        "config.yaml",
        "model.pt",
    ),
    "indextts2-w2v-bert": (
        "model.safetensors",
        "conformer_shaw.pt",
    ),
    "indextts2-maskgct": (
        "semantic_codec/model.safetensors",
        "acoustic_codec/model.safetensors",
    ),
    "indextts2-campplus": (
        "campplus_cn_common.bin",
        "config.yaml",
    ),
    "indextts2-bigvgan": (
        "bigvgan_generator.pt",
    ),
}
|
||||
|
||||
|
||||
def default_model_keys() -> list[str]:
    """Return a fresh list of the model keys selected by default."""
    defaults = ("sensevoice-small", "fun-cosyvoice3-0.5b-2512")
    return list(defaults)
|
||||
@@ -71,6 +100,61 @@ def _target(root: Path, model_id: str) -> Path:
|
||||
return root / model_id.replace("/", "__")
|
||||
|
||||
|
||||
def _reuse_root_values(raw: str | None) -> list[Path]:
|
||||
if not raw:
|
||||
return [root for root in DEFAULT_REUSE_ROOTS if root is not None]
|
||||
roots: list[Path] = []
|
||||
for chunk in raw.replace(";", os.pathsep).replace(",", os.pathsep).split(os.pathsep):
|
||||
value = chunk.strip()
|
||||
if value:
|
||||
roots.append(Path(value).expanduser())
|
||||
return roots
|
||||
|
||||
|
||||
def _model_hints(model_key: str) -> tuple[str, ...]:
    """Return the directory-name hints used to locate *model_key* on disk.

    Falls back to the model id from ``MODELS`` with ``/`` replaced by ``__``
    when ``MODEL_HINTS`` has no entry for the key.
    """
    model_id = MODELS[model_key][1]
    fallback = (model_id.replace("/", "__"),)
    return MODEL_HINTS.get(model_key, fallback)
|
||||
|
||||
|
||||
def _required_files(model_key: str) -> tuple[str, ...]:
    """Return the relative paths required for *model_key*, or ``()`` if none are listed."""
    if model_key in MODEL_REQUIRED_FILES:
        return MODEL_REQUIRED_FILES[model_key]
    return ()
|
||||
|
||||
|
||||
def _is_model_ready(path: Path, *, model_key: str) -> bool:
    """Report whether *path* holds a usable copy of *model_key*.

    A missing path is never ready. When the model lists required files, every
    one of them must exist beneath *path*; otherwise any existing file or
    directory counts.
    """
    if not path.exists():
        return False
    expected = _required_files(model_key)
    if expected:
        for relative in expected:
            if not (path / relative).exists():
                return False
        return True
    return path.is_dir() or path.is_file()
|
||||
|
||||
|
||||
def _find_reusable_source(model_key: str, roots: list[Path]) -> Path | None:
    """Return the first ready copy of *model_key* found under *roots*.

    For every root and every directory hint, the hinted directory is checked
    first and then its ``checkpoints`` subdirectory. Returns ``None`` when no
    candidate is ready.
    """
    for root in roots:
        for hint in _model_hints(model_key):
            base = root / hint
            for option in (base, base / "checkpoints"):
                if _is_model_ready(option, model_key=model_key):
                    return option
    return None
|
||||
|
||||
|
||||
def _mirror_existing_source(source: Path, target: Path) -> None:
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
if target.exists():
|
||||
return
|
||||
try:
|
||||
target.symlink_to(source, target_is_directory=source.is_dir())
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
if source.is_dir():
|
||||
shutil.copytree(source, target)
|
||||
else:
|
||||
shutil.copy2(source, target)
|
||||
|
||||
|
||||
def _download_modelscope(model_id: str, target: Path) -> None:
|
||||
from modelscope import snapshot_download
|
||||
|
||||
@@ -99,6 +183,12 @@ def _git_lfs_pull_if_needed(target: Path) -> None:
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Download the supported local STT/TTS model weights.")
|
||||
parser.add_argument("--root", type=Path, default=DEFAULT_ROOT)
|
||||
parser.add_argument(
|
||||
"--reuse-root",
|
||||
action="append",
|
||||
dest="reuse_roots",
|
||||
help="Search these roots first and reuse existing weights instead of downloading.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
action="append",
|
||||
@@ -110,6 +200,9 @@ def main() -> None:
|
||||
root = args.root.expanduser().resolve()
|
||||
root.mkdir(parents=True, exist_ok=True)
|
||||
selected = args.model or default_model_keys()
|
||||
reuse_roots = [root] + _reuse_root_values(os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_SEARCH_ROOTS"))
|
||||
if args.reuse_roots:
|
||||
reuse_roots.extend(Path(value).expanduser().resolve() for value in args.reuse_roots)
|
||||
|
||||
failures: list[tuple[str, str]] = []
|
||||
for key in selected:
|
||||
@@ -118,6 +211,17 @@ def main() -> None:
|
||||
print(f"[{key}] {source}:{model_id} -> {target}", flush=True)
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
if _is_model_ready(target, model_key=key):
|
||||
print(f"[{key}] reusing existing target: {target}", flush=True)
|
||||
continue
|
||||
|
||||
reusable = _find_reusable_source(key, reuse_roots)
|
||||
if reusable is not None:
|
||||
print(f"[{key}] reusing existing source: {reusable}", flush=True)
|
||||
if reusable.resolve() != target.resolve():
|
||||
_mirror_existing_source(reusable, target)
|
||||
continue
|
||||
|
||||
if source == "modelscope":
|
||||
_download_modelscope(model_id, target)
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import importlib
|
||||
import importlib.util
|
||||
import io
|
||||
import inspect
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -19,6 +23,29 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _load_voice_assets_module():
|
||||
module_name = "_opentalking_voice_assets_local_cosyvoice"
|
||||
module = sys.modules.get(module_name)
|
||||
if module is not None:
|
||||
return module
|
||||
module_path = Path(__file__).resolve().parents[1] / "opentalking" / "providers" / "tts" / "voice_assets.py"
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Unable to load voice assets module from {module_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
_voice_assets = _load_voice_assets_module()
|
||||
LOCAL_COSYVOICE_PROVIDER = _voice_assets.LOCAL_COSYVOICE_PROVIDER
|
||||
VoiceAsset = _voice_assets.VoiceAsset
|
||||
iter_voice_assets = _voice_assets.iter_voice_assets
|
||||
local_audio_model_root = _voice_assets.local_audio_model_root
|
||||
resolve_voice_asset = _voice_assets.resolve_voice_asset
|
||||
|
||||
|
||||
|
||||
def _soundfile_load_wav(wav: str, target_sr: int):
|
||||
import torch
|
||||
@@ -128,6 +155,7 @@ def _patch_cosyvoice_load_wav() -> None:
|
||||
class SynthesizeRequest(BaseModel):
|
||||
text: str
|
||||
voice: str | None = None
|
||||
zero_shot_spk_id: str | None = None
|
||||
model: str | None = None
|
||||
sample_rate: int | None = None
|
||||
prompt_audio: str | None = None
|
||||
@@ -150,6 +178,34 @@ def _cosyvoice_flow(cosyvoice: Any) -> Any | None:
|
||||
return getattr(model, "flow", None)
|
||||
|
||||
|
||||
def _callable_supports_keyword(fn: Any, name: str) -> bool:
|
||||
try:
|
||||
signature = inspect.signature(fn)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return name in signature.parameters or any(
|
||||
param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()
|
||||
)
|
||||
|
||||
|
||||
def _voice_signature(asset: VoiceAsset) -> tuple[str, int, int, str]:
|
||||
try:
|
||||
stat = asset.prompt_audio.stat()
|
||||
except OSError:
|
||||
stat = None
|
||||
try:
|
||||
prompt_text = asset.prompt_text.read_text(encoding="utf-8").strip() if asset.prompt_text else ""
|
||||
except OSError:
|
||||
prompt_text = ""
|
||||
digest = hashlib.sha1(prompt_text.encode("utf-8")).hexdigest()
|
||||
return (
|
||||
str(asset.prompt_audio.resolve()),
|
||||
int(getattr(stat, "st_mtime_ns", 0) or 0),
|
||||
int(getattr(stat, "st_size", 0) or 0),
|
||||
digest,
|
||||
)
|
||||
|
||||
|
||||
def current_streaming_tuning(cosyvoice: Any) -> dict[str, Any]:
|
||||
model = _cosyvoice_model(cosyvoice)
|
||||
return {
|
||||
@@ -195,6 +251,20 @@ def ensure_cosyvoice_flow_half(cosyvoice: Any) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _is_cuda_runtime_incompatibility(exc: BaseException) -> bool:
|
||||
text = f"{type(exc).__name__}: {exc}".lower()
|
||||
return any(
|
||||
marker in text
|
||||
for marker in (
|
||||
"no kernel image is available for execution on the device",
|
||||
"cuda error",
|
||||
"invalid device function",
|
||||
"tensorrt",
|
||||
"trt",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def reset_streaming_tuning(cosyvoice: Any) -> dict[str, Any]:
|
||||
model = _cosyvoice_model(cosyvoice)
|
||||
baseline = getattr(model, "_opentalking_streaming_tuning", None)
|
||||
@@ -359,6 +429,7 @@ class CosyVoiceService:
|
||||
*,
|
||||
model_dir: str,
|
||||
runtime_dir: str,
|
||||
audio_root: str | None = None,
|
||||
device: str,
|
||||
prompt_audio: str,
|
||||
prompt_text: str,
|
||||
@@ -373,12 +444,15 @@ class CosyVoiceService:
|
||||
token_max_hop_len: int | None = None,
|
||||
stream_scale_factor: int | None = None,
|
||||
flow_n_timesteps: int | None = None,
|
||||
max_token_text_ratio: float | None = 6.0,
|
||||
max_token_text_ratio: float | None = None,
|
||||
min_token_text_ratio: float | None = None,
|
||||
mask_stop_tokens: bool = True,
|
||||
mask_stop_tokens: bool = False,
|
||||
use_zero_shot_spk_id: bool = False,
|
||||
precache_system_spks: bool = False,
|
||||
) -> None:
|
||||
self.model_dir = model_dir
|
||||
self.runtime_dir = runtime_dir
|
||||
self.audio_root = audio_root or ""
|
||||
self.device = device
|
||||
self.prompt_audio = prompt_audio
|
||||
self.prompt_text = prompt_text
|
||||
@@ -396,6 +470,8 @@ class CosyVoiceService:
|
||||
self.max_token_text_ratio = max_token_text_ratio
|
||||
self.min_token_text_ratio = min_token_text_ratio
|
||||
self.mask_stop_tokens = mask_stop_tokens
|
||||
self.use_zero_shot_spk_id = use_zero_shot_spk_id
|
||||
self.precache_system_spks = precache_system_spks
|
||||
self._model: Any | None = None
|
||||
self._model_lock = threading.Lock()
|
||||
self._loaded_model_kwargs: dict[str, Any] = {}
|
||||
@@ -403,6 +479,60 @@ class CosyVoiceService:
|
||||
self._flow_tuning: dict[str, Any] = {}
|
||||
self._llm_token_ratio_tuning: dict[str, Any] = {}
|
||||
self._llm_stop_token_patch: dict[str, Any] = {}
|
||||
self._zero_shot_spk_cache: dict[str, tuple[str, int, int, str]] = {}
|
||||
|
||||
def _audio_root(self) -> Path:
|
||||
if self.audio_root.strip():
|
||||
return Path(self.audio_root).expanduser().resolve()
|
||||
return local_audio_model_root()
|
||||
|
||||
def _resolve_voice_asset(self, voice_id: str | None) -> VoiceAsset | None:
|
||||
voice_key = (voice_id or "").strip()
|
||||
if not voice_key or voice_key == "local-default":
|
||||
return None
|
||||
return resolve_voice_asset(
|
||||
voice_key,
|
||||
provider=LOCAL_COSYVOICE_PROVIDER,
|
||||
sources=("clones", "system"),
|
||||
model_root=self._audio_root(),
|
||||
require_prompt_text=True,
|
||||
)
|
||||
|
||||
def _ensure_zero_shot_spk_registered(self, model: Any, voice_id: str, asset: VoiceAsset) -> bool:
|
||||
if not voice_id or asset.prompt_text is None:
|
||||
return False
|
||||
add_zero_shot_spk = getattr(model, "add_zero_shot_spk", None)
|
||||
if not callable(add_zero_shot_spk):
|
||||
return False
|
||||
signature = _voice_signature(asset)
|
||||
if self._zero_shot_spk_cache.get(voice_id) == signature:
|
||||
return True
|
||||
|
||||
prompt_text = asset.prompt_text.read_text(encoding="utf-8").strip()
|
||||
if not prompt_text:
|
||||
return False
|
||||
prompt_text = self._prompt_text_for_zero_shot(prompt_text)
|
||||
prompt_audio = str(asset.prompt_audio)
|
||||
if _callable_supports_keyword(add_zero_shot_spk, "zero_shot_spk_id"):
|
||||
add_zero_shot_spk(prompt_text, prompt_audio, zero_shot_spk_id=voice_id)
|
||||
else:
|
||||
add_zero_shot_spk(prompt_text, prompt_audio, voice_id)
|
||||
self._zero_shot_spk_cache[voice_id] = signature
|
||||
print(f"zero_shot_spk registered voice_id={voice_id} prompt_audio={prompt_audio}", flush=True)
|
||||
save_spkinfo = getattr(model, "save_spkinfo", None)
|
||||
if callable(save_spkinfo):
|
||||
save_spkinfo()
|
||||
return True
|
||||
|
||||
def _precache_system_zero_shot_spks(self, model: Any) -> None:
    """Register every system voice that has prompt text as a zero-shot speaker on *model*."""
    for asset in iter_voice_assets(
        provider=LOCAL_COSYVOICE_PROVIDER,
        sources=("system",),
        model_root=self._audio_root(),
        require_prompt_text=True,
    ):
        self._ensure_zero_shot_spk_registered(model, asset.voice_id, asset)
|
||||
|
||||
def model(self) -> Any:
|
||||
if self._model is not None:
|
||||
@@ -438,11 +568,52 @@ class CosyVoiceService:
|
||||
"fp16": self.fp16,
|
||||
"trt_concurrent": self.trt_concurrent,
|
||||
}
|
||||
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
|
||||
try:
|
||||
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
|
||||
except Exception as exc:
|
||||
if not self.load_trt:
|
||||
raise
|
||||
print(
|
||||
"CosyVoice TensorRT startup failed; falling back to non-TRT runtime: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
flush=True,
|
||||
)
|
||||
self.load_trt = False
|
||||
model_kwargs["load_trt"] = False
|
||||
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
|
||||
flow_half_applied = False
|
||||
if self.load_trt and self.fp16:
|
||||
flow_half_applied = ensure_cosyvoice_flow_half(self._model)
|
||||
try:
|
||||
flow_half_applied = ensure_cosyvoice_flow_half(self._model)
|
||||
except Exception as exc:
|
||||
if not _is_cuda_runtime_incompatibility(exc):
|
||||
raise
|
||||
print(
|
||||
"CosyVoice TensorRT/FP16 startup failed after load; "
|
||||
"falling back to non-TRT fp32 runtime: "
|
||||
f"{type(exc).__name__}: {exc}",
|
||||
flush=True,
|
||||
)
|
||||
self.load_trt = False
|
||||
self.fp16 = False
|
||||
old_model = self._model
|
||||
self._model = None
|
||||
del old_model
|
||||
gc.collect()
|
||||
try:
|
||||
import torch
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
except Exception:
|
||||
pass
|
||||
model_kwargs["load_trt"] = False
|
||||
model_kwargs["fp16"] = False
|
||||
self._model, self._loaded_model_kwargs = _instantiate_automodel(AutoModel, model_kwargs)
|
||||
self._zero_shot_spk_cache.clear()
|
||||
self._apply_runtime_tuning()
|
||||
if self.precache_system_spks:
|
||||
self._precache_system_zero_shot_spks(self._model)
|
||||
# Keep the service zero-shot first so it does not require precomputed spk2info.pt.
|
||||
print(
|
||||
"loaded cosyvoice "
|
||||
@@ -516,6 +687,35 @@ class CosyVoiceService:
|
||||
),
|
||||
}
|
||||
|
||||
def reset_model_after_empty_audio(self, *, reason: str) -> None:
    """Drop the loaded model and its cached tuning so the next use reloads it.

    TensorRT and FP16 are switched off for the reload, the zero-shot speaker
    cache and all recorded tuning state are cleared, and the released model's
    memory is reclaimed on a best-effort basis.

    Args:
        reason: Text included in the log line describing why the reset ran.
    """
    with self._model_lock:
        previous = self._model
        self._model = None
        self._loaded_model_kwargs = {}
        accelerated = self.load_trt or self.fp16
        if accelerated:
            print(
                "cosyvoice empty audio recovery: disabling TRT/FP16 for retry",
                flush=True,
            )
        self.load_trt = False
        self.fp16 = False
        self._zero_shot_spk_cache.clear()
        self._streaming_tuning = {}
        self._flow_tuning = {}
        self._llm_token_ratio_tuning = {}
        self._llm_stop_token_patch = {}
        if previous is not None:
            del previous
            gc.collect()
            # Best effort only: torch may be missing or CUDA unusable here.
            try:
                import torch

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass
    print(f"cosyvoice model reset after empty audio: {reason}", flush=True)
|
||||
|
||||
def _to_wav_bytes(self, speech: Any, sample_rate: int) -> bytes:
|
||||
if hasattr(speech, "detach"):
|
||||
speech = speech.detach().cpu().numpy()
|
||||
@@ -552,6 +752,14 @@ class CosyVoiceService:
|
||||
return f"You are a helpful assistant.<|endofprompt|>{text}"
|
||||
return "You are a helpful assistant.<|endofprompt|>"
|
||||
|
||||
def _asset_prompt_text(self, asset: VoiceAsset, fallback_prompt_text: str = "") -> str:
|
||||
prompt_text = ""
|
||||
if asset.prompt_text is not None:
|
||||
prompt_text = asset.prompt_text.read_text(encoding="utf-8").strip()
|
||||
if not prompt_text:
|
||||
prompt_text = fallback_prompt_text.strip()
|
||||
return self._prompt_text_for_zero_shot(prompt_text)
|
||||
|
||||
def synthesize_wav(self, req: SynthesizeRequest) -> tuple[bytes, int, float]:
|
||||
text = req.text.strip()
|
||||
if not text:
|
||||
@@ -559,6 +767,7 @@ class CosyVoiceService:
|
||||
prompt_audio = (req.prompt_audio or self.prompt_audio).strip()
|
||||
prompt_text = (req.prompt_text or self.prompt_text).strip()
|
||||
mode = (req.mode or self.mode).strip().lower()
|
||||
voice_id = (req.zero_shot_spk_id or req.voice or "").strip()
|
||||
model = self.model()
|
||||
sample_rate = int(getattr(model, "sample_rate", 24000) or 24000)
|
||||
t0 = time.perf_counter()
|
||||
@@ -572,17 +781,40 @@ class CosyVoiceService:
|
||||
instruction = (req.instruction or self.instruction).strip()
|
||||
iterator = model.inference_instruct2(text, instruction, prompt_audio, stream=False)
|
||||
else:
|
||||
if not prompt_audio or not prompt_text:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="zero_shot mode requires prompt_audio and prompt_text",
|
||||
asset = self._resolve_voice_asset(voice_id)
|
||||
if asset is not None:
|
||||
asset_prompt_text = self._asset_prompt_text(asset, prompt_text)
|
||||
asset_prompt_audio = str(asset.prompt_audio)
|
||||
if (
|
||||
self.use_zero_shot_spk_id
|
||||
and _callable_supports_keyword(model.inference_zero_shot, "zero_shot_spk_id")
|
||||
and self._ensure_zero_shot_spk_registered(model, asset.voice_id, asset)
|
||||
):
|
||||
iterator = model.inference_zero_shot(text, "", "", stream=False, zero_shot_spk_id=asset.voice_id)
|
||||
else:
|
||||
iterator = model.inference_zero_shot(
|
||||
text,
|
||||
asset_prompt_text,
|
||||
asset_prompt_audio,
|
||||
stream=False,
|
||||
)
|
||||
else:
|
||||
if not prompt_audio or not prompt_text:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="zero_shot mode requires prompt_audio and prompt_text",
|
||||
)
|
||||
iterator = model.inference_zero_shot(
|
||||
text,
|
||||
self._prompt_text_for_zero_shot(prompt_text),
|
||||
prompt_audio,
|
||||
stream=False,
|
||||
)
|
||||
if asset is not None:
|
||||
print(
|
||||
f"zero_shot {'spk_id' if self.use_zero_shot_spk_id else 'prompt_path'} voice_id={asset.voice_id} stream=False prompt_audio={asset.prompt_audio}",
|
||||
flush=True,
|
||||
)
|
||||
iterator = model.inference_zero_shot(
|
||||
text,
|
||||
self._prompt_text_for_zero_shot(prompt_text),
|
||||
prompt_audio,
|
||||
stream=False,
|
||||
)
|
||||
parts: list[np.ndarray] = []
|
||||
with self._model_lock:
|
||||
for item in _with_request_streaming_tuning(model, iterator):
|
||||
@@ -595,17 +827,22 @@ class CosyVoiceService:
|
||||
wav_bytes = self._to_wav_bytes(np.concatenate(parts), sample_rate)
|
||||
return wav_bytes, sample_rate, time.perf_counter() - t0
|
||||
|
||||
def _streaming_iterator(self, req: SynthesizeRequest) -> tuple[Iterator[Any], int, int, float, Any]:
|
||||
def _streaming_iterator(
|
||||
self,
|
||||
req: SynthesizeRequest,
|
||||
) -> tuple[Iterator[Any], int, int, float, Any, Callable[[], Iterator[Any]] | None]:
|
||||
text = req.text.strip()
|
||||
if not text:
|
||||
raise HTTPException(status_code=400, detail="text is required")
|
||||
prompt_audio = (req.prompt_audio or self.prompt_audio).strip()
|
||||
prompt_text = (req.prompt_text or self.prompt_text).strip()
|
||||
mode = (req.mode or self.mode).strip().lower()
|
||||
voice_id = (req.zero_shot_spk_id or req.voice or "").strip()
|
||||
model = self.model()
|
||||
source_sr = int(getattr(model, "sample_rate", 24000) or 24000)
|
||||
target_sr = int(req.sample_rate or source_sr)
|
||||
t0 = time.perf_counter()
|
||||
fallback_iterator_factory: Callable[[], Iterator[Any]] | None = None
|
||||
if mode == "cross_lingual":
|
||||
if not prompt_audio:
|
||||
raise HTTPException(status_code=400, detail="prompt_audio is required")
|
||||
@@ -616,47 +853,111 @@ class CosyVoiceService:
|
||||
instruction = (req.instruction or self.instruction).strip()
|
||||
iterator = model.inference_instruct2(text, instruction, prompt_audio, stream=True)
|
||||
else:
|
||||
if not prompt_audio or not prompt_text:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="zero_shot mode requires prompt_audio and prompt_text",
|
||||
asset = self._resolve_voice_asset(voice_id)
|
||||
if asset is not None:
|
||||
asset_prompt_text = self._asset_prompt_text(asset, prompt_text)
|
||||
asset_prompt_audio = str(asset.prompt_audio)
|
||||
if (
|
||||
self.use_zero_shot_spk_id
|
||||
and _callable_supports_keyword(model.inference_zero_shot, "zero_shot_spk_id")
|
||||
and self._ensure_zero_shot_spk_registered(model, asset.voice_id, asset)
|
||||
):
|
||||
iterator = model.inference_zero_shot(text, "", "", stream=True, zero_shot_spk_id=asset.voice_id)
|
||||
|
||||
def fallback_iterator(
|
||||
*,
|
||||
text: str = text,
|
||||
prompt_text: str = asset_prompt_text,
|
||||
prompt_audio: str = asset_prompt_audio,
|
||||
voice_id: str = asset.voice_id,
|
||||
) -> Iterator[Any]:
|
||||
self._zero_shot_spk_cache.pop(voice_id, None)
|
||||
print(
|
||||
"zero_shot_spk_id produced no audio; falling back to prompt "
|
||||
f"voice_id={voice_id} prompt_audio={prompt_audio}",
|
||||
flush=True,
|
||||
)
|
||||
return model.inference_zero_shot(text, prompt_text, prompt_audio, stream=True)
|
||||
|
||||
fallback_iterator_factory = fallback_iterator
|
||||
else:
|
||||
iterator = model.inference_zero_shot(
|
||||
text,
|
||||
asset_prompt_text,
|
||||
asset_prompt_audio,
|
||||
stream=True,
|
||||
)
|
||||
else:
|
||||
if not prompt_audio or not prompt_text:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="zero_shot mode requires prompt_audio and prompt_text",
|
||||
)
|
||||
iterator = model.inference_zero_shot(
|
||||
text,
|
||||
self._prompt_text_for_zero_shot(prompt_text),
|
||||
prompt_audio,
|
||||
stream=True,
|
||||
)
|
||||
iterator = model.inference_zero_shot(
|
||||
text,
|
||||
self._prompt_text_for_zero_shot(prompt_text),
|
||||
prompt_audio,
|
||||
stream=True,
|
||||
)
|
||||
return iterator, source_sr, target_sr, t0, model
|
||||
if asset is not None:
|
||||
print(
|
||||
f"zero_shot {'spk_id' if self.use_zero_shot_spk_id else 'prompt_path'} voice_id={asset.voice_id} stream=True prompt_audio={asset.prompt_audio}",
|
||||
flush=True,
|
||||
)
|
||||
return iterator, source_sr, target_sr, t0, model, fallback_iterator_factory
|
||||
|
||||
def synthesize_pcm_stream(self, req: SynthesizeRequest) -> tuple[Iterator[bytes], int]:
|
||||
iterator, source_sr, target_sr, t0, model = self._streaming_iterator(req)
|
||||
iterator, source_sr, target_sr, t0, model, fallback_iterator_factory = self._streaming_iterator(req)
|
||||
|
||||
def generate() -> Iterator[bytes]:
|
||||
first = True
|
||||
chunks = 0
|
||||
samples = 0
|
||||
with self._model_lock:
|
||||
tuned_iterator = _with_request_streaming_tuning(model, iterator)
|
||||
output_sr = target_sr
|
||||
|
||||
def emit(
|
||||
tuned_iterator: Iterator[Any],
|
||||
*,
|
||||
source_sr_for_attempt: int,
|
||||
target_sr_for_attempt: int,
|
||||
t0_for_attempt: float,
|
||||
) -> Iterator[bytes]:
|
||||
nonlocal first, chunks, samples, output_sr
|
||||
output_sr = target_sr_for_attempt
|
||||
for item in tuned_iterator:
|
||||
speech = item.get("tts_speech") if isinstance(item, dict) else item
|
||||
pcm = self._audio_to_i16(speech)
|
||||
pcm = self._resample_linear(pcm, source_sr, target_sr)
|
||||
pcm = self._resample_linear(pcm, source_sr_for_attempt, target_sr_for_attempt)
|
||||
if pcm.size == 0:
|
||||
continue
|
||||
if first:
|
||||
print(
|
||||
f"first_pcm chars={len(req.text.strip())} sr={target_sr} seconds={time.perf_counter() - t0:.3f}",
|
||||
f"first_pcm chars={len(req.text.strip())} sr={target_sr_for_attempt} seconds={time.perf_counter() - t0_for_attempt:.3f}",
|
||||
flush=True,
|
||||
)
|
||||
first = False
|
||||
chunks += 1
|
||||
samples += int(pcm.size)
|
||||
yield pcm.astype("<i2", copy=False).tobytes()
|
||||
|
||||
with self._model_lock:
|
||||
yield from emit(
|
||||
_with_request_streaming_tuning(model, iterator),
|
||||
source_sr_for_attempt=source_sr,
|
||||
target_sr_for_attempt=target_sr,
|
||||
t0_for_attempt=t0,
|
||||
)
|
||||
if chunks == 0 and fallback_iterator_factory is not None:
|
||||
yield from emit(
|
||||
_with_request_streaming_tuning(model, fallback_iterator_factory()),
|
||||
source_sr_for_attempt=source_sr,
|
||||
target_sr_for_attempt=target_sr,
|
||||
t0_for_attempt=t0,
|
||||
)
|
||||
if chunks == 0:
|
||||
raise RuntimeError("CosyVoice returned no audio")
|
||||
print(
|
||||
f"synth_stream chars={len(req.text.strip())} sr={target_sr} chunks={chunks} audio_seconds={samples / target_sr:.3f} wall_seconds={time.perf_counter() - t0:.3f}",
|
||||
f"synth_stream chars={len(req.text.strip())} sr={output_sr} chunks={chunks} audio_seconds={samples / output_sr:.3f} wall_seconds={time.perf_counter() - t0:.3f}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
@@ -683,17 +984,47 @@ def create_app(service: CosyVoiceService) -> FastAPI:
|
||||
|
||||
@app.post("/synthesize")
|
||||
def synthesize(req: SynthesizeRequest) -> StreamingResponse:
|
||||
try:
|
||||
def open_stream() -> tuple[Iterator[bytes], bytes, int]:
|
||||
stream, sr = service.synthesize_pcm_stream(req)
|
||||
iterator = iter(stream)
|
||||
first = next(iterator)
|
||||
return iterator, first, sr
|
||||
|
||||
try:
|
||||
iterator, first_chunk, sr = open_stream()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"cosyvoice synth failed: {type(exc).__name__}: {exc}",
|
||||
) from exc
|
||||
if "CosyVoice returned no audio" in str(exc):
|
||||
reset = getattr(service, "reset_model_after_empty_audio", None)
|
||||
if callable(reset):
|
||||
reset(reason=str(exc))
|
||||
try:
|
||||
iterator, first_chunk, sr = open_stream()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as retry_exc:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"cosyvoice synth failed after model reset: {type(retry_exc).__name__}: {retry_exc}",
|
||||
) from retry_exc
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"cosyvoice synth failed: {type(exc).__name__}: {exc}",
|
||||
) from exc
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"cosyvoice synth failed: {type(exc).__name__}: {exc}",
|
||||
) from exc
|
||||
|
||||
def response_stream() -> Iterator[bytes]:
|
||||
yield first_chunk
|
||||
yield from iterator
|
||||
|
||||
return StreamingResponse(
|
||||
stream,
|
||||
response_stream(),
|
||||
media_type=f"audio/L16; rate={sr}; channels=1",
|
||||
headers={"X-Audio-Sample-Rate": str(sr)},
|
||||
)
|
||||
@@ -705,6 +1036,55 @@ def _local_audio_root() -> Path:
|
||||
return Path(os.environ.get("OPENTALKING_LOCAL_AUDIO_MODEL_ROOT", "./models/local-audio")).expanduser()
|
||||
|
||||
|
||||
def _default_system_voice_prompt(root: Path) -> tuple[str, str] | None:
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
voice_roots = [
|
||||
root / "voices" / "system",
|
||||
repo_root / "opentalking" / "assets" / "voices" / "system",
|
||||
]
|
||||
seen: set[Path] = set()
|
||||
for voice_root in voice_roots:
|
||||
try:
|
||||
resolved = voice_root.resolve()
|
||||
except OSError:
|
||||
resolved = voice_root
|
||||
if resolved in seen or not voice_root.is_dir():
|
||||
continue
|
||||
seen.add(resolved)
|
||||
for voice_dir in sorted(path for path in voice_root.iterdir() if path.is_dir()):
|
||||
prompt_audio = voice_dir / "prompt.wav"
|
||||
prompt_text = voice_dir / "prompt.txt"
|
||||
if not prompt_audio.is_file() or not prompt_text.is_file():
|
||||
continue
|
||||
try:
|
||||
text = prompt_text.read_text(encoding="utf-8").strip()
|
||||
except OSError:
|
||||
text = ""
|
||||
if text:
|
||||
print(f"using default CosyVoice system voice prompt: {voice_dir.name}", flush=True)
|
||||
return str(prompt_audio), text
|
||||
return None
|
||||
|
||||
|
||||
def _torch_cuda_supports_device(device: str) -> tuple[bool, str]:
|
||||
if not device.startswith("cuda"):
|
||||
return True, ""
|
||||
try:
|
||||
import torch
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
return False, "torch.cuda.is_available() is false"
|
||||
index = int(device.split(":", 1)[1]) if ":" in device else 0
|
||||
major, minor = torch.cuda.get_device_capability(index)
|
||||
wanted = f"sm_{major}{minor}"
|
||||
arch_list = set(torch.cuda.get_arch_list() or [])
|
||||
if arch_list and wanted not in arch_list:
|
||||
return False, f"device capability {wanted} is not in torch arch list {sorted(arch_list)}"
|
||||
except Exception as exc:
|
||||
return False, f"failed to inspect torch CUDA support: {type(exc).__name__}: {exc}"
|
||||
return True, ""
|
||||
|
||||
|
||||
def _env_bool(name: str, default: bool = False) -> bool:
|
||||
raw = os.environ.get(name, "").strip().lower()
|
||||
if not raw:
|
||||
@@ -733,6 +1113,31 @@ def build_service_from_env() -> CosyVoiceService:
|
||||
fp16_raw = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_FP16", "auto").strip().lower()
|
||||
fp16 = device.startswith("cuda") if fp16_raw == "auto" else fp16_raw not in {"0", "false", "no", "off"}
|
||||
root = _local_audio_root()
|
||||
load_trt = _env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_TRT", False)
|
||||
cuda_supported, cuda_reason = _torch_cuda_supports_device(device)
|
||||
if not cuda_supported:
|
||||
print(
|
||||
"CosyVoice CUDA runtime is not compatible with this torch build; "
|
||||
f"falling back to CPU runtime: {cuda_reason}",
|
||||
flush=True,
|
||||
)
|
||||
device = "cpu"
|
||||
fp16 = False
|
||||
load_trt = False
|
||||
os.environ["OPENTALKING_TTS_LOCAL_COSYVOICE_PRELOAD"] = "0"
|
||||
mode = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_MODE", "zero_shot")
|
||||
prompt_audio = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_AUDIO", "").strip()
|
||||
prompt_text = os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_TEXT", "").strip()
|
||||
normalized_mode = mode.strip().lower()
|
||||
if (
|
||||
(normalized_mode in {"cross_lingual", "instruct"} and not prompt_audio)
|
||||
or (normalized_mode not in {"cross_lingual", "instruct"} and (not prompt_audio or not prompt_text))
|
||||
):
|
||||
default_prompt = _default_system_voice_prompt(root)
|
||||
if default_prompt is not None:
|
||||
default_audio, default_text = default_prompt
|
||||
prompt_audio = prompt_audio or default_audio
|
||||
prompt_text = prompt_text or default_text
|
||||
return CosyVoiceService(
|
||||
model_dir=os.environ.get(
|
||||
"OPENTALKING_TTS_LOCAL_COSYVOICE_MODEL_DIR",
|
||||
@@ -742,26 +1147,29 @@ def build_service_from_env() -> CosyVoiceService:
|
||||
"OPENTALKING_TTS_LOCAL_COSYVOICE_RUNTIME_DIR",
|
||||
str(root / "runtime" / "CosyVoice"),
|
||||
),
|
||||
audio_root=str(root),
|
||||
device=device,
|
||||
prompt_audio=os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_AUDIO", ""),
|
||||
prompt_text=os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_PROMPT_TEXT", ""),
|
||||
mode=os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_MODE", "zero_shot"),
|
||||
prompt_audio=prompt_audio,
|
||||
prompt_text=prompt_text,
|
||||
mode=mode,
|
||||
instruction=os.environ.get(
|
||||
"OPENTALKING_TTS_LOCAL_COSYVOICE_INSTRUCTION",
|
||||
"You are a helpful assistant.<|endofprompt|>",
|
||||
),
|
||||
fp16=fp16,
|
||||
load_jit=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_JIT", False),
|
||||
load_trt=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_TRT", False),
|
||||
load_trt=load_trt,
|
||||
load_vllm=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_LOAD_VLLM", False),
|
||||
trt_concurrent=int(os.environ.get("OPENTALKING_TTS_LOCAL_COSYVOICE_TRT_CONCURRENT", "1") or "1"),
|
||||
token_hop_len=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_TOKEN_HOP_LEN"),
|
||||
token_max_hop_len=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_TOKEN_MAX_HOP_LEN"),
|
||||
stream_scale_factor=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_STREAM_SCALE_FACTOR"),
|
||||
flow_n_timesteps=_env_optional_int("OPENTALKING_TTS_LOCAL_COSYVOICE_FLOW_N_TIMESTEPS"),
|
||||
max_token_text_ratio=_env_optional_float("OPENTALKING_TTS_LOCAL_COSYVOICE_MAX_TOKEN_TEXT_RATIO", 6.0),
|
||||
max_token_text_ratio=_env_optional_float("OPENTALKING_TTS_LOCAL_COSYVOICE_MAX_TOKEN_TEXT_RATIO"),
|
||||
min_token_text_ratio=_env_optional_float("OPENTALKING_TTS_LOCAL_COSYVOICE_MIN_TOKEN_TEXT_RATIO"),
|
||||
mask_stop_tokens=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_MASK_STOP_TOKENS", True),
|
||||
mask_stop_tokens=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_MASK_STOP_TOKENS", False),
|
||||
use_zero_shot_spk_id=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_USE_SPK_ID", False),
|
||||
precache_system_spks=_env_bool("OPENTALKING_TTS_LOCAL_COSYVOICE_PRECACHE_SPKS", False),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
@@ -63,10 +62,28 @@ def _resample_linear(pcm: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
|
||||
|
||||
def _to_wav_bytes(pcm: np.ndarray, sr: int) -> bytes:
|
||||
buf = io.BytesIO()
|
||||
sf.write(buf, np.asarray(pcm, dtype=np.int16), sr, format="WAV", subtype="PCM_16")
|
||||
_write_wav_i16(buf, np.asarray(pcm, dtype=np.int16), sr)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
def _write_wav_i16(path_or_file: str | Path | io.BytesIO, data: np.ndarray, sample_rate: int) -> None:
|
||||
pcm = np.asarray(data)
|
||||
if pcm.ndim == 2 and pcm.shape[0] == 1:
|
||||
pcm = pcm[0]
|
||||
elif pcm.ndim == 2:
|
||||
pcm = pcm.T.reshape(-1)
|
||||
if np.issubdtype(pcm.dtype, np.floating):
|
||||
pcm = np.clip(pcm, -1.0, 1.0)
|
||||
pcm = np.round(pcm * 32767.0).astype("<i2")
|
||||
else:
|
||||
pcm = np.clip(pcm, -32768, 32767).astype("<i2")
|
||||
with wave.open(path_or_file if hasattr(path_or_file, "write") else str(path_or_file), "wb") as wf:
|
||||
wf.setnchannels(1)
|
||||
wf.setsampwidth(2)
|
||||
wf.setframerate(int(sample_rate))
|
||||
wf.writeframes(pcm.reshape(-1).tobytes())
|
||||
|
||||
|
||||
class LocalIndexTTSService:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -216,7 +233,7 @@ class LocalIndexTTSService:
|
||||
data = data[0]
|
||||
elif data.ndim == 2:
|
||||
data = data.T
|
||||
sf.write(path, data, sample_rate, subtype="PCM_16")
|
||||
_write_wav_i16(path, data, sample_rate)
|
||||
return None
|
||||
|
||||
torchaudio.save = save
|
||||
@@ -286,10 +303,15 @@ class LocalIndexTTSService:
|
||||
raise HTTPException(status_code=400, detail=f"prompt_audio does not exist: {prompt_audio}")
|
||||
target_sr = int(req.sample_rate or 16000)
|
||||
t0 = time.perf_counter()
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
|
||||
fd, tmp_name = tempfile.mkstemp(suffix=".wav")
|
||||
os.close(fd)
|
||||
tmp_path = Path(tmp_name)
|
||||
try:
|
||||
with self._lock:
|
||||
self.model().infer(prompt_audio, text, tmp.name, **self._infer_kwargs(req))
|
||||
pcm, sr = _audio_to_i16(Path(tmp.name))
|
||||
self.model().infer(prompt_audio, text, tmp_name, **self._infer_kwargs(req))
|
||||
pcm, sr = _audio_to_i16(tmp_path)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
pcm = _resample_linear(pcm, sr, target_sr)
|
||||
elapsed = time.perf_counter() - t0
|
||||
print(
|
||||
|
||||
@@ -55,23 +55,22 @@ def test_app_clears_stale_subtitle_state_on_context_reset() -> None:
|
||||
assert model_clear_idx < model_set_idx
|
||||
|
||||
|
||||
def test_avatar_selection_stage_supports_manage_knowledge_bases() -> None:
|
||||
source = (ROOT / "apps/web/src/components/AvatarSelectionStage.tsx").read_text(encoding="utf-8")
|
||||
def test_settings_panel_supports_knowledge_base_selection() -> None:
|
||||
source = (ROOT / "apps/web/src/components/SettingsPanel.tsx").read_text(encoding="utf-8")
|
||||
|
||||
assert "knowledgeBaseIds" in source
|
||||
assert "selected knowledge bases" not in source.lower()
|
||||
assert "可用知识库" not in source
|
||||
|
||||
agent_idx = source.index("Agent 增强")
|
||||
start_idx = source.index("{queued ?", agent_idx)
|
||||
agent_block = source[agent_idx:start_idx]
|
||||
assert "当前形象知识库" in agent_block
|
||||
assert "{knowledgeBases.length} 个知识库" in agent_block
|
||||
assert "{selectedKnowledgeBaseIds.length || 0} 个知识库" not in agent_block
|
||||
assert "可用知识库" not in agent_block
|
||||
assert "onManageKnowledgeBases" not in agent_block
|
||||
assert "flex flex-wrap" in agent_block
|
||||
assert "inline-flex max-w-full" in agent_block
|
||||
knowledge_idx = source.index('title="知识库"')
|
||||
model_idx = source.index('title="驱动模型"')
|
||||
knowledge_block = source[knowledge_idx:model_idx]
|
||||
assert "{knowledgeBases.length} 个知识库" in knowledge_block
|
||||
assert "onManageKnowledgeBases" in knowledge_block
|
||||
assert "selectedKnowledgeBaseSet.has(knowledgeBase.id)" in knowledge_block
|
||||
assert "disabled={!knowledgeBaseReady}" in knowledge_block
|
||||
assert "已选" in knowledge_block
|
||||
assert "可用知识库" not in knowledge_block
|
||||
|
||||
|
||||
def test_settings_panel_places_knowledge_between_avatar_and_model() -> None:
|
||||
|
||||
Reference in New Issue
Block a user