mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
refactor: inject workflow node memory via protocol (#32784)
This commit is contained in:
@ -1,6 +1,8 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, cast, final
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
@ -11,6 +13,7 @@ from core.helper.code_executor.code_executor import (
|
||||
CodeExecutor,
|
||||
)
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@ -18,7 +21,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeType, SystemVariableKey
|
||||
from core.workflow.file.file_manager import file_manager
|
||||
from core.workflow.graph.graph import NodeFactory
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@ -29,7 +32,6 @@ from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.variables.segments import StringSegment
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
*,
|
||||
conversation_id: str | None,
|
||||
app_id: str,
|
||||
node_data_memory: MemoryConfig | None,
|
||||
model_instance: ModelInstance,
|
||||
) -> TokenBufferMemory | None:
|
||||
if not node_data_memory or not conversation_id:
|
||||
return None
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
||||
class DefaultWorkflowCodeExecutor:
|
||||
def execute(
|
||||
self,
|
||||
@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
model_instance = self._build_model_instance_for_llm_node(node_data)
|
||||
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
|
||||
return None
|
||||
|
||||
node_memory = MemoryConfig.model_validate(raw_memory_config)
|
||||
return llm_utils.fetch_memory(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
conversation_id = (
|
||||
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
|
||||
)
|
||||
return fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self.graph_init_params.app_id,
|
||||
node_data_memory=node_memory,
|
||||
model_instance=model_instance,
|
||||
|
||||
Reference in New Issue
Block a user