Mirror of https://github.com/langgenius/dify.git (synced 2026-05-05 09:58:04 +08:00)
refactor: inject workflow node memory via protocol (#32784)
@@ -1,22 +1,21 @@
 from collections.abc import Sequence
 from typing import cast
 
-from sqlalchemy import select
-from sqlalchemy.orm import Session
-
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
+from core.model_runtime.entities import PromptMessageRole
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+)
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.workflow.enums import SystemVariableKey
 from core.workflow.file.models import File
 from core.workflow.runtime import VariablePool
-from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
-from extensions.ext_database import db
-from models.model import Conversation
+from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
 
 from .exc import InvalidVariableTypeError
+from .protocols import PromptMessageMemory
 
 
 def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
@@ -42,23 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
     raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
 
 
-def fetch_memory(
-    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
-) -> TokenBufferMemory | None:
-    if not node_data_memory:
-        return None
+def convert_history_messages_to_text(
+    *,
+    history_messages: Sequence[PromptMessage],
+    human_prefix: str,
+    ai_prefix: str,
+) -> str:
+    string_messages: list[str] = []
+    for message in history_messages:
+        if message.role == PromptMessageRole.USER:
+            role = human_prefix
+        elif message.role == PromptMessageRole.ASSISTANT:
+            role = ai_prefix
+        else:
+            continue
 
-    # get conversation id
-    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
-    if not isinstance(conversation_id_variable, StringSegment):
-        return None
-    conversation_id = conversation_id_variable.value
+        if isinstance(message.content, list):
+            content_parts = []
+            for content in message.content:
+                if isinstance(content, TextPromptMessageContent):
+                    content_parts.append(content.data)
+                elif isinstance(content, ImagePromptMessageContent):
+                    content_parts.append("[image]")
 
-    with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
-        conversation = session.scalar(stmt)
-        if not conversation:
-            return None
+            inner_msg = "\n".join(content_parts)
+            string_messages.append(f"{role}: {inner_msg}")
+        else:
+            string_messages.append(f"{role}: {message.content}")
 
-    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-    return memory
+    return "\n".join(string_messages)
+
+
+def fetch_memory_text(
+    *,
+    memory: PromptMessageMemory,
+    max_token_limit: int,
+    message_limit: int | None = None,
+    human_prefix: str = "Human",
+    ai_prefix: str = "Assistant",
+) -> str:
+    history_messages = memory.get_history_prompt_messages(
+        max_token_limit=max_token_limit,
+        message_limit=message_limit,
+    )
+    return convert_history_messages_to_text(
+        history_messages=history_messages,
+        human_prefix=human_prefix,
+        ai_prefix=ai_prefix,
+    )
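The new helper depends only on the PromptMessageMemory protocol imported from .protocols, which this diff does not show. Judging from the single call memory.get_history_prompt_messages(max_token_limit=..., message_limit=...), it is plausibly a structural typing.Protocol along these lines (a sketch inferred from the call site, not the actual definition):

# Sketch of what .protocols plausibly declares, inferred from how
# fetch_memory_text uses its memory argument; the real definition
# may differ in parameter ordering or defaults.
from collections.abc import Sequence
from typing import Protocol

from core.model_runtime.entities.message_entities import PromptMessage


class PromptMessageMemory(Protocol):
    def get_history_prompt_messages(
        self,
        max_token_limit: int,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]: ...

Because typing.Protocol matches structurally, TokenBufferMemory (whose construction moves out of this module along with the deleted fetch_memory) satisfies such a protocol without inheriting from it, and so does any other object exposing a matching method.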
@@ -1338,48 +1338,16 @@ def _handle_memory_completion_mode(
     )
     if not memory_config.role_prefix:
         raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
-    memory_messages = memory.get_history_prompt_messages(
+    memory_text = llm_utils.fetch_memory_text(
+        memory=memory,
         max_token_limit=rest_tokens,
         message_limit=memory_config.window.size if memory_config.window.enabled else None,
-    )
-    memory_text = _convert_history_messages_to_text(
-        history_messages=memory_messages,
         human_prefix=memory_config.role_prefix.user,
         ai_prefix=memory_config.role_prefix.assistant,
     )
     return memory_text
 
 
-def _convert_history_messages_to_text(
-    *,
-    history_messages: Sequence[PromptMessage],
-    human_prefix: str,
-    ai_prefix: str,
-) -> str:
-    string_messages: list[str] = []
-    for message in history_messages:
-        if message.role == PromptMessageRole.USER:
-            role = human_prefix
-        elif message.role == PromptMessageRole.ASSISTANT:
-            role = ai_prefix
-        else:
-            continue
-
-        if isinstance(message.content, list):
-            content_parts = []
-            for content in message.content:
-                if isinstance(content, TextPromptMessageContent):
-                    content_parts.append(content.data)
-                elif isinstance(content, ImagePromptMessageContent):
-                    content_parts.append("[image]")
-
-            inner_msg = "\n".join(content_parts)
-            string_messages.append(f"{role}: {inner_msg}")
-        else:
-            string_messages.append(f"{role}: {message.content}")
-    return "\n".join(string_messages)
-
-
 def _handle_completion_template(
     *,
     template: LLMNodeCompletionModelPromptTemplate,
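The practical payoff of the protocol seam: fetch_memory_text no longer needs a database-backed TokenBufferMemory behind it. A minimal sketch of driving it with a test double, assuming UserPromptMessage and AssistantPromptMessage from core.model_runtime.entities.message_entities and an illustrative import path for llm_utils:

# Hypothetical test double: any object with a matching
# get_history_prompt_messages satisfies the protocol structurally.
from collections.abc import Sequence

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    UserPromptMessage,
)
from core.workflow.nodes.llm import llm_utils  # module path is an assumption


class StubMemory:
    def __init__(self, messages: Sequence[PromptMessage]) -> None:
        self._messages = list(messages)

    def get_history_prompt_messages(
        self, max_token_limit: int, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        # Ignore the token budget; apply only the message-count window.
        return self._messages if message_limit is None else self._messages[-message_limit:]


memory = StubMemory([
    UserPromptMessage(content="What is Dify?"),
    AssistantPromptMessage(content="An LLM app platform."),
])
text = llm_utils.fetch_memory_text(memory=memory, max_token_limit=2000)
# With the default prefixes shown in the diff, text is:
# "Human: What is Dify?\nAssistant: An LLM app platform."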