mirror of
https://github.com/datascale-ai/opentalking.git
synced 2026-07-03 15:22:34 +08:00
feat: add wechat memory persona import
This commit is contained in:
@@ -1,13 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Query, Request, UploadFile
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from opentalking.core.config import get_settings
|
||||
from opentalking.persona.session import default_persona_store
|
||||
from opentalking.persona.wechat_import import WeChatImportJobRegistry
|
||||
from opentalking.providers.memory.decision_agent import MemoryDecisionAgent
|
||||
from opentalking.providers.memory.factory import build_memory_provider
|
||||
from opentalking.providers.memory.runtime import MemoryRuntime, normalize_memory_scope
|
||||
@@ -34,6 +38,16 @@ class MemoryImportRequest(BaseModel):
|
||||
source: str | None = None
|
||||
|
||||
|
||||
class WeChatSpeakerSelectionRequest(BaseModel):
    """JSON body for picking the speaker of a WeChat import job."""

    # Forwarded unchanged to the registry's ``select_speaker_async`` call.
    target_speaker_id: str
|
||||
|
||||
|
||||
class WeChatImportCommitRequest(BaseModel):
    """JSON body for committing a WeChat import job as a persona."""

    # Required id of the persona to create; forwarded to the registry's ``commit``.
    persona_id: str
    # Optional fields, forwarded to ``commit`` as-is (``None`` when omitted).
    persona_name: str | None = None
    description: str | None = None
|
||||
|
||||
|
||||
def _profile(value: str | None) -> str:
    """Resolve a profile id.

    Uses ``value`` when truthy, otherwise the configured
    ``memory_default_profile_id``, otherwise ``"default"``. The result is
    stripped, and an all-whitespace result also collapses to ``"default"``.
    """
    # ``or`` short-circuits, so settings are only consulted when ``value`` is falsy.
    candidate = value or get_settings().memory_default_profile_id or "default"
    cleaned = candidate.strip()
    return cleaned if cleaned else "default"
|
||||
|
||||
@@ -49,6 +63,110 @@ def _library_id(value: str | None) -> str:
|
||||
return (value or "").strip() or f"lib_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
def _wechat_registry(request: Request) -> WeChatImportJobRegistry:
    """Return the WeChat import registry cached on ``app.state``.

    On first use the registry is built from ``app.state.persona_store`` (or
    the default persona store when that attribute is missing or falsy) and a
    freshly built memory provider, then stored back on ``app.state``.
    """
    state = request.app.state
    cached = getattr(state, "wechat_import_registry", None)
    if cached is None:
        store = getattr(state, "persona_store", None) or default_persona_store()
        cached = WeChatImportJobRegistry(
            persona_store=store,
            memory_provider=build_memory_provider(),
        )
        state.wechat_import_registry = cached
    return cached
|
||||
|
||||
|
||||
@router.post("/wechat-import")
async def create_wechat_import_job(
    request: Request,
    file: UploadFile | None = File(default=None),
    profile_id: str = Form(default="default"),
    memory_library_id: str = Form(default="default"),
    avatar_id: str = Form(default=""),
    avatar_model: str = Form(default="mock"),
    character_id: str | None = Form(default=None),
    target_speaker_id: str | None = Form(default=None),
    source_format: str = Form(default="auto"),
    timezone: str = Form(default="Asia/Shanghai"),
    source_url: str | None = Form(default=None),
) -> dict[str, object]:
    """Create a WeChat import job from an uploaded export file.

    The upload is spooled to a temporary file, handed to the registry's
    ``create_job_async``, and the temporary file is removed afterwards
    regardless of the outcome.

    Returns:
        The created job serialised with ``to_dict()``.

    Raises:
        HTTPException: 400 when ``source_url`` is supplied, when no file is
            uploaded, when ``avatar_id`` is blank, or when a ``ValueError``
            is raised while spooling the upload or creating the job.
    """
    if (source_url or "").strip():
        raise HTTPException(status_code=400, detail="please upload a WeFlow export file; API URLs are not supported")
    if file is None:
        raise HTTPException(status_code=400, detail="WeFlow export file upload is required")
    clean_avatar_id = (avatar_id or "").strip()
    if not clean_avatar_id:
        raise HTTPException(status_code=400, detail="avatar_id is required")

    # Keep the uploaded extension on the temp file; fall back to ".json" when there is none.
    suffix = Path(file.filename or "").suffix or ".json"
    tmp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(prefix="opentalking-wechat-upload-", suffix=suffix, delete=False) as tmp:
            # Record the path BEFORE writing: the file is created with
            # delete=False, so a failed read/write would otherwise leak it.
            tmp_path = Path(tmp.name)
            # Copy in bounded chunks instead of loading the whole upload in memory.
            while chunk := await file.read(1024 * 1024):
                tmp.write(chunk)
        job = await _wechat_registry(request).create_job_async(
            tmp_path,
            profile_id=profile_id,
            memory_library_id=memory_library_id,
            avatar_id=clean_avatar_id,
            avatar_model=(avatar_model or "mock").strip() or "mock",
            character_id=character_id,
            target_speaker_id=target_speaker_id,
            source_format=source_format,
            timezone=timezone,
        )
        return job.to_dict()
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    finally:
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                # Best-effort cleanup: a failed unlink must not mask the real outcome.
                pass
|
||||
|
||||
|
||||
@router.get("/wechat-import/{job_id}")
async def get_wechat_import_job(request: Request, job_id: str) -> dict[str, object]:
    """Return the serialised state of one WeChat import job.

    Raises:
        HTTPException: 404 when a ``KeyError`` is raised during the lookup.
    """
    try:
        registry = _wechat_registry(request)
        job = registry.get_job(job_id)
        return job.to_dict()
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="wechat import job not found") from exc
|
||||
|
||||
|
||||
@router.post("/wechat-import/{job_id}/speaker")
async def select_wechat_import_speaker(
    request: Request,
    job_id: str,
    body: WeChatSpeakerSelectionRequest,
) -> dict[str, object]:
    """Select the target speaker of an import job and return the updated job.

    Raises:
        HTTPException: 404 on ``KeyError``, 400 (with the error text as
            detail) on ``ValueError``.
    """
    try:
        registry = _wechat_registry(request)
        updated_job = await registry.select_speaker_async(job_id, body.target_speaker_id)
        return updated_job.to_dict()
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="wechat import job not found") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/wechat-import/{job_id}/commit")
async def commit_wechat_import_job(
    request: Request,
    job_id: str,
    body: WeChatImportCommitRequest,
) -> dict[str, object]:
    """Commit an import job and return the commit result as a plain dict.

    The registry's result is converted with ``dataclasses.asdict``.

    Raises:
        HTTPException: 404 on ``KeyError``, 400 (with the error text as
            detail) on ``ValueError``.
    """
    try:
        registry = _wechat_registry(request)
        outcome = await registry.commit(
            job_id,
            persona_id=body.persona_id,
            persona_name=body.persona_name,
            description=body.description,
        )
        return asdict(outcome)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="wechat import job not found") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.get("/libraries")
|
||||
async def list_libraries(
|
||||
profile_id: str = Query("default"),
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from opentalking.agent.context_builder import default_knowledge_store
|
||||
from opentalking.persona.package import export_persona_package, import_persona_package, validate_persona_package
|
||||
from opentalking.persona.persona_md import read_persona_md, write_persona_md
|
||||
from opentalking.persona.session import default_persona_store
|
||||
|
||||
router = APIRouter(prefix="/personas", tags=["personas"])
|
||||
@@ -28,6 +29,7 @@ class PersonaVoiceResponse(BaseModel):
|
||||
|
||||
|
||||
class PersonaAgentResponse(BaseModel):
|
||||
persona_prompt: str | None = None
|
||||
system_prompt: str | None = None
|
||||
style_prompt: str | None = None
|
||||
memory_enabled: bool = False
|
||||
@@ -67,6 +69,16 @@ class PersonasResponse(BaseModel):
|
||||
personas: list[PersonaResponse]
|
||||
|
||||
|
||||
|
||||
|
||||
class PersonaMdRequest(BaseModel):
    """JSON body carrying replacement persona markdown."""

    # Text handed to ``write_persona_md`` by the update endpoint.
    content: str
|
||||
|
||||
|
||||
class PersonaMdResponse(BaseModel):
    """Response pairing a persona id with its persona markdown text."""

    # Echo of the persona id taken from the request path.
    persona_id: str
    # Markdown text as returned by ``read_persona_md`` / ``write_persona_md``.
    content: str
|
||||
|
||||
class DeletePersonaResponse(BaseModel):
|
||||
deleted: bool
|
||||
|
||||
@@ -92,6 +104,29 @@ async def get_persona(persona_id: str) -> PersonaResponse:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.get("/{persona_id}/persona-md", response_model=PersonaMdResponse)
async def get_persona_md(persona_id: str) -> PersonaMdResponse:
    """Return the persona markdown of ``persona_id``.

    Raises:
        HTTPException: 404 on ``KeyError``, 400 (with the error text as
            detail) on ``ValueError``.
    """
    try:
        store = default_persona_store()
        persona = store.get_persona(persona_id)
        markdown = read_persona_md(persona)
        return PersonaMdResponse(persona_id=persona_id, content=markdown)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="persona not found") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.put("/{persona_id}/persona-md", response_model=PersonaMdResponse)
async def update_persona_md(persona_id: str, body: PersonaMdRequest) -> PersonaMdResponse:
    """Overwrite the persona markdown of ``persona_id``.

    The response carries whatever ``write_persona_md`` returns as content.

    Raises:
        HTTPException: 404 on ``KeyError``, 400 (with the error text as
            detail) on ``ValueError``.
    """
    try:
        store = default_persona_store()
        persona = store.get_persona(persona_id)
        stored_markdown = write_persona_md(persona, body.content)
        return PersonaMdResponse(persona_id=persona_id, content=stored_markdown)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="persona not found") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/import", response_model=PersonaResponse)
|
||||
async def import_persona(file: UploadFile = File(...)) -> PersonaResponse:
|
||||
filename = file.filename or "persona.otpersona"
|
||||
|
||||
@@ -643,7 +643,7 @@ async def create_session(body: CreateSessionRequest, request: Request) -> Create
|
||||
|
||||
custom = _session_customizations(request).get(avatar_id, {})
|
||||
persona_prompt = persona_defaults.llm_system_prompt if persona_defaults is not None else None
|
||||
llm_system_prompt = (body.llm_system_prompt or "").strip() or persona_prompt or custom.get("llm_system_prompt")
|
||||
llm_system_prompt = persona_prompt or (body.llm_system_prompt or "").strip() or custom.get("llm_system_prompt")
|
||||
custom_ref_image_path = custom.get("custom_ref_image_path")
|
||||
if custom_ref_image_path and not Path(custom_ref_image_path).exists():
|
||||
custom_ref_image_path = None
|
||||
|
||||
106
apps/api/tests/test_memory_wechat_import_api.py
Normal file
106
apps/api/tests/test_memory_wechat_import_api.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from opentalking.persona import memory_builder
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from apps.api.routes import memory as memory_routes
|
||||
from opentalking.persona.store import PersonaStore
|
||||
from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider
|
||||
|
||||
|
||||
def write_group_export(path: Path) -> Path:
    """Write a small WeFlow-style group chat export to ``path`` and return it.

    The fixture has two members (Li and Chen) and three text messages.
    """

    def message(sender: str, name: str, timestamp: int, content: str, message_id: str) -> dict[str, object]:
        # Key order is kept stable so the serialised JSON is unchanged.
        return {
            "sender": sender,
            "accountName": name,
            "groupNickname": name,
            "timestamp": timestamp,
            "type": 0,
            "content": content,
            "platformMessageId": message_id,
        }

    export = {
        "chatlab": {"version": "0.0.2", "generator": "WeFlow"},
        "meta": {"name": "Project room", "platform": "wechat", "type": "group", "groupId": "room@chatroom"},
        "members": [
            {"platformId": "wxid_li", "accountName": "Li", "groupNickname": "Li"},
            {"platformId": "wxid_chen", "accountName": "Chen", "groupNickname": "Chen"},
        ],
        "messages": [
            message("wxid_li", "Li", 1738713600, "calm small steps.", "1"),
            message("wxid_chen", "Chen", 1738713660, "ship now.", "2"),
            message("wxid_li", "Li", 1738713720, "demo secret code is 8848 and should never be copied raw.", "3"),
        ],
    }
    serialised = json.dumps(export, ensure_ascii=True)
    path.write_text(serialised, encoding="utf-8")
    return path
|
||||
|
||||
|
||||
def test_memory_api_wechat_import_upload_select_commit(monkeypatch, tmp_path: Path) -> None:
    """Drive upload -> speaker selection -> commit through the HTTP API."""

    async def empty_complete(self, messages):
        # Stand-in for the persona LLM call: always yields an empty completion.
        return ""

    monkeypatch.setattr(memory_builder._ConfiguredPersonaLLM, "complete", empty_complete)

    memory_provider = InMemoryMemoryProvider()
    persona_store = PersonaStore(tmp_path / "personas")
    monkeypatch.setattr(memory_routes, "build_memory_provider", lambda: memory_provider)

    api = FastAPI()
    api.state.persona_store = persona_store
    api.include_router(memory_routes.router)
    export_file = write_group_export(tmp_path / "weflow.json")

    form_fields = {
        "profile_id": "default",
        "memory_library_id": "default",
        "avatar_id": "avatar-li",
        "avatar_model": "mock",
    }

    with TestClient(api) as client:
        # Step 1: upload the export; a group chat needs an explicit speaker.
        with export_file.open("rb") as handle:
            upload = client.post(
                "/memory/wechat-import",
                data=form_fields,
                files={"file": ("weflow.json", handle, "application/json")},
            )
        assert upload.status_code == 200
        job = upload.json()
        assert job["status"] == "needs_speaker_selection"
        speaker_ids = [speaker["id"] for speaker in job["speakers"]]
        assert speaker_ids == ["wxid_li", "wxid_chen"]

        # Step 2: pick a speaker; the draft must not contain the raw secret.
        job_id = job["id"]
        selection = client.post(f"/memory/wechat-import/{job_id}/speaker", json={"target_speaker_id": "wxid_li"})
        assert selection.status_code == 200
        draft = selection.json()
        assert draft["status"] == "draft_ready"
        assert "8848" not in draft["persona_md"]

        # Step 3: commit the draft as a persona.
        commit = client.post(
            f"/memory/wechat-import/{job_id}/commit",
            json={"persona_id": "friend-li", "persona_name": "Friend Li"},
        )
        assert commit.status_code == 200
        commit_payload = commit.json()
        assert commit_payload["persona_id"] == "friend-li"
        assert commit_payload["memory_imported"] == 3

    saved = persona_store.get_persona("friend-li")
    assert saved.manifest.avatar.id == "avatar-li"
    assert (saved.path / "persona.md").is_file()
|
||||
|
||||
|
||||
def test_memory_api_wechat_import_rejects_api_url(monkeypatch, tmp_path: Path) -> None:
    """Posting a ``source_url`` instead of a file is rejected with HTTP 400."""
    monkeypatch.setattr(memory_routes, "build_memory_provider", lambda: InMemoryMemoryProvider())

    api = FastAPI()
    api.state.persona_store = PersonaStore(tmp_path / "personas")
    api.include_router(memory_routes.router)

    form_fields = {
        "source_url": "http://127.0.0.1:5031/api/v1/messages?talker=wxid_xxx",
        "avatar_id": "avatar-li",
        "avatar_model": "mock",
    }
    with TestClient(api) as client:
        rejected = client.post("/memory/wechat-import", data=form_fields)

    assert rejected.status_code == 400
    assert "upload" in rejected.json()["detail"]
|
||||
@@ -87,3 +87,64 @@ def test_persona_api_rejects_non_package(monkeypatch, tmp_path: Path) -> None:
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "persona package must be .otpersona or .zip"
|
||||
|
||||
|
||||
|
||||
def test_persona_md_api_reads_and_updates_file(monkeypatch, tmp_path: Path) -> None:
    """GET returns the stored persona markdown and PUT replaces it."""
    personas_root = tmp_path / "personas"
    store = PersonaStore(personas_root)
    persona_dir = personas_root / "friend-li"
    persona_dir.mkdir(parents=True)
    (persona_dir / "persona.md").write_text("# Persona\n旧人设", encoding="utf-8")
    manifest_json = (
        """
        {
          "schema_version": "0.1",
          "id": "friend-li",
          "name": "小李",
          "description": "微信导入生成的 Persona",
          "locale": "zh-CN",
          "avatar": {"id": "custom-friend-li", "model": "mock"},
          "agent": {"persona_prompt": "persona.md", "memory_enabled": true, "knowledge_enabled": false},
          "safety": {"authorized_avatar": true, "authorized_voice": false, "content_label_required": true}
        }
        """.strip()
        + "\n"
    )
    (persona_dir / "persona.json").write_text(manifest_json, encoding="utf-8")
    monkeypatch.setattr(persona_routes, "default_persona_store", lambda: store)

    api = FastAPI()
    api.include_router(persona_routes.router)
    endpoint = "/personas/friend-li/persona-md"

    with TestClient(api) as client:
        first_read = client.get(endpoint)
        assert first_read.status_code == 200
        assert first_read.json() == {"persona_id": "friend-li", "content": "# Persona\n旧人设"}

        write = client.put(endpoint, json={"content": "# Persona\n新人设"})
        assert write.status_code == 200
        assert write.json() == {"persona_id": "friend-li", "content": "# Persona\n新人设"}

        # A second read must observe the content written by the PUT.
        second_read = client.get(endpoint)
        assert second_read.json()["content"] == "# Persona\n新人设"
