feat: add mention type variable

Novice
2026-01-12 17:39:36 +08:00
parent d65ae68668
commit bb190f9610
23 changed files with 457 additions and 439 deletions
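
This change lets an LLM node's prompt_template mix static chat messages with "$context" mentions that reference another node's conversation context; each mention is resolved from the variable pool as an ArrayPromptMessageSegment and expanded into list[PromptMessage] in the original template order. A hedged sketch of the DSL value (the static-message fields and node id are illustrative assumptions; only the {"$context": [...]} form is documented in the entities diff below):

prompt_template = [
    {"role": "system", "text": "You are a helpful assistant."},  # static message (fields assumed)
    {"$context": ["upstream_llm_node", "context"]},  # mention, expanded at runtime
]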

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class PromptMessageContext(BaseModel):
"""Context variable reference in prompt template.
YAML/JSON format: { "$context": ["node_id", "variable_name"] }
This will be expanded to list[PromptMessage] at runtime.
"""
model_config = ConfigDict(populate_by_name=True)
value_selector: Sequence[str] = Field(alias="$context")
# Union type for prompt template items (static message or context variable reference)
PromptTemplateItem: TypeAlias = Annotated[
LLMNodeChatModelMessage | PromptMessageContext,
Field(discriminator=None),
]
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: MemoryConfig | None = None
context: ContextConfig
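
A minimal standalone sketch of how the new PromptMessageContext entity parses that form with pydantic; this is a reproduction for illustration, not the repository module itself:

from collections.abc import Sequence

from pydantic import BaseModel, ConfigDict, Field


class PromptMessageContext(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    value_selector: Sequence[str] = Field(alias="$context")


# The DSL form documented in the docstring above:
ref = PromptMessageContext.model_validate({"$context": ["node_id", "variable_name"]})
assert list(ref.value_selector) == ["node_id", "variable_name"]

# populate_by_name=True also accepts the field name directly when the entity
# is constructed in code rather than loaded from YAML/JSON.
ref = PromptMessageContext(value_selector=["node_id", "variable_name"])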

View File

@@ -7,7 +7,7 @@ import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast
from sqlalchemy import select
@@ -52,6 +52,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
@@ -88,6 +89,7 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
PromptMessageContext,
)
from .exc import (
InvalidContextStructureError,
@@ -160,8 +162,9 @@ class LLMNode(Node[LLMNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
try:
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# Parse prompt template to separate static messages and context references
prompt_template = self.node_data.prompt_template
static_messages, context_refs, template_order = self._parse_prompt_template()
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
@@ -223,21 +226,40 @@
):
query = query_variable.text
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# Get prompt messages
prompt_messages: Sequence[PromptMessage]
stop: Sequence[str] | None
if isinstance(prompt_template, list) and context_refs:
prompt_messages, stop = self._build_prompt_messages_with_context(
context_refs=context_refs,
template_order=template_order,
static_messages=static_messages,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config,
context_files=context_files,
)
else:
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
self.node_data.prompt_template,
),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
generator = LLMNode.invoke_llm(
@@ -304,7 +326,7 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": self._build_context(prompt_messages, clean_text, model_config.mode),
"context": self._build_context(prompt_messages, clean_text),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
@@ -602,17 +624,15 @@ class LLMNode(Node[LLMNodeData]):
def _build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
model_mode: str,
) -> list[dict[str, Any]]:
) -> list[PromptMessage]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
"""
context_messages: list[PromptMessage] = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_mode, prompt_messages=context_messages
)
return context_messages
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
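
A hedged sketch of what the reworked _build_context now produces: PromptMessage objects rather than saved-prompt dicts, so the node's "context" output can be stored as an ArrayPromptMessageSegment and consumed by downstream "$context" mentions. The import path and message contents below are assumptions for illustration:

from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    UserPromptMessage,
)

prompt_messages = [
    SystemPromptMessage(content="You are a helpful assistant."),
    UserPromptMessage(content="Summarize the report."),
]
assistant_response = "Here is the summary ..."

# System messages are dropped and the assistant reply is appended.
context = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
context.append(AssistantPromptMessage(content=assistant_response))
# -> [UserPromptMessage(...), AssistantPromptMessage(...)]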
@@ -629,6 +649,106 @@ class LLMNode(Node[LLMNodeData]):
return messages
def _parse_prompt_template(
self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
"""
Parse prompt_template to separate static messages and context references.
Returns:
Tuple of (static_messages, context_refs, template_order)
- static_messages: list of LLMNodeChatModelMessage
- context_refs: list of PromptMessageContext
- template_order: list of (index, type) tuples preserving original order
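Example (hypothetical template): [msg_a, {"$context": ["node_id", "context"]}, msg_b]
yields ([msg_a, msg_b], [ctx_ref], [(0, "static"), (1, "context"), (2, "static")])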
"""
prompt_template = self.node_data.prompt_template
static_messages: list[LLMNodeChatModelMessage] = []
context_refs: list[PromptMessageContext] = []
template_order: list[tuple[int, str]] = []
if isinstance(prompt_template, list):
for idx, item in enumerate(prompt_template):
if isinstance(item, PromptMessageContext):
context_refs.append(item)
template_order.append((idx, "context"))
else:
static_messages.append(item)
template_order.append((idx, "static"))
# Transform static messages for jinja2
if static_messages:
self.node_data.prompt_template = self._transform_chat_messages(static_messages)
return static_messages, context_refs, template_order
def _build_prompt_messages_with_context(
self,
*,
context_refs: list[PromptMessageContext],
template_order: list[tuple[int, str]],
static_messages: list[LLMNodeChatModelMessage],
query: str | None,
files: Sequence[File],
context: str | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
"""
Build prompt messages by combining static messages and context references in DSL order.
Returns:
Tuple of (prompt_messages, stop_sequences)
"""
variable_pool = self.graph_runtime_state.variable_pool
# Build a map from context index to its messages
context_messages_map: dict[int, list[PromptMessage]] = {}
context_idx = 0
for idx, type_ in template_order:
if type_ == "context":
ctx_ref = context_refs[context_idx]
ctx_var = variable_pool.get(ctx_ref.value_selector)
if ctx_var is None:
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
if not isinstance(ctx_var, ArrayPromptMessageSegment):
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
context_messages_map[idx] = list(ctx_var.value)
context_idx += 1
# Process static messages
static_prompt_messages: Sequence[PromptMessage] = []
stop: Sequence[str] | None = None
if static_messages:
static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# Combine messages according to original DSL order
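# e.g. (hypothetical values): with template_order [(0, "static"), (1, "context"), (2, "static")],
# static prompt messages [system, user, memory] and context messages [ctx_user, ctx_assistant],
# the loop yields [system, ctx_user, ctx_assistant, user] and the trailing extend appends memory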
combined_messages: list[PromptMessage] = []
static_msg_iter = iter(static_prompt_messages)
for idx, type_ in template_order:
if type_ == "context":
combined_messages.extend(context_messages_map[idx])
else:
if msg := next(static_msg_iter, None):
combined_messages.append(msg)
# Append any remaining static messages (e.g., memory messages)
combined_messages.extend(static_msg_iter)
return combined_messages, stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables: dict[str, Any] = {}