Mirror of https://github.com/langgenius/dify.git (synced 2026-05-04 17:38:04 +08:00)
feat: add mention type variable
@@ -1,7 +1,7 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
     jinja2_text: str | None = None


+class PromptMessageContext(BaseModel):
+    """Context variable reference in prompt template.
+
+    YAML/JSON format: { "$context": ["node_id", "variable_name"] }
+    This will be expanded to list[PromptMessage] at runtime.
+    """
+
+    model_config = ConfigDict(populate_by_name=True)
+
+    value_selector: Sequence[str] = Field(alias="$context")
+
+
+# Union type for prompt template items (static message or context variable reference)
+PromptTemplateItem: TypeAlias = Annotated[
+    LLMNodeChatModelMessage | PromptMessageContext,
+    Field(discriminator=None),
+]
+
+
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
-    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
     prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: MemoryConfig | None = None
     context: ContextConfig
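For orientation, here is a minimal sketch (not part of the commit) of how a prompt_template that mixes static chat messages with a "$context" mention would be validated against the new PromptTemplateItem union. The module path, the "role"/"text" field names of LLMNodeChatModelMessage, and the node id "llm_upstream" are assumptions for illustration.

    from pydantic import TypeAdapter

    from core.workflow.nodes.llm.entities import PromptTemplateItem  # assumed module path

    raw_template = [
        {"role": "system", "text": "You are a helpful assistant."},
        {"$context": ["llm_upstream", "context"]},  # mention: expands to list[PromptMessage] at runtime
        {"role": "user", "text": "{{#sys.query#}}"},
    ]

    items = TypeAdapter(list[PromptTemplateItem]).validate_python(raw_template)
    # items[1] is a PromptMessageContext; because populate_by_name=True, the same object
    # can also be built as PromptMessageContext(value_selector=["llm_upstream", "context"]).

Since the field carries the "$context" alias, the mention can be written directly in a workflow DSL file while code keeps using the value_selector name.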
@@ -7,7 +7,7 @@ import logging
 import re
 import time
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast

 from sqlalchemy import select

@@ -52,6 +52,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.tools.signature import sign_upload_file
 from core.variables import (
     ArrayFileSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     FileSegment,
     NoneSegment,
@@ -88,6 +89,7 @@ from .entities import (
     LLMNodeCompletionModelPromptTemplate,
     LLMNodeData,
     ModelConfig,
+    PromptMessageContext,
 )
 from .exc import (
     InvalidContextStructureError,
@@ -160,8 +162,9 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool = self.graph_runtime_state.variable_pool

         try:
-            # init messages template
-            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
+            # Parse prompt template to separate static messages and context references
+            prompt_template = self.node_data.prompt_template
+            static_messages, context_refs, template_order = self._parse_prompt_template()

             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data=self.node_data)
@@ -223,21 +226,40 @@ class LLMNode(Node[LLMNodeData]):
             ):
                 query = query_variable.text

-            prompt_messages, stop = LLMNode.fetch_prompt_messages(
-                sys_query=query,
-                sys_files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config,
-                prompt_template=self.node_data.prompt_template,
-                memory_config=self.node_data.memory,
-                vision_enabled=self.node_data.vision.enabled,
-                vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=variable_pool,
-                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
-                tenant_id=self.tenant_id,
-                context_files=context_files,
-            )
+            # Get prompt messages
+            prompt_messages: Sequence[PromptMessage]
+            stop: Sequence[str] | None
+            if isinstance(prompt_template, list) and context_refs:
+                prompt_messages, stop = self._build_prompt_messages_with_context(
+                    context_refs=context_refs,
+                    template_order=template_order,
+                    static_messages=static_messages,
+                    query=query,
+                    files=files,
+                    context=context,
+                    memory=memory,
+                    model_config=model_config,
+                    context_files=context_files,
+                )
+            else:
+                prompt_messages, stop = LLMNode.fetch_prompt_messages(
+                    sys_query=query,
+                    sys_files=files,
+                    context=context,
+                    memory=memory,
+                    model_config=model_config,
+                    prompt_template=cast(
+                        Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+                        self.node_data.prompt_template,
+                    ),
+                    memory_config=self.node_data.memory,
+                    vision_enabled=self.node_data.vision.enabled,
+                    vision_detail=self.node_data.vision.configs.detail,
+                    variable_pool=variable_pool,
+                    jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                    tenant_id=self.tenant_id,
+                    context_files=context_files,
+                )

             # handle invoke result
             generator = LLMNode.invoke_llm(
@@ -304,7 +326,7 @@ class LLMNode(Node[LLMNodeData]):
                 "reasoning_content": reasoning_content,
                 "usage": jsonable_encoder(usage),
                 "finish_reason": finish_reason,
-                "context": self._build_context(prompt_messages, clean_text, model_config.mode),
+                "context": self._build_context(prompt_messages, clean_text),
             }
             if structured_output:
                 outputs["structured_output"] = structured_output.structured_output
@@ -602,17 +624,15 @@ class LLMNode(Node[LLMNodeData]):
     def _build_context(
         prompt_messages: Sequence[PromptMessage],
         assistant_response: str,
-        model_mode: str,
-    ) -> list[dict[str, Any]]:
+    ) -> list[PromptMessage]:
         """
         Build context from prompt messages and assistant response.
         Excludes system messages and includes the current LLM response.
+        Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
         """
         context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
         context_messages.append(AssistantPromptMessage(content=assistant_response))
-        return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-            model_mode=model_mode, prompt_messages=context_messages
-        )
+        return context_messages

     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
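To make the change concrete, a small sketch (not part of the commit) of what the node's "context" output now carries; the message classes and PromptMessageRole are assumed to be importable from core.model_runtime.entities and to take a content argument.

    from core.model_runtime.entities import (  # assumed import location
        AssistantPromptMessage,
        PromptMessageRole,
        SystemPromptMessage,
        UserPromptMessage,
    )

    prompt_messages = [
        SystemPromptMessage(content="You are a helpful assistant."),
        UserPromptMessage(content="Summarize the report."),
    ]
    assistant_response = "Here is the summary..."

    # Same filtering as _build_context: drop system messages, append the current reply.
    context = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
    context.append(AssistantPromptMessage(content=assistant_response))
    # context == [UserPromptMessage(...), AssistantPromptMessage(...)]

Previously this output was flattened to a string via prompt_messages_to_prompt_for_saving; now the raw messages are returned so they can be stored as an ArrayPromptMessageSegment and re-injected by a downstream "$context" mention.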
@@ -629,6 +649,106 @@ class LLMNode(Node[LLMNodeData]):

         return messages

+    def _parse_prompt_template(
+        self,
+    ) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
+        """
+        Parse prompt_template to separate static messages and context references.
+
+        Returns:
+            Tuple of (static_messages, context_refs, template_order)
+            - static_messages: list of LLMNodeChatModelMessage
+            - context_refs: list of PromptMessageContext
+            - template_order: list of (index, type) tuples preserving original order
+        """
+        prompt_template = self.node_data.prompt_template
+        static_messages: list[LLMNodeChatModelMessage] = []
+        context_refs: list[PromptMessageContext] = []
+        template_order: list[tuple[int, str]] = []
+
+        if isinstance(prompt_template, list):
+            for idx, item in enumerate(prompt_template):
+                if isinstance(item, PromptMessageContext):
+                    context_refs.append(item)
+                    template_order.append((idx, "context"))
+                else:
+                    static_messages.append(item)
+                    template_order.append((idx, "static"))
+            # Transform static messages for jinja2
+            if static_messages:
+                self.node_data.prompt_template = self._transform_chat_messages(static_messages)
+
+        return static_messages, context_refs, template_order
+
+    def _build_prompt_messages_with_context(
+        self,
+        *,
+        context_refs: list[PromptMessageContext],
+        template_order: list[tuple[int, str]],
+        static_messages: list[LLMNodeChatModelMessage],
+        query: str | None,
+        files: Sequence[File],
+        context: str | None,
+        memory: BaseMemory | None,
+        model_config: ModelConfigWithCredentialsEntity,
+        context_files: list[File],
+    ) -> tuple[list[PromptMessage], Sequence[str] | None]:
+        """
+        Build prompt messages by combining static messages and context references in DSL order.
+
+        Returns:
+            Tuple of (prompt_messages, stop_sequences)
+        """
+        variable_pool = self.graph_runtime_state.variable_pool
+
+        # Build a map from context index to its messages
+        context_messages_map: dict[int, list[PromptMessage]] = {}
+        context_idx = 0
+        for idx, type_ in template_order:
+            if type_ == "context":
+                ctx_ref = context_refs[context_idx]
+                ctx_var = variable_pool.get(ctx_ref.value_selector)
+                if ctx_var is None:
+                    raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
+                if not isinstance(ctx_var, ArrayPromptMessageSegment):
+                    raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
+                context_messages_map[idx] = list(ctx_var.value)
+                context_idx += 1
+
+        # Process static messages
+        static_prompt_messages: Sequence[PromptMessage] = []
+        stop: Sequence[str] | None = None
+        if static_messages:
+            static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
+                sys_query=query,
+                sys_files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config,
+                prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
+                memory_config=self.node_data.memory,
+                vision_enabled=self.node_data.vision.enabled,
+                vision_detail=self.node_data.vision.configs.detail,
+                variable_pool=variable_pool,
+                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                tenant_id=self.tenant_id,
+                context_files=context_files,
+            )
+
+        # Combine messages according to original DSL order
+        combined_messages: list[PromptMessage] = []
+        static_msg_iter = iter(static_prompt_messages)
+        for idx, type_ in template_order:
+            if type_ == "context":
+                combined_messages.extend(context_messages_map[idx])
+            else:
+                if msg := next(static_msg_iter, None):
+                    combined_messages.append(msg)
+        # Append any remaining static messages (e.g., memory messages)
+        combined_messages.extend(static_msg_iter)
+
+        return combined_messages, stop
+
     def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
         variables: dict[str, Any] = {}
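The order-preserving merge in _build_prompt_messages_with_context is the subtle part: static template entries are consumed positionally from the fetch_prompt_messages result, "$context" entries are spliced in at their original index, and any surplus static messages (for example memory messages) keep their tail position, as the code's own comment notes. A self-contained toy run of the same algorithm, with strings standing in for PromptMessage objects (all names and values illustrative only):

    template_order = [(0, "static"), (1, "context"), (2, "static")]
    context_messages_map = {1: ["ctx-user-1", "ctx-assistant-1"]}      # resolved from the "$context" mention
    static_prompt_messages = ["system-msg", "user-msg", "memory-msg"]  # as returned by fetch_prompt_messages

    combined: list[str] = []
    static_iter = iter(static_prompt_messages)
    for idx, kind in template_order:
        if kind == "context":
            combined.extend(context_messages_map[idx])
        elif msg := next(static_iter, None):
            combined.append(msg)
    combined.extend(static_iter)  # leftover static messages are appended at the end

    print(combined)
    # ['system-msg', 'ctx-user-1', 'ctx-assistant-1', 'user-msg', 'memory-msg']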