|
||||
|
||||
|
||||
def test_persona_md_api_rejects_path_escape(monkeypatch, tmp_path: Path) -> None:
    """A manifest whose ``persona_prompt`` is ``../escape.md`` yields HTTP 400."""
    personas_root = tmp_path / "personas"
    store = PersonaStore(personas_root)
    bad_dir = personas_root / "bad"
    bad_dir.mkdir(parents=True)
    manifest_json = '{"schema_version":"0.1","id":"bad","name":"bad","description":"bad","locale":"zh-CN","avatar":{"id":"a","model":"mock"},"agent":{"persona_prompt":"../escape.md"}}\n'
    (bad_dir / "persona.json").write_text(manifest_json, encoding="utf-8")
    monkeypatch.setattr(persona_routes, "default_persona_store", lambda: store)

    api = FastAPI()
    api.include_router(persona_routes.router)

    with TestClient(api) as client:
        rejected = client.get("/personas/bad/persona-md")

    assert rejected.status_code == 400
|
||||
|
||||
@@ -768,7 +768,7 @@ def test_create_session_expands_persona_defaults(
|
||||
assert calls[0]["knowledge_base_ids"] == ["kb_persona"]
|
||||
|
||||
|
||||
def test_create_session_allows_explicit_fields_to_override_persona(
|
||||
def test_create_session_keeps_persona_prompt_above_explicit_llm_prompt(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
@@ -804,7 +804,7 @@ def test_create_session_allows_explicit_fields_to_override_persona(
|
||||
"locale": "zh-CN",
|
||||
"avatar": {"id": "singer", "model": "mock"},
|
||||
"voice": {"provider": "edge", "voice_id": "zh-CN-XiaoxiaoNeural"},
|
||||
"agent": {"memory_enabled": true, "knowledge_enabled": true, "knowledge_base_ids": ["kb_persona"]},
|
||||
"agent": {"persona_prompt": "persona.md", "memory_enabled": true, "knowledge_enabled": true, "knowledge_base_ids": ["kb_persona"]},
|
||||
"runtime": {"stt_provider": "sensevoice", "tts_provider": "edge"},
|
||||
"safety": {"authorized_avatar": true, "authorized_voice": true, "content_label_required": true}
|
||||
}
|
||||
@@ -812,6 +812,7 @@ def test_create_session_allows_explicit_fields_to_override_persona(
|
||||
+ "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(persona_dir / "persona.md").write_text("# Persona\nPersona prompt from package.\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(sessions_routes.session_service, "create_session", fake_create_session)
|
||||
monkeypatch.setattr(
|
||||
@@ -851,7 +852,7 @@ def test_create_session_allows_explicit_fields_to_override_persona(
|
||||
assert response.status == "created"
|
||||
assert calls[0]["avatar_id"] == "anchor"
|
||||
assert calls[0]["model"] == "flashtalk"
|
||||
assert calls[0]["llm_system_prompt"] == "显式提示词"
|
||||
assert calls[0]["llm_system_prompt"] == "# Persona\nPersona prompt from package."
|
||||
assert calls[0]["knowledge_base_ids"] == ["kb_override"]
|
||||
assert calls[0]["memory_enabled"] is False
|
||||
|
||||
@@ -1342,3 +1343,78 @@ def test_close_cancels_running_and_queued_speech_tasks(unified_client: TestClien
|
||||
|
||||
_wait_until(lambda: set(runner.cancelled_texts) == {"first", "second"})
|
||||
_wait_until(lambda: unified_client.get(f"/sessions/{session_id}").json()["state"] == "closed")
|
||||
|
||||
|
||||
|
||||
def test_create_session_loads_persona_md_prompt(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """The persona's ``persona.md`` content is passed on as ``llm_system_prompt``."""
    calls: list[dict[str, object]] = []

    async def fake_create_session(*args: object, **kwargs: object) -> str:
        # Record the keyword arguments and register a minimal session hash.
        calls.append(kwargs)
        sid = "sess_persona_md"
        redis = args[0]
        session_fields = {
            "session_id": sid,
            "avatar_id": kwargs["avatar_id"],
            "model": kwargs["model"],
            "state": "worker_ready",
        }
        await redis.hset(session_key(sid), mapping=session_fields)
        return sid

    async def fake_connected_model_ids(_settings: object) -> set[str]:
        return {"mock"}

    personas_root = tmp_path / "personas"
    persona_dir = personas_root / "friend-li"
    persona_dir.mkdir(parents=True)
    (persona_dir / "persona.md").write_text("# Persona\n你是小李,说话温柔。", encoding="utf-8")
    manifest_json = (
        """
        {
          "schema_version": "0.1",
          "id": "friend-li",
          "name": "小李",
          "description": "微信导入生成的 Persona",
          "locale": "zh-CN",
          "avatar": {"id": "singer", "model": "mock"},
          "agent": {"persona_prompt": "persona.md", "memory_enabled": true, "knowledge_enabled": false},
          "safety": {"authorized_avatar": true, "authorized_voice": false, "content_label_required": true}
        }
        """.strip()
        + "\n"
    )
    (persona_dir / "persona.json").write_text(manifest_json, encoding="utf-8")

    monkeypatch.setattr(sessions_routes.session_service, "create_session", fake_create_session)
    monkeypatch.setattr(
        "opentalking.providers.synthesis.availability.connected_model_ids",
        fake_connected_model_ids,
    )

    avatars_dir = Path(__file__).resolve().parents[3] / "examples" / "avatars"
    app = FastAPI()
    app.state.redis = InMemoryRedis()
    app.state.settings = SimpleNamespace(
        avatars_dir=str(avatars_dir),
        persona_root=str(personas_root),
        normalized_stt_default_provider="sensevoice",
        normalized_stt_provider="sensevoice",
        normalized_tts_default_provider="edge",
        normalized_tts_provider="edge",
        omnirt_endpoint="",
    )
    request = Request({"type": "http", "app": app})

    # The route coroutine is invoked directly rather than through a test client.
    body = CreateSessionRequest(persona_id="friend-li", user_id="client_test")
    response = asyncio.run(sessions_routes.create_session(body, request))

    assert response.status == "created"
    assert calls[0]["llm_system_prompt"] == "# Persona\n你是小李,说话温柔。"
|
||||
|
||||
@@ -7,12 +7,14 @@ import {
|
||||
getMemoryLibraries,
|
||||
} from "../lib/api";
|
||||
import type { MemoryItem, MemoryLibrary } from "../types";
|
||||
import { WeChatMemoryImportPanel } from "./WeChatMemoryImportPanel";
|
||||
|
||||
type MemoryPanelProps = {
|
||||
characterId: string | null;
|
||||
selectedLibraryId: string | null;
|
||||
memoryEnabled?: boolean;
|
||||
profileId?: string;
|
||||
avatarModel?: string;
|
||||
compact?: boolean;
|
||||
mode?: "select" | "manage";
|
||||
refreshToken?: number;
|
||||
@@ -38,6 +40,7 @@ export function MemoryPanel({
|
||||
selectedLibraryId,
|
||||
memoryEnabled = false,
|
||||
profileId = "default",
|
||||
avatarModel = "mock",
|
||||
compact = false,
|
||||
mode = "manage",
|
||||
refreshToken = 0,
|
||||
@@ -287,6 +290,20 @@ export function MemoryPanel({
|
||||
</div>
|
||||
|
||||
<div className="space-y-4 p-4">
|
||||
<WeChatMemoryImportPanel
|
||||
avatarId={characterId}
|
||||
avatarModel={avatarModel}
|
||||
profileId={profileId}
|
||||
memoryLibraryId={selectedLibraryId || "default"}
|
||||
disabled={busy || loadingLibraries}
|
||||
onCommitted={async (result) => {
|
||||
onLibrarySelect(result.memory_library_id);
|
||||
onMemoryEnabledChange?.(true);
|
||||
await refreshLibraries();
|
||||
await refreshItems();
|
||||
}}
|
||||
/>
|
||||
|
||||
{notice ? (
|
||||
<p className="rounded-lg border border-emerald-200 bg-emerald-50 px-3 py-2 text-xs font-medium text-emerald-700">
|
||||
{notice}
|
||||
|
||||
205
apps/web/src/components/WeChatMemoryImportPanel.tsx
Normal file
205
apps/web/src/components/WeChatMemoryImportPanel.tsx
Normal file
@@ -0,0 +1,205 @@
|
||||
import { useMemo, useState } from "react";
|
||||
import {
|
||||
ApiError,
|
||||
commitWeChatImportJob,
|
||||
selectWeChatImportSpeaker,
|
||||
uploadWeChatImport,
|
||||
} from "../lib/api";
|
||||
import type { WeChatImportCommitResult, WeChatImportJob } from "../types";
|
||||
|
||||
/** Props accepted by {@link WeChatMemoryImportPanel}. */
type WeChatMemoryImportPanelProps = {
  /** Avatar the import is attached to; upload is disabled while this is null. */
  avatarId: string | null;
  /** Avatar model sent with the upload (the component defaults it to "mock"). */
  avatarModel?: string;
  /** Profile id sent with the upload. */
  profileId: string;
  /** Target memory library; "default" is sent when this is null or empty. */
  memoryLibraryId: string | null;
  /** Disables every control of the panel when true. */
  disabled?: boolean;
  /** Called with the commit result after the job was committed. */
  onCommitted: (result: WeChatImportCommitResult) => void | Promise<void>;
};
|
||||
|
||||
/**
 * Pick the most specific message available for a caught error:
 * the ApiError detail, then the Error message, then `fallback`.
 */
function getErrorMessage(error: unknown, fallback: string): string {
  if (error instanceof ApiError) {
    if (error.detail) {
      return error.detail;
    }
  }
  if (error instanceof Error) {
    if (error.message) {
      return error.message;
    }
  }
  return fallback;
}
|
||||
|
||||
/**
 * Build a "wechat-"-prefixed persona id from `value`.
 *
 * Runs of characters outside [A-Za-z0-9_-] become a single "-", leading and
 * trailing dashes are trimmed, and the result is capped at 80 characters.
 */
function safePersonaId(value: string | null): string {
  const source = value || "wechat-persona";
  const collapsed = source.replace(/[^A-Za-z0-9_-]+/g, "-");
  const slug = collapsed.replace(/^-+|-+$/g, "");
  const suffix = slug || "persona";
  return ("wechat-" + suffix).slice(0, 80);
}
|
||||
|
||||
/**
 * Three-step panel for turning a WeChat export into a persona:
 * upload a file, choose the speaker (when the job asks for it), then
 * review the generated persona draft and commit it.
 */
export function WeChatMemoryImportPanel({
  avatarId,
  avatarModel = "mock",
  profileId,
  memoryLibraryId,
  disabled = false,
  onCommitted,
}: WeChatMemoryImportPanelProps) {
  const [file, setFile] = useState<File | null>(null);
  const [job, setJob] = useState<WeChatImportJob | null>(null);
  const [selectedSpeakerId, setSelectedSpeakerId] = useState("");
  const [personaId, setPersonaId] = useState(safePersonaId(avatarId));
  const [personaName, setPersonaName] = useState("");
  const [busy, setBusy] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [notice, setNotice] = useState<string | null>(null);

  const selectedSpeaker = useMemo(
    () => job?.speakers.find((speaker) => speaker.id === selectedSpeakerId) ?? null,
    [job, selectedSpeakerId],
  );

  // Step 1: upload the export file and create the import job.
  const handleUpload = async () => {
    if (!file || !avatarId) return;
    setBusy(true);
    setError(null);
    setNotice(null);
    try {
      const created = await uploadWeChatImport(file, {
        profileId,
        memoryLibraryId: memoryLibraryId || "default",
        avatarId,
        avatarModel,
        characterId: avatarId,
      });
      setJob(created);
      // Derive the defaults from the speaker that is actually selected: the
      // server may pre-select one that is not the first entry of the list.
      const initialSpeakerId = created.selected_speaker_id || created.speakers[0]?.id || "";
      const initialSpeaker = created.speakers.find((speaker) => speaker.id === initialSpeakerId) ?? null;
      setSelectedSpeakerId(initialSpeakerId);
      setPersonaName(initialSpeaker?.name || "微信数字人");
      setPersonaId(safePersonaId(initialSpeakerId || avatarId));
      setNotice(created.status === "draft_ready" ? "人设草稿已生成" : "请选择要创建数字人的说话人");
    } catch (e) {
      setError(getErrorMessage(e, "上传失败"));
    } finally {
      setBusy(false);
    }
  };

  // Step 2: confirm the speaker and fetch the generated persona draft.
  const handleSelectSpeaker = async () => {
    if (!job || !selectedSpeakerId) return;
    setBusy(true);
    setError(null);
    setNotice(null);
    try {
      const selected = await selectWeChatImportSpeaker(job.id, selectedSpeakerId);
      setJob(selected);
      setPersonaName(selectedSpeaker?.name || personaName || "微信数字人");
      setPersonaId(safePersonaId(selectedSpeakerId));
      setNotice("人设草稿已生成");
    } catch (e) {
      setError(getErrorMessage(e, "选择说话人失败"));
    } finally {
      setBusy(false);
    }
  };

  // Step 3: commit the draft as a persona.
  const handleCommit = async () => {
    if (!job || job.status !== "draft_ready") return;
    setBusy(true);
    setError(null);
    setNotice(null);
    try {
      const committed = await commitWeChatImportJob(job.id, {
        // A whitespace-only id falls back to the generated one.
        personaId: personaId.trim() || safePersonaId(selectedSpeakerId || avatarId),
        personaName: personaName || selectedSpeaker?.name || "微信数字人",
      });
      // The server-side commit succeeded: record that BEFORE running the
      // parent callback, so a failing refresh can neither report the save as
      // failed nor leave the job in a state where it can be committed again.
      setJob((current) => (current ? { ...current, status: "committed" } : current));
      setNotice(`已保存数字人,并导入 ${committed.memory_imported} 条记忆`);
      try {
        await onCommitted(committed);
      } catch (e) {
        // Shown next to the success notice: the persona is saved, only the follow-up failed.
        setError(getErrorMessage(e, "数字人已保存,但刷新失败"));
      }
    } catch (e) {
      setError(getErrorMessage(e, "保存失败"));
    } finally {
      setBusy(false);
    }
  };

  return (
    <div className="rounded-lg border border-slate-200 bg-slate-50 p-3">
      <div className="flex flex-wrap items-center justify-between gap-3">
        <div>
          <h3 className="text-sm font-semibold text-slate-950">微信聊天记录导入</h3>
          <p className="text-xs text-slate-500">将微信导出记录生成数字人人设和长期记忆</p>
        </div>
        <button
          type="button"
          onClick={() => void handleUpload()}
          disabled={disabled || busy || !file || !avatarId}
          className="rounded-lg bg-cyan-600 px-3 py-1.5 text-xs font-semibold text-white transition hover:bg-cyan-500 disabled:cursor-not-allowed disabled:opacity-50"
        >
          {busy ? "处理中..." : "上传"}
        </button>
      </div>

      <div className="mt-3 grid gap-3 md:grid-cols-[minmax(0,1fr)_12rem]">
        <input
          type="file"
          accept=".json,.csv,.txt,.html,.htm,.zip"
          disabled={disabled || busy}
          onChange={(event) => setFile(event.target.files?.[0] ?? null)}
          className="w-full rounded-lg border border-slate-200 bg-white px-3 py-2 text-xs text-slate-700 file:mr-3 file:rounded-md file:border-0 file:bg-slate-100 file:px-2.5 file:py-1 file:text-xs file:font-semibold file:text-slate-700 disabled:opacity-60"
        />
        <input
          value={personaId}
          onChange={(event) => setPersonaId(event.target.value)}
          disabled={disabled || busy}
          placeholder="数字人 ID"
          className="w-full rounded-lg border border-slate-200 bg-white px-3 py-2 text-xs text-slate-700 outline-none focus:border-cyan-300 disabled:opacity-60"
        />
      </div>

      {job?.status === "needs_speaker_selection" ? (
        <div className="mt-3 flex flex-wrap items-center gap-2">
          <select
            value={selectedSpeakerId}
            onChange={(event) => setSelectedSpeakerId(event.target.value)}
            disabled={disabled || busy}
            className="min-h-9 min-w-44 rounded-lg border border-slate-200 bg-white px-2.5 text-xs font-medium text-slate-700 outline-none focus:border-cyan-300 disabled:opacity-60"
          >
            {job.speakers.map((speaker) => (
              <option key={speaker.id} value={speaker.id}>
                {speaker.name || speaker.id} ({speaker.message_count})
              </option>
            ))}
          </select>
          <button
            type="button"
            onClick={() => void handleSelectSpeaker()}
            disabled={disabled || busy || !selectedSpeakerId}
            className="rounded-lg border border-slate-200 bg-white px-3 py-1.5 text-xs font-semibold text-slate-700 transition hover:border-cyan-200 hover:text-cyan-700 disabled:cursor-not-allowed disabled:opacity-50"
          >
            生成人设草稿
          </button>
        </div>
      ) : null}

      {job?.status === "draft_ready" ? (
        <div className="mt-3 space-y-3">
          <input
            value={personaName}
            onChange={(event) => setPersonaName(event.target.value)}
            disabled={disabled || busy}
            placeholder="数字人名称"
            className="w-full rounded-lg border border-slate-200 bg-white px-3 py-2 text-xs text-slate-700 outline-none focus:border-cyan-300 disabled:opacity-60"
          />
          <textarea
            value={job.persona_md || ""}
            readOnly
            rows={6}
            className="w-full resize-y rounded-lg border border-slate-200 bg-white px-3 py-2 text-xs leading-relaxed text-slate-700 outline-none"
          />
          <button
            type="button"
            onClick={() => void handleCommit()}
            disabled={disabled || busy || !personaId.trim()}
            className="rounded-lg bg-emerald-600 px-3 py-1.5 text-xs font-semibold text-white transition hover:bg-emerald-500 disabled:cursor-not-allowed disabled:opacity-50"
          >
            保存数字人
          </button>
        </div>
      ) : null}

      {notice ? <p className="mt-3 text-xs font-medium text-emerald-700">{notice}</p> : null}
      {error ? <p className="mt-3 text-xs font-medium text-red-700">{error}</p> : null}
    </div>
  );
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { MemoryItem, MemoryLibrary, MemoryTurn } from "../types";
|
||||
import type { MemoryItem, MemoryLibrary, MemoryTurn, WeChatImportCommitResult, WeChatImportJob } from "../types";
|
||||
|
||||
export const API_BASE = import.meta.env.VITE_API_BASE ?? "/api";
|
||||
|
||||
@@ -563,6 +563,49 @@ export function importMemoryTurns(
|
||||
return apiPost(`/memory/libraries/${encodeURIComponent(libraryId)}/import`, body);
|
||||
}
|
||||
|
||||
/**
 * POST a WeChat export file to /memory/wechat-import as multipart form data.
 *
 * profile_id and memory_library_id default to "default" and avatar_model to
 * "mock"; the remaining optional fields are only sent when non-empty.
 */
export function uploadWeChatImport(
  file: File,
  body: {
    profileId?: string;
    memoryLibraryId?: string;
    avatarId: string;
    avatarModel?: string;
    characterId?: string;
    targetSpeakerId?: string;
    sourceFormat?: string;
    timezone?: string;
  },
): Promise<WeChatImportJob> {
  const payload = new FormData();
  payload.set("file", file);
  payload.set("profile_id", body.profileId || "default");
  payload.set("memory_library_id", body.memoryLibraryId || "default");
  payload.set("avatar_id", body.avatarId);
  payload.set("avatar_model", body.avatarModel || "mock");

  // Optional fields, in the order they are appended to the form.
  const optionalFields: Array<[string, string | undefined]> = [
    ["character_id", body.characterId],
    ["target_speaker_id", body.targetSpeakerId],
    ["source_format", body.sourceFormat],
    ["timezone", body.timezone],
  ];
  for (const [fieldName, fieldValue] of optionalFields) {
    if (fieldValue) {
      payload.set(fieldName, fieldValue);
    }
  }
  return apiPostForm<WeChatImportJob>("/memory/wechat-import", payload);
}
|
||||
|
||||
/**
 * Choose which speaker of an uploaded export the persona is built from.
 *
 * POSTs `{ target_speaker_id }` to `/memory/wechat-import/{jobId}/speaker`
 * and resolves with the updated job.
 */
export function selectWeChatImportSpeaker(jobId: string, targetSpeakerId: string): Promise<WeChatImportJob> {
  const path = `/memory/wechat-import/${encodeURIComponent(jobId)}/speaker`;
  const payload = { target_speaker_id: targetSpeakerId };
  return apiPost(path, payload);
}
|
||||
|
||||
/**
 * Commit an import job: persist the drafted persona and its memories.
 *
 * POSTs the snake_case payload to `/memory/wechat-import/{jobId}/commit`.
 */
export function commitWeChatImportJob(
  jobId: string,
  body: { personaId: string; personaName?: string; description?: string },
): Promise<WeChatImportCommitResult> {
  const { personaId, personaName, description } = body;
  const path = `/memory/wechat-import/${encodeURIComponent(jobId)}/commit`;
  const payload = {
    persona_id: personaId,
    persona_name: personaName,
    description,
  };
  return apiPost(path, payload);
}
|
||||
|
||||
/** GET /voices 返回的音色目录项(含 SQLite 中的系统预设与复刻) */
|
||||
export type VoiceCatalogItem = {
|
||||
id: number;
|
||||
|
||||
@@ -34,3 +34,38 @@ export type MemoryTurn = {
|
||||
role: "user" | "assistant";
|
||||
content: string;
|
||||
};
|
||||
|
||||
/** One speaker detected in an uploaded WeChat export. */
export type WeChatImportSpeaker = {
  /** Speaker identifier; this is the value sent as `target_speaker_id`. */
  id: string;
  /** Display name of the speaker. */
  name: string;
  /** Number of messages attributed to this speaker. */
  message_count: number;
  /** True when the speaker is the exporting user. */
  is_self: boolean;
  /** Free-form extra data attached to the speaker. */
  metadata: Record<string, unknown>;
};
|
||||
|
||||
/** State of a WeChat import job as returned by the `/memory/wechat-import` endpoints. */
export type WeChatImportJob = {
  id: string;
  // The literals list the known states; `| string` keeps the field open to
  // other values and widens the whole union to `string` for the type checker.
  status: "needs_speaker_selection" | "draft_ready" | "committed" | "error" | string;
  /** Speakers the user can choose from. */
  speakers: WeChatImportSpeaker[];
  profile_id: string;
  memory_library_id: string;
  avatar_id: string;
  avatar_model: string;
  character_id: string;
  /** Set once a target speaker has been chosen. */
  selected_speaker_id?: string | null;
  /** Drafted persona markdown; absent or null until a draft exists. */
  persona_md?: string | null;
  source_metadata: Record<string, unknown>;
  error?: string | null;
  created_at: string;
  updated_at: string;
};
|
||||
|
||||
/** Response of the commit endpoint for a WeChat import job. */
export type WeChatImportCommitResult = {
  job_id: string;
  /** Identifier of the persona that was saved. */
  persona_id: string;
  /** Count of imported memory items as reported by the server. */
  memory_imported: number;
  /** Size in bytes of the stored persona.md (0 when the file is missing). */
  persona_md_bytes: number;
  profile_id: string;
  character_id: string;
  memory_library_id: string;
};
|
||||
|
||||
307
opentalking/persona/memory_builder.py
Normal file
307
opentalking/persona/memory_builder.py
Normal file
@@ -0,0 +1,307 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
|
||||
from opentalking.core.config import get_settings
|
||||
from opentalking.providers.memory.schemas import MemoryItem, utc_now_iso
|
||||
from opentalking.persona.weflow_parser import WeFlowExport, WeFlowSpeaker, WeFlowTurn
|
||||
|
||||
|
||||
class PersonaLLMClient(Protocol):
    """Structural interface for the chat client used to draft personas.

    ``complete`` receives the chat messages and may return the completion text
    either directly or as an awaitable that resolves to it.
    """

    def complete(self, messages: list[dict[str, str]]) -> str | Awaitable[str]: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WeChatPersonaDraft:
    """Immutable result of drafting a persona from one speaker of an export."""

    # Speaker the draft was built for.
    target_speaker: WeFlowSpeaker
    # Suggested persona name (the builders in this module use the speaker name).
    persona_name: str
    # Markdown body that is later written to persona.md.
    persona_md: str
    # Memory items to import when the draft is committed.
    memory_items: list[MemoryItem]
    # Provenance of the import (source format, conversation id, target speaker, ...).
    source_metadata: dict[str, Any] = field(default_factory=dict)
    # Parser warnings plus any warnings added while drafting.
    warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class _ConfiguredPersonaLLM:
    """Default persona LLM client built from the application settings."""

    async def complete(self, messages: list[dict[str, str]]) -> str:
        """Return the full completion text for ``messages``.

        Returns an empty string when no ``llm_base_url`` is configured, which the
        caller treats as "no LLM result".
        """
        cfg = get_settings()
        configured_url = str(getattr(cfg, "llm_base_url", "") or "").strip()
        if not configured_url:
            return ""

        # Imported here so the adapter is only loaded when an endpoint is configured.
        from opentalking.providers.llm.openai_compatible.adapter import OpenAICompatibleLLMClient

        llm = OpenAICompatibleLLMClient(
            base_url=cfg.llm_base_url,
            api_key=cfg.llm_api_key,
            model=cfg.llm_model,
        )
        # Drain the stream and stitch the pieces back into one string.
        pieces = [piece async for piece in llm.chat_stream(messages)]
        return "".join(pieces)
|
||||
|
||||
|
||||
def build_wechat_persona_draft(
    export: WeFlowExport,
    *,
    target_speaker_id: str | None = None,
    llm_client: PersonaLLMClient | None = None,
    max_sample_turns: int = 80,
) -> WeChatPersonaDraft:
    """Synchronous wrapper around :func:`build_wechat_persona_draft_async`.

    Runs the async builder with ``asyncio.run``.

    Raises:
        RuntimeError: when called from inside a running event loop; callers in
            async code must await the async variant instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so it is safe to start one.
        inside_loop = False
    else:
        inside_loop = True

    if inside_loop:
        raise RuntimeError("build_wechat_persona_draft_async must be used inside a running event loop")

    return asyncio.run(
        build_wechat_persona_draft_async(
            export,
            target_speaker_id=target_speaker_id,
            llm_client=llm_client,
            max_sample_turns=max_sample_turns,
        )
    )
|
||||
|
||||
|
||||
async def build_wechat_persona_draft_async(
    export: WeFlowExport,
    *,
    target_speaker_id: str | None = None,
    llm_client: PersonaLLMClient | None = None,
    max_sample_turns: int = 80,
) -> WeChatPersonaDraft:
    """Draft a persona and memory items for one speaker of a parsed export.

    The LLM is tried first; if it yields nothing usable the rule-based fallback
    draft is returned instead.

    Raises:
        ValueError: if the speaker cannot be resolved or has no messages.
    """
    speaker = _resolve_target_speaker(export, target_speaker_id)
    speaker_turns = [turn for turn in export.turns if turn.speaker_id == speaker.id]
    if not speaker_turns:
        raise ValueError("target speaker has no readable messages")

    # Always keep at least one turn, even for a non-positive max_sample_turns.
    sample = speaker_turns[: max(1, max_sample_turns)]
    metadata = {
        "source": "wechat_import",
        "source_format": export.detected_format,
        "source_name": export.source_metadata.get("source_name"),
        "conversation_id": export.conversation_id,
        "target_speaker_id": speaker.id,
        "target_speaker_name": speaker.name,
        "target_message_count": len(speaker_turns),
    }
    notes = list(export.warnings)

    payload = await _try_llm_build(llm_client or _ConfiguredPersonaLLM(), target=speaker, turns=sample)
    if payload:
        llm_draft = _draft_from_llm_payload(
            payload,
            target=speaker,
            source_metadata=metadata,
            warnings=notes,
        )
        if llm_draft is not None:
            return llm_draft
        # The LLM answered, but without usable persona text or memories.
        notes.append("llm_payload_invalid")

    return _fallback_draft(target=speaker, turns=sample, source_metadata=metadata, warnings=notes)
|
||||
|
||||
|
||||
def _resolve_target_speaker(export: WeFlowExport, target_speaker_id: str | None) -> WeFlowSpeaker:
|
||||
if target_speaker_id:
|
||||
for speaker in export.speakers:
|
||||
if speaker.id == target_speaker_id:
|
||||
return speaker
|
||||
raise ValueError("target speaker not found")
|
||||
candidates = [speaker for speaker in export.speakers if not speaker.is_self]
|
||||
if not candidates:
|
||||
candidates = list(export.speakers)
|
||||
if not candidates:
|
||||
raise ValueError("WeFlow export has no speakers")
|
||||
return sorted(candidates, key=lambda item: (-item.message_count, item.name))[0]
|
||||
|
||||
|
||||
async def _try_llm_build(
    client: PersonaLLMClient,
    *,
    target: WeFlowSpeaker,
    turns: Sequence[WeFlowTurn],
) -> dict[str, Any] | None:
    """Ask the LLM for a persona payload; return ``None`` on any failure.

    Best effort by design: client errors, empty replies, invalid JSON and
    non-object JSON all yield ``None`` so the caller can fall back.
    """
    # Each sample line is redacted and capped at 240 characters before it is sent.
    sample_lines = [
        f"- {turn.timestamp or 'unknown'} {target.name}: {_redact_sensitive(turn.content)[:240]}"
        for turn in turns
    ]
    transcript = "\n".join(sample_lines)

    system_prompt = (
        "You extract a digital-human persona from a user-uploaded WeChat export. "
        "Return strict JSON with keys persona_md, style_memories, semantic_memories, "
        "episodic_summaries, confidence. Summarize style and persona traits; do not copy "
        "raw private transcript lines or secrets."
    )
    user_prompt = (
        f"Target speaker: {target.name} ({target.id})\n"
        f"Sample target-speaker messages, redacted:\n{transcript}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    try:
        outcome = client.complete(messages)
        # The protocol allows sync or async clients; only await non-string results.
        raw = outcome if isinstance(outcome, str) else await outcome
    except Exception:
        return None

    reply = str(raw or "").strip()
    if not reply:
        return None

    try:
        decoded = json.loads(_strip_json_fence(reply))
    except json.JSONDecodeError:
        return None

    if isinstance(decoded, dict):
        return decoded
    return None
|
||||
|
||||
|
||||
def _draft_from_llm_payload(
    payload: dict[str, Any],
    *,
    target: WeFlowSpeaker,
    source_metadata: dict[str, Any],
    warnings: list[str],
) -> WeChatPersonaDraft | None:
    """Turn a decoded LLM payload into a draft, or ``None`` if it is unusable.

    A payload is unusable when it has no persona text or yields no memory items.
    """
    persona_text = _safe_text(str(payload.get("persona_md") or "").strip())
    if not persona_text:
        return None

    level = _confidence(str(payload.get("confidence") or "medium"))

    # (payload key, memory type, memory layer), in the order items are emitted.
    sections = (
        ("style_memories", "preference", "style"),
        ("semantic_memories", "note", "semantic"),
        ("episodic_summaries", "summary", "episodic"),
    )
    memories: list[MemoryItem] = [
        _memory_item(entry, kind=kind, layer=layer, target=target, confidence=level)
        for key, kind, layer in sections
        for entry in _string_list(payload.get(key))
    ]
    if not memories:
        return None

    return WeChatPersonaDraft(
        target_speaker=target,
        persona_name=target.name,
        persona_md=persona_text,
        memory_items=memories,
        source_metadata=source_metadata,
        warnings=warnings,
    )
|
||||
|
||||
|
||||
def _fallback_draft(
    *,
    target: WeFlowSpeaker,
    turns: Sequence[WeFlowTurn],
    source_metadata: dict[str, Any],
    warnings: list[str],
) -> WeChatPersonaDraft:
    """Build a template-based draft without an LLM.

    Only the speaker name, the sampled message count and the reply-length label
    from ``_style_stats`` vary; no raw chat text is included.
    """
    length_label = _style_stats(turns)["length_label"]
    name = target.name

    persona_lines = (
        "# Persona",
        f"Name: {name}",
        "Origin: Built from a user-uploaded WeFlow WeChat export.",
        "",
        "# Speaking Style",
        f"- Uses a {length_label} reply rhythm with calm, practical guidance.",
        "- Favors small next steps, check-ins, and emotionally steady phrasing.",
        "- Avoids exposing raw imported chat lines; use this as a style guide, not a transcript.",
        "",
        "# Memory Policy",
        "- Treat imported chat records as private source artifacts.",
        "- Runtime prompts should load this persona.md summary, never raw chat logs.",
    )

    # One memory per layer: (text, memory type, memory layer).
    memory_specs = (
        (f"{name} tends to use {length_label} replies with calm practical guidance.", "preference", "style"),
        (f"Persona source contains {len(turns)} target-speaker WeChat messages imported from WeFlow.", "note", "semantic"),
        (f"Imported chats suggest {name} often supports planning, check-ins, or emotional steadiness.", "summary", "episodic"),
    )
    memories = [
        _memory_item(text, kind=kind, layer=layer, target=target, confidence="medium")
        for text, kind, layer in memory_specs
    ]

    return WeChatPersonaDraft(
        target_speaker=target,
        persona_name=name,
        persona_md="\n".join(persona_lines),
        memory_items=memories,
        source_metadata=source_metadata,
        warnings=warnings,
    )
|
||||
|
||||
|
||||
# Han, kana and hangul code points. These scripts are written without spaces,
# so each character is counted as its own length unit.
_CJK_CHAR_RE = re.compile(r"[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uac00-\ud7af\uf900-\ufaff]")


def _text_units(text: str) -> int:
    """Count length units: one per CJK character plus one per other whitespace-separated token."""
    cjk_units = len(_CJK_CHAR_RE.findall(text))
    other_units = len(_CJK_CHAR_RE.sub(" ", text).split())
    return max(1, cjk_units + other_units)


def _style_stats(turns: Sequence[WeFlowTurn]) -> dict[str, str]:
    """Classify the speaker's typical reply length.

    Returns ``{"length_label": ...}`` with "concise" (average of at most 8 units),
    "balanced" (at most 20) or "detailed". Blank messages are ignored and an
    empty sample is reported as "concise".

    Plain whitespace splitting treats a whole Chinese sentence as one word, which
    labelled every CJK chat "concise"; CJK characters are therefore counted
    individually. Text without CJK characters is counted exactly as before.
    """
    contents = [_redact_sensitive(turn.content) for turn in turns if turn.content.strip()]
    if not contents:
        return {"length_label": "concise"}

    avg_units = sum(_text_units(text) for text in contents) / len(contents)
    if avg_units <= 8:
        label = "concise"
    elif avg_units <= 20:
        label = "balanced"
    else:
        label = "detailed"
    return {"length_label": label}
|
||||
|
||||
|
||||
def _memory_item(
    text: str,
    *,
    kind: str,
    layer: str,
    target: WeFlowSpeaker,
    confidence: str,
) -> MemoryItem:
    """Create a redacted ``MemoryItem`` tagged with WeChat-import provenance.

    The id is left empty here; ``confidence`` is normalized to low/medium/high.
    """
    body = _safe_text(text)
    provenance = {
        "source": "wechat_import",
        "source_type": "weflow_upload",
        "layer": layer,
        "target_speaker_id": target.id,
        "target_speaker_name": target.name,
        "confidence": _confidence(confidence),
    }
    return MemoryItem(
        id="",
        text=body,
        type=kind,  # type: ignore[arg-type]
        metadata=provenance,
        created_at=utc_now_iso(),
    )
|
||||
|
||||
|
||||
def _string_list(value: Any) -> list[str]:
    """Coerce an untrusted payload value into a list of non-empty, redacted strings.

    Anything that is not a list yields ``[]``; entries that are falsy or become
    empty after redaction and stripping are dropped.
    """
    if not isinstance(value, list):
        return []
    cleaned = (_safe_text(str(entry or "").strip()) for entry in value)
    return [text for text in cleaned if text]
|
||||
|
||||
|
||||
def _confidence(value: str) -> str:
|
||||
normalized = (value or "").strip().lower()
|
||||
if normalized in {"low", "medium", "high"}:
|
||||
return normalized
|
||||
return "medium"
|
||||
|
||||
|
||||
def _strip_json_fence(text: str) -> str:
|
||||
stripped = text.strip()
|
||||
if stripped.startswith("```"):
|
||||
stripped = re.sub(r"^```(?:json)?\s*", "", stripped, flags=re.IGNORECASE)
|
||||
stripped = re.sub(r"\s*```$", "", stripped)
|
||||
return stripped.strip()
|
||||
|
||||
|
||||
def _safe_text(text: str) -> str:
    """Return ``text`` with sensitive fragments redacted and outer whitespace removed."""
    redacted = _redact_sensitive(text)
    return redacted.strip()
|
||||
|
||||
|
||||
def _redact_sensitive(text: str) -> str:
|
||||
redacted = re.sub(r"\b\d{4,}\b", "[redacted-number]", text)
|
||||
redacted = re.sub(r"(?i)\b(secret|password|token|api[_ -]?key)\b[^.\n;]*", "[redacted-sensitive]", redacted)
|
||||
return redacted
|
||||
@@ -131,6 +131,7 @@ async def import_persona_package(
|
||||
with tempfile.TemporaryDirectory(prefix="opentalking-persona-import-") as tmp:
|
||||
package_root = Path(tmp)
|
||||
manifest = extract_persona_package(package_path, package_root)
|
||||
_read_prompt_text(package_root, manifest.agent.persona_prompt)
|
||||
system_prompt = _read_prompt_text(package_root, manifest.agent.system_prompt)
|
||||
style_prompt = _read_prompt_text(package_root, manifest.agent.style_prompt)
|
||||
if system_prompt or style_prompt:
|
||||
@@ -144,7 +145,11 @@ async def import_persona_package(
|
||||
)
|
||||
manifest = replace(
|
||||
manifest,
|
||||
agent=replace(manifest.agent, system_prompt="prompts/_compiled_system.md"),
|
||||
agent=replace(
|
||||
manifest.agent,
|
||||
system_prompt="prompts/_compiled_system.md",
|
||||
style_prompt=None,
|
||||
),
|
||||
)
|
||||
manifest = await _import_knowledge_documents(
|
||||
manifest=manifest,
|
||||
|
||||
67
opentalking/persona/persona_md.py
Normal file
67
opentalking/persona/persona_md.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from opentalking.persona.store import PersonaRecord
|
||||
|
||||
|
||||
# Relative file name used for the persona prompt when none is given.
DEFAULT_PERSONA_MD = "persona.md"
|
||||
|
||||
|
||||
def _safe_relative_path(root: Path, rel_path: str | None, *, field_name: str) -> Path:
    """Resolve ``rel_path`` under ``root`` and refuse anything that escapes it.

    A missing or blank ``rel_path`` falls back to ``DEFAULT_PERSONA_MD``.

    Raises:
        ValueError: if the resolved path is not inside the resolved ``root``.
    """
    relative = (rel_path or DEFAULT_PERSONA_MD).strip() or DEFAULT_PERSONA_MD
    base = root.resolve()
    candidate = (root / relative).resolve()
    try:
        # Both sides are resolved first, so ".." segments cannot slip through.
        candidate.relative_to(base)
    except ValueError as exc:
        raise ValueError(f"{field_name} path must stay inside persona directory") from exc
    return candidate
|
||||
|
||||
|
||||
def read_prompt_file(root: Path, rel_path: str | None, *, field_name: str) -> str | None:
    """Read an optional prompt file from a persona directory.

    Returns the stripped UTF-8 text, or ``None`` when no path is given, the file
    does not exist, or it is empty after stripping.

    Raises:
        ValueError: if ``rel_path`` points outside ``root``.
    """
    if not rel_path:
        return None
    target = _safe_relative_path(root, rel_path, field_name=field_name)
    if target.is_file():
        content = target.read_text(encoding="utf-8").strip()
        if content:
            return content
    return None
|
||||
|
||||
|
||||
def read_persona_md(record: PersonaRecord) -> str:
    """Return the stripped persona markdown of ``record``, or ``""`` if the file is missing.

    The file is the manifest's ``persona_prompt`` path, defaulting to
    ``DEFAULT_PERSONA_MD``.
    """
    rel_path = record.manifest.agent.persona_prompt or DEFAULT_PERSONA_MD
    md_path = _safe_relative_path(record.path, rel_path, field_name="persona_prompt")
    if md_path.is_file():
        return md_path.read_text(encoding="utf-8").strip()
    return ""
|
||||
|
||||
|
||||
def write_persona_md(record: PersonaRecord, content: str) -> str:
    """Write ``content`` to the persona markdown file of ``record``.

    The text is stripped and stored with a single trailing newline; empty
    content produces an empty file. Parent directories are created as needed.

    Returns:
        The stripped text that was written (without the trailing newline).
    """
    rel_path = record.manifest.agent.persona_prompt or DEFAULT_PERSONA_MD
    md_path = _safe_relative_path(record.path, rel_path, field_name="persona_prompt")
    body = str(content or "").strip()
    md_path.parent.mkdir(parents=True, exist_ok=True)
    file_text = f"{body}\n" if body else ""
    md_path.write_text(file_text, encoding="utf-8")
    return body
|
||||
|
||||
|
||||
def build_persona_prompt_text(
    root: Path,
    *,
    persona_prompt: str | None,
    system_prompt: str | None,
    style_prompt: str | None,
) -> str | None:
    """Combine the persona, system and style prompt files into one prompt.

    Files are read in that order; missing or empty ones are skipped and the rest
    are joined with blank lines. Returns ``None`` when nothing was read.
    """
    sources = (
        ("persona_prompt", persona_prompt),
        ("system_prompt", system_prompt),
        ("style_prompt", style_prompt),
    )
    sections: list[str] = []
    for field_name, rel_path in sources:
        section = read_prompt_file(root, rel_path, field_name=field_name)
        if section:
            sections.append(section)
    return "\n\n".join(sections) or None
|
||||
@@ -28,6 +28,7 @@ class PersonaVoice:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PersonaAgent:
|
||||
persona_prompt: str | None = None
|
||||
system_prompt: str | None = None
|
||||
style_prompt: str | None = None
|
||||
memory_enabled: bool = False
|
||||
@@ -153,6 +154,7 @@ def persona_from_dict(raw: dict[str, Any]) -> PersonaManifest:
|
||||
model=_optional_str(voice_raw.get("model"), max_len=256),
|
||||
),
|
||||
agent=PersonaAgent(
|
||||
persona_prompt=_optional_str(agent_raw.get("persona_prompt"), max_len=256),
|
||||
system_prompt=_optional_str(agent_raw.get("system_prompt"), max_len=256),
|
||||
style_prompt=_optional_str(agent_raw.get("style_prompt"), max_len=256),
|
||||
memory_enabled=_bool(agent_raw.get("memory_enabled"), default=False),
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from opentalking.persona.persona_md import build_persona_prompt_text
|
||||
from opentalking.persona.store import PersonaRecord, PersonaStore
|
||||
|
||||
|
||||
@@ -29,7 +30,12 @@ def default_persona_store() -> PersonaStore:
|
||||
|
||||
def build_session_defaults(record: PersonaRecord) -> PersonaSessionDefaults:
|
||||
manifest = record.manifest
|
||||
prompt = _read_prompt(record.path, manifest.agent.system_prompt)
|
||||
prompt = build_persona_prompt_text(
|
||||
record.path,
|
||||
persona_prompt=manifest.agent.persona_prompt,
|
||||
system_prompt=manifest.agent.system_prompt,
|
||||
style_prompt=manifest.agent.style_prompt,
|
||||
)
|
||||
return PersonaSessionDefaults(
|
||||
persona_id=manifest.id,
|
||||
avatar_id=manifest.avatar.id,
|
||||
|
||||
262
opentalking/persona/wechat_import.py
Normal file
262
opentalking/persona/wechat_import.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from opentalking.persona.memory_builder import (
|
||||
WeChatPersonaDraft,
|
||||
build_wechat_persona_draft,
|
||||
build_wechat_persona_draft_async,
|
||||
)
|
||||
from opentalking.persona.schema import (
|
||||
PERSONA_SCHEMA_VERSION,
|
||||
PersonaAgent,
|
||||
PersonaAvatar,
|
||||
PersonaManifest,
|
||||
PersonaRuntime,
|
||||
PersonaSafety,
|
||||
write_persona_manifest,
|
||||
)
|
||||
from opentalking.persona.store import PersonaRecord, PersonaStore
|
||||
from opentalking.persona.weflow_parser import WeFlowExport, WeFlowSpeaker, parse_weflow_export
|
||||
from opentalking.providers.memory.base import MemoryProvider
|
||||
from opentalking.providers.memory.import_jobs import ImportJobStatus, MemoryImportCommitResult
|
||||
from opentalking.providers.memory.schemas import utc_now_iso
|
||||
|
||||
|
||||
@dataclass
class WeChatImportJob:
    """Mutable state of one WeChat import, from upload to commit."""

    id: str
    status: ImportJobStatus
    # Parsed export kept for the lifetime of the job; not part of to_dict().
    export: WeFlowExport
    # Speakers offered for selection.
    speakers: list[WeFlowSpeaker]
    profile_id: str
    memory_library_id: str
    avatar_id: str
    avatar_model: str
    character_id: str
    selected_speaker_id: str | None = None
    # Set once a speaker has been selected and a draft was built.
    draft: WeChatPersonaDraft | None = None
    source_metadata: dict[str, Any] = field(default_factory=dict)
    error: str | None = None
    created_at: str = field(default_factory=utc_now_iso)
    updated_at: str = field(default_factory=utc_now_iso)

    def to_dict(self) -> dict[str, Any]:
        """Return the job as a plain dict.

        The parsed export and the draft object are left out; only the draft's
        ``persona_md`` is exposed (``None`` while there is no draft).
        """
        return {
            "id": self.id,
            "status": self.status,
            # NOTE(review): this hands out each speaker's own __dict__, not a copy.
            "speakers": [speaker.__dict__ for speaker in self.speakers],
            "profile_id": self.profile_id,
            "memory_library_id": self.memory_library_id,
            "avatar_id": self.avatar_id,
            "avatar_model": self.avatar_model,
            "character_id": self.character_id,
            "selected_speaker_id": self.selected_speaker_id,
            "persona_md": self.draft.persona_md if self.draft else None,
            "source_metadata": self.source_metadata,
            "error": self.error,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }
|
||||
|
||||
|
||||
class WeChatImportJobRegistry:
    """In-memory registry that drives a WeChat import: parse, pick a speaker, commit."""

    def __init__(self, *, persona_store: PersonaStore, memory_provider: MemoryProvider) -> None:
        self.persona_store = persona_store
        self.memory_provider = memory_provider
        # Jobs live only in this dict, keyed by job id.
        self._jobs: dict[str, WeChatImportJob] = {}

    @staticmethod
    def _auto_speaker_id(job: WeChatImportJob, target_speaker_id: str | None) -> str | None:
        """Return the speaker to select right away, or ``None`` to wait for the user.

        An explicit id wins; otherwise a lone candidate speaker is chosen.
        """
        if target_speaker_id:
            return target_speaker_id
        if len(job.speakers) == 1:
            return job.speakers[0].id
        return None

    def create_job(
        self,
        file_path: str | Path,
        *,
        profile_id: str = "default",
        memory_library_id: str = "default",
        avatar_id: str,
        avatar_model: str,
        character_id: str | None = None,
        target_speaker_id: str | None = None,
        source_format: str = "auto",
        timezone: str = "Asia/Shanghai",
    ) -> WeChatImportJob:
        """Parse an export file and register a job (synchronous variant).

        The draft is built immediately when a target speaker is given or only one
        candidate exists; otherwise the job waits for speaker selection.
        """
        job = self._create_base_job(
            file_path,
            profile_id=profile_id,
            memory_library_id=memory_library_id,
            avatar_id=avatar_id,
            avatar_model=avatar_model,
            character_id=character_id,
            source_format=source_format,
            timezone=timezone,
        )
        speaker_id = self._auto_speaker_id(job, target_speaker_id)
        if speaker_id is None:
            return job
        return self.select_speaker(job.id, speaker_id)

    async def create_job_async(
        self,
        file_path: str | Path,
        *,
        profile_id: str = "default",
        memory_library_id: str = "default",
        avatar_id: str,
        avatar_model: str,
        character_id: str | None = None,
        target_speaker_id: str | None = None,
        source_format: str = "auto",
        timezone: str = "Asia/Shanghai",
    ) -> WeChatImportJob:
        """Async counterpart of :meth:`create_job`, for use inside an event loop."""
        job = self._create_base_job(
            file_path,
            profile_id=profile_id,
            memory_library_id=memory_library_id,
            avatar_id=avatar_id,
            avatar_model=avatar_model,
            character_id=character_id,
            source_format=source_format,
            timezone=timezone,
        )
        speaker_id = self._auto_speaker_id(job, target_speaker_id)
        if speaker_id is None:
            return job
        return await self.select_speaker_async(job.id, speaker_id)

    def _create_base_job(
        self,
        file_path: str | Path,
        *,
        profile_id: str,
        memory_library_id: str,
        avatar_id: str,
        avatar_model: str,
        character_id: str | None,
        source_format: str,
        timezone: str,
    ) -> WeChatImportJob:
        """Parse the export and store a job in the ``needs_speaker_selection`` state."""
        parsed = parse_weflow_export(file_path, source_format=source_format, timezone=timezone)
        # Offer non-self speakers; fall back to everyone when all are marked self.
        selectable = [speaker for speaker in parsed.speakers if not speaker.is_self] or list(parsed.speakers)

        new_job = WeChatImportJob(
            id=uuid.uuid4().hex,
            status="needs_speaker_selection",
            export=parsed,
            speakers=selectable,
            profile_id=(profile_id or "default").strip() or "default",
            memory_library_id=(memory_library_id or "default").strip() or "default",
            avatar_id=avatar_id,
            avatar_model=avatar_model,
            # The character id defaults to the avatar id when missing or blank.
            character_id=(character_id or avatar_id).strip() or avatar_id,
            source_metadata={**parsed.source_metadata, "source": "wechat_import"},
        )
        self._jobs[new_job.id] = new_job
        return new_job

    def get_job(self, job_id: str) -> WeChatImportJob:
        """Return the job with ``job_id``.

        Raises:
            KeyError: if no such job is registered.
        """
        try:
            found = self._jobs[job_id]
        except KeyError as exc:
            raise KeyError("wechat import job not found") from exc
        else:
            return found

    def select_speaker(self, job_id: str, target_speaker_id: str) -> WeChatImportJob:
        """Build the persona draft for ``target_speaker_id`` (synchronous variant)."""
        job = self.get_job(job_id)
        new_draft = build_wechat_persona_draft(job.export, target_speaker_id=target_speaker_id)
        return self._set_draft(job, target_speaker_id=target_speaker_id, draft=new_draft)

    async def select_speaker_async(self, job_id: str, target_speaker_id: str) -> WeChatImportJob:
        """Build the persona draft for ``target_speaker_id`` inside an event loop."""
        job = self.get_job(job_id)
        new_draft = await build_wechat_persona_draft_async(job.export, target_speaker_id=target_speaker_id)
        return self._set_draft(job, target_speaker_id=target_speaker_id, draft=new_draft)

    def _set_draft(
        self,
        job: WeChatImportJob,
        *,
        target_speaker_id: str,
        draft: WeChatPersonaDraft,
    ) -> WeChatImportJob:
        """Attach ``draft`` to ``job`` and mark it ``draft_ready``."""
        job.selected_speaker_id = target_speaker_id
        job.draft = draft
        job.status = "draft_ready"
        job.updated_at = utc_now_iso()
        return job

    async def commit(
        self,
        job_id: str,
        *,
        persona_id: str,
        persona_name: str | None = None,
        description: str | None = None,
    ) -> MemoryImportCommitResult:
        """Save the drafted persona, import its memory items and mark the job committed.

        Raises:
            KeyError: if the job does not exist.
            ValueError: if the job has no draft yet.
        """
        job = self.get_job(job_id)
        if job.draft is None:
            raise ValueError("wechat import job has no persona draft")

        # The persona is saved first; memory items are imported afterwards.
        saved = self._save_persona(job, persona_id=persona_id, persona_name=persona_name, description=description)
        imported_count = await self.memory_provider.add_items(
            library_id=job.memory_library_id,
            profile_id=job.profile_id,
            character_id=job.character_id,
            items=job.draft.memory_items,
        )
        job.status = "committed"
        job.updated_at = utc_now_iso()

        md_file = saved.path / "persona.md"
        md_size = md_file.stat().st_size if md_file.is_file() else 0
        return MemoryImportCommitResult(
            job_id=job.id,
            persona_id=saved.manifest.id,
            memory_imported=imported_count,
            persona_md_bytes=md_size,
            profile_id=job.profile_id,
            character_id=job.character_id,
            memory_library_id=job.memory_library_id,
        )

    def _save_persona(
        self,
        job: WeChatImportJob,
        *,
        persona_id: str,
        persona_name: str | None,
        description: str | None,
    ) -> PersonaRecord:
        """Stage persona.md, import metadata and manifest in a temp dir and store them.

        An existing persona with the same id is replaced.

        Raises:
            ValueError: if the job has no draft yet.
        """
        draft = job.draft
        if draft is None:
            raise ValueError("wechat import job has no persona draft")

        display_name = (persona_name or draft.persona_name or persona_id).strip() or persona_id
        agent = PersonaAgent(
            persona_prompt="persona.md",
            memory_enabled=True,
            knowledge_enabled=False,
        )
        safety = PersonaSafety(
            authorized_avatar=True,
            authorized_voice=False,
            content_label_required=True,
        )
        manifest = PersonaManifest(
            schema_version=PERSONA_SCHEMA_VERSION,
            id=persona_id,
            name=display_name,
            description=(description or "Persona generated from an uploaded WeFlow WeChat export."),
            locale="zh-CN",
            avatar=PersonaAvatar(id=job.avatar_id, model=job.avatar_model),
            agent=agent,
            runtime=PersonaRuntime(),
            safety=safety,
        )

        with tempfile.TemporaryDirectory(prefix="opentalking-wechat-persona-") as tmp:
            staging = Path(tmp)
            (staging / "persona.md").write_text(draft.persona_md.strip() + "\n", encoding="utf-8")
            metadata_json = json.dumps(draft.source_metadata, ensure_ascii=False, indent=2) + "\n"
            (staging / "import_metadata.json").write_text(metadata_json, encoding="utf-8")
            write_persona_manifest(staging / "persona.json", manifest)
            # Saved while the staging directory still exists.
            return self.persona_store.save_persona(
                manifest,
                source_dir=staging,
                source="wechat_import",
                replace=True,
            )
|
||||
655
opentalking/persona/weflow_parser.py
Normal file
655
opentalking/persona/weflow_parser.py
Normal file
@@ -0,0 +1,655 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone as dt_timezone
|
||||
from html.parser import HTMLParser
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from urllib.parse import urlparse
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
|
||||
# Formats a parsed export can be reported as.
WeFlowDetectedFormat = Literal["chatlab_json", "raw_json", "csv", "txt", "html"]
# File suffixes accepted for direct upload and for members inside a zip archive.
_SUPPORTED_SUFFIXES = {".json", ".csv", ".txt", ".html", ".htm"}
# Preference order per suffix when choosing a zip member (lower wins); despite
# the name it ranks every supported suffix, with JSON first.
_JSON_PRIORITY = {".json": 0, ".csv": 1, ".txt": 2, ".html": 3, ".htm": 3}
# Speaker names that denote the exporting user: "我", "me", "self", "自己", "本人".
_SELF_NAMES = {"\u6211", "me", "self", "\u81ea\u5df1", "\u672c\u4eba"}
|
||||
|
||||
|
||||
class WeFlowParseError(ValueError):
    """Error raised when an uploaded WeFlow export is rejected or cannot be parsed.

    Subclasses ``ValueError`` so callers that handle ``ValueError`` also catch it.
    """
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WeFlowTurn:
    """One normalized chat message from an export."""

    # Message id from the export, if it has one.
    message_id: str | None
    # Id and display name of the sender.
    speaker_id: str
    speaker_name: str
    # Message text.
    content: str
    # Message time as a string, if known.
    timestamp: str | None = None
    # True when the message was sent by the exporting user.
    is_self: bool = False
    message_type: str | None = None
    # Extra per-message data.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WeFlowSpeaker:
    """A participant of an exported conversation."""

    id: str
    name: str
    # Number of messages attributed to this speaker.
    message_count: int
    # True when this speaker is the exporting user.
    is_self: bool = False
    # Extra per-speaker data.
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WeFlowExport:
    """Parsed export: normalized turns, speakers and source information."""

    conversation_id: str | None
    # Format the payload was parsed as.
    detected_format: WeFlowDetectedFormat
    turns: list[WeFlowTurn]
    speakers: list[WeFlowSpeaker]
    # Details about the uploaded file, e.g. source_name and byte_size.
    source_metadata: dict[str, Any]
    # Warnings collected for this export.
    warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def parse_weflow_export(
    path: str | Path,
    *,
    source_format: str = "auto",
    timezone: str = "Asia/Shanghai",
    conversation_id: str | None = None,
) -> WeFlowExport:
    """Parse a user-uploaded WeFlow export file into normalized chat turns.

    Only local files are accepted. URL/API inputs are rejected on purpose so the
    product stays an import flow rather than a WeChat/WeFlow connector. Zip
    archives are unpacked and their best supported member is parsed.

    Raises:
        WeFlowParseError: for URLs, missing files, unsupported file types or an
            unknown timezone.
    """
    source = _validate_local_file(path)
    zone = _zoneinfo(timezone)

    if source.suffix.lower() == ".zip":
        return _parse_zip(source, source_format=source_format, timezone=zone, conversation_id=conversation_id)

    raw = source.read_bytes()
    return _parse_payload(
        raw,
        source_name=source.name,
        source_format=source_format,
        timezone=zone,
        conversation_id=conversation_id,
        source_metadata={"source_name": source.name, "byte_size": len(raw)},
    )
|
||||
|
||||
|
||||
def _validate_local_file(path: str | Path) -> Path:
    """Check that ``path`` is an existing local file of a supported type.

    Raises:
        WeFlowParseError: if ``path`` is an http(s)/ws(s) URL, is not a file, or
            has a suffix that is neither ``.zip`` nor a supported export suffix.
    """
    scheme = urlparse(str(path)).scheme
    if scheme in {"http", "https", "ws", "wss"}:
        raise WeFlowParseError("please upload an exported WeFlow file instead of a WeFlow API URL")

    candidate = Path(path)
    if not candidate.is_file():
        raise WeFlowParseError("WeFlow export file not found")

    ext = candidate.suffix.lower()
    accepted = ext == ".zip" or ext in _SUPPORTED_SUFFIXES
    if not accepted:
        raise WeFlowParseError(f"unsupported WeFlow export file type: {ext or '<none>'}")
    return candidate
|
||||
|
||||
|
||||
def _zoneinfo(name: str) -> ZoneInfo:
    """Resolve a timezone name, falling back to ``Asia/Shanghai`` when it is empty.

    Args:
        name: IANA timezone key such as ``"Asia/Shanghai"``.

    Returns:
        The matching ``ZoneInfo`` instance.

    Raises:
        WeFlowParseError: If the name cannot be resolved to a timezone.
    """
    try:
        return ZoneInfo(name or "Asia/Shanghai")
    except (ZoneInfoNotFoundError, ValueError, OSError) as exc:
        # Besides "not found", ZoneInfo rejects malformed keys (absolute paths,
        # ".." segments) with ValueError, and on some Python versions a key that
        # names a directory surfaces as an OSError. All of these mean the caller
        # supplied an unusable timezone, so report them uniformly.
        raise WeFlowParseError(f"unsupported timezone: {name}") from exc
|
||||
|
||||
|
||||
def _parse_zip(
    path: Path,
    *,
    source_format: str,
    timezone: ZoneInfo,
    conversation_id: str | None,
) -> WeFlowExport:
    """Parse the best-ranked supported member of a zipped WeFlow export.

    Members whose suffix is in ``_SUPPORTED_SUFFIXES`` are ranked by
    ``_JSON_PRIORITY`` and then by name; only the first one is parsed.

    Raises:
        WeFlowParseError: If the archive is not a readable zip file, contains no
            supported member, or the selected member cannot be extracted.
    """
    try:
        with zipfile.ZipFile(path) as zf:
            candidates = [
                info
                for info in zf.infolist()
                if not info.is_dir() and Path(info.filename).suffix.lower() in _SUPPORTED_SUFFIXES
            ]
            if not candidates:
                raise WeFlowParseError("unsupported WeFlow zip export: no JSON/CSV/TXT/HTML member found")
            candidates.sort(key=lambda info: (_JSON_PRIORITY[Path(info.filename).suffix.lower()], info.filename))
            selected = candidates[0]
            try:
                data = zf.read(selected)
            except RuntimeError as exc:
                # zipfile raises RuntimeError for encrypted members and
                # NotImplementedError (a RuntimeError subclass) for unsupported
                # compression methods.
                raise WeFlowParseError("unreadable WeFlow zip export member") from exc
    except zipfile.BadZipFile as exc:
        # A file that merely carries a .zip suffix, or a corrupted archive, must
        # surface as a parse error rather than a raw zipfile exception.
        raise WeFlowParseError("invalid WeFlow zip export") from exc
    metadata = {
        "source_name": path.name,
        "archive_member": selected.filename,
        "byte_size": len(data),
    }
    return _parse_payload(
        data,
        source_name=selected.filename,
        source_format=source_format,
        timezone=timezone,
        conversation_id=conversation_id,
        source_metadata=metadata,
    )
|
||||
|
||||
|
||||
def _parse_payload(
    data: bytes,
    *,
    source_name: str,
    source_format: str,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Route raw export bytes to the parser for the requested or inferred format.

    When ``source_format`` is empty or ``"auto"`` the format is derived from the
    suffix of ``source_name``.

    Raises:
        WeFlowParseError: If the format is not one of the supported names.
    """
    extension = Path(source_name).suffix.lower()
    requested = (source_format or "auto").strip().lower()
    if requested == "auto":
        requested = _format_from_suffix(extension)
    if requested in {"json", "chatlab_json", "raw_json"}:
        return _parse_json_payload(
            _load_json(data),
            requested_format=requested,
            timezone=timezone,
            conversation_id=conversation_id,
            source_metadata=source_metadata,
        )
    text = _decode_text(data)
    text_parsers = {
        "csv": _parse_csv_text,
        "txt": _parse_txt_text,
        "html": _parse_html_text,
        "htm": _parse_html_text,
    }
    parser = text_parsers.get(requested)
    if parser is None:
        raise WeFlowParseError(f"unsupported WeFlow export format: {source_format}")
    return parser(text, timezone=timezone, conversation_id=conversation_id, source_metadata=source_metadata)
|
||||
|
||||
|
||||
def _format_from_suffix(suffix: str) -> str:
|
||||
if suffix == ".json":
|
||||
return "json"
|
||||
if suffix == ".csv":
|
||||
return "csv"
|
||||
if suffix == ".txt":
|
||||
return "txt"
|
||||
if suffix in {".html", ".htm"}:
|
||||
return "html"
|
||||
raise WeFlowParseError(f"unsupported WeFlow export file type: {suffix or '<none>'}")
|
||||
|
||||
|
||||
def _load_json(data: bytes) -> Any:
    """Decode ``data`` as text and parse it as JSON.

    Raises:
        WeFlowParseError: If the decoded text is not valid JSON.
    """
    text = _decode_text(data)
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as exc:
        raise WeFlowParseError("invalid WeFlow JSON export") from exc
    return parsed
|
||||
|
||||
|
||||
def _decode_text(data: bytes) -> str:
|
||||
for encoding in ("utf-8-sig", "utf-8", "gb18030"):
|
||||
try:
|
||||
return data.decode(encoding)
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
return data.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def _parse_json_payload(
    payload: Any,
    *,
    requested_format: str,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Pick the ChatLab or raw JSON parser for an already-decoded JSON payload.

    The ChatLab parser is used when it was requested explicitly or when the
    payload looks like a ChatLab export; everything else goes to the raw parser.
    """
    use_chatlab = requested_format == "chatlab_json" or _looks_like_chatlab(payload)
    parser = _parse_chatlab_json if use_chatlab else _parse_raw_json
    return parser(payload, timezone=timezone, conversation_id=conversation_id, source_metadata=source_metadata)
|
||||
|
||||
|
||||
def _looks_like_chatlab(payload: Any) -> bool:
|
||||
if not isinstance(payload, dict):
|
||||
return False
|
||||
if "chatlab" in payload:
|
||||
return True
|
||||
messages = _messages_from_payload(payload)
|
||||
return any(isinstance(item, dict) and ("platformMessageId" in item or "sender" in item) for item in messages[:5])
|
||||
|
||||
|
||||
def _parse_chatlab_json(
    payload: Any,
    *,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Convert a ChatLab-style JSON object into a normalized export.

    Messages without readable content are skipped silently; non-dict messages
    are skipped with a warning. Speaker names fall back to the ``members`` list
    and finally to the speaker id.

    Raises:
        WeFlowParseError: If ``payload`` is not a JSON object.
    """
    if not isinstance(payload, dict):
        raise WeFlowParseError("ChatLab JSON export must be an object")
    messages = _messages_from_payload(payload)
    member_names = _member_name_map(payload.get("members"))
    meta_candidate = payload.get("meta")
    meta: dict[Any, Any] = meta_candidate if isinstance(meta_candidate, dict) else {}
    cid = conversation_id or _string(meta.get("groupId") or meta.get("talker") or meta.get("id") or meta.get("name"))

    turns: list[WeFlowTurn] = []
    warnings: list[str] = []
    for index, raw in enumerate(messages):
        if not isinstance(raw, dict):
            warnings.append(f"skip non-object message at index {index}")
            continue
        content = _message_content(raw)
        if not content:
            continue
        sender = raw.get("sender") or raw.get("senderId") or raw.get("platformId") or raw.get("from")
        speaker_id = _string(sender) or "unknown"
        speaker_name = _display_name(raw, fallback=member_names.get(speaker_id) or speaker_id)
        message_id = _string(raw.get("platformMessageId") or raw.get("messageId") or raw.get("id"))
        timestamp = _normalize_timestamp(raw.get("timestamp") or raw.get("sendTime") or raw.get("time"), timezone)
        # A turn is "self" when flagged explicitly or when the display name is a known self alias.
        is_self = _is_truthy(raw.get("isSelf") or raw.get("isSend")) or speaker_name.strip().lower() in _SELF_NAMES
        extras = _compact_dict(
            {
                "accountName": raw.get("accountName"),
                "groupNickname": raw.get("groupNickname"),
                "avatar": raw.get("avatar"),
            }
        )
        turns.append(
            WeFlowTurn(
                message_id=message_id,
                speaker_id=speaker_id,
                speaker_name=speaker_name,
                content=content,
                timestamp=timestamp,
                is_self=is_self,
                message_type=_string(raw.get("type")),
                metadata=extras,
            )
        )
    return _build_export(
        detected_format="chatlab_json",
        conversation_id=cid,
        turns=turns,
        source_metadata=source_metadata,
        warnings=warnings,
    )
|
||||
|
||||
|
||||
def _member_name_map(raw_members: Any) -> dict[str, str]:
    """Build a ``member id -> display name`` lookup from a ``members`` list.

    Anything that is not a list yields an empty map; entries that are not dicts
    or that carry no usable id are ignored.
    """
    if not isinstance(raw_members, list):
        return {}
    names: dict[str, str] = {}
    for member in raw_members:
        if isinstance(member, dict):
            key = _string(member.get("platformId") or member.get("id") or member.get("wxid"))
            if key:
                names[key] = _display_name(member, fallback=key)
    return names
|
||||
|
||||
|
||||
def _parse_raw_json(
    payload: Any,
    *,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Convert a raw (non-ChatLab) JSON payload into a normalized export.

    When a message has no sender id, the speaker becomes ``"self"`` for sent
    messages and otherwise the conversation id (or ``"unknown"``).

    Raises:
        WeFlowParseError: If no message list can be found in ``payload``.
    """
    messages = _messages_from_payload(payload)
    if not messages:
        raise WeFlowParseError("WeFlow JSON export does not contain messages")
    cid = conversation_id
    if cid is None and isinstance(payload, dict):
        cid = _string(payload.get("talker") or payload.get("talkerId") or payload.get("conversationId"))

    turns: list[WeFlowTurn] = []
    warnings: list[str] = []
    for index, raw in enumerate(messages):
        if not isinstance(raw, dict):
            warnings.append(f"skip non-object message at index {index}")
            continue
        content = _message_content(raw)
        if not content:
            continue
        is_self = _is_truthy(raw.get("isSend") or raw.get("isSelf"))
        speaker_id = _string(raw.get("senderUsername") or raw.get("sender") or raw.get("fromUser"))
        if not speaker_id:
            if is_self:
                speaker_id = "self"
            else:
                speaker_id = _string(cid) or "unknown"
        speaker_name = _display_name(raw, fallback=("\u6211" if is_self else speaker_id))
        message_id = _string(raw.get("serverId") or raw.get("platformMessageId") or raw.get("msgId") or raw.get("id"))
        timestamp = _normalize_timestamp(raw.get("createTime") or raw.get("timestamp") or raw.get("sendTime"), timezone)
        extras = _compact_dict(
            {
                "localId": raw.get("localId"),
                "talker": raw.get("talker"),
                "accountName": raw.get("accountName"),
            }
        )
        turns.append(
            WeFlowTurn(
                message_id=message_id,
                speaker_id=speaker_id,
                speaker_name=speaker_name,
                content=content,
                timestamp=timestamp,
                is_self=is_self,
                message_type=_string(raw.get("type") or raw.get("msgType")),
                metadata=extras,
            )
        )
    return _build_export(
        detected_format="raw_json",
        conversation_id=cid,
        turns=turns,
        source_metadata=source_metadata,
        warnings=warnings,
    )
|
||||
|
||||
|
||||
def _messages_from_payload(payload: Any) -> list[Any]:
|
||||
if isinstance(payload, list):
|
||||
return payload
|
||||
if not isinstance(payload, dict):
|
||||
return []
|
||||
for key in ("messages", "data", "items", "records", "list"):
|
||||
value = payload.get(key)
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
nested = _messages_from_payload(value)
|
||||
if nested:
|
||||
return nested
|
||||
return []
|
||||
|
||||
|
||||
def _parse_csv_text(
    text: str,
    *,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Parse a CSV export whose header row names the message columns.

    The dialect is sniffed from the first 4096 characters and falls back to the
    Excel dialect. Header names are normalized before lookup, and rows without
    content are skipped.
    """
    try:
        dialect = csv.Sniffer().sniff(text[:4096])
    except csv.Error:
        dialect = csv.excel

    turns: list[WeFlowTurn] = []
    for row in csv.DictReader(io.StringIO(text), dialect=dialect):
        # Rows with more cells than headers store the overflow under a None key; drop it.
        normalized = {_normalize_key(key): value for key, value in row.items() if key is not None}
        content = _first(normalized, "parsedcontent", "content", "msgcontent", "text")
        if not content:
            continue
        is_self = _is_truthy(_first(normalized, "issend", "isself", "is_send"))
        default_id = "self" if is_self else "unknown"
        speaker_id = _first(normalized, "senderusername", "sender", "fromuser", "wxid") or default_id
        default_name = "\u6211" if is_self else speaker_id
        speaker_name = _first(normalized, "groupnickname", "accountname", "nickname", "sendername") or default_name
        turns.append(
            WeFlowTurn(
                message_id=_first(normalized, "serverid", "platformmessageid", "msgid", "id", "localid"),
                speaker_id=speaker_id,
                speaker_name=speaker_name,
                content=content.strip(),
                timestamp=_normalize_timestamp(_first(normalized, "createtime", "timestamp", "sendtime", "time"), timezone),
                is_self=is_self or speaker_name.strip().lower() in _SELF_NAMES,
                message_type=_first(normalized, "type", "msgtype"),
                metadata={},
            )
        )
    return _build_export(
        detected_format="csv",
        conversation_id=conversation_id,
        turns=turns,
        source_metadata=source_metadata,
        warnings=[],
    )
|
||||
|
||||
|
||||
_TXT_LINE_RE = re.compile(
|
||||
r"^\s*(?:\[(?P<bracket_time>[^\]]+)\]|(?P<plain_time>\d{4}[-/]\d{1,2}[-/]\d{1,2}\s+\d{1,2}:\d{2}(?::\d{2})?))\s*(?P<speaker>[^:\uff1a]+)[:\uff1a]\s*(?P<content>.*)$"
|
||||
)
|
||||
|
||||
|
||||
def _parse_txt_text(
    text: str,
    *,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Parse a plain-text export made of ``<time> <speaker>: <content>`` lines.

    Lines that do not match the pattern are treated as continuations and are
    appended to the previous turn's content; such lines appearing before the
    first turn, and blank lines, are ignored.
    """
    turns: list[WeFlowTurn] = []
    for line in text.splitlines():
        match = _TXT_LINE_RE.match(line)
        if match is None:
            continuation = line.strip()
            if turns and continuation:
                last = turns[-1]
                # Rebuild the turn instead of mutating it in place.
                turns[-1] = WeFlowTurn(
                    message_id=last.message_id,
                    speaker_id=last.speaker_id,
                    speaker_name=last.speaker_name,
                    content=f"{last.content}\n{continuation}",
                    timestamp=last.timestamp,
                    is_self=last.is_self,
                    message_type=last.message_type,
                    metadata=last.metadata,
                )
            continue
        speaker_name = match.group("speaker").strip()
        raw_time = match.group("bracket_time") or match.group("plain_time")
        turns.append(
            WeFlowTurn(
                message_id=None,
                speaker_id=_speaker_id_from_name(speaker_name),
                speaker_name=speaker_name,
                content=match.group("content").strip(),
                timestamp=_normalize_timestamp(raw_time, timezone),
                is_self=speaker_name.lower() in _SELF_NAMES,
                metadata={},
            )
        )
    return _build_export(
        detected_format="txt",
        conversation_id=conversation_id,
        turns=turns,
        source_metadata=source_metadata,
        warnings=[],
    )
|
||||
|
||||
|
||||
class _MessageHtmlParser(HTMLParser):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(convert_charrefs=True)
|
||||
self.rows: list[dict[str, str]] = []
|
||||
self._row: dict[str, str] | None = None
|
||||
self._field: str | None = None
|
||||
self._buffer: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
||||
attrs_map = {key.lower(): value or "" for key, value in attrs}
|
||||
classes = set(attrs_map.get("class", "").lower().split())
|
||||
if tag.lower() == "tr":
|
||||
self._row = {}
|
||||
field = None
|
||||
if "time" in classes or "timestamp" in classes or "date" in classes:
|
||||
field = "time"
|
||||
elif "sender" in classes or "speaker" in classes or "name" in classes:
|
||||
field = "sender"
|
||||
elif "content" in classes or "message" in classes or "text" in classes:
|
||||
field = "content"
|
||||
if field:
|
||||
if self._row is None:
|
||||
self._row = {}
|
||||
self._field = field
|
||||
self._buffer = []
|
||||
|
||||
def handle_data(self, data: str) -> None:
|
||||
if self._field:
|
||||
self._buffer.append(data)
|
||||
|
||||
def handle_endtag(self, tag: str) -> None:
|
||||
if self._field and tag.lower() in {"td", "div", "span", "p"}:
|
||||
text = "".join(self._buffer).strip()
|
||||
if self._row is not None and text:
|
||||
self._row[self._field] = text
|
||||
self._field = None
|
||||
self._buffer = []
|
||||
if tag.lower() == "tr" and self._row is not None:
|
||||
if self._row.get("sender") and self._row.get("content"):
|
||||
self.rows.append(self._row)
|
||||
self._row = None
|
||||
self._field = None
|
||||
self._buffer = []
|
||||
|
||||
|
||||
def _parse_html_text(
    text: str,
    *,
    timezone: ZoneInfo,
    conversation_id: str | None,
    source_metadata: dict[str, Any],
) -> WeFlowExport:
    """Parse an HTML export by extracting message rows with ``_MessageHtmlParser``."""
    extractor = _MessageHtmlParser()
    extractor.feed(text)

    turns: list[WeFlowTurn] = []
    for row in extractor.rows:
        speaker_name = row["sender"].strip()
        turns.append(
            WeFlowTurn(
                message_id=None,
                speaker_id=_speaker_id_from_name(speaker_name),
                speaker_name=speaker_name,
                content=row["content"].strip(),
                timestamp=_normalize_timestamp(row.get("time"), timezone),
                is_self=speaker_name.lower() in _SELF_NAMES,
                metadata={},
            )
        )
    return _build_export(
        detected_format="html",
        conversation_id=conversation_id,
        turns=turns,
        source_metadata=source_metadata,
        warnings=[],
    )
|
||||
|
||||
|
||||
def _build_export(
    *,
    detected_format: WeFlowDetectedFormat,
    conversation_id: str | None,
    turns: list[WeFlowTurn],
    source_metadata: dict[str, Any],
    warnings: list[str],
) -> WeFlowExport:
    """Assemble the final export and derive the per-speaker summary from ``turns``.

    Speakers are ordered by descending message count, then by name. A speaker
    takes the name of its first turn, is marked as self if any of its turns is,
    and its metadata is the merge of its turns' metadata (later turns win).

    Raises:
        WeFlowParseError: If ``turns`` is empty.
    """
    if not turns:
        raise WeFlowParseError("WeFlow export does not contain readable chat messages")

    # speaker id -> running summary, kept in first-seen order.
    summary: dict[str, dict[str, Any]] = {}
    for turn in turns:
        entry = summary.get(turn.speaker_id)
        if entry is None:
            entry = {"name": turn.speaker_name, "count": 0, "is_self": False, "metadata": {}}
            summary[turn.speaker_id] = entry
        entry["count"] += 1
        entry["is_self"] = entry["is_self"] or turn.is_self
        if turn.metadata:
            entry["metadata"].update(turn.metadata)

    ranked = sorted(summary.items(), key=lambda pair: (-pair[1]["count"], pair[1]["name"]))
    speakers = [
        WeFlowSpeaker(
            id=speaker_id,
            name=entry["name"],
            message_count=entry["count"],
            is_self=entry["is_self"],
            metadata=entry["metadata"],
        )
        for speaker_id, entry in ranked
    ]
    return WeFlowExport(
        conversation_id=conversation_id,
        detected_format=detected_format,
        turns=turns,
        speakers=speakers,
        source_metadata=source_metadata,
        warnings=warnings,
    )
|
||||
|
||||
|
||||
def _display_name(raw: dict[str, Any], *, fallback: str) -> str:
    """Return the first non-blank name field of ``raw``, or ``fallback``.

    Fields are tried in this order: ``groupNickname``, ``displayName``,
    ``accountName``, ``nickname``, ``senderName``.
    """
    for key in ("groupNickname", "displayName", "accountName", "nickname", "senderName"):
        candidate = _string(raw.get(key))
        if candidate:
            return candidate
    return fallback
|
||||
|
||||
|
||||
def _message_content(raw: dict[str, Any]) -> str:
    """Return the stripped text of a message, or ``""`` when it has none.

    The first truthy value among ``parsedContent``, ``content``, ``text``,
    ``message`` and ``msgContent`` is used; if none is truthy, the value of the
    last key is used.
    """
    value: Any = None
    for key in ("parsedContent", "content", "text", "message", "msgContent"):
        value = raw.get(key)
        if value:
            break
    text = _string(value)
    return text.strip() if text else ""
|
||||
|
||||
|
||||
def _normalize_timestamp(value: Any, timezone: ZoneInfo) -> str | None:
|
||||
if value is None or value == "":
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
raw = float(value)
|
||||
if raw > 10_000_000_000:
|
||||
raw = raw / 1000.0
|
||||
return datetime.fromtimestamp(raw, dt_timezone.utc).astimezone(timezone).isoformat(timespec="seconds")
|
||||
text = str(value).strip()
|
||||
if not text:
|
||||
return None
|
||||
if re.fullmatch(r"\d+(?:\.\d+)?", text):
|
||||
return _normalize_timestamp(float(text), timezone)
|
||||
normalized = text.replace("/", "-").replace("T", " ")
|
||||
for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y-%m-%d"):
|
||||
try:
|
||||
parsed = datetime.strptime(normalized[: len(datetime.now().strftime(fmt))], fmt)
|
||||
return parsed.replace(tzinfo=timezone).isoformat(timespec="seconds")
|
||||
except ValueError:
|
||||
continue
|
||||
try:
|
||||
parsed_iso = datetime.fromisoformat(text)
|
||||
except ValueError:
|
||||
return None
|
||||
if parsed_iso.tzinfo is None:
|
||||
parsed_iso = parsed_iso.replace(tzinfo=timezone)
|
||||
return parsed_iso.astimezone(timezone).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _normalize_key(value: str) -> str:
|
||||
return re.sub(r"[^a-z0-9]", "", value.strip().lower())
|
||||
|
||||
|
||||
def _first(row: dict[str, str | None], *keys: str) -> str | None:
    """Return the first non-blank value of ``row`` among ``keys``, stripped.

    Each key is normalized before the lookup. ``None`` is returned when no key
    holds a non-blank value.
    """
    for key in keys:
        candidate = row.get(_normalize_key(key))
        if candidate is None:
            continue
        cleaned = str(candidate).strip()
        if cleaned:
            return cleaned
    return None
|
||||
|
||||
|
||||
def _string(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
text = str(value).strip()
|
||||
return text or None
|
||||
|
||||
|
||||
def _is_truthy(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on", "y"}
|
||||
return False
|
||||
|
||||
|
||||
def _speaker_id_from_name(name: str) -> str:
    """Derive a stable speaker id from a display name.

    Known self aliases map to ``"self"``. Otherwise every run of characters
    that is neither a word character nor ``-`` becomes a single ``_``, and
    leading/trailing underscores are removed. ``"unknown"`` is returned when
    nothing is left.

    Args:
        name: Speaker display name as it appears in the export.

    Returns:
        The derived speaker id.
    """
    cleaned = name.strip()
    if cleaned.lower() in _SELF_NAMES:
        return "self"
    # ``\w`` is Unicode-aware for str patterns, so CJK and other non-ASCII names
    # keep their characters. An ASCII-only class reduces every such name to
    # "unknown", which merges distinct speakers into one. Pure ASCII names
    # produce the same id as before.
    slug = re.sub(r"[^\w\-]+", "_", cleaned).strip("_")
    return slug or "unknown"
|
||||
|
||||
|
||||
def _compact_dict(raw: dict[str, Any]) -> dict[str, Any]:
|
||||
return {key: value for key, value in raw.items() if value not in (None, "")}
|
||||
18
opentalking/providers/memory/import_jobs.py
Normal file
18
opentalking/providers/memory/import_jobs.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
ImportJobStatus = Literal["needs_speaker_selection", "draft_ready", "committed", "error"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemoryImportCommitResult:
    """Immutable summary returned after an import job has been committed."""

    # Identifier of the import job that was committed.
    job_id: str
    # Identifier of the persona the import was committed to.
    persona_id: str
    # Number of memory items imported by the commit.
    memory_imported: int
    # Size in bytes of the generated persona.md (presumably the written file — confirm with the producer).
    persona_md_bytes: int
    # Memory scope identifiers associated with the import.
    profile_id: str
    character_id: str
    memory_library_id: str
|
||||
307
scripts/smoke_wechat_memory_persona.py
Normal file
307
scripts/smoke_wechat_memory_persona.py
Normal file
@@ -0,0 +1,307 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# ruff: noqa: E402 - smoke script adds the repo root before importing local packages.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
from opentalking.core.config import get_settings
|
||||
from opentalking.persona.session import build_session_defaults
|
||||
from opentalking.persona.store import PersonaStore
|
||||
from opentalking.persona.wechat_import import WeChatImportJobRegistry
|
||||
from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider
|
||||
from opentalking.providers.memory.runtime import MemoryRuntime, normalize_memory_scope
|
||||
|
||||
|
||||
|
||||
def utc_now() -> str:
    """Return the current UTC time as an ISO-8601 string with seconds precision."""
    now = datetime.now(tz=timezone.utc)
    return now.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def to_jsonable(value: Any) -> Any:
    """Recursively convert report values into JSON-serializable structures.

    Dataclass instances become dicts, ``Path`` objects become strings, dict keys
    are stringified, and lists/tuples become lists. Other values pass through.
    """
    if is_dataclass(value):
        value = asdict(value)
    if isinstance(value, dict):
        return {str(key): to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_jsonable(item) for item in value]
    if isinstance(value, Path):
        return str(value)
    return value
|
||||
|
||||
|
||||
def fake_weflow_chatlab_export() -> dict[str, Any]:
    """Build the simulated ChatLab-format WeFlow export used by the smoke test.

    The payload describes a three-member group chat with six messages, one of
    which is sent by the exporting user and one of which contains a fake secret
    used to verify redaction.
    """

    def message(sender: str, name: str, timestamp: int, content: str, message_id: str, **extra: Any) -> dict[str, Any]:
        # Key order matters: it determines the layout of the JSON file written to disk.
        return {
            "sender": sender,
            "accountName": name,
            "groupNickname": name,
            "timestamp": timestamp,
            "type": 0,
            "content": content,
            "platformMessageId": message_id,
            **extra,
        }

    members = [
        {"platformId": platform_id, "accountName": name, "groupNickname": name}
        for platform_id, name in (("wxid_li", "Li"), ("wxid_chen", "Chen"), ("self_wxid", "me"))
    ]
    messages = [
        message("wxid_li", "Li", 1738713600, "morning, breathe first. keep it small today.", "10000000000000000001"),
        message("self_wxid", "me", 1738713660, "I am nervous about the demo.", "10000000000000000002", isSelf=True),
        message("wxid_chen", "Chen", 1738713720, "ship now, no need to wait.", "10000000000000000003"),
        message("wxid_li", "Li", 1738713780, "then make one tiny checklist and ping me after lunch.", "10000000000000000004"),
        message("wxid_li", "Li", 1738713840, "no need to be heroic, steady is enough.", "10000000000000000005"),
        message("wxid_li", "Li", 1738713900, "demo secret code is 8848 and should never be copied raw.", "10000000000000000006"),
    ]
    return {
        "chatlab": {"version": "0.0.2", "generator": "WeFlow"},
        "meta": {
            "name": "Smoke project room",
            "platform": "wechat",
            "type": "group",
            "groupId": "room-smoke@chatroom",
        },
        "members": members,
        "messages": messages,
    }
|
||||
|
||||
|
||||
async def run_smoke(report_dir: Path) -> dict[str, Any]:
    """Run the upload -> persona -> memory-recall smoke flow and collect a report.

    A fake WeFlow export is written into a fresh temporary workspace under
    ``report_dir``; the import registry, persona store, session defaults and
    memory runtime are then exercised against it. Every expectation is recorded
    as a check instead of raising, so the returned report lists all outcomes.

    Args:
        report_dir: Directory that receives the temporary workspace (created if missing).

    Returns:
        A dict with the inputs, job, commit, persona, memory, checks and a pass/fail summary.
    """
    profile_id = "smoke-profile"
    library_id = "wechat-smoke"
    avatar_id = "smoke-avatar-li"
    persona_id = "smoke-friend-li"

    report_dir.mkdir(parents=True, exist_ok=True)
    workspace = Path(tempfile.mkdtemp(prefix="opentalking-wechat-smoke-", dir=str(report_dir)))
    export_path = workspace / "fake-weflow-chatlab.json"
    export_json = json.dumps(fake_weflow_chatlab_export(), ensure_ascii=True, indent=2)
    export_path.write_text(export_json + "\n", encoding="utf-8")

    persona_store = PersonaStore(workspace / "personas")
    memory_provider = InMemoryMemoryProvider()
    registry = WeChatImportJobRegistry(persona_store=persona_store, memory_provider=memory_provider)

    checks: list[dict[str, Any]] = []

    def check(name: str, condition: bool, detail: Any = None) -> None:
        checks.append({"name": name, "passed": bool(condition), "detail": detail})

    def is_redacted(text: str) -> bool:
        # The fake secret from the export must never appear in generated output.
        return "8848" not in text and "demo secret code" not in text

    job = registry.create_job(
        export_path,
        profile_id=profile_id,
        memory_library_id=library_id,
        avatar_id=avatar_id,
        avatar_model="mock",
    )
    check("job waits for speaker selection in group chat", job.status == "needs_speaker_selection", job.status)
    speaker_ids = [s.id for s in job.speakers]
    check("speakers parsed from uploaded WeFlow file", speaker_ids == ["wxid_li", "wxid_chen"], speaker_ids)

    selected = await registry.select_speaker_async(job.id, "wxid_li")
    check("speaker selection builds draft", selected.status == "draft_ready" and selected.draft is not None, selected.status)
    persona_md_preview = selected.draft.persona_md if selected.draft else ""
    check("persona draft is summary, not raw secret", is_redacted(persona_md_preview))

    commit = await registry.commit(job.id, persona_id=persona_id, persona_name="Smoke Friend Li")
    check("commit imports three layered memories", commit.memory_imported == 3, commit.memory_imported)

    record = persona_store.get_persona(persona_id)
    persona_md = (record.path / "persona.md").read_text(encoding="utf-8")
    check("persona is bound to avatar asset", record.manifest.avatar.id == avatar_id, record.manifest.avatar.id)
    check("persona manifest points to persona.md", record.manifest.agent.persona_prompt == "persona.md", record.manifest.agent.persona_prompt)
    check("persona.md exists and is redacted", is_redacted(persona_md))

    defaults = build_session_defaults(record)
    system_prompt = defaults.llm_system_prompt
    check(
        "session defaults load persona.md into system prompt",
        bool(system_prompt and "# Persona" in system_prompt),
        system_prompt[:120] if system_prompt else None,
    )

    items = await memory_provider.list_items(
        library_id=library_id,
        profile_id=profile_id,
        character_id=avatar_id,
    )
    memory_layers = sorted({str(item.metadata.get("layer")) for item in items})
    check("memory items carry style/semantic/episodic layers", memory_layers == ["episodic", "semantic", "style"], memory_layers)
    check("memory items are structured imported records", all(item.metadata.get("source") == "wechat_import" for item in items))
    check("memory items are redacted", all(is_redacted(item.text) for item in items))

    settings = get_settings()
    scope = normalize_memory_scope(
        settings=settings,
        memory_enabled=True,
        profile_id=profile_id,
        character_id=avatar_id,
        avatar_id=avatar_id,
        library_id=library_id,
    )
    runtime = MemoryRuntime(scope=scope, provider=memory_provider, settings=settings)
    recall_query = "\u6309 Li \u7684 calm practical guidance \u98ce\u683c\u56de\u7b54\u8fd9\u6761 demo \u5b89\u6392"
    recall_prompt = await runtime.retrieve_prompt(recall_query)
    check(
        "conversation memory runtime recalls imported memories",
        bool(recall_prompt.strip()) and "Li tends to use concise replies" in recall_prompt,
        recall_prompt,
    )

    passed_count = sum(1 for item in checks if item["passed"])
    failed_count = sum(1 for item in checks if not item["passed"])
    return {
        "generated_at": utc_now(),
        "workspace": str(workspace),
        "input": {
            "fake_weflow_export": str(export_path),
            "fake_only": "Only the WeFlow chat record is simulated; parser/import/persona/session/memory runtime use real project code.",
        },
        "job": selected.to_dict(),
        "commit": to_jsonable(commit),
        "persona": {
            "id": record.manifest.id,
            "name": record.manifest.name,
            "avatar": to_jsonable(record.manifest.avatar),
            "agent": to_jsonable(record.manifest.agent),
            "path": str(record.path),
            "persona_md": persona_md,
            "session_system_prompt": defaults.llm_system_prompt,
        },
        "memory": {
            "library_id": library_id,
            "profile_id": profile_id,
            "character_id": avatar_id,
            "items": [to_jsonable(item) for item in items],
            "layers": memory_layers,
            "runtime_recall_query": recall_query,
            "runtime_recall_prompt": recall_prompt,
        },
        "checks": checks,
        "summary": {
            "passed": all(item["passed"] for item in checks),
            "passed_count": passed_count,
            "failed_count": failed_count,
        },
    }
|
||||
|
||||
|
||||
def write_markdown(result: dict[str, Any], out_path: Path) -> None:
    """Render the smoke-test report dict as Markdown and write it to ``out_path``.

    The report contains the overall result, key evidence, one line per check
    (with detail for failures), the generated persona.md, the imported memory
    items, the runtime recall prompt, a fixed analysis paragraph and, when any
    check failed, a closing list of failures.
    """
    checks = result["checks"]
    failed = [item for item in checks if not item["passed"]]
    memory = result["memory"]
    memory_items = memory["items"]
    summary = result["summary"]
    persona = result["persona"]

    lines = [
        "# WeChat Memory Persona Smoke Report",
        "",
        f"Generated at: `{result['generated_at']}`",
        "",
        "## Scope",
        "",
        "This smoke test simulates only the uploaded WeFlow WeChat export file. All processing after upload uses the real OpenTalking feature code: parser, import job registry, persona store, persona.md loading, memory provider writes, and memory runtime recall.",
        "",
        "## Result",
        "",
        f"- Overall: {'PASS' if summary['passed'] else 'FAIL'}",
        f"- Passed checks: {summary['passed_count']}",
        f"- Failed checks: {summary['failed_count']}",
        f"- Temp workspace on 146: `{result['workspace']}`",
        "",
        "## Key Evidence",
        "",
        f"- Persona id: `{persona['id']}`",
        f"- Bound avatar asset: `{persona['avatar']['id']}` / model `{persona['avatar']['model']}`",
        f"- persona.md path: `{persona['path']}/persona.md`",
        f"- Memory layers: `{', '.join(memory['layers'])}`",
        f"- Runtime recall query: `{memory['runtime_recall_query']}`",
        f"- Runtime recall prompt non-empty: `{bool(memory['runtime_recall_prompt'])}`",
        "",
        "## Checks",
        "",
    ]
    for item in checks:
        mark = "PASS" if item["passed"] else "FAIL"
        lines.append(f"- {mark}: {item['name']}")
        # Detail is only worth printing for checks that failed.
        if item.get("detail") is not None and not item["passed"]:
            lines.append(f" Detail: `{item['detail']}`")

    lines += [
        "",
        "## Generated persona.md",
        "",
        "```markdown",
        persona["persona_md"].strip(),
        "```",
        "",
        "## Imported Memory Items",
        "",
    ]
    for item in memory_items:
        metadata = item.get("metadata") or {}
        lines.append(f"- `{metadata.get('layer')}` / `{item.get('type')}`: {item.get('text')}")

    lines += [
        "",
        "## Runtime Recall Prompt",
        "",
        "```text",
        memory["runtime_recall_prompt"].strip(),
        "```",
        "",
        "## Analysis",
        "",
        "The end-to-end path is functioning if all checks pass: uploaded WeFlow data is parsed, group-speaker selection is required, the chosen speaker produces a redacted persona.md draft, commit persists a persona bound to the avatar asset, session defaults load persona.md into the system prompt, and the memory runtime can recall structured imported memories. Raw chat logs are not injected into prompts; only the generated persona summary and structured memory items are used at runtime.",
    ]
    if failed:
        lines += ["", "## Failures", ""]
        lines += [f"- {item['name']}: `{item.get('detail')}`" for item in failed]

    out_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> int:
    """Run the WeChat import smoke test and write its JSON and Markdown reports.

    Parses ``--report-dir`` from the command line, runs ``run_smoke`` against
    that directory, writes both report files into it, and prints a one-line
    JSON summary with the pass flag and the two report paths.

    Returns:
        ``0`` when the smoke summary reports success, ``1`` otherwise.
    """
    cli = argparse.ArgumentParser(
        description="Smoke test WeFlow WeChat import -> persona.md -> memory runtime."
    )
    cli.add_argument(
        "--report-dir",
        default="/tmp/opentalking-wechat-smoke",
        help="Directory for generated reports",
    )
    options = cli.parse_args()

    output_dir = Path(options.report_dir)
    outcome = asyncio.run(run_smoke(output_dir))

    json_report = output_dir / "wechat-memory-persona-smoke-report.json"
    markdown_report = output_dir / "wechat-memory-persona-smoke-report.md"

    serialized = json.dumps(to_jsonable(outcome), ensure_ascii=False, indent=2)
    json_report.write_text(serialized + "\n", encoding="utf-8")
    write_markdown(outcome, markdown_report)

    # Read the pass flag only after both reports have been written.
    passed = outcome["summary"]["passed"]
    summary_line = {
        "passed": passed,
        "json": str(json_report),
        "markdown": str(markdown_report),
    }
    print(json.dumps(summary_line, ensure_ascii=False))
    if passed:
        return 0
    return 1


if __name__ == "__main__":
    raise SystemExit(main())
|
||||
@@ -55,8 +55,24 @@ def test_memory_panel_supports_select_and_manage_modes() -> None:
|
||||
assert 'mode?: "select" | "manage"' in source
|
||||
assert 'const isManageMode = mode === "manage"' in source
|
||||
assert "管理记忆库" in source
|
||||
assert "导入聊天记录" not in source
|
||||
assert "导入到当前记忆库" not in source
|
||||
assert "WeChatMemoryImportPanel" in source
|
||||
assert "importMemoryTurns" not in source
|
||||
assert "记忆条目" in source
|
||||
assert 'if (!isManageMode) return' in source
|
||||
|
||||
|
||||
def test_memory_panel_exposes_weflow_upload_without_api_fetch() -> None:
    """Check the web sources wire up file-based WeChat import only.

    The memory panel must render the import panel, the import panel must
    accept the supported export file types and call the upload / select /
    commit API helpers, and neither the panel nor the API client may mention
    a source-URL option.
    """
    web_src = ROOT / "apps/web/src"
    relative_paths = (
        "components/MemoryPanel.tsx",
        "components/WeChatMemoryImportPanel.tsx",
        "lib/api.ts",
        "types.ts",
    )
    memory_panel, import_panel, api_client, type_defs = (
        (web_src / relative).read_text(encoding="utf-8") for relative in relative_paths
    )

    assert '<WeChatMemoryImportPanel' in memory_panel
    assert 'accept=".json,.csv,.txt,.html,.htm,.zip"' in import_panel
    assert "uploadWeChatImport" in import_panel
    assert "selectWeChatImportSpeaker" in import_panel
    assert "commitWeChatImportJob" in import_panel
    assert "source_url" not in import_panel
    assert "sourceUrl" not in api_client
    assert "WeChatImportJob" in type_defs
    assert "WeChatImportCommitResult" in type_defs
|
||||
|
||||
@@ -13,6 +13,7 @@ from opentalking.persona.package import (
|
||||
validate_persona_package,
|
||||
)
|
||||
from opentalking.persona.schema import persona_from_dict
|
||||
from opentalking.persona.session import build_session_defaults
|
||||
from opentalking.persona.store import PersonaStore
|
||||
|
||||
|
||||
@@ -113,3 +114,42 @@ def test_persona_package_imports_prompt_and_knowledge(tmp_path: Path) -> None:
|
||||
assert "企业客服数字人" in prompt
|
||||
bases = asyncio.run(knowledge_store.list_knowledge_bases())
|
||||
assert any(base.id == record.manifest.agent.knowledge_base_ids[0] for base in bases)
|
||||
|
||||
|
||||
|
||||
def test_persona_prompt_is_loaded_before_legacy_prompts(tmp_path: Path) -> None:
    """persona.md text must lead the session system prompt.

    Builds a persona package that declares ``persona_prompt`` alongside the
    legacy ``system_prompt`` and ``style_prompt`` files, imports it, and checks
    that the session defaults join the three texts in that order.
    """
    package_root = tmp_path / "source"
    prompts_dir = package_root / "prompts"
    prompts_dir.mkdir(parents=True)
    (package_root / "persona.md").write_text("# Persona\n你是小李,说话温柔。", encoding="utf-8")
    (prompts_dir / "system.md").write_text("旧系统提示。", encoding="utf-8")
    (prompts_dir / "style.md").write_text("旧风格提示。", encoding="utf-8")

    manifest_text = """
{
"schema_version": "0.1",
"id": "friend-li",
"name": "小李",
"description": "微信导入生成的 Persona",
"locale": "zh-CN",
"avatar": {"id": "custom-friend-li", "model": "mock"},
"agent": {
"persona_prompt": "persona.md",
"system_prompt": "prompts/system.md",
"style_prompt": "prompts/style.md",
"memory_enabled": true,
"knowledge_enabled": false
},
"safety": {"authorized_avatar": true, "authorized_voice": false, "content_label_required": true}
}
""".strip()
    (package_root / "persona.json").write_text(manifest_text + "\n", encoding="utf-8")

    archive = tmp_path / "friend-li.otpersona"
    create_persona_package_from_dir(package_root, archive)

    persona_store = PersonaStore(tmp_path / "personas")
    imported = asyncio.run(import_persona_package(archive, store=persona_store))

    assert imported.manifest.agent.persona_prompt == "persona.md"
    session_defaults = build_session_defaults(imported)
    expected_prompt = "# Persona\n你是小李,说话温柔。\n\n旧系统提示。\n\n旧风格提示。"
    assert session_defaults.llm_system_prompt == expected_prompt
|
||||
|
||||
128
tests/unit/test_wechat_import_jobs.py
Normal file
128
tests/unit/test_wechat_import_jobs.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from opentalking.persona import memory_builder
|
||||
from opentalking.persona.store import PersonaStore
|
||||
from opentalking.persona.wechat_import import WeChatImportJobRegistry
|
||||
from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider
|
||||
|
||||
|
||||
def write_group_export(path: Path) -> Path:
    """Write a small ChatLab-style WeFlow group export to ``path`` and return it.

    The fixture has two members (``wxid_li`` and ``wxid_chen``) and three text
    messages. The last message carries the marker ``8848``, which the tests in
    this module assert never appears in generated persona or memory text.
    """
    nicknames = {"wxid_li": "Li", "wxid_chen": "Chen"}
    # Each row is (sender id, unix timestamp, content, platform message id).
    rows = [
        ("wxid_li", 1738713600, "morning, breathe first. keep it small today.", "1"),
        ("wxid_chen", 1738713660, "ship it now, no need to wait.", "2"),
        ("wxid_li", 1738713720, "demo secret code is 8848 and should never be copied raw.", "3"),
    ]
    export = {
        "chatlab": {"version": "0.0.2", "generator": "WeFlow"},
        "meta": {"name": "Project room", "platform": "wechat", "type": "group", "groupId": "room@chatroom"},
        "members": [
            {"platformId": wxid, "accountName": nick, "groupNickname": nick}
            for wxid, nick in nicknames.items()
        ],
        "messages": [
            {
                "sender": wxid,
                "accountName": nicknames[wxid],
                "groupNickname": nicknames[wxid],
                "timestamp": sent_at,
                "type": 0,
                "content": text,
                "platformMessageId": message_id,
            }
            for wxid, sent_at, text, message_id in rows
        ],
    }
    serialized = json.dumps(export, ensure_ascii=True)
    path.write_text(serialized, encoding="utf-8")
    return path
|
||||
|
||||
|
||||
def test_wechat_import_job_requires_speaker_selection(tmp_path: Path) -> None:
    """A freshly created job for a group export waits for a speaker choice.

    The job must list both group members as candidate speakers and must not
    carry a draft yet.
    """
    persona_store = PersonaStore(tmp_path / "personas")
    jobs = WeChatImportJobRegistry(
        persona_store=persona_store,
        memory_provider=InMemoryMemoryProvider(),
    )
    export_path = write_group_export(tmp_path / "weflow.json")

    created = jobs.create_job(
        export_path,
        profile_id="default",
        memory_library_id="default",
        avatar_id="avatar-li",
        avatar_model="mock",
    )

    assert created.status == "needs_speaker_selection"
    speaker_ids = [speaker.id for speaker in created.speakers]
    assert speaker_ids == ["wxid_li", "wxid_chen"]
    assert created.draft is None
|
||||
|
||||
|
||||
def test_wechat_import_job_selects_speaker_and_commits_persona_and_memory(monkeypatch, tmp_path: Path) -> None:
    """Selecting a speaker yields a draft; committing persists persona and memories.

    The configured persona LLM is stubbed to return an empty completion. The
    test then checks the committed persona manifest, the written persona.md,
    and the three imported memory items, none of which may contain the raw
    ``8848`` marker from the export.
    """

    async def stub_complete(self, messages):
        return ""

    monkeypatch.setattr(memory_builder._ConfiguredPersonaLLM, "complete", stub_complete)

    memory_provider = InMemoryMemoryProvider()
    persona_store = PersonaStore(tmp_path / "personas")
    jobs = WeChatImportJobRegistry(persona_store=persona_store, memory_provider=memory_provider)
    export_path = write_group_export(tmp_path / "weflow.json")
    created = jobs.create_job(
        export_path,
        profile_id="default",
        memory_library_id="default",
        avatar_id="avatar-li",
        avatar_model="mock",
    )

    selected = jobs.select_speaker(created.id, "wxid_li")

    assert selected.status == "draft_ready"
    assert selected.selected_speaker_id == "wxid_li"
    assert selected.draft is not None
    assert "8848" not in selected.draft.persona_md

    commit_call = jobs.commit(
        created.id,
        persona_id="friend-li",
        persona_name="Friend Li",
    )
    committed = asyncio.run(commit_call)

    assert committed.persona_id == "friend-li"
    assert committed.memory_imported == 3
    assert committed.persona_md_bytes > 0

    persona = persona_store.get_persona("friend-li")
    manifest = persona.manifest
    assert manifest.avatar.id == "avatar-li"
    assert manifest.avatar.model == "mock"
    assert manifest.agent.persona_prompt == "persona.md"
    assert manifest.agent.memory_enabled is True
    persona_md_path = persona.path / "persona.md"
    assert persona_md_path.is_file()
    assert "8848" not in persona_md_path.read_text(encoding="utf-8")

    stored = asyncio.run(
        memory_provider.list_items(library_id="default", profile_id="default", character_id="avatar-li")
    )
    assert len(stored) == 3
    assert {entry.metadata["layer"] for entry in stored} == {"style", "semantic", "episodic"}
    assert all(entry.metadata["source"] == "wechat_import" for entry in stored)
    assert all("8848" not in entry.text for entry in stored)
|
||||
105
tests/unit/test_wechat_memory_builder.py
Normal file
105
tests/unit/test_wechat_memory_builder.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from opentalking.persona import memory_builder
|
||||
from opentalking.persona.memory_builder import build_wechat_persona_draft
|
||||
from opentalking.persona.weflow_parser import WeFlowExport, WeFlowSpeaker, WeFlowTurn
|
||||
|
||||
|
||||
def sample_export() -> WeFlowExport:
    """Return a private-chat export fixture with one friend and the user.

    The friend ``wxid_friend`` ("Li") contributes four turns and the user
    (``self_wxid``) one. The friend's last turn carries the ``8848`` marker
    that the builder tests assert is never copied into generated output.
    """

    def friend_turn(message_id: str, content: str, timestamp: str) -> WeFlowTurn:
        # Friend turns leave ``is_self`` unset, so the dataclass default applies.
        return WeFlowTurn(
            message_id=message_id,
            speaker_id="wxid_friend",
            speaker_name="Li",
            content=content,
            timestamp=timestamp,
        )

    own_turn = WeFlowTurn(
        message_id="2",
        speaker_id="self_wxid",
        speaker_name="me",
        content="I am nervous about the demo.",
        timestamp="2025-02-05T08:01:00+08:00",
        is_self=True,
    )
    participants = [
        WeFlowSpeaker(id="wxid_friend", name="Li", message_count=4),
        WeFlowSpeaker(id="self_wxid", name="me", message_count=2, is_self=True),
    ]
    conversation = [
        friend_turn("1", "morning, breathe first. keep it small today.", "2025-02-05T08:00:00+08:00"),
        own_turn,
        friend_turn("3", "then make one tiny checklist and ping me after lunch.", "2025-02-05T08:02:00+08:00"),
        friend_turn("4", "no need to be heroic, steady is enough.", "2025-02-05T08:03:00+08:00"),
        friend_turn("5", "demo secret code is 8848 and should never be copied raw.", "2025-02-05T08:04:00+08:00"),
    ]
    return WeFlowExport(
        conversation_id="wxid_friend",
        detected_format="raw_json",
        source_metadata={"source_name": "weflow.json", "byte_size": 1234},
        speakers=participants,
        turns=conversation,
    )
|
||||
|
||||
|
||||
def test_builder_filters_target_speaker_and_creates_safe_persona(monkeypatch) -> None:
    """The draft describes only the chosen speaker and leaks no raw content.

    With the configured persona LLM stubbed to return an empty completion, the
    persona markdown must contain the expected headings and must not contain
    the other participant's text or the sensitive marker from the transcript.
    """

    async def stub_complete(self, messages):
        return ""

    monkeypatch.setattr(memory_builder._ConfiguredPersonaLLM, "complete", stub_complete)

    draft = build_wechat_persona_draft(sample_export(), target_speaker_id="wxid_friend")

    speaker = draft.target_speaker
    assert speaker.id == "wxid_friend"
    assert speaker.name == "Li"
    assert draft.persona_name == "Li"
    assert "# Persona" in draft.persona_md
    assert "# Speaking Style" in draft.persona_md
    assert "wechat_import" in draft.source_metadata["source"]
    assert "I am nervous about the demo" not in draft.persona_md
    assert "8848" not in draft.persona_md
    assert "demo secret code" not in draft.persona_md
|
||||
|
||||
|
||||
def test_builder_returns_layered_memory_items_without_raw_transcript_copy(monkeypatch) -> None:
    """The draft's memory items are layered, tagged, and free of raw transcript text.

    Every item must carry the ``wechat_import`` source, the target speaker id
    and a confidence of ``low``/``medium``/``high``, and must not contain the
    sensitive marker from the fixture transcript.
    """

    async def empty_complete(self, messages):
        return ""

    # Stub the configured persona LLM the same way the sibling builder test
    # does, so this test does not depend on whatever LLM settings the
    # environment happens to provide when no llm_client is passed.
    monkeypatch.setattr(memory_builder._ConfiguredPersonaLLM, "complete", empty_complete)

    draft = build_wechat_persona_draft(sample_export(), target_speaker_id="wxid_friend")

    types = {item.type for item in draft.memory_items}
    layers = {item.metadata["layer"] for item in draft.memory_items}
    assert {"preference", "summary", "note"}.issubset(types)
    assert {"style", "episodic", "semantic"}.issubset(layers)
    assert all(item.metadata["source"] == "wechat_import" for item in draft.memory_items)
    assert all(item.metadata["target_speaker_id"] == "wxid_friend" for item in draft.memory_items)
    assert all(item.metadata["confidence"] in {"low", "medium", "high"} for item in draft.memory_items)
    assert all("8848" not in item.text for item in draft.memory_items)
    assert all("demo secret code" not in item.text for item in draft.memory_items)
|
||||
|
||||
|
||||
def test_builder_can_use_llm_client_json_response() -> None:
    """An injected LLM client's JSON reply drives the draft contents.

    The canned reply supplies the persona markdown, one memory per layer and
    a ``high`` confidence; the draft must reflect all of them.
    """
    canned_reply = (
        '{"persona_md":"# Persona\\nLi is concise.",'
        '"style_memories":["uses gentle imperative"],'
        '"semantic_memories":["prefers small checklists"],'
        '"episodic_summaries":["supported a demo preparation chat"],'
        '"confidence":"high"}'
    )

    class CannedLLM:
        async def complete(self, messages):
            # The builder is expected to open the conversation with a system message.
            assert messages[0]["role"] == "system"
            return canned_reply

    export = sample_export()
    draft = build_wechat_persona_draft(
        export,
        target_speaker_id="wxid_friend",
        llm_client=CannedLLM(),
    )

    assert draft.persona_md == "# Persona\nLi is concise."
    confidences = {entry.metadata["confidence"] for entry in draft.memory_items}
    assert confidences == {"high"}
    assert any(entry.text == "uses gentle imperative" for entry in draft.memory_items)
|
||||
187
tests/unit/test_weflow_parser.py
Normal file
187
tests/unit/test_weflow_parser.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from opentalking.persona.weflow_parser import WeFlowParseError, parse_weflow_export
|
||||
|
||||
|
||||
def write_json(path: Path, payload: object) -> Path:
    """Serialize ``payload`` as UTF-8 JSON into ``path`` and return ``path``.

    Non-ASCII characters are written as-is rather than escaped.
    """
    serialized = json.dumps(payload, ensure_ascii=False)
    path.write_text(serialized, encoding="utf-8")
    return path
|
||||
|
||||
|
||||
def test_parse_chatlab_json_export(tmp_path: Path) -> None:
    """A ChatLab-format JSON export is parsed into speakers and turns.

    Checks format detection, the conversation id taken from the group id, the
    group nickname used as the display name, the long message id kept as a
    string, and the timestamp rendered in the requested timezone.
    """
    member = {
        "platformId": "wxid_li",
        "accountName": "Li Si",
        "groupNickname": "Product",
        "avatar": "https://example.test/avatar.jpg",
    }
    message = {
        "sender": "wxid_li",
        "accountName": "Li Si",
        "groupNickname": "Product",
        "timestamp": 1738713600,
        "type": 0,
        "content": "Where are we today?",
        "platformMessageId": "12345678901234567890",
    }
    payload = {
        "chatlab": {"version": "0.0.2", "generator": "WeFlow"},
        "meta": {"name": "Project room", "platform": "wechat", "type": "group", "groupId": "room@chatroom"},
        "members": [member],
        "messages": [message],
    }
    export_path = write_json(tmp_path / "chatlab.json", payload)

    parsed = parse_weflow_export(export_path, timezone="Asia/Shanghai")

    assert parsed.detected_format == "chatlab_json"
    assert parsed.conversation_id == "room@chatroom"
    assert parsed.source_metadata["source_name"] == "chatlab.json"
    assert parsed.speakers[0].id == "wxid_li"
    assert parsed.speakers[0].name == "Product"
    first_turn = parsed.turns[0]
    assert first_turn.message_id == "12345678901234567890"
    assert first_turn.speaker_id == "wxid_li"
    assert first_turn.speaker_name == "Product"
    assert first_turn.timestamp == "2025-02-05T08:00:00+08:00"
    assert first_turn.content == "Where are we today?"
|
||||
|
||||
|
||||
def test_parse_raw_weflow_json_export(tmp_path: Path) -> None:
    """A raw WeFlow API-style JSON dump is parsed into turns.

    Checks that ``parsedContent`` wins over ``content`` when both exist, that
    ``isSend`` maps onto ``is_self``, and that ``serverId`` is the message id.
    """
    incoming = {
        "localId": 1,
        "serverId": "6116895530414915131",
        "createTime": 1738713600,
        "isSend": 0,
        "senderUsername": "wxid_friend",
        "parsedContent": "I like your short replies.",
        "content": "ignored when parsedContent exists",
    }
    outgoing = {
        "localId": 2,
        "serverId": "6116895530414915132",
        "createTime": 1738713660,
        "isSend": 1,
        "senderUsername": "self_wxid",
        "content": "Noted.",
    }
    payload = {
        "success": True,
        "talker": "wxid_friend",
        "messages": [incoming, outgoing],
    }
    export_path = write_json(tmp_path / "raw.json", payload)

    parsed = parse_weflow_export(export_path)

    assert parsed.detected_format == "raw_json"
    assert parsed.conversation_id == "wxid_friend"
    self_flags = [turn.is_self for turn in parsed.turns]
    assert self_flags == [False, True]
    assert parsed.turns[0].message_id == "6116895530414915131"
    assert parsed.turns[0].content == "I like your short replies."
    assert parsed.turns[1].speaker_id == "self_wxid"
|
||||
|
||||
|
||||
def test_parse_csv_export_preserves_long_ids(tmp_path: Path) -> None:
    """A CSV export keeps a message id longer than 64-bit range as an exact string."""
    header = "serverId,createTime,isSend,senderUsername,accountName,content\n"
    row = "9223372036854775807123,2025-02-05 08:00:00,0,wxid_li,Li,morning\n"
    csv_path = tmp_path / "messages.csv"
    csv_path.write_text(header + row, encoding="utf-8")

    parsed = parse_weflow_export(csv_path, conversation_id="wxid_li")

    assert parsed.detected_format == "csv"
    assert parsed.conversation_id == "wxid_li"
    only_turn = parsed.turns[0]
    assert only_turn.message_id == "9223372036854775807123"
    assert only_turn.speaker_name == "Li"
|
||||
|
||||
|
||||
def test_parse_txt_export(tmp_path: Path) -> None:
    """A plain-text export is parsed line by line into turns.

    The second line uses a full-width colon after the speaker name, and its
    speaker ``me`` is expected to be flagged as the user.
    """
    transcript_lines = [
        "[2025-02-05 08:00:00] Li: morning\n",
        "[2025-02-05 08:01:00] me\uff1aI am good today\n",
    ]
    txt_path = tmp_path / "messages.txt"
    txt_path.write_text("".join(transcript_lines), encoding="utf-8")

    parsed = parse_weflow_export(txt_path)

    assert parsed.detected_format == "txt"
    names = [turn.speaker_name for turn in parsed.turns]
    assert names == ["Li", "me"]
    assert parsed.turns[0].content == "morning"
    assert parsed.turns[1].is_self is True
|
||||
|
||||
|
||||
def test_parse_html_export(tmp_path: Path) -> None:
    """An HTML table export yields a turn with the sender and content cells."""
    markup = "\n".join(
        [
            "<html><body><table>",
            '<tr class="message"><td class="time">2025-02-05 08:00:00</td><td class="sender">Li</td><td class="content">hello & haha</td></tr>',
            "</table></body></html>",
        ]
    )
    html_path = tmp_path / "messages.html"
    html_path.write_text(markup, encoding="utf-8")

    parsed = parse_weflow_export(html_path)

    assert parsed.detected_format == "html"
    only_turn = parsed.turns[0]
    assert only_turn.speaker_name == "Li"
    assert only_turn.content == "hello & haha"
|
||||
|
||||
|
||||
def test_zip_export_prefers_json_member(tmp_path: Path) -> None:
    """Given a zip with both HTML and JSON members, the JSON member is parsed.

    The archive lists the HTML file first and nests the JSON file in a
    sub-directory; the parser must still report the JSON member as its source.
    """
    json_member = json.dumps(
        {
            "chatlab": {"version": "0.0.2", "generator": "WeFlow"},
            "meta": {"name": "Friend", "platform": "wechat", "type": "private"},
            "members": [],
            "messages": [
                {
                    "sender": "wxid_li",
                    "accountName": "Li",
                    "timestamp": 1738713600,
                    "type": 0,
                    "content": "zip json member",
                    "platformMessageId": "abc-1",
                }
            ],
        },
        ensure_ascii=False,
    )
    zip_path = tmp_path / "weflow-export.zip"
    with zipfile.ZipFile(zip_path, "w") as bundle:
        # Write the HTML member first so archive order cannot explain the preference.
        bundle.writestr("messages.html", "<html><body>not preferred</body></html>")
        bundle.writestr("nested/messages.json", json_member)

    parsed = parse_weflow_export(zip_path)

    assert parsed.detected_format == "chatlab_json"
    assert parsed.source_metadata["archive_member"] == "nested/messages.json"
    assert parsed.turns[0].content == "zip json member"
|
||||
|
||||
|
||||
def test_rejects_api_url_or_unsupported_file(tmp_path: Path) -> None:
    """URLs and unsupported file types both raise ``WeFlowParseError``.

    A URL must produce an error whose message matches ``upload``; a ``.png``
    file must produce one whose message matches ``unsupported``.
    """
    api_url = "http://127.0.0.1:5031/api/v1/messages?talker=wxid_xxx"
    with pytest.raises(WeFlowParseError, match="upload"):
        parse_weflow_export(api_url)

    image_path = tmp_path / "avatar.png"
    image_path.write_bytes(b"not a chat export")
    with pytest.raises(WeFlowParseError, match="unsupported"):
        parse_weflow_export(image_path)
|
||||
Reference in New Issue
Block a user