Mirror of https://github.com/langgenius/dify.git (synced 2026-05-01 16:08:04 +08:00)
Merge branch 'feat/memory-orchestration-be-dev-env' into deploy/dev
@@ -1,4 +1,5 @@
 SYSTEM_VARIABLE_NODE_ID = "sys"
 ENVIRONMENT_VARIABLE_NODE_ID = "env"
 CONVERSATION_VARIABLE_NODE_ID = "conversation"
+MEMORY_BLOCK_VARIABLE_NODE_ID = "memory_block"
 RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
@@ -55,6 +55,10 @@ class VariablePool(BaseModel):
         description="RAG pipeline variables.",
         default_factory=list,
     )
+    memory_blocks: Mapping[str, str] = Field(
+        description="Memory blocks.",
+        default_factory=dict,
+    )
 
     def model_post_init(self, context: Any, /):
         # Create a mapping from field names to SystemVariableKey enum values
@@ -75,6 +79,9 @@ class VariablePool(BaseModel):
             rag_pipeline_variables_map[node_id][key] = value
         for key, value in rag_pipeline_variables_map.items():
             self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
+        # Add memory blocks to the variable pool
+        for memory_id, memory_value in self.memory_blocks.items():
+            self.add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value)
 
     def add(self, selector: Sequence[str], value: Any, /):
         """
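A self-contained sketch of what the new seeding loop in `model_post_init` does, using a plain dict as a stand-in for the real `VariablePool` (whose remaining fields are not shown in this diff). Note that, as written, memory blocks are registered under `CONVERSATION_VARIABLE_NODE_ID` rather than the newly added `MEMORY_BLOCK_VARIABLE_NODE_ID`; the block ID and value below are placeholders.

```python
CONVERSATION_VARIABLE_NODE_ID = "conversation"

# Stand-in for the pool's internal (node_id -> name -> value) mapping.
variable_dictionary: dict[str, dict[str, str]] = {}

def add(selector: list[str], value: str) -> None:
    # Selectors here are (node_id, variable_name) pairs, as in the hunk above.
    node_id, name = selector
    variable_dictionary.setdefault(node_id, {})[name] = value

# Mirrors the loop added in model_post_init: every configured memory block
# becomes a conversation-scoped variable keyed by its block ID.
memory_blocks = {"mb-1": "User prefers terse answers."}  # placeholder content
for memory_id, memory_value in memory_blocks.items():
    add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value)

assert variable_dictionary["conversation"]["mb-1"].startswith("User prefers")
```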
@@ -105,10 +105,10 @@ class RedisChannel:
             command_type = CommandType(command_type_value)
 
             if command_type == CommandType.ABORT:
-                return AbortCommand(**data)
+                return AbortCommand.model_validate(data)
             else:
                 # For other command types, use base class
-                return GraphEngineCommand(**data)
+                return GraphEngineCommand.model_validate(data)
 
         except (ValueError, TypeError):
            return None
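The `Model(**data)` → `Model.model_validate(data)` change above repeats across the node hunks below. A minimal, self-contained Pydantic v2 sketch (the class is a stand-in, not the real `AbortCommand` or `GraphEngineCommand`) showing that both paths validate the fields, while `model_validate` accepts the mapping directly and is the idiomatic v2 entry point for building a model from an untrusted payload:

```python
from pydantic import BaseModel

class DemoCommand(BaseModel):  # stand-in class for illustration only
    command_type: str
    reason: str | None = None

payload = {"command_type": "abort", "reason": "user cancelled"}

# model_validate takes the mapping (or another model instance) as-is instead
# of unpacking it into keyword arguments.
a = DemoCommand(**payload)
b = DemoCommand.model_validate(payload)
assert a == b
```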
@@ -16,7 +16,7 @@ class EndNode(Node):
     _node_data: EndNodeData
 
     def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = EndNodeData(**data)
+        self._node_data = EndNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
@@ -342,10 +342,13 @@ class IterationNode(Node):
         iterator_list_value: Sequence[object],
         iter_run_map: dict[str, float],
     ) -> Generator[NodeEventBase, None, None]:
+        # Flatten the list of lists if all outputs are lists
+        flattened_outputs = self._flatten_outputs_if_needed(outputs)
+
         yield IterationSucceededEvent(
             start_at=started_at,
             inputs=inputs,
-            outputs={"output": outputs},
+            outputs={"output": flattened_outputs},
             steps=len(iterator_list_value),
             metadata={
                 WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
@@ -357,13 +360,39 @@ class IterationNode(Node):
         yield StreamCompletedEvent(
             node_run_result=NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                outputs={"output": outputs},
+                outputs={"output": flattened_outputs},
                 metadata={
                     WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
                 },
             )
         )
 
+    def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
+        """
+        Flatten the outputs list if all elements are lists.
+        This maintains backward compatibility with version 1.8.1 behavior.
+        """
+        if not outputs:
+            return outputs
+
+        # Check if all non-None outputs are lists
+        non_none_outputs = [output for output in outputs if output is not None]
+        if not non_none_outputs:
+            return outputs
+
+        if all(isinstance(output, list) for output in non_none_outputs):
+            # Flatten the list of lists
+            flattened: list[Any] = []
+            for output in outputs:
+                if isinstance(output, list):
+                    flattened.extend(output)
+                elif output is not None:
+                    # This shouldn't happen based on our check, but handle it gracefully
+                    flattened.append(output)
+            return flattened
+
+        return outputs
+
     def _handle_iteration_failure(
         self,
         started_at: datetime,
@@ -373,10 +402,13 @@ class IterationNode(Node):
         iter_run_map: dict[str, float],
         error: IterationNodeError,
     ) -> Generator[NodeEventBase, None, None]:
+        # Flatten the list of lists if all outputs are lists (even in failure case)
+        flattened_outputs = self._flatten_outputs_if_needed(outputs)
+
         yield IterationFailedEvent(
             start_at=started_at,
             inputs=inputs,
-            outputs={"output": outputs},
+            outputs={"output": flattened_outputs},
             steps=len(iterator_list_value),
             metadata={
                 WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
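A standalone restatement of the flattening rule introduced above, useful for seeing the edge cases at a glance. The function name is illustrative, not part of the node; the behavior mirrors `_flatten_outputs_if_needed`.

```python
def flatten_if_needed(outputs: list[object]) -> list[object]:
    """Flatten one level only when every non-None element is itself a list."""
    non_none = [o for o in outputs if o is not None]
    if not non_none or not all(isinstance(o, list) for o in non_none):
        return outputs
    flat: list[object] = []
    for o in outputs:
        if isinstance(o, list):
            flat.extend(o)  # None entries are dropped, as in the node method
    return flat

assert flatten_if_needed([[1, 2], None, [3]]) == [1, 2, 3]  # all lists -> flattened
assert flatten_if_needed([1, [2, 3]]) == [1, [2, 3]]        # mixed -> unchanged
assert flatten_if_needed([]) == []                          # empty -> unchanged
```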
@@ -18,7 +18,7 @@ class IterationStartNode(Node):
     _node_data: IterationStartNodeData
 
     def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = IterationStartNodeData(**data)
+        self._node_data = IterationStartNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
@@ -41,7 +41,7 @@ class ListOperatorNode(Node):
     _node_data: ListOperatorNodeData
 
     def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = ListOperatorNodeData(**data)
+        self._node_data = ListOperatorNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
@@ -6,11 +6,15 @@ import re
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal
 
-from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
+from core.memory.entities import MemoryCreatedBy, MemoryScope
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
@@ -71,6 +75,8 @@ from core.workflow.node_events import (
 from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
+from models import UserFrom, Workflow
+from models.engine import db
 
 from . import llm_utils
 from .entities import (
@@ -315,6 +321,11 @@ class LLMNode(Node):
         if self._file_outputs:
             outputs["files"] = ArrayFileSegment(value=self._file_outputs)
 
+        try:
+            self._handle_chatflow_memory(result_text, variable_pool)
+        except Exception as e:
+            logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
+
         # Send final chunk event to indicate streaming is complete
         yield StreamChunkEvent(
             selector=[self._node_id, "text"],
@@ -1184,6 +1195,79 @@ class LLMNode(Node):
     def retry(self) -> bool:
         return self._node_data.retry_config.retry_enabled
 
+    def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
+        if not self._node_data.memory or self._node_data.memory.mode != "block":
+            return
+
+        conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
+        if not conversation_id_segment:
+            raise ValueError("Conversation ID not found in variable pool.")
+        conversation_id = conversation_id_segment.text
+
+        user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
+        if not user_query_segment:
+            raise ValueError("User query not found in variable pool.")
+        user_query = user_query_segment.text
+
+        from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
+        from services.chatflow_history_service import ChatflowHistoryService
+
+        ChatflowHistoryService.save_node_message(
+            prompt_message=(UserPromptMessage(content=user_query)),
+            node_id=self.node_id,
+            conversation_id=conversation_id,
+            app_id=self.app_id,
+            tenant_id=self.tenant_id
+        )
+        ChatflowHistoryService.save_node_message(
+            prompt_message=(AssistantPromptMessage(content=llm_output)),
+            node_id=self.node_id,
+            conversation_id=conversation_id,
+            app_id=self.app_id,
+            tenant_id=self.tenant_id
+        )
+
+        memory_config = self._node_data.memory
+        if not memory_config:
+            return
+        block_ids = memory_config.block_id
+        if not block_ids:
+            return
+
+        # FIXME: This is a dirty workaround and may cause incorrect workflow version resolution
+        with Session(db.engine) as session:
+            stmt = select(Workflow).where(
+                Workflow.tenant_id == self.tenant_id,
+                Workflow.app_id == self.app_id
+            )
+            workflow = session.scalars(stmt).first()
+            if not workflow:
+                raise ValueError("Workflow not found.")
+            memory_blocks = workflow.memory_blocks
+
+        for block_id in block_ids:
+            memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
+
+            if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
+                is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
+                from services.chatflow_memory_service import ChatflowMemoryService
+                ChatflowMemoryService.update_node_memory_if_needed(
+                    tenant_id=self.tenant_id,
+                    app_id=self.app_id,
+                    node_id=self.id,
+                    conversation_id=conversation_id,
+                    memory_block_spec=memory_block_spec,
+                    variable_pool=variable_pool,
+                    is_draft=is_draft,
+                    created_by=self._get_user_from_context()
+                )
+
+    def _get_user_from_context(self) -> MemoryCreatedBy:
+        if self.user_from == UserFrom.ACCOUNT:
+            return MemoryCreatedBy(account_id=self.user_id)
+        else:
+            return MemoryCreatedBy(end_user_id=self.user_id)
+
 
 def _combine_message_content_with_role(
     *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
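To make the control flow of `_handle_chatflow_memory` easier to scan, here is a compressed, self-contained sketch of its gating and attribution logic. The enum and dataclass are stand-ins for `MemoryScope`, `UserFrom`, and `MemoryCreatedBy`, which this diff only imports; values and names below are illustrative assumptions, not the project's actual definitions.

```python
from dataclasses import dataclass
from enum import Enum

class Scope(Enum):      # stand-in for MemoryScope
    NODE = "node"
    APP = "app"

@dataclass
class CreatedBy:        # stand-in for MemoryCreatedBy
    account_id: str | None = None
    end_user_id: str | None = None

def created_by(user_from: str, user_id: str) -> CreatedBy:
    # Mirrors _get_user_from_context: account-initiated runs attribute the
    # memory write to an account, everything else to an end user.
    if user_from == "account":
        return CreatedBy(account_id=user_id)
    return CreatedBy(end_user_id=user_id)

def should_update_node_memory(memory_mode: str, block_scope: Scope) -> bool:
    # Mirrors the two gates in the diff: the node's memory mode must be
    # "block", and the matched block spec must be node-scoped; draft runs are
    # flagged separately when the node was invoked from the debugger.
    return memory_mode == "block" and block_scope is Scope.NODE

assert should_update_node_memory("block", Scope.NODE)
assert not should_update_node_memory("block", Scope.APP)
assert created_by("account", "u-1").account_id == "u-1"
```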
@@ -18,7 +18,7 @@ class LoopEndNode(Node):
     _node_data: LoopEndNodeData
 
     def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = LoopEndNodeData(**data)
+        self._node_data = LoopEndNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
@@ -18,7 +18,7 @@ class LoopStartNode(Node):
     _node_data: LoopStartNodeData
 
     def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = LoopStartNodeData(**data)
+        self._node_data = LoopStartNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
@@ -16,7 +16,7 @@ class StartNode(Node):
     _node_data: StartNodeData
 
     def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = StartNodeData(**data)
+        self._node_data = StartNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
@@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
     _node_data: VariableAssignerNodeData
 
    def init_node_data(self, data: Mapping[str, Any]):
-        self._node_data = VariableAssignerNodeData(**data)
+        self._node_data = VariableAssignerNodeData.model_validate(data)
 
     def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy