fix: release mem0 provider on runtime refresh

This commit is contained in:
charm-ch
2026-06-22 17:58:30 +08:00
committed by zyairehhh
parent 8a2dc83094
commit 2eb85ffe4d
5 changed files with 87 additions and 13 deletions

View File

@@ -11,7 +11,7 @@ from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from apps.api.core.config import get_settings
from opentalking.providers.memory.factory import build_memory_provider
from opentalking.providers.memory.factory import close_cached_memory_provider
from opentalking.providers.stt.factory import (
clear_stt_adapter_cache,
normalize_stt_provider,
@@ -531,12 +531,14 @@ def _build_updates(payload: RuntimeConfigPayload) -> dict[str, str]:
return updates
def _refresh_settings(request: Request) -> Any:
async def _refresh_settings(request: Request) -> Any:
get_settings.cache_clear()
settings = get_settings()
request.app.state.settings = settings
clear_stt_adapter_cache()
build_memory_provider.cache_clear()
await close_cached_memory_provider()
if hasattr(request.app.state, "wechat_import_registry"):
delattr(request.app.state, "wechat_import_registry")
return settings
@@ -580,7 +582,7 @@ async def apply_runtime_config(payload: RuntimeConfigPayload, request: Request)
for key in _RUNTIME_ENV_KEYS:
if key in values:
os.environ[key] = values[key]
settings = _refresh_settings(request)
settings = await _refresh_settings(request)
refreshed_runners = _refresh_live_runners(request, settings)
result = _current_payload(settings)
result["applied"] = True

View File

@@ -177,13 +177,13 @@ async def test_runtime_config_apply_persists_llm_stt_tts_and_keeps_blank_keys(mo
async def test_runtime_config_apply_updates_mem0_keys_and_refreshes_memory_provider(monkeypatch, tmp_path) -> None:
cleared = False
refreshed = False
def fake_cache_clear() -> None:
nonlocal cleared
cleared = True
async def fake_close_cached_memory_provider() -> None:
nonlocal refreshed
refreshed = True
monkeypatch.setattr(runtime_config, "build_memory_provider", SimpleNamespace(cache_clear=fake_cache_clear))
monkeypatch.setattr(runtime_config, "close_cached_memory_provider", fake_close_cached_memory_provider, raising=False)
payload = await runtime_config.apply_runtime_config(
runtime_config.RuntimeConfigPayload(
@@ -200,7 +200,7 @@ async def test_runtime_config_apply_updates_mem0_keys_and_refreshes_memory_provi
_request(monkeypatch, tmp_path),
)
assert cleared is True
assert refreshed is True
assert payload["mem0"]["llm"]["api_key_set"] is True
assert payload["mem0"]["embedder"]["api_key_set"] is True
assert "sk-new-mem0-llm" not in str(payload)
@@ -210,6 +210,22 @@ async def test_runtime_config_apply_updates_mem0_keys_and_refreshes_memory_provi
assert os.environ.get("DASHSCOPE_API_KEY") != "sk-new-mem0-llm"
async def test_runtime_config_apply_discards_stale_wechat_memory_registry(monkeypatch, tmp_path) -> None:
async def fake_close_cached_memory_provider() -> None:
return None
request = _request(monkeypatch, tmp_path)
request.app.state.wechat_import_registry = object()
monkeypatch.setattr(runtime_config, "close_cached_memory_provider", fake_close_cached_memory_provider, raising=False)
await runtime_config.apply_runtime_config(
runtime_config.RuntimeConfigPayload(mem0_llm_api_key="sk-new-mem0-llm"),
request,
)
assert not hasattr(request.app.state, "wechat_import_registry")
async def test_runtime_config_apply_rejects_unknown_provider(monkeypatch, tmp_path) -> None:
with pytest.raises(HTTPException) as exc_info:
await runtime_config.apply_runtime_config(

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
from functools import lru_cache
from typing import Any
@@ -9,6 +10,8 @@ from opentalking.providers.memory.mem0_provider import InMemoryMemoryProvider, M
from opentalking.providers.memory.noop import NoopMemoryProvider
from opentalking.providers.memory.sqlite_provider import SQLiteMemoryProvider
log = logging.getLogger(__name__)
def _strip(value: Any) -> str:
return str(value or "").strip()
@@ -111,3 +114,15 @@ def build_memory_provider() -> MemoryProvider:
if provider in {"memory", "inmemory", "in-memory"}:
return InMemoryMemoryProvider()
raise ValueError(f"unsupported memory provider: {settings.memory_provider}")
async def close_cached_memory_provider() -> None:
if build_memory_provider.cache_info().currsize <= 0:
return
provider = build_memory_provider()
try:
await provider.close()
except Exception: # noqa: BLE001
log.warning("failed to close cached memory provider", exc_info=True)
finally:
build_memory_provider.cache_clear()

View File

@@ -387,9 +387,23 @@ class Mem0MemoryProvider(MemoryProvider):
return True
async def close(self) -> None:
close = getattr(self._client, "close", None)
if callable(close):
await _maybe_await(close())
closed: set[int] = set()
async def close_candidate(candidate: Any) -> None:
if candidate is None:
return
marker = id(candidate)
if marker in closed:
return
closed.add(marker)
close = getattr(candidate, "close", None)
if callable(close):
await _maybe_await(close())
vector_store = getattr(self._client, "vector_store", None)
await close_candidate(self._client)
await close_candidate(vector_store)
await close_candidate(getattr(vector_store, "client", None))
async def _add(
self,

View File

@@ -875,6 +875,33 @@ def test_mem0_provider_summary_write_uses_infer_false_metadata() -> None:
asyncio.run(run())
def test_mem0_provider_close_closes_nested_vector_store_client() -> None:
class FakeVectorClient:
def __init__(self) -> None:
self.closed = False
def close(self) -> None:
self.closed = True
class FakeVectorStore:
def __init__(self) -> None:
self.client = FakeVectorClient()
class FakeMem0:
def __init__(self) -> None:
self.vector_store = FakeVectorStore()
async def run() -> None:
fake = FakeMem0()
provider = Mem0MemoryProvider(client=fake)
await provider.close()
assert fake.vector_store.client.closed is True
asyncio.run(run())
def test_sqlite_memory_provider_roundtrip(tmp_path) -> None:
async def run() -> None:
provider = SQLiteMemoryProvider(tmp_path / "memory.sqlite3")