refactor: inject workflow node memory via protocol (#32784)

This commit is contained in:
-LAN-
2026-03-02 01:55:49 +08:00
committed by GitHub
parent ef2b5d6107
commit 69b3e94630
7 changed files with 130 additions and 108 deletions

View File

@ -1,6 +1,8 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
@ -11,6 +13,7 @@ from core.helper.code_executor.code_executor import (
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@ -18,7 +21,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
@ -29,7 +32,6 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
*,
conversation_id: str | None,
app_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,