dify/api/dify_graph/nodes/llm/llm_utils.py
Novice 6b75188ddc fix: resolve import migrations and test failures after segment 3 merge
- Migrate core.model_runtime -> dify_graph.model_runtime across 20+ files
- Migrate core.workflow.file -> dify_graph.file across 15+ files
- Migrate core.workflow.enums -> dify_graph.enums in service files
- Fix SandboxContext phantom import in dify_graph/context/__init__.py
- Fix core.app.workflow.node_factory -> core.workflow.node_factory
- Fix toast import paths (useToastContext from toast/context)
- Fix app-info.tsx import paths for relocated app-operations
- Fix 15 frontend test files for API changes, missing QueryClientProvider,
  i18n key renames, and component behavior changes
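
As a concrete illustration of the import-path migrations listed above (a hedged sketch;
the exact symbols differ per file), imports of the form

    from core.model_runtime.entities import PromptMessageRole
    from core.workflow.file.models import File

become

    from dify_graph.model_runtime.entities import PromptMessageRole
    from dify_graph.file.models import File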

Made-with: Cursor
2026-03-23 10:31:11 +08:00

288 lines
10 KiB
Python

from collections.abc import Sequence
from typing import Any, cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from dify_graph.enums import SystemVariableKey
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import PromptMessageRole
from dify_graph.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    TextPromptMessageContent,
    ToolPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.llm.entities import LLMGenerationData
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment

from .exc import InvalidVariableTypeError


def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
    model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
        model_instance.model_name,
        model_instance.credentials,
    )
    if not model_schema:
        raise ValueError(f"Model schema not found for {model_instance.model_name}")
    return model_schema


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
    variable = variable_pool.get(selector)
    if variable is None:
        return []
    elif isinstance(variable, FileSegment):
        return [variable.value]
    elif isinstance(variable, ArrayFileSegment):
        return variable.value
    elif isinstance(variable, NoneSegment | ArrayAnySegment):
        return []
    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")


def fetch_memory(
    variable_pool: VariablePool,
    app_id: str,
    tenant_id: str,
    node_data_memory: MemoryConfig | None,
    model_instance: ModelInstance,
    node_id: str = "",
) -> BaseMemory | None:
    """
    Fetch memory based on configuration mode.

    Returns TokenBufferMemory for conversation mode (default),
    or NodeTokenBufferMemory for node mode (Chatflow only).
    """
    if not node_data_memory:
        return None

    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_variable, StringSegment):
        return None
    conversation_id = conversation_id_variable.value

    if node_data_memory.mode == MemoryMode.NODE:
        if not node_id:
            return None
        return NodeTokenBufferMemory(
            app_id=app_id,
            conversation_id=conversation_id,
            node_id=node_id,
            tenant_id=tenant_id,
            model_instance=model_instance,
        )
    else:
        from extensions.ext_database import db
        from models.model import Conversation

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
            conversation = session.scalar(stmt)

        if not conversation:
            return None

        return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
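
# Illustrative call site for fetch_memory above (a hedged sketch, not part of this module):
# an LLM node would typically resolve memory roughly like this, where `variable_pool`,
# `node_data`, `model_instance`, and the app/tenant/node ids are assumed to come from the
# node's runtime state rather than from this file.
#
#     memory = fetch_memory(
#         variable_pool=variable_pool,
#         app_id=app_id,
#         tenant_id=tenant_id,
#         node_data_memory=node_data.memory,
#         model_instance=model_instance,
#         node_id=node_id,
#     )
#     # None means memory is disabled, no conversation id is available, or node mode
#     # was requested without a node_id.
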

def convert_history_messages_to_text(
    *,
    history_messages: Sequence[PromptMessage],
    human_prefix: str,
    ai_prefix: str,
) -> str:
    string_messages: list[str] = []
    for message in history_messages:
        if message.role == PromptMessageRole.USER:
            role = human_prefix
        elif message.role == PromptMessageRole.ASSISTANT:
            role = ai_prefix
        else:
            continue

        if isinstance(message.content, list):
            content_parts = []
            for content in message.content:
                if isinstance(content, TextPromptMessageContent):
                    content_parts.append(content.data)
                elif isinstance(content, ImagePromptMessageContent):
                    content_parts.append("[image]")
            inner_msg = "\n".join(content_parts)
            string_messages.append(f"{role}: {inner_msg}")
        else:
            string_messages.append(f"{role}: {message.content}")
    return "\n".join(string_messages)


def fetch_memory_text(
    *,
    memory: PromptMessageMemory,
    max_token_limit: int,
    message_limit: int | None = None,
    human_prefix: str = "Human",
    ai_prefix: str = "Assistant",
) -> str:
    history_messages = memory.get_history_prompt_messages(
        max_token_limit=max_token_limit,
        message_limit=message_limit,
    )
    return convert_history_messages_to_text(
        history_messages=history_messages,
        human_prefix=human_prefix,
        ai_prefix=ai_prefix,
    )


def build_context(
    prompt_messages: Sequence[PromptMessage],
    assistant_response: str,
    generation_data: LLMGenerationData | None = None,
    files: Sequence[Any] | None = None,
) -> list[PromptMessage]:
    """
    Build context from prompt messages and assistant response.

    Excludes system messages and includes the current LLM response.
    Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
    """
    context_messages: list[PromptMessage] = [
        _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
    ]

    file_suffix = ""
    if files:
        file_descriptions = _build_file_descriptions(files)
        if file_descriptions:
            file_suffix = f"\n\n{file_descriptions}"

    if generation_data and generation_data.trace:
        context_messages.extend(_build_messages_from_trace(generation_data, assistant_response, file_suffix))
    else:
        context_messages.append(AssistantPromptMessage(content=assistant_response + file_suffix))
    return context_messages
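
# Illustrative usage of build_context above (a hedged sketch, not part of this module):
# after a generation finishes, the prompt messages and final answer could be packed into
# a context list roughly like this; `prompt_messages`, `answer_text`, `generation_data`,
# and `generated_files` are assumed names, not values defined in this file.
#
#     context = build_context(
#         prompt_messages=prompt_messages,
#         assistant_response=answer_text,
#         generation_data=generation_data,
#         files=generated_files,
#     )
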

def _build_file_descriptions(files: Sequence[Any]) -> str:
    if not files:
        return ""
    descriptions: list[str] = ["[Generated Files]"]
    for file in files:
        file_id = getattr(file, "id", None) or getattr(file, "related_id", None)
        filename = getattr(file, "filename", "unknown")
        file_type = getattr(file, "type", "unknown")
        if hasattr(file_type, "value"):
            file_type = file_type.value
        if file_id:
            descriptions.append(f"- {filename} (id: {file_id}, type: {file_type})")
    return "\n".join(descriptions)


def _build_messages_from_trace(
    generation_data: LLMGenerationData,
    assistant_response: str,
    file_suffix: str = "",
) -> list[PromptMessage]:
    # Rebuild the assistant/tool message sequence from the generation trace, then append
    # any response text not already covered by model segments (plus the file descriptions).
    from dify_graph.nodes.llm.entities import ModelTraceSegment, ToolTraceSegment

    messages: list[PromptMessage] = []
    covered_text_len = 0
    for segment in generation_data.trace:
        if segment.type == "model" and isinstance(segment.output, ModelTraceSegment):
            model_output = segment.output
            segment_content = model_output.text or ""
            covered_text_len += len(segment_content)
            if model_output.tool_calls:
                tool_calls = [
                    AssistantPromptMessage.ToolCall(
                        id=tc.id or "",
                        type="function",
                        function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                            name=tc.name or "",
                            arguments=tc.arguments or "{}",
                        ),
                    )
                    for tc in model_output.tool_calls
                ]
                messages.append(AssistantPromptMessage(content=segment_content, tool_calls=tool_calls))
            elif segment_content:
                messages.append(AssistantPromptMessage(content=segment_content))
        elif segment.type == "tool" and isinstance(segment.output, ToolTraceSegment):
            tool_output = segment.output
            messages.append(
                ToolPromptMessage(
                    content=tool_output.output or "",
                    tool_call_id=tool_output.id or "",
                    name=tool_output.name or "",
                )
            )

    remaining_text = assistant_response[covered_text_len:]
    final_content = remaining_text + file_suffix
    if final_content:
        messages.append(AssistantPromptMessage(content=final_content))
    return messages


def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
    # Strip heavy inline payloads from multimodal content: keep only the file reference when
    # one exists, otherwise shorten the base64 data to a small truncated placeholder.
    content = message.content
    if content is None or isinstance(content, str):
        return message
    new_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, MultiModalPromptMessageContent):
            if item.file_ref:
                new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
            else:
                truncated_base64 = ""
                if item.base64_data:
                    truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
                new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
        else:
            new_content.append(item)
    return message.model_copy(update={"content": new_content})


def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
    return [_restore_message_content(msg) for msg in messages]


def _restore_message_content(message: PromptMessage) -> PromptMessage:
    from dify_graph.file.file_manager import restore_multimodal_content

    content = message.content
    if content is None or isinstance(content, str):
        return message
    restored_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, MultiModalPromptMessageContent):
            restored_item = restore_multimodal_content(item)
            restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
        else:
            restored_content.append(item)
    return message.model_copy(update={"content": restored_content})