From 9d80770dfc014684c1aa86c8f546f6fddfb9756f Mon Sep 17 00:00:00 2001 From: Harry Date: Thu, 22 Jan 2026 17:25:58 +0800 Subject: [PATCH] feat(sandbox): enhance sandbox management and tool artifact handling - Introduced SandboxManager.delete_storage method for improved storage management. - Refactored skill loading and tool artifact handling in DifyCliInitializer and SandboxBashSession. - Updated LLMNode to extract and compile tool artifacts, enhancing integration with skills. - Improved attribute management in AttrMap for better error handling and retrieval methods. --- .../console/app/workflow_draft_variable.py | 6 +- api/core/app_assets/__init__.py | 2 + api/core/app_assets/builder/skill_builder.py | 13 +-- api/core/app_assets/paths.py | 4 - api/core/sandbox/bash/session.py | 37 ++------- .../initializer/dify_cli_initializer.py | 6 +- api/core/sandbox/manager.py | 9 +++ api/core/skill/entities/tool_artifact.py | 22 +++++ api/core/skill/skill_manager.py | 62 -------------- api/core/workflow/nodes/llm/entities.py | 2 + api/core/workflow/nodes/llm/node.py | 81 +++++++++++++++---- api/libs/attr_map.py | 18 ++--- api/tests/unit_tests/libs/test_attr_map.py | 39 +++++---- 13 files changed, 147 insertions(+), 154 deletions(-) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 3ff388d330..fbaa2f8f0b 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -16,6 +16,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.file import helpers as file_helpers +from core.sandbox.manager import SandboxManager from core.variables.segment_group import SegmentGroup from core.variables.segments import ArrayFileSegment, 
ArrayPromptMessageSegment, FileSegment, Segment from core.variables.types import SegmentType @@ -23,7 +24,7 @@ from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIAB from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type -from libs.login import login_required +from libs.login import current_account_with_tenant, login_required from models import App, AppMode from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService @@ -249,6 +250,9 @@ class WorkflowVariableCollectionApi(Resource): @console_ns.response(204, "Workflow variables deleted successfully") @_api_prerequisite def delete(self, app_model: App): + # FIXME(Mairuis): move to SandboxArtifactService + current_user, _ = current_account_with_tenant() + SandboxManager.delete_storage(app_model.tenant_id, current_user.id) draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) diff --git a/api/core/app_assets/__init__.py b/api/core/app_assets/__init__.py index 3191c851ae..fa57145176 100644 --- a/api/core/app_assets/__init__.py +++ b/api/core/app_assets/__init__.py @@ -1,3 +1,4 @@ +from .constants import AppAssetsAttrs from .entities import ( AssetItem, FileAsset, @@ -8,6 +9,7 @@ from .parser import AssetItemParser, AssetParser, FileAssetParser, SkillAssetPar from .paths import AssetPaths __all__ = [ + "AppAssetsAttrs", "AssetItem", "AssetItemParser", "AssetPackager", diff --git a/api/core/app_assets/builder/skill_builder.py b/api/core/app_assets/builder/skill_builder.py index 9d46c71b72..c83539c8bf 100644 --- a/api/core/app_assets/builder/skill_builder.py +++ b/api/core/app_assets/builder/skill_builder.py @@ -51,16 +51,11 @@ class SkillBuilder: loaded = self._load_all(ctx) # 2. 
Compile all skills (CPU-bound, single thread) - documents = [ - SkillDocument(skill_id=s.node.id, content=s.content, metadata=s.metadata) - for s in loaded - ] + documents = [SkillDocument(skill_id=s.node.id, content=s.content, metadata=s.metadata) for s in loaded] artifact_set = SkillCompiler().compile_all(documents, tree, ctx.build_id) # 3. Save tool artifact - SkillManager.save_tool_artifact( - ctx.tenant_id, ctx.app_id, ctx.build_id, artifact_set.get_tool_artifact() - ) + SkillManager.save_artifact(ctx.tenant_id, ctx.app_id, ctx.build_id, artifact_set) # 4. Prepare compiled skills for upload to_upload: list[_CompiledSkill] = [] @@ -68,9 +63,7 @@ class SkillBuilder: artifact = artifact_set.get(skill.node.id) if artifact is None: continue - resolved_key = AssetPaths.build_resolved_file( - ctx.tenant_id, ctx.app_id, ctx.build_id, skill.node.id - ) + resolved_key = AssetPaths.build_resolved_file(ctx.tenant_id, ctx.app_id, ctx.build_id, skill.node.id) to_upload.append( _CompiledSkill( node=skill.node, diff --git a/api/core/app_assets/paths.py b/api/core/app_assets/paths.py index c30f73f7fc..fd900729ee 100644 --- a/api/core/app_assets/paths.py +++ b/api/core/app_assets/paths.py @@ -13,10 +13,6 @@ class AssetPaths: def build_resolved_file(tenant_id: str, app_id: str, assets_id: str, node_id: str) -> str: return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/resolved/{node_id}" - @staticmethod - def build_tool_artifact(tenant_id: str, app_id: str, assets_id: str) -> str: - return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/tool_artifact.json" - @staticmethod def build_skill_artifact_set(tenant_id: str, app_id: str, assets_id: str) -> str: return f"{AssetPaths._BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/skill_artifact_set.json" diff --git a/api/core/sandbox/bash/session.py b/api/core/sandbox/bash/session.py index 272d7a004c..b92c7f639e 100644 --- a/api/core/sandbox/bash/session.py +++ b/api/core/sandbox/bash/session.py @@ -8,7 
+8,6 @@ from types import TracebackType from core.sandbox.sandbox import Sandbox from core.session.cli_api import CliApiSession, CliApiSessionManager from core.skill.entities.tool_artifact import ToolArtifact -from core.skill.skill_manager import SkillManager from core.virtual_environment.__base.helpers import pipeline from ..bash.dify_cli import DifyCliConfig @@ -19,17 +18,10 @@ logger = logging.getLogger(__name__) class SandboxBashSession: - def __init__( - self, - *, - sandbox: Sandbox, - node_id: str, - allow_tools: list[tuple[str, str]] | None, - ) -> None: + def __init__(self, *, sandbox: Sandbox, node_id: str, tools: ToolArtifact | None) -> None: self._sandbox = sandbox self._node_id = node_id - self._allow_tools = allow_tools - + self._tools = tools self._bash_tool: SandboxBashTool | None = None self._cli_api_session: CliApiSession | None = None self._tenant_id = sandbox.tenant_id @@ -42,8 +34,8 @@ class SandboxBashSession: tenant_id=self._tenant_id, user_id=self._user_id, ) - if self._allow_tools is not None: - tools_path = self._setup_node_tools_directory(self._node_id, self._allow_tools, self._cli_api_session) + if self._tools is not None and not self._tools.is_empty(): + tools_path = self._setup_node_tools_directory(self._node_id, self._tools, self._cli_api_session) else: tools_path = DifyCli.GLOBAL_TOOLS_PATH @@ -57,24 +49,9 @@ class SandboxBashSession: def _setup_node_tools_directory( self, node_id: str, - allow_tools: list[tuple[str, str]], + tools: ToolArtifact, cli_api_session: CliApiSession, ) -> str | None: - artifact: ToolArtifact | None = SkillManager.load_tool_artifact( - self._sandbox.tenant_id, - self._app_id, - self._assets_id, - ) - - if artifact is None or artifact.is_empty(): - logger.info("No tools found in artifact for assets_id=%s", self._assets_id) - return None - - artifact = artifact.filter(allow_tools) - if artifact.is_empty(): - logger.info("No tools found in artifact for assets_id=%s", self._assets_id) - return None - 
node_tools_path = f"{DifyCli.TOOLS_ROOT}/{node_id}" vm = self._sandbox.vm @@ -86,7 +63,7 @@ class SandboxBashSession: ) config_json = json.dumps( - DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, artifact=artifact).model_dump( + DifyCliConfig.create(session=cli_api_session, tenant_id=self._tenant_id, artifact=tools).model_dump( mode="json" ), ensure_ascii=False, @@ -98,7 +75,7 @@ ).execute(raise_on_error=True) logger.info( - "Node %s tools initialized, path=%s, tool_count=%d", node_id, node_tools_path, len(artifact.references) + "Node %s tools initialized, path=%s, tool_count=%d", node_id, node_tools_path, len(tools.references) ) return node_tools_path diff --git a/api/core/sandbox/initializer/dify_cli_initializer.py b/api/core/sandbox/initializer/dify_cli_initializer.py index 0cc77ee0d8..0ac218bcb6 100644 --- a/api/core/sandbox/initializer/dify_cli_initializer.py +++ b/api/core/sandbox/initializer/dify_cli_initializer.py @@ -60,8 +60,8 @@ class DifyCliInitializer(SandboxInitializer): logger.info("Dify CLI uploaded to sandbox, path=%s", DifyCli.PATH) - artifact = SkillManager.load_tool_artifact(self._tenant_id, self._app_id, self._assets_id) - if artifact is None or not artifact.references: + artifact = SkillManager.load_artifact(self._tenant_id, self._app_id, self._assets_id) + if artifact is None or artifact.get_tool_artifact().is_empty(): logger.info("No tools found in artifact for assets_id=%s", self._assets_id) return @@ -72,7 +72,7 @@ class DifyCliInitializer(SandboxInitializer): ["mkdir", "-p", DifyCli.GLOBAL_TOOLS_PATH], error_message="Failed to create global tools dir" ).execute(raise_on_error=True) - config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, artifact) + config = DifyCliConfig.create(self._cli_api_session, self._tenant_id, artifact.get_tool_artifact()) config_json = json.dumps(config.model_dump(mode="json"), ensure_ascii=False) config_path = 
f"{DifyCli.GLOBAL_TOOLS_PATH}/{DifyCli.CONFIG_FILENAME}" vm.upload_file(config_path, BytesIO(config_json.encode("utf-8"))) diff --git a/api/core/sandbox/manager.py b/api/core/sandbox/manager.py index ee50f9502f..5de620b79b 100644 --- a/api/core/sandbox/manager.py +++ b/api/core/sandbox/manager.py @@ -9,6 +9,7 @@ from core.sandbox.entities import AppAssets, SandboxType from core.sandbox.entities.providers import SandboxProviderEntity from core.sandbox.initializer.app_assets_initializer import AppAssetsInitializer from core.sandbox.initializer.dify_cli_initializer import DifyCliInitializer +from core.sandbox.initializer.skill_initializer import SkillInitializer from core.sandbox.sandbox import Sandbox from core.sandbox.storage.archive_storage import ArchiveSandboxStorage from core.virtual_environment.__base.virtual_environment import VirtualEnvironment @@ -123,6 +124,7 @@ class SandboxManager: .app(app_id) .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) + .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id)) .storage(storage, assets.id) .build() ) @@ -130,6 +132,11 @@ class SandboxManager: logger.info("Sandbox created: id=%s, assets=%s", sandbox.vm.metadata.id, sandbox.assets_id) return sandbox + @classmethod + def delete_storage(cls, tenant_id: str, user_id: str) -> None: + storage = ArchiveSandboxStorage(tenant_id, SandboxBuilder.draft_id(user_id)) + storage.delete() + @classmethod def create_draft( cls, @@ -153,6 +160,7 @@ class SandboxManager: .app(app_id) .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) .initializer(DifyCliInitializer(tenant_id, user_id, app_id, assets.id)) + .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id)) .storage(storage, assets.id) .build() ) @@ -183,6 +191,7 @@ class SandboxManager: .app(app_id) .initializer(AppAssetsInitializer(tenant_id, app_id, assets.id)) .initializer(DifyCliInitializer(tenant_id, 
user_id, app_id, assets.id)) + .initializer(SkillInitializer(tenant_id, user_id, app_id, assets.id)) .storage(storage, assets.id) .build() ) diff --git a/api/core/skill/entities/tool_artifact.py b/api/core/skill/entities/tool_artifact.py index dcd6d682b0..3a3b424479 100644 --- a/api/core/skill/entities/tool_artifact.py +++ b/api/core/skill/entities/tool_artifact.py @@ -35,3 +35,25 @@ class ToolArtifact(BaseModel): if f"{reference.provider}.{reference.tool_name}" in tool_names ], ) + + def merge(self, other: "ToolArtifact") -> "ToolArtifact": + dep_map: dict[str, ToolDependency] = {} + for dep in self.dependencies: + key = f"{dep.provider}.{dep.tool_name}" + dep_map[key] = dep + for dep in other.dependencies: + key = f"{dep.provider}.{dep.tool_name}" + if key not in dep_map: + dep_map[key] = dep + + ref_map: dict[str, ToolReference] = {} + for ref in self.references: + ref_map[ref.uuid] = ref + for ref in other.references: + if ref.uuid not in ref_map: + ref_map[ref.uuid] = ref + + return ToolArtifact( + dependencies=list(dep_map.values()), + references=list(ref_map.values()), + ) \ No newline at end of file diff --git a/api/core/skill/skill_manager.py b/api/core/skill/skill_manager.py index 16021561f5..c222c4b996 100644 --- a/api/core/skill/skill_manager.py +++ b/api/core/skill/skill_manager.py @@ -1,71 +1,9 @@ -from core.app.entities.app_asset_entities import AppAssetFileTree -from core.app_assets.entities import SkillAsset from core.app_assets.paths import AssetPaths from core.skill.entities.skill_artifact_set import SkillArtifactSet -from core.skill.entities.skill_document import SkillDocument -from core.skill.skill_compiler import SkillCompiler from extensions.ext_storage import storage -from .entities import ToolArtifact - class SkillManager: - @staticmethod - def _load_content(storage_key: str) -> str: - import json - - try: - data = json.loads(storage.load_once(storage_key)) - return data.get("content", "") if isinstance(data, dict) else "" - except 
Exception: - return "" - - @staticmethod - def save_tool_artifact( - tenant_id: str, - app_id: str, - assets_id: str, - artifact: ToolArtifact, - ) -> None: - key = AssetPaths.build_tool_artifact(tenant_id, app_id, assets_id) - storage.save(key, artifact.model_dump_json(indent=2).encode("utf-8")) - - @staticmethod - def load_tool_artifact( - tenant_id: str, - app_id: str, - assets_id: str, - ) -> ToolArtifact | None: - key = AssetPaths.build_tool_artifact(tenant_id, app_id, assets_id) - try: - data = storage.load_once(key) - return ToolArtifact.model_validate_json(data) - except Exception: - return None - - @staticmethod - def compile_all( - documents: list[SkillDocument], - file_tree: AppAssetFileTree, - assets_id: str, - ) -> SkillArtifactSet: - compiler = SkillCompiler() - return compiler.compile_all(documents, file_tree, assets_id) - - @staticmethod - def assets_to_documents(assets: list[SkillAsset]) -> list[SkillDocument]: - documents: list[SkillDocument] = [] - for asset in assets: - content = SkillManager._load_content(asset.storage_key) - documents.append( - SkillDocument( - skill_id=asset.asset_id, - content=content, - metadata=asset.metadata, - ) - ) - return documents - @staticmethod def load_artifact( tenant_id: str, diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index e8972d5432..903f6acd88 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -59,6 +59,8 @@ class PromptConfig(BaseModel): class LLMNodeChatModelMessage(ChatModelMessage): text: str = "" jinja2_text: str | None = None + skill: bool = False + metadata: Mapping[str, Any] | None = None class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index e7c8036862..587e3af6b2 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,13 +7,16 @@ import logging 
import re import time from collections.abc import Generator, Mapping, Sequence +from functools import reduce from typing import TYPE_CHECKING, Any, Literal, cast from sqlalchemy import select from core.agent.entities import AgentEntity, AgentLog, AgentResult, AgentToolEntity, ExecutionContext from core.agent.patterns import StrategyFactory +from core.app.entities.app_asset_entities import AppAssetFileTree from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app_assets.constants import AppAssetsAttrs from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError @@ -52,6 +55,11 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.sandbox import Sandbox from core.sandbox.bash.session import SandboxBashSession +from core.skill.constants import SkillAttrs +from core.skill.entities.skill_artifact_set import SkillArtifactSet +from core.skill.entities.skill_document import SkillDocument +from core.skill.entities.tool_artifact import ToolArtifact +from core.skill.skill_compiler import SkillCompiler from core.tools.__base.tool import Tool from core.tools.signature import sign_upload_file from core.tools.tool_manager import ToolManager @@ -281,6 +289,7 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, context_files=context_files, + sandbox=self.graph_runtime_state.sandbox, ) # Variables for outputs @@ -289,12 +298,14 @@ class LLMNode(Node[LLMNodeData]): sandbox = self.graph_runtime_state.sandbox if sandbox: + tool_artifact = self._extract_tool_artifact() generator = self._invoke_llm_with_sandbox( sandbox=sandbox, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, variable_pool=variable_pool, + 
tool_artifact=tool_artifact, ) elif self.tool_call_enabled: generator = self._invoke_llm_with_tools( @@ -847,6 +858,7 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables=self.node_data.prompt_config.jinja2_variables or [], variable_pool=variable_pool, vision_detail_config=self.node_data.vision.configs.detail, + sandbox=self.graph_runtime_state.sandbox, ) combined_messages.extend(processed_msgs) static_idx += 1 @@ -1181,6 +1193,7 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables: Sequence[VariableSelector], tenant_id: str, context_files: list[File] | None = None, + sandbox: Sandbox | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -1193,6 +1206,7 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables=jinja2_variables, variable_pool=variable_pool, vision_detail_config=vision_detail, + sandbox=sandbox, ) ) @@ -1473,8 +1487,17 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, + sandbox: Sandbox | None = None, ) -> Sequence[PromptMessage]: prompt_messages: list[PromptMessage] = [] + + # Extract skill compilation context from sandbox if available + artifact_set: SkillArtifactSet | None = None + file_tree: AppAssetFileTree | None = None + if sandbox: + artifact_set = sandbox.attrs.get_or_none(SkillAttrs.ARTIFACT_SET) + file_tree = sandbox.attrs.get_or_none(AppAssetsAttrs.FILE_TREE) + for message in messages: if message.edition_type == "jinja2": result_text = _render_jinja2_message( @@ -1482,6 +1505,16 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) + + # Compile skill references after jinja2 rendering + if artifact_set is not None and file_tree is not None: + skill_artifact = SkillCompiler().compile_one( + artifact_set, + SkillDocument(skill_id="anonymous", content=result_text, metadata={}), + file_tree, + ) + result_text = 
skill_artifact.content + prompt_message = _combine_message_content_with_role( contents=[TextPromptMessageContent(data=result_text)], role=message.role ) @@ -1514,6 +1547,16 @@ class LLMNode(Node[LLMNodeData]): # Create message with text from all segments plain_text = segment_group.text + + # Compile skill references after context and variable substitution + if plain_text and artifact_set is not None and file_tree is not None: + skill_artifact = SkillCompiler().compile_one( + artifact_set, + SkillDocument(skill_id="anonymous", content=plain_text, metadata={}), + file_tree, + ) + plain_text = skill_artifact.content + if plain_text: prompt_message = _combine_message_content_with_role( contents=[TextPromptMessageContent(data=plain_text)], role=message.role @@ -1767,6 +1810,28 @@ class LLMNode(Node[LLMNodeData]): generation_data, ) + def _extract_tool_artifact(self) -> ToolArtifact | None: + """Extract tool artifact from prompt template.""" + + sandbox = self.graph_runtime_state.sandbox + if not sandbox: + raise LLMNodeError("Sandbox not found") + + artifact_set = sandbox.attrs.get(SkillAttrs.ARTIFACT_SET) + file_tree = sandbox.attrs.get(AppAssetsAttrs.FILE_TREE) + tool_artifacts: list[ToolArtifact] = [] + for prompt in self.node_data.prompt_template: + if isinstance(prompt, LLMNodeChatModelMessage): + skill_artifact = SkillCompiler().compile_one( + artifact_set, SkillDocument(skill_id="anonymous", content=prompt.text, metadata={}), file_tree + ) + tool_artifacts.append(skill_artifact.tools) + + if len(tool_artifacts) == 0: + return None + + return reduce(lambda x, y: x.merge(y), tool_artifacts) + def _invoke_llm_with_tools( self, model_instance: ModelInstance, @@ -1811,17 +1876,6 @@ class LLMNode(Node[LLMNodeData]): result = yield from self._process_tool_outputs(outputs) return result - def _get_allow_tools_list(self) -> list[tuple[str, str]] | None: - if not self._node_data.tools: - return None - - allow_tools = [] - for tool in self._node_data.tools: - if not 
tool.enabled: - continue - allow_tools.append((tool.provider_name, tool.tool_name)) - return allow_tools or None - def _invoke_llm_with_sandbox( self, sandbox: Sandbox, @@ -1829,12 +1883,11 @@ class LLMNode(Node[LLMNodeData]): prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None, variable_pool: VariablePool, + tool_artifact: ToolArtifact | None, ) -> Generator[NodeEventBase, None, LLMGenerationData]: - allow_tools = self._get_allow_tools_list() - result: LLMGenerationData | None = None - with SandboxBashSession(sandbox=sandbox, node_id=self.id, allow_tools=allow_tools) as session: + with SandboxBashSession(sandbox=sandbox, node_id=self.id, tools=tool_artifact) as session: prompt_files = self._extract_prompt_files(variable_pool) model_features = self._get_model_features(model_instance) diff --git a/api/libs/attr_map.py b/api/libs/attr_map.py index a200d72d42..cad546afce 100644 --- a/api/libs/attr_map.py +++ b/api/libs/attr_map.py @@ -7,8 +7,8 @@ to the same AttrKey instance can read/write the corresponding attribute. SESSION_KEY: AttrKey[Session] = AttrKey("session", Session) attrs = AttrMap() attrs.set(SESSION_KEY, session) - session = attrs.get(SESSION_KEY) # -> Session | None - session = attrs.require(SESSION_KEY) # -> Session (raises if not set) + session = attrs.get(SESSION_KEY) # -> Session (raises if not set) + session = attrs.get_or_none(SESSION_KEY) # -> Session | None Note: AttrMap is NOT thread-safe. Each instance should be confined to a single thread/context (e.g., one AttrMap per Sandbox/VirtualEnvironment instance). @@ -106,7 +106,13 @@ class AttrMap: raise AttrMapTypeError(key, key.type_, type(value)) self._data[key] = value - def get(self, key: AttrKey[T]) -> T | None: + def get(self, key: AttrKey[T]) -> T: + """Retrieve an attribute. 
Raises AttrMapKeyError if not set.""" + if key not in self._data: + raise AttrMapKeyError(key) + return cast(T, self._data[key]) + + def get_or_none(self, key: AttrKey[T]) -> T | None: """Retrieve an attribute, returning None if not set.""" return cast(T | None, self._data.get(key)) @@ -122,12 +128,6 @@ class AttrMap: return cast(T, self._data[key]) return default - def require(self, key: AttrKey[T]) -> T: - """Retrieve an attribute, raising AttrMapKeyError if not set.""" - if key not in self._data: - raise AttrMapKeyError(key) - return cast(T, self._data[key]) - def has(self, key: AttrKey[Any]) -> bool: """Check if an attribute is set.""" return key in self._data diff --git a/api/tests/unit_tests/libs/test_attr_map.py b/api/tests/unit_tests/libs/test_attr_map.py index 0e0bfcea1d..bbca2c5441 100644 --- a/api/tests/unit_tests/libs/test_attr_map.py +++ b/api/tests/unit_tests/libs/test_attr_map.py @@ -51,11 +51,27 @@ class TestAttrMap: assert result == "hello" - def test_get_returns_none_for_missing(self): + def test_get_raises_when_not_set(self): key: AttrKey[str] = AttrKey("session", str) attrs = AttrMap() - assert attrs.get(key) is None + with pytest.raises(AttrMapKeyError) as exc_info: + attrs.get(key) + + assert exc_info.value.key is key + + def test_get_or_none_returns_none_for_missing(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + + assert attrs.get_or_none(key) is None + + def test_get_or_none_returns_value_when_set(self): + key: AttrKey[str] = AttrKey("session", str) + attrs = AttrMap() + attrs.set(key, "hello") + + assert attrs.get_or_none(key) == "hello" def test_get_or_default_returns_value_when_set(self): key: AttrKey[str] = AttrKey("session", str) @@ -74,25 +90,6 @@ class TestAttrMap: assert result == "default" - def test_require_returns_value_when_set(self): - key: AttrKey[str] = AttrKey("session", str) - attrs = AttrMap() - attrs.set(key, "hello") - - result = attrs.require(key) - - assert result == "hello" - - def 
test_require_raises_when_not_set(self): - key: AttrKey[str] = AttrKey("session", str) - attrs = AttrMap() - - with pytest.raises(AttrMapKeyError) as exc_info: - attrs.require(key) - - assert exc_info.value.key is key - assert "session" in str(exc_info.value) - def test_has(self): key: AttrKey[str] = AttrKey("session", str) attrs = AttrMap()