mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 08:28:03 +08:00
feat: introduce attribute management system for sandbox
- Added AttrMap and AttrKey classes for type-safe attribute storage. - Implemented AppAssetsAttrs and SkillAttrs for managing application and skill attributes. - Refactored Sandbox and initializers to utilize the new attribute management system, enhancing modularity and clarity in asset handling.
This commit is contained in:
7
api/core/app_assets/constants.py
Normal file
7
api/core/app_assets/constants.py
Normal file
@ -0,0 +1,7 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from libs.attr_map import AttrKey
|
||||
|
||||
|
||||
class AppAssetsAttrs:
|
||||
# Skill artifact set
|
||||
FILE_TREE = AttrKey("file_tree", AppAssetFileTree)
|
||||
@ -100,9 +100,6 @@ class SandboxBuilder:
|
||||
environments=self._environments,
|
||||
user_id=self._user_id,
|
||||
)
|
||||
for init in self._initializers:
|
||||
init.initialize(vm)
|
||||
|
||||
sandbox = Sandbox(
|
||||
vm=vm,
|
||||
storage=self._storage,
|
||||
@ -111,6 +108,9 @@ class SandboxBuilder:
|
||||
app_id=self._app_id,
|
||||
assets_id=self._assets_id,
|
||||
)
|
||||
for init in self._initializers:
|
||||
init.initialize(sandbox)
|
||||
|
||||
sandbox.mount()
|
||||
return sandbox
|
||||
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
import logging
|
||||
|
||||
from core.app_assets.constants import AppAssetsAttrs
|
||||
from core.app_assets.paths import AssetPaths
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.virtual_environment.__base.helpers import pipeline
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
from services.app_asset_service import AppAssetService
|
||||
|
||||
from ..entities import AppAssets
|
||||
from .base import SandboxInitializer
|
||||
@ -20,42 +22,17 @@ class AppAssetsInitializer(SandboxInitializer):
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
def initialize(self, sandbox: Sandbox) -> None:
|
||||
vm = sandbox.vm
|
||||
# load app assets
|
||||
app_assets = AppAssetService.get_tenant_app_assets(self._tenant_id, self._assets_id)
|
||||
sandbox.attrs.set(AppAssetsAttrs.FILE_TREE, app_assets.asset_tree)
|
||||
|
||||
zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id)
|
||||
download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key)
|
||||
|
||||
(
|
||||
pipeline(env)
|
||||
.add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip")
|
||||
# unzip with silent error and return 1 if the zip is empty
|
||||
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
|
||||
.add(
|
||||
["sh", "-c", f"unzip {AppAssets.ZIP_PATH} -d {AppAssets.PATH} 2>/dev/null || [ $? -eq 1 ]"],
|
||||
error_message="Failed to unzip assets",
|
||||
)
|
||||
.execute(timeout=APP_ASSETS_DOWNLOAD_TIMEOUT, raise_on_error=True)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"App assets initialized for app_id=%s, published_id=%s",
|
||||
self._app_id,
|
||||
self._assets_id,
|
||||
)
|
||||
|
||||
|
||||
class DraftAppAssetsInitializer(SandboxInitializer):
|
||||
def __init__(self, tenant_id: str, app_id: str, assets_id: str) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
zip_key = AssetPaths.build_zip(self._tenant_id, self._app_id, self._assets_id)
|
||||
download_url = FilePresignStorage(storage.storage_runner).get_download_url(zip_key)
|
||||
|
||||
(
|
||||
pipeline(env)
|
||||
.add(["rm", "-rf", AppAssets.PATH])
|
||||
pipeline(vm)
|
||||
.add(["wget", "-q", download_url, "-O", AppAssets.ZIP_PATH], error_message="Failed to download assets zip")
|
||||
# unzip with silent error and return 1 if the zip is empty
|
||||
# FIXME(Mairuis): should use a more robust way to check if the zip is empty
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
|
||||
|
||||
class SandboxInitializer(ABC):
|
||||
@abstractmethod
|
||||
def initialize(self, env: VirtualEnvironment) -> None: ...
|
||||
def initialize(self, env: Sandbox) -> None: ...
|
||||
|
||||
@ -5,10 +5,10 @@ import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
from core.skill.skill_manager import SkillManager
|
||||
from core.virtual_environment.__base.helpers import pipeline
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
|
||||
from ..bash.dify_cli import DifyCliConfig, DifyCliLocator
|
||||
from ..entities import DifyCli
|
||||
@ -35,18 +35,19 @@ class DifyCliInitializer(SandboxInitializer):
|
||||
self._tools = []
|
||||
self._cli_api_session = None
|
||||
|
||||
def initialize(self, env: VirtualEnvironment) -> None:
|
||||
binary = self._locator.resolve(env.metadata.os, env.metadata.arch)
|
||||
def initialize(self, sandbox: Sandbox) -> None:
|
||||
vm = sandbox.vm
|
||||
binary = self._locator.resolve(vm.metadata.os, vm.metadata.arch)
|
||||
|
||||
pipeline(env).add(
|
||||
pipeline(vm).add(
|
||||
["mkdir", "-p", f"{DifyCli.ROOT}/bin"], error_message="Failed to create dify CLI directory"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
env.upload_file(DifyCli.PATH, BytesIO(binary.path.read_bytes()))
|
||||
vm.upload_file(DifyCli.PATH, BytesIO(binary.path.read_bytes()))
|
||||
|
||||
# Use 'cp' with mode preservation workaround: copy file to itself to claim ownership,
|
||||
# then use 'install' to set executable permission
|
||||
pipeline(env).add(
|
||||
pipeline(vm).add(
|
||||
[
|
||||
"sh",
|
||||
"-c",
|
||||
@ -67,16 +68,16 @@ class DifyCliInitializer(SandboxInitializer):
|
||||
# FIXME(Mairuis): store it in workflow context
|
||||
self._cli_api_session = CliApiSessionManager().create(tenant_id=self._tenant_id, user_id=self._user_id)
|
||||
|
||||
pipeline(env).add(
|
||||
pipeline(vm).add(
|
||||
["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, artifact)
|
||||
config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False)
|
||||
config_path = f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}"
|
||||
env.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
|
||||
vm.upload_file(config_path, BytesIO(config_json.encode("utf-8")))
|
||||
|
||||
pipeline(env, cwd=DifyCli.GLOBAL_TOOLS_PATH).add(
|
||||
pipeline(vm, cwd=DifyCli.GLOBAL_TOOLS_PATH).add(
|
||||
[DifyCli.PATH, "init"], error_message="Failed to initialize Dify CLI"
|
||||
).execute(raise_on_error=True)
|
||||
|
||||
|
||||
43
api/core/sandbox/initializer/skill_initializer.py
Normal file
43
api/core/sandbox/initializer/skill_initializer.py
Normal file
@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.skill import SkillAttrs
|
||||
from core.skill.skill_manager import SkillManager
|
||||
|
||||
from .base import SandboxInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillInitializer(SandboxInitializer):
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
app_id: str,
|
||||
assets_id: str,
|
||||
) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
self._user_id = user_id
|
||||
self._assets_id = assets_id
|
||||
|
||||
def initialize(self, sandbox: Sandbox) -> None:
|
||||
artifact_set = SkillManager.load_artifact(
|
||||
self._tenant_id,
|
||||
self._app_id,
|
||||
self._assets_id,
|
||||
)
|
||||
if artifact_set is None:
|
||||
raise ValueError(
|
||||
f"No skill artifact set found for tenant_id={self._tenant_id},"
|
||||
f"app_id={self._app_id}, "
|
||||
f"assets_id={self._assets_id} "
|
||||
)
|
||||
|
||||
sandbox.attrs.set(
|
||||
SkillAttrs.ARTIFACT_SET,
|
||||
artifact_set,
|
||||
)
|
||||
@ -7,7 +7,7 @@ from typing import Final
|
||||
from core.sandbox.builder import SandboxBuilder
|
||||
from core.sandbox.entities import AppAssets, SandboxType
|
||||
from core.sandbox.entities.providers import SandboxProviderEntity
|
||||
from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer, DraftAppAssetsInitializer
|
||||
from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer
|
||||
from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
from core.sandbox.storage.archive_storage import ArchiveSandboxStorage
|
||||
@ -151,7 +151,7 @@ class SandboxManager:
|
||||
.options(sandbox_provider.config)
|
||||
.user(user_id)
|
||||
.app(app_id)
|
||||
.initializer(DraftAppAssetsInitializer(tenant_id, app_id, assets.id))
|
||||
.initializer(AppAssetsInitializer(tenant_id, app_id, assets.id))
|
||||
.initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id))
|
||||
.storage(storage, assets.id)
|
||||
.build()
|
||||
|
||||
@ -3,6 +3,8 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from libs.attr_map import AttrMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.sandbox.storage.sandbox_storage import SandboxStorage
|
||||
from core.virtual_environment.__base.virtual_environment import VirtualEnvironment
|
||||
@ -27,6 +29,11 @@ class Sandbox:
|
||||
self._user_id = user_id
|
||||
self._app_id = app_id
|
||||
self._assets_id = assets_id
|
||||
self._attributes = AttrMap()
|
||||
|
||||
@property
|
||||
def attrs(self) -> AttrMap:
|
||||
return self._attributes
|
||||
|
||||
@property
|
||||
def vm(self) -> VirtualEnvironment:
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from .constants import SkillAttrs
|
||||
from .entities import ToolArtifact, ToolDependency, ToolReference
|
||||
from .skill_manager import SkillManager
|
||||
|
||||
__all__ = [
|
||||
"SkillAttrs",
|
||||
"SkillManager",
|
||||
"ToolArtifact",
|
||||
"ToolDependency",
|
||||
|
||||
7
api/core/skill/constants.py
Normal file
7
api/core/skill/constants.py
Normal file
@ -0,0 +1,7 @@
|
||||
from core.skill.entities.skill_artifact_set import SkillArtifactSet
|
||||
from libs.attr_map import AttrKey
|
||||
|
||||
|
||||
class SkillAttrs:
|
||||
# Skill artifact set
|
||||
ARTIFACT_SET = AttrKey("skill_artifact_set", SkillArtifactSet)
|
||||
164
api/libs/attr_map.py
Normal file
164
api/libs/attr_map.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""
|
||||
Type-safe attribute storage inspired by Netty's AttributeKey/AttributeMap pattern.
|
||||
|
||||
Provides loosely-coupled typed attribute storage where only code with access
|
||||
to the same AttrKey instance can read/write the corresponding attribute.
|
||||
|
||||
SESSION_KEY: AttrKey[Session] = AttrKey("session", Session)
|
||||
attrs = AttrMap()
|
||||
attrs.set(SESSION_KEY, session)
|
||||
session = attrs.get(SESSION_KEY) # -> Session | None
|
||||
session = attrs.require(SESSION_KEY) # -> Session (raises if not set)
|
||||
|
||||
Note: AttrMap is NOT thread-safe. Each instance should be confined to a single
|
||||
thread/context (e.g., one AttrMap per Sandbox/VirtualEnvironment instance).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic, TypeVar, cast, final, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D")
|
||||
|
||||
|
||||
@final
|
||||
class AttrKey(Generic[T]):
|
||||
"""
|
||||
A type-safe key for attribute storage.
|
||||
|
||||
Identity-based: different AttrKey instances with same name are distinct keys.
|
||||
This enables different modules to define keys independently without collision.
|
||||
"""
|
||||
|
||||
__slots__ = ("_name", "_type")
|
||||
|
||||
def __init__(self, name: str, type_: type[T]) -> None:
|
||||
self._name = name
|
||||
self._type = type_
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type_(self) -> type[T]:
|
||||
return self._type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"AttrKey({self._name!r}, {self._type.__name__})"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self is other
|
||||
|
||||
|
||||
class AttrMapKeyError(KeyError):
|
||||
"""Raised when a required attribute is not set."""
|
||||
|
||||
key: AttrKey[Any]
|
||||
|
||||
def __init__(self, key: AttrKey[Any]) -> None:
|
||||
self.key = key
|
||||
super().__init__(f"Required attribute '{key.name}' (type: {key.type_.__name__}) is not set")
|
||||
|
||||
|
||||
class AttrMapTypeError(TypeError):
|
||||
"""Raised when attribute value type doesn't match the key's declared type."""
|
||||
|
||||
key: AttrKey[Any]
|
||||
expected_type: type[Any]
|
||||
actual_type: type[Any]
|
||||
|
||||
def __init__(self, key: AttrKey[Any], expected_type: type[Any], actual_type: type[Any]) -> None:
|
||||
self.key = key
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(
|
||||
f"Attribute '{key.name}' expects type '{expected_type.__name__}', "
|
||||
f"got '{actual_type.__name__}'"
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class AttrMap:
|
||||
"""
|
||||
Thread-confined container for storing typed attributes using AttrKey instances.
|
||||
|
||||
NOT thread-safe. Each instance should be owned by a single context
|
||||
(e.g., one AttrMap per Sandbox/VirtualEnvironment instance).
|
||||
"""
|
||||
|
||||
__slots__ = ("_data",)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[AttrKey[Any], Any] = {}
|
||||
|
||||
def set(self, key: AttrKey[T], value: T, *, validate: bool = True) -> None:
|
||||
"""
|
||||
Store an attribute. Raises AttrMapTypeError if validate=True and type mismatches.
|
||||
|
||||
Note: Runtime validation only checks outer type (e.g., `list` not `list[str]`).
|
||||
"""
|
||||
if validate and not isinstance(value, key.type_):
|
||||
raise AttrMapTypeError(key, key.type_, type(value))
|
||||
self._data[key] = value
|
||||
|
||||
def get(self, key: AttrKey[T]) -> T | None:
|
||||
"""Retrieve an attribute, returning None if not set."""
|
||||
return cast(T | None, self._data.get(key))
|
||||
|
||||
@overload
|
||||
def get_or_default(self, key: AttrKey[T], default: T) -> T: ...
|
||||
|
||||
@overload
|
||||
def get_or_default(self, key: AttrKey[T], default: D) -> T | D: ...
|
||||
|
||||
def get_or_default(self, key: AttrKey[T], default: T | D) -> T | D:
|
||||
"""Retrieve an attribute, returning default if not set."""
|
||||
if key in self._data:
|
||||
return cast(T, self._data[key])
|
||||
return default
|
||||
|
||||
def require(self, key: AttrKey[T]) -> T:
|
||||
"""Retrieve an attribute, raising AttrMapKeyError if not set."""
|
||||
if key not in self._data:
|
||||
raise AttrMapKeyError(key)
|
||||
return cast(T, self._data[key])
|
||||
|
||||
def has(self, key: AttrKey[Any]) -> bool:
|
||||
"""Check if an attribute is set."""
|
||||
return key in self._data
|
||||
|
||||
def remove(self, key: AttrKey[Any]) -> bool:
|
||||
"""Remove an attribute. Returns True if it was present."""
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_if_absent(self, key: AttrKey[T], value: T, *, validate: bool = True) -> T:
|
||||
"""
|
||||
Set attribute only if not already set. Returns existing or newly set value.
|
||||
|
||||
Raises AttrMapTypeError if validate=True and type mismatches.
|
||||
"""
|
||||
if key in self._data:
|
||||
return cast(T, self._data[key])
|
||||
if validate and not isinstance(value, key.type_):
|
||||
raise AttrMapTypeError(key, key.type_, type(value))
|
||||
self._data[key] = value
|
||||
return value
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all attributes."""
|
||||
self._data.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
keys = [k.name for k in self._data]
|
||||
return f"AttrMap({keys})"
|
||||
@ -58,6 +58,22 @@ class AppAssetService:
|
||||
session.commit()
|
||||
return assets
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_app_assets(tenant_id: str, assets_id: str) -> AppAssets:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
app_assets = (
|
||||
session.query(AppAssets)
|
||||
.filter(
|
||||
AppAssets.tenant_id == tenant_id,
|
||||
AppAssets.id == assets_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not app_assets:
|
||||
raise ValueError(f"App assets not found for tenant_id={tenant_id}, assets_id={assets_id}")
|
||||
|
||||
return app_assets
|
||||
|
||||
@staticmethod
|
||||
def get_assets(tenant_id: str, app_id: str, user_id: str, *, is_draft: bool) -> AppAssets | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
|
||||
251
api/tests/unit_tests/libs/test_attr_map.py
Normal file
251
api/tests/unit_tests/libs/test_attr_map.py
Normal file
@ -0,0 +1,251 @@
|
||||
import pytest
|
||||
|
||||
from libs.attr_map import AttrKey, AttrMap, AttrMapKeyError, AttrMapTypeError
|
||||
|
||||
|
||||
class TestAttrKey:
|
||||
def test_identity_based_equality(self):
|
||||
key1 = AttrKey("session", str)
|
||||
key2 = AttrKey("session", str)
|
||||
|
||||
assert key1 != key2
|
||||
assert key1 == key1
|
||||
|
||||
def test_identity_based_hash(self):
|
||||
key1 = AttrKey("session", str)
|
||||
key2 = AttrKey("session", str)
|
||||
|
||||
assert hash(key1) != hash(key2)
|
||||
assert hash(key1) == hash(key1)
|
||||
|
||||
def test_can_be_used_as_dict_key(self):
|
||||
key1 = AttrKey("session", str)
|
||||
key2 = AttrKey("session", str)
|
||||
data: dict[AttrKey[str], str] = {}
|
||||
|
||||
data[key1] = "value1"
|
||||
data[key2] = "value2"
|
||||
|
||||
assert data[key1] == "value1"
|
||||
assert data[key2] == "value2"
|
||||
assert len(data) == 2
|
||||
|
||||
def test_properties(self):
|
||||
key = AttrKey("my_key", int)
|
||||
|
||||
assert key.name == "my_key"
|
||||
assert key.type_ is int
|
||||
|
||||
def test_repr(self):
|
||||
key = AttrKey("session", str)
|
||||
assert repr(key) == "AttrKey('session', str)"
|
||||
|
||||
|
||||
class TestAttrMap:
|
||||
def test_set_and_get(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
attrs.set(key, "hello")
|
||||
result = attrs.get(key)
|
||||
|
||||
assert result == "hello"
|
||||
|
||||
def test_get_returns_none_for_missing(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
assert attrs.get(key) is None
|
||||
|
||||
def test_get_or_default_returns_value_when_set(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
attrs.set(key, "hello")
|
||||
|
||||
result = attrs.get_or_default(key, "default")
|
||||
|
||||
assert result == "hello"
|
||||
|
||||
def test_get_or_default_returns_default_when_not_set(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
result = attrs.get_or_default(key, "default")
|
||||
|
||||
assert result == "default"
|
||||
|
||||
def test_require_returns_value_when_set(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
attrs.set(key, "hello")
|
||||
|
||||
result = attrs.require(key)
|
||||
|
||||
assert result == "hello"
|
||||
|
||||
def test_require_raises_when_not_set(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
with pytest.raises(AttrMapKeyError) as exc_info:
|
||||
attrs.require(key)
|
||||
|
||||
assert exc_info.value.key is key
|
||||
assert "session" in str(exc_info.value)
|
||||
|
||||
def test_has(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
assert not attrs.has(key)
|
||||
|
||||
attrs.set(key, "hello")
|
||||
|
||||
assert attrs.has(key)
|
||||
|
||||
def test_remove_existing(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
attrs.set(key, "hello")
|
||||
|
||||
result = attrs.remove(key)
|
||||
|
||||
assert result is True
|
||||
assert not attrs.has(key)
|
||||
|
||||
def test_remove_non_existing(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
result = attrs.remove(key)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_set_if_absent_when_absent(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
result = attrs.set_if_absent(key, "first")
|
||||
|
||||
assert result == "first"
|
||||
assert attrs.get(key) == "first"
|
||||
|
||||
def test_set_if_absent_when_present(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
attrs.set(key, "existing")
|
||||
|
||||
result = attrs.set_if_absent(key, "new")
|
||||
|
||||
assert result == "existing"
|
||||
assert attrs.get(key) == "existing"
|
||||
|
||||
def test_clear(self):
|
||||
key1: AttrKey[str] = AttrKey("key1", str)
|
||||
key2: AttrKey[int] = AttrKey("key2", int)
|
||||
attrs = AttrMap()
|
||||
attrs.set(key1, "hello")
|
||||
attrs.set(key2, 42)
|
||||
|
||||
attrs.clear()
|
||||
|
||||
assert len(attrs) == 0
|
||||
assert not attrs.has(key1)
|
||||
assert not attrs.has(key2)
|
||||
|
||||
def test_len(self):
|
||||
key1: AttrKey[str] = AttrKey("key1", str)
|
||||
key2: AttrKey[int] = AttrKey("key2", int)
|
||||
attrs = AttrMap()
|
||||
|
||||
assert len(attrs) == 0
|
||||
|
||||
attrs.set(key1, "hello")
|
||||
assert len(attrs) == 1
|
||||
|
||||
attrs.set(key2, 42)
|
||||
assert len(attrs) == 2
|
||||
|
||||
def test_repr(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
attrs.set(key, "hello")
|
||||
|
||||
result = repr(attrs)
|
||||
|
||||
assert "AttrMap" in result
|
||||
assert "session" in result
|
||||
|
||||
|
||||
class TestAttrMapTypeValidation:
|
||||
def test_set_with_wrong_type_raises(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
with pytest.raises(AttrMapTypeError) as exc_info:
|
||||
attrs.set(key, 123) # type: ignore[arg-type]
|
||||
|
||||
assert exc_info.value.key is key
|
||||
assert exc_info.value.expected_type is str
|
||||
assert exc_info.value.actual_type is int
|
||||
|
||||
def test_set_with_validate_false_allows_wrong_type(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
attrs.set(key, 123, validate=False) # type: ignore[arg-type]
|
||||
|
||||
assert attrs.get(key) == 123
|
||||
|
||||
def test_set_if_absent_with_wrong_type_raises(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
with pytest.raises(AttrMapTypeError):
|
||||
attrs.set_if_absent(key, 123) # type: ignore[arg-type]
|
||||
|
||||
def test_set_if_absent_with_validate_false_allows_wrong_type(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
attrs.set_if_absent(key, 123, validate=False) # type: ignore[arg-type]
|
||||
|
||||
assert attrs.get(key) == 123
|
||||
|
||||
def test_subclass_type_validation(self):
|
||||
class Animal:
|
||||
pass
|
||||
|
||||
class Dog(Animal):
|
||||
pass
|
||||
|
||||
key: AttrKey[Animal] = AttrKey("animal", Animal)
|
||||
attrs = AttrMap()
|
||||
|
||||
attrs.set(key, Dog())
|
||||
|
||||
assert isinstance(attrs.get(key), Dog)
|
||||
|
||||
|
||||
class TestAttrMapIsolation:
|
||||
def test_different_keys_with_same_name_are_isolated(self):
|
||||
key_in_module_a: AttrKey[str] = AttrKey("config", str)
|
||||
key_in_module_b: AttrKey[str] = AttrKey("config", str)
|
||||
attrs = AttrMap()
|
||||
|
||||
attrs.set(key_in_module_a, "value_a")
|
||||
attrs.set(key_in_module_b, "value_b")
|
||||
|
||||
assert attrs.get(key_in_module_a) == "value_a"
|
||||
assert attrs.get(key_in_module_b) == "value_b"
|
||||
|
||||
def test_multiple_attr_maps_are_independent(self):
|
||||
key: AttrKey[str] = AttrKey("session", str)
|
||||
attrs1 = AttrMap()
|
||||
attrs2 = AttrMap()
|
||||
|
||||
attrs1.set(key, "map1")
|
||||
attrs2.set(key, "map2")
|
||||
|
||||
assert attrs1.get(key) == "map1"
|
||||
assert attrs2.get(key) == "map2"
|
||||
Reference in New Issue
Block a user