refactor: inject workflow node memory via protocol (#32784)

This commit is contained in:
-LAN-
2026-03-02 01:55:49 +08:00
committed by GitHub
parent ef2b5d6107
commit 69b3e94630
7 changed files with 130 additions and 108 deletions

View File

@ -1,22 +1,21 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessageRole
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from models.model import Conversation
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from .exc import InvalidVariableTypeError
from .protocols import PromptMessageMemory
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
@ -42,23 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def fetch_memory(
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
def convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
return "\n".join(string_messages)
def fetch_memory_text(
*,
memory: PromptMessageMemory,
max_token_limit: int,
message_limit: int | None = None,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
) -> str:
history_messages = memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
return convert_history_messages_to_text(
history_messages=history_messages,
human_prefix=human_prefix,
ai_prefix=ai_prefix,
)

View File

@ -1338,48 +1338,16 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_messages = memory.get_history_prompt_messages(
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,