refactor: inject memory interface into LLMNode (#32754)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2026-03-01 04:05:18 +08:00
committed by GitHub
parent 1f0fca89a8
commit c034eb036c
4 changed files with 115 additions and 14 deletions

View File

@ -12,6 +12,7 @@ from core.helper.ssrf_proxy import ssrf_proxy
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
@ -26,9 +27,11 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@ -177,6 +180,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return LLMNode(
id=node_id,
config=node_config,
@ -185,6 +189,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.DATASOURCE:
@ -278,3 +283,21 @@ class DifyNodeFactory(NodeFactory):
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance
def _build_memory_for_llm_node(
self,
*,
node_data: Mapping[str, Any],
model_instance: ModelInstance,
) -> PromptMessageMemory | None:
raw_memory_config = node_data.get("memory")
if raw_memory_config is None:
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
)