mirror of
https://github.com/datascale-ai/opentalking.git
synced 2026-07-03 15:22:34 +08:00
fix: release mem0 provider on runtime refresh
This commit is contained in:
@@ -11,7 +11,7 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from apps.api.core.config import get_settings
|
||||
from opentalking.providers.memory.factory import build_memory_provider
|
||||
from opentalking.providers.memory.factory import close_cached_memory_provider
|
||||
from opentalking.providers.stt.factory import (
|
||||
clear_stt_adapter_cache,
|
||||
normalize_stt_provider,
|
||||
@@ -531,12 +531,14 @@ def _build_updates(payload: RuntimeConfigPayload) -> dict[str, str]:
|
||||
return updates
|
||||
|
||||
|
||||
def _refresh_settings(request: Request) -> Any:
|
||||
async def _refresh_settings(request: Request) -> Any:
|
||||
get_settings.cache_clear()
|
||||
settings = get_settings()
|
||||
request.app.state.settings = settings
|
||||
clear_stt_adapter_cache()
|
||||
build_memory_provider.cache_clear()
|
||||
await close_cached_memory_provider()
|
||||
if hasattr(request.app.state, "wechat_import_registry"):
|
||||
delattr(request.app.state, "wechat_import_registry")
|
||||
return settings
|
||||
|
||||
|
||||
@@ -580,7 +582,7 @@ async def apply_runtime_config(payload: RuntimeConfigPayload, request: Request)
|
||||
for key in _RUNTIME_ENV_KEYS:
|
||||
if key in values:
|
||||
os.environ[key] = values[key]
|
||||
settings = _refresh_settings(request)
|
||||
settings = await _refresh_settings(request)
|
||||
refreshed_runners = _refresh_live_runners(request, settings)
|
||||
result = _current_payload(settings)
|
||||
result["applied"] = True
|
||||
|
||||
@@ -177,13 +177,13 @@ async def test_runtime_config_apply_persists_llm_stt_tts_and_keeps_blank_keys(mo
|
||||
|
||||
|
||||
async def test_runtime_config_apply_updates_mem0_keys_and_refreshes_memory_provider(monkeypatch, tmp_path) -> None:
|
||||
cleared = False
|
||||
refreshed = False
|
||||
|
||||
def fake_cache_clear() -> None:
|
||||
nonlocal cleared
|
||||
cleared = True
|
||||
async def fake_close_cached_memory_provider() -> None:
|
||||
nonlocal refreshed
|
||||
refreshed = True
|
||||
|
||||
monkeypatch.setattr(runtime_config, "build_memory_provider", SimpleNamespace(cache_clear=fake_cache_clear))
|
||||
monkeypatch.setattr(runtime_config, "close_cached_memory_provider", fake_close_cached_memory_provider, raising=False)
|
||||
|
||||
payload = await runtime_config.apply_runtime_config(
|
||||
runtime_config.RuntimeConfigPayload(
|
||||
@@ -200,7 +200,7 @@ async def test_runtime_config_apply_updates_mem0_keys_and_refreshes_memory_provi
|
||||
_request(monkeypatch, tmp_path),
|
||||
)
|
||||
|
||||
assert cleared is True
|
||||
assert refreshed is True
|
||||
assert payload["mem0"]["llm"]["api_key_set"] is True
|
||||
assert payload["mem0"]["embedder"]["api_key_set"] is True
|
||||
assert "sk-new-mem0-llm" not in str(payload)
|
||||
@@ -210,6 +210,22 @@ async def test_runtime_config_apply_updates_mem0_keys_and_refreshes_memory_provi
|
||||
assert os.environ.get("DASHSCOPE_API_KEY") != "sk-new-mem0-llm"
|
||||
|
||||
|
||||
async def test_runtime_config_apply_discards_stale_wechat_memory_registry(monkeypatch, tmp_path) -> None:
|
||||
async def fake_close_cached_memory_provider() -> None:
|
||||
return None
|
||||
|
||||
request = _request(monkeypatch, tmp_path)
|
||||
request.app.state.wechat_import_registry = object()
|
||||
monkeypatch.setattr(runtime_config, "close_cached_memory_provider", fake_close_cached_memory_provider, raising=False)
|
||||
|
||||
await runtime_config.apply_runtime_config(
|
||||
runtime_config.RuntimeConfigPayload(mem0_llm_api_key="sk-new-mem0-llm"),
|
||||
request,
|
||||
)
|
||||
|
||||
assert not hasattr(request.app.state, "wechat_import_registry")
|
||||
|
||||
|
||||
async def test_runtime_config_apply_rejects_unknown_provider(monkeypatch, tmp_path) -> None:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await runtime_config.apply_runtime_config(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
@@ -9,6 +10,8 @@ from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider, M
|
||||
from opentalking.providers.memory.noop import NoopMemoryProvider
|
||||
from opentalking.providers.memory.sqlite_provider import SQLiteMemoryProvider
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strip(value: Any) -> str:
|
||||
return str(value or "").strip()
|
||||
@@ -111,3 +114,15 @@ def build_memory_provider() -> MemoryProvider:
|
||||
if provider in {"memory", "inmemory", "in-memory"}:
|
||||
return InMemoryMemoryProvider()
|
||||
raise ValueError(f"unsupported memory provider: {settings.memory_provider}")
|
||||
|
||||
|
||||
async def close_cached_memory_provider() -> None:
|
||||
if build_memory_provider.cache_info().currsize <= 0:
|
||||
return
|
||||
provider = build_memory_provider()
|
||||
try:
|
||||
await provider.close()
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("failed to close cached memory provider", exc_info=True)
|
||||
finally:
|
||||
build_memory_provider.cache_clear()
|
||||
|
||||
@@ -387,9 +387,23 @@ class Mem0MemoryProvider(MemoryProvider):
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
close = getattr(self._client, "close", None)
|
||||
if callable(close):
|
||||
await _maybe_await(close())
|
||||
closed: set[int] = set()
|
||||
|
||||
async def close_candidate(candidate: Any) -> None:
|
||||
if candidate is None:
|
||||
return
|
||||
marker = id(candidate)
|
||||
if marker in closed:
|
||||
return
|
||||
closed.add(marker)
|
||||
close = getattr(candidate, "close", None)
|
||||
if callable(close):
|
||||
await _maybe_await(close())
|
||||
|
||||
vector_store = getattr(self._client, "vector_store", None)
|
||||
await close_candidate(self._client)
|
||||
await close_candidate(vector_store)
|
||||
await close_candidate(getattr(vector_store, "client", None))
|
||||
|
||||
async def _add(
|
||||
self,
|
||||
|
||||
@@ -875,6 +875,33 @@ def test_mem0_provider_summary_write_uses_infer_false_metadata() -> None:
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_mem0_provider_close_closes_nested_vector_store_client() -> None:
|
||||
class FakeVectorClient:
|
||||
def __init__(self) -> None:
|
||||
self.closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
class FakeVectorStore:
|
||||
def __init__(self) -> None:
|
||||
self.client = FakeVectorClient()
|
||||
|
||||
class FakeMem0:
|
||||
def __init__(self) -> None:
|
||||
self.vector_store = FakeVectorStore()
|
||||
|
||||
async def run() -> None:
|
||||
fake = FakeMem0()
|
||||
provider = Mem0MemoryProvider(client=fake)
|
||||
|
||||
await provider.close()
|
||||
|
||||
assert fake.vector_store.client.closed is True
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_sqlite_memory_provider_roundtrip(tmp_path) -> None:
|
||||
async def run() -> None:
|
||||
provider = SQLiteMemoryProvider(tmp_path / "memory.sqlite3")
|
||||
|
||||
Reference in New Issue
Block a user