mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
feat: session management for InnerAPI&VM
This commit is contained in:
11
api/core/session/__init__.py
Normal file
11
api/core/session/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
from .inner_api import InnerApiSession, InnerApiSessionManager
|
||||
from .session import BaseSession, RedisSessionStorage, SessionManager, SessionStorage
|
||||
|
||||
__all__ = [
|
||||
"BaseSession",
|
||||
"InnerApiSession",
|
||||
"InnerApiSessionManager",
|
||||
"RedisSessionStorage",
|
||||
"SessionManager",
|
||||
"SessionStorage",
|
||||
]
|
||||
19
api/core/session/inner_api.py
Normal file
19
api/core/session/inner_api.py
Normal file
@ -0,0 +1,19 @@
|
||||
from typing import Any
|
||||
|
||||
from .session import BaseSession, SessionManager
|
||||
|
||||
|
||||
class InnerApiSession(BaseSession):
|
||||
"""Inner API Session"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InnerApiSessionManager(SessionManager[InnerApiSession]):
|
||||
def __init__(self, ttl: int | None = None):
|
||||
super().__init__(key_prefix="inner_api_session", session_class=InnerApiSession, ttl=ttl)
|
||||
|
||||
def create(self, tenant_id: str, user_id: str, context: dict[str, Any] | None = None) -> InnerApiSession:
|
||||
session = InnerApiSession(tenant_id=tenant_id, user_id=user_id, context=context or {})
|
||||
self.save(session)
|
||||
return session
|
||||
106
api/core/session/session.py
Normal file
106
api/core/session/session.py
Normal file
@ -0,0 +1,106 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionStorage(Protocol):
|
||||
"""Session storage interface."""
|
||||
|
||||
def get(self, key: str) -> str | None: ...
|
||||
def set(self, key: str, value: str, ttl: int) -> None: ...
|
||||
def delete(self, key: str) -> bool: ...
|
||||
def exists(self, key: str) -> bool: ...
|
||||
def refresh_ttl(self, key: str, ttl: int) -> bool: ...
|
||||
|
||||
|
||||
class RedisSessionStorage:
|
||||
"""Redis storage implementation (default)."""
|
||||
|
||||
def get(self, key: str) -> str | None:
|
||||
result = redis_client.get(key)
|
||||
if result is None:
|
||||
return None
|
||||
return result.decode() if isinstance(result, bytes) else result
|
||||
|
||||
def set(self, key: str, value: str, ttl: int) -> None:
|
||||
redis_client.setex(key, ttl, value)
|
||||
|
||||
def delete(self, key: str) -> bool:
|
||||
return redis_client.delete(key) > 0
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return redis_client.exists(key) > 0
|
||||
|
||||
def refresh_ttl(self, key: str, ttl: int) -> bool:
|
||||
return bool(redis_client.expire(key, ttl))
|
||||
|
||||
|
||||
class BaseSession(BaseModel):
|
||||
"""Base session model."""
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
context: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def update_timestamp(self) -> None:
|
||||
self.updated_at = datetime.now(UTC)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseSession)
|
||||
|
||||
|
||||
class SessionManager(Generic[T]):
|
||||
"""Generic session manager."""
|
||||
|
||||
DEFAULT_TTL = 7200 # 2 hours
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key_prefix: str,
|
||||
session_class: type[T],
|
||||
storage: SessionStorage | None = None,
|
||||
ttl: int | None = None,
|
||||
):
|
||||
self._key_prefix = key_prefix
|
||||
self._session_class = session_class
|
||||
self._storage = storage or RedisSessionStorage()
|
||||
self._ttl = ttl or self.DEFAULT_TTL
|
||||
|
||||
def _get_key(self, session_id: str) -> str:
|
||||
return f"{self._key_prefix}:{session_id}"
|
||||
|
||||
def save(self, session: T) -> None:
|
||||
session.update_timestamp()
|
||||
key = self._get_key(session.id)
|
||||
self._storage.set(key, session.model_dump_json(), self._ttl)
|
||||
|
||||
def get(self, session_id: str) -> T | None:
|
||||
key = self._get_key(session_id)
|
||||
data = self._storage.get(key)
|
||||
if data is None:
|
||||
return None
|
||||
try:
|
||||
return self._session_class.model_validate(json.loads(data))
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
logger.warning("Failed to deserialize session %s: %s", session_id, e)
|
||||
return None
|
||||
|
||||
def delete(self, session_id: str) -> bool:
|
||||
return self._storage.delete(self._get_key(session_id))
|
||||
|
||||
def exists(self, session_id: str) -> bool:
|
||||
return self._storage.exists(self._get_key(session_id))
|
||||
|
||||
def refresh_ttl(self, session_id: str) -> bool:
|
||||
return self._storage.refresh_ttl(self._get_key(session_id), self._ttl)
|
||||
3
api/core/virtual_environment/session/__init__.py
Normal file
3
api/core/virtual_environment/session/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .sandbox_session import SandboxProvider, SandboxSession, SandboxSessionManager
|
||||
|
||||
__all__ = ["SandboxProvider", "SandboxSession", "SandboxSessionManager"]
|
||||
47
api/core/virtual_environment/session/sandbox_session.py
Normal file
47
api/core/virtual_environment/session/sandbox_session.py
Normal file
@ -0,0 +1,47 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.session import BaseSession, SessionManager
|
||||
from core.virtual_environment.__base.entities import Arch
|
||||
|
||||
|
||||
class SandboxProvider(StrEnum):
|
||||
E2B = "e2b"
|
||||
DOCKER = "docker"
|
||||
LOCAL = "local"
|
||||
|
||||
|
||||
class SandboxSession(BaseSession):
|
||||
provider: SandboxProvider
|
||||
sandbox_id: str
|
||||
arch: Arch
|
||||
connection_config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SandboxSessionManager(SessionManager[SandboxSession]):
|
||||
def __init__(self, ttl: int | None = None):
|
||||
super().__init__(key_prefix="sandbox_session", session_class=SandboxSession, ttl=ttl)
|
||||
|
||||
def create(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: SandboxProvider,
|
||||
sandbox_id: str,
|
||||
arch: Arch,
|
||||
connection_config: dict[str, Any] | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> SandboxSession:
|
||||
session = SandboxSession(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
sandbox_id=sandbox_id,
|
||||
arch=arch,
|
||||
connection_config=connection_config or {},
|
||||
context=context or {},
|
||||
)
|
||||
self.save(session)
|
||||
return session
|
||||
Reference in New Issue
Block a user