Merge branch 'feat/memory-orchestration-be-dev-env' into deploy/dev

This commit is contained in:
Stream
2025-10-11 16:16:15 +08:00
156 changed files with 3100 additions and 806 deletions

View File

@ -1,4 +1,5 @@
# Reserved node-id prefixes used as the first element of variable selectors
# in the VariablePool, e.g. ("sys", "query") or ("env", "api_key").
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
# Prefix for memory-block variables injected into the pool at post-init.
MEMORY_BLOCK_VARIABLE_NODE_ID = "memory_block"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@ -55,6 +55,10 @@ class VariablePool(BaseModel):
description="RAG pipeline variables.",
default_factory=list,
)
memory_blocks: Mapping[str, str] = Field(
description="Memory blocks.",
default_factory=dict,
)
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
@ -75,6 +79,9 @@ class VariablePool(BaseModel):
rag_pipeline_variables_map[node_id][key] = value
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
# Add memory blocks to the variable pool
for memory_id, memory_value in self.memory_blocks.items():
self.add([CONVERSATION_VARIABLE_NODE_ID, memory_id], memory_value)
def add(self, selector: Sequence[str], value: Any, /):
"""

View File

@ -105,10 +105,10 @@ class RedisChannel:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None

View File

@ -16,7 +16,7 @@ class EndNode(Node):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
self._node_data = EndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -342,10 +342,13 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
yield IterationSucceededEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
@ -357,13 +360,39 @@ class IterationNode(Node):
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
},
)
)
def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]:
"""
Flatten the outputs list if all elements are lists.
This maintains backward compatibility with version 1.8.1 behavior.
"""
if not outputs:
return outputs
# Check if all non-None outputs are lists
non_none_outputs = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs
if all(isinstance(output, list) for output in non_none_outputs):
# Flatten the list of lists
flattened: list[Any] = []
for output in outputs:
if isinstance(output, list):
flattened.extend(output)
elif output is not None:
# This shouldn't happen based on our check, but handle it gracefully
flattened.append(output)
return flattened
return outputs
def _handle_iteration_failure(
self,
started_at: datetime,
@ -373,10 +402,13 @@ class IterationNode(Node):
iter_run_map: dict[str, float],
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
flattened_outputs = self._flatten_outputs_if_needed(outputs)
yield IterationFailedEvent(
start_at=started_at,
inputs=inputs,
outputs={"output": outputs},
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,

View File

@ -18,7 +18,7 @@ class IterationStartNode(Node):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
self._node_data = IterationStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -41,7 +41,7 @@ class ListOperatorNode(Node):
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
self._node_data = ListOperatorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -6,11 +6,15 @@ import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@ -71,6 +75,8 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from models import UserFrom, Workflow
from models.engine import db
from . import llm_utils
from .entities import (
@ -315,6 +321,11 @@ class LLMNode(Node):
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
try:
self._handle_chatflow_memory(result_text, variable_pool)
except Exception as e:
logger.warning("Memory orchestration failed for node %s: %s", self.node_id, str(e))
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@ -1184,6 +1195,79 @@ class LLMNode(Node):
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
def _handle_chatflow_memory(self, llm_output: str, variable_pool: VariablePool):
    """Persist this turn's messages and refresh node-scoped memory blocks.

    Saves the user query and the assistant output for this node and
    conversation via ChatflowHistoryService, then, for each memory block id
    configured on the node whose spec has scope NODE, asks
    ChatflowMemoryService to update the node memory. No-op unless the
    node's memory mode is "block".

    :param llm_output: final text produced by the LLM for this node
    :param variable_pool: pool holding the system variables read here
        (conversation id and user query)
    :raises ValueError: if the conversation id or user query is missing
        from the pool, or the workflow row cannot be found
    """
    # Only "block"-mode memory participates in memory orchestration.
    if not self._node_data.memory or self._node_data.memory.mode != "block":
        return
    conversation_id_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.CONVERSATION_ID))
    if not conversation_id_segment:
        raise ValueError("Conversation ID not found in variable pool.")
    conversation_id = conversation_id_segment.text
    user_query_segment = variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
    if not user_query_segment:
        raise ValueError("User query not found in variable pool.")
    user_query = user_query_segment.text
    # Function-scope imports — presumably to avoid an import cycle between
    # workflow nodes and the services layer; TODO confirm.
    from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
    from services.chatflow_history_service import ChatflowHistoryService
    # Record both sides of the exchange against this node's history.
    ChatflowHistoryService.save_node_message(
        prompt_message=(UserPromptMessage(content=user_query)),
        node_id=self.node_id,
        conversation_id=conversation_id,
        app_id=self.app_id,
        tenant_id=self.tenant_id
    )
    ChatflowHistoryService.save_node_message(
        prompt_message=(AssistantPromptMessage(content=llm_output)),
        node_id=self.node_id,
        conversation_id=conversation_id,
        app_id=self.app_id,
        tenant_id=self.tenant_id
    )
    memory_config = self._node_data.memory
    if not memory_config:
        return
    block_ids = memory_config.block_id
    if not block_ids:
        return
    # FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
    with Session(db.engine) as session:
        # Looks up the workflow by tenant/app only — no version filter, which
        # is what the FIXME above refers to.
        stmt = select(Workflow).where(
            Workflow.tenant_id == self.tenant_id,
            Workflow.app_id == self.app_id
        )
        workflow = session.scalars(stmt).first()
        if not workflow:
            raise ValueError("Workflow not found.")
        memory_blocks = workflow.memory_blocks
        for block_id in block_ids:
            memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
            # Only NODE-scoped blocks are updated here; unknown ids are skipped.
            if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
                # Debugger invocations are treated as draft runs.
                is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
                from services.chatflow_memory_service import ChatflowMemoryService
                ChatflowMemoryService.update_node_memory_if_needed(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    # NOTE(review): the history calls above use self.node_id but
                    # this passes self.id — confirm they are the same value or
                    # that self.id is intended here.
                    node_id=self.id,
                    conversation_id=conversation_id,
                    memory_block_spec=memory_block_spec,
                    variable_pool=variable_pool,
                    is_draft=is_draft,
                    created_by=self._get_user_from_context()
                )
def _get_user_from_context(self) -> MemoryCreatedBy:
    """Build a MemoryCreatedBy attribution record for the invoking user.

    Console accounts are attributed via ``account_id``; every other origin
    is treated as an end user and attributed via ``end_user_id``.
    """
    field = "account_id" if self.user_from == UserFrom.ACCOUNT else "end_user_id"
    return MemoryCreatedBy(**{field: self.user_id})
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@ -18,7 +18,7 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
self._node_data = LoopEndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -18,7 +18,7 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
self._node_data = LoopStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -16,7 +16,7 @@ class StartNode(Node):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
self._node_data = StartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy