mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
fix: use thread local isolation the context (#31410)
This commit is contained in:
@ -1,6 +1,8 @@
|
||||
"""Tests for execution context module."""
|
||||
|
||||
import contextvars
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -149,6 +151,54 @@ class TestExecutionContext:
|
||||
|
||||
assert ctx.user == user
|
||||
|
||||
def test_thread_safe_context_manager(self):
|
||||
"""Test shared ExecutionContext works across threads without token mismatch."""
|
||||
test_var = contextvars.ContextVar("thread_safe_test_var")
|
||||
|
||||
class TrackingAppContext(AppContext):
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return default
|
||||
|
||||
def get_extension(self, name: str) -> Any:
|
||||
return None
|
||||
|
||||
@contextmanager
|
||||
def enter(self):
|
||||
token = test_var.set(threading.get_ident())
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
test_var.reset(token)
|
||||
|
||||
ctx = ExecutionContext(app_context=TrackingAppContext())
|
||||
errors: list[Exception] = []
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def worker():
|
||||
try:
|
||||
for _ in range(20):
|
||||
with ctx:
|
||||
try:
|
||||
barrier.wait()
|
||||
barrier.wait()
|
||||
except threading.BrokenBarrierError:
|
||||
return
|
||||
except Exception as exc:
|
||||
errors.append(exc)
|
||||
try:
|
||||
barrier.abort()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t1 = threading.Thread(target=worker)
|
||||
t2 = threading.Thread(target=worker)
|
||||
t1.start()
|
||||
t2.start()
|
||||
t1.join(timeout=5)
|
||||
t2.join(timeout=5)
|
||||
|
||||
assert not errors
|
||||
|
||||
|
||||
class TestIExecutionContextProtocol:
|
||||
"""Test IExecutionContext protocol."""
|
||||
|
||||
Reference in New Issue
Block a user