feat: add context to agent node

Novice
2026-01-16 11:28:49 +08:00
parent 2591615a3c
commit a7826d9ea4
10 changed files with 458 additions and 681 deletions

View File

@@ -17,6 +17,12 @@ from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
@@ -527,6 +533,95 @@ class AgentNode(Node[AgentNodeData]):
        # Conversation-level memory doesn't need saving here
        return None

    def _build_context(
        self,
        parameters_for_log: dict[str, Any],
        user_query: str,
        assistant_response: str,
        agent_logs: list[AgentLogEvent],
    ) -> list[PromptMessage]:
        """
        Build context from user query, tool calls, and assistant response.

        Format: user -> assistant(with tool_calls) -> tool -> assistant

        The context includes:
        - Current user query (always present, may be empty)
        - Assistant message with tool_calls (if tools were called)
        - Tool results
        - Assistant's final response
        """
        context_messages: list[PromptMessage] = []

        # Always add user query (even if empty, to maintain conversation structure)
        context_messages.append(UserPromptMessage(content=user_query or ""))

        # Extract actual tool calls from agent logs
        # Only include logs with label starting with "CALL " - these are real tool invocations
        tool_calls: list[AssistantPromptMessage.ToolCall] = []
        tool_results: list[tuple[str, str, str]] = []  # (tool_call_id, tool_name, result)
        for log in agent_logs:
            if log.status == "success" and log.label and log.label.startswith("CALL "):
                # Extract tool name from label (format: "CALL tool_name")
                tool_name = log.label[5:]  # Remove "CALL " prefix
                tool_call_id = log.message_id

                # Parse tool response from data
                data = log.data or {}
                tool_response = ""

                # Try to extract the actual tool response
                if "tool_response" in data:
                    tool_response = data["tool_response"]
                elif "output" in data:
                    tool_response = data["output"]
                elif "result" in data:
                    tool_response = data["result"]
                if isinstance(tool_response, dict):
                    tool_response = str(tool_response)

                # Get tool input for arguments
                tool_input = data.get("tool_call_input", {}) or data.get("input", {})
                if isinstance(tool_input, dict):
                    import json

                    tool_input_str = json.dumps(tool_input, ensure_ascii=False)
                else:
                    tool_input_str = str(tool_input) if tool_input else ""

                if tool_response:
                    tool_calls.append(
                        AssistantPromptMessage.ToolCall(
                            id=tool_call_id,
                            type="function",
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=tool_name,
                                arguments=tool_input_str,
                            ),
                        )
                    )
                    tool_results.append((tool_call_id, tool_name, str(tool_response)))

        # Add assistant message with tool_calls if there were tool calls
        if tool_calls:
            context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))

        # Add tool result messages
        for tool_call_id, tool_name, result in tool_results:
            context_messages.append(
                ToolPromptMessage(
                    content=result,
                    tool_call_id=tool_call_id,
                    name=tool_name,
                )
            )

        # Add final assistant response
        context_messages.append(AssistantPromptMessage(content=assistant_response))
        return context_messages

    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],
@@ -782,20 +877,11 @@ class AgentNode(Node[AgentNodeData]):
                    is_final=True,
                )

            # Save to node memory if in node memory mode
            from core.workflow.nodes.llm import llm_utils

            # Get user query from parameters for building context
            user_query = parameters_for_log.get("query", "")
            # Get user query from sys.query
            user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
            user_query = user_query_var.text if user_query_var else ""
            llm_utils.save_node_memory(
                memory=memory,
                variable_pool=self.graph_runtime_state.variable_pool,
                user_query=user_query,
                assistant_response=text,
                assistant_files=files,
            )
            # Build context from history, user query, tool calls and assistant response
            context = self._build_context(parameters_for_log, user_query, text, agent_logs)

            yield StreamCompletedEvent(
                node_run_result=NodeRunResult(
@@ -805,6 +891,7 @@ class AgentNode(Node[AgentNodeData]):
                        "usage": jsonable_encoder(llm_usage),
                        "files": ArrayFileSegment(value=files),
                        "json": json_output,
                        "context": context,
                        **variables,
                    },
                    metadata={
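For reference, a minimal sketch (not part of the diff) of the context sequence that _build_context assembles for a run with one successful tool call; the query, tool name, call id, and response texts below are hypothetical, while the message constructors mirror the ones used above:

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)

# user -> assistant(with tool_calls) -> tool -> assistant
expected_context = [
    UserPromptMessage(content="What is the weather in Berlin?"),
    AssistantPromptMessage(
        content="",
        tool_calls=[
            AssistantPromptMessage.ToolCall(
                id="log-message-id",  # taken from the agent log's message_id
                type="function",
                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name="get_weather",  # label "CALL get_weather" with the prefix stripped
                    arguments='{"city": "Berlin"}',  # tool_call_input serialized with json.dumps
                ),
            )
        ],
    ),
    ToolPromptMessage(content="12°C, cloudy", tool_call_id="log-message-id", name="get_weather"),
    AssistantPromptMessage(content="It is currently 12°C and cloudy in Berlin."),
]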

View File

@@ -285,7 +285,7 @@ class Node(Generic[NodeDataT]):
            extractor_configs.append(node_config)
        return extractor_configs

    def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
    def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
        """
        Execute all extractor nodes associated with this node.
@@ -349,7 +349,7 @@ class Node(Generic[NodeDataT]):
        self._start_at = naive_utc_now()

        # Step 1: Execute associated extractor nodes before main node execution
        yield from self._execute_extractor_nodes()
        yield from self._execute_mention_nodes()

        # Create and push start event with required fields
        start_event = NodeRunStartedEvent(

View File

@@ -12,6 +12,13 @@ from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    PromptMessageRole,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
@@ -139,50 +146,6 @@ def fetch_memory(
    return TokenBufferMemory(conversation=conversation, model_instance=model_instance)


def save_node_memory(
    memory: BaseMemory | None,
    variable_pool: VariablePool,
    user_query: str,
    assistant_response: str,
    user_files: Sequence["File"] | None = None,
    assistant_files: Sequence["File"] | None = None,
) -> None:
    """
    Save dialogue turn to node memory if applicable.

    This function handles the storage logic for NodeTokenBufferMemory.
    For TokenBufferMemory (conversation-level), no action is taken as it uses
    the Message table which is managed elsewhere.

    :param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
    :param variable_pool: Variable pool containing system variables
    :param user_query: User's input text
    :param assistant_response: Assistant's response text
    :param user_files: Files attached by user (optional)
    :param assistant_files: Files generated by assistant (optional)
    """
    if not isinstance(memory, NodeTokenBufferMemory):
        return

    # Get workflow_run_id as the key for this execution
    workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
    if not isinstance(workflow_run_id_var, StringSegment):
        return
    workflow_run_id = workflow_run_id_var.value
    if not workflow_run_id:
        return

    memory.add_messages(
        workflow_run_id=workflow_run_id,
        user_content=user_query,
        user_files=list(user_files) if user_files else None,
        assistant_content=assistant_response,
        assistant_files=list(assistant_files) if assistant_files else None,
    )
    memory.flush()


def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration
@@ -246,3 +209,45 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
        )
        session.execute(stmt)
        session.commit()
def build_context(
    prompt_messages: Sequence[PromptMessage],
    assistant_response: str,
) -> list[PromptMessage]:
    """
    Build context from prompt messages and assistant response.

    Excludes system messages and includes the current LLM response.
    Returns list[PromptMessage] for use with ArrayPromptMessageSegment.

    Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
    """
    context_messages: list[PromptMessage] = [
        _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
    ]
    context_messages.append(AssistantPromptMessage(content=assistant_response))
    return context_messages


def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
    """
    Truncate multi-modal content base64 data in a message to avoid storing large data.

    Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
    """
    content = message.content
    if content is None or isinstance(content, str):
        return message

    # Process list content, truncating multi-modal base64 data
    new_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, MultiModalPromptMessageContent):
            # Truncate base64_data similar to prompt_messages_to_prompt_for_saving
            truncated_base64 = ""
            if item.base64_data:
                truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
            new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
        else:
            new_content.append(item)
    return message.model_copy(update={"content": new_content})
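A minimal standalone sketch (not part of the diff) of the truncation rule _truncate_multimodal_content applies to multi-modal payloads, shown on a plain string; build_context additionally drops system-role messages before appending the assistant response:

# Only the first and last 10 characters of a base64 payload survive in the stored context.
def truncate_base64(base64_data: str) -> str:
    if not base64_data:
        return ""
    return base64_data[:10] + "...[TRUNCATED]..." + base64_data[-10:]


sample = "iVBORw0KGgoAAAANSUhEUgAA" * 100  # hypothetical image payload
print(truncate_base64(sample))  # iVBORw0KGg...[TRUNCATED]...ANSUhEUgAA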

View File

@@ -20,7 +20,6 @@ from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
    ImagePromptMessageContent,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContentType,
    TextPromptMessageContent,
@@ -327,25 +326,13 @@ class LLMNode(Node[LLMNodeData]):
                "reasoning_content": reasoning_content,
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
                "context": self._build_context(prompt_messages, clean_text),
                "context": llm_utils.build_context(prompt_messages, clean_text),
            }
            if structured_output:
                outputs["structured_output"] = structured_output.structured_output
            if self._file_outputs:
                outputs["files"] = ArrayFileSegment(value=self._file_outputs)

            # Write to Node Memory if in node memory mode
            # Resolve the query template to get actual user content
            actual_query = variable_pool.convert_template(query or "").text
            llm_utils.save_node_memory(
                memory=memory,
                variable_pool=variable_pool,
                user_query=actual_query,
                assistant_response=clean_text,
                user_files=files,
                assistant_files=self._file_outputs,
            )

            # Send final chunk event to indicate streaming is complete
            yield StreamChunkEvent(
                selector=[self._node_id, "text"],
@@ -607,48 +594,6 @@ class LLMNode(Node[LLMNodeData]):
        # Separated mode: always return clean text and reasoning_content
        return clean_text, reasoning_content or ""

    @staticmethod
    def _build_context(
        prompt_messages: Sequence[PromptMessage],
        assistant_response: str,
    ) -> list[PromptMessage]:
        """
        Build context from prompt messages and assistant response.

        Excludes system messages and includes the current LLM response.
        Returns list[PromptMessage] for use with ArrayPromptMessageSegment.

        Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
        """
        context_messages: list[PromptMessage] = [
            LLMNode._truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
        ]
        context_messages.append(AssistantPromptMessage(content=assistant_response))
        return context_messages

    @staticmethod
    def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
        """
        Truncate multi-modal content base64 data in a message to avoid storing large data.

        Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
        """
        content = message.content
        if content is None or isinstance(content, str):
            return message

        # Process list content, truncating multi-modal base64 data
        new_content: list[PromptMessageContentUnionTypes] = []
        for item in content:
            if isinstance(item, MultiModalPromptMessageContent):
                # Truncate base64_data similar to prompt_messages_to_prompt_for_saving
                truncated_base64 = ""
                if item.base64_data:
                    truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
                new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
            else:
                new_content.append(item)
        return message.model_copy(update={"content": new_content})

    def _transform_chat_messages(
        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -716,54 +661,158 @@ class LLMNode(Node[LLMNodeData]):
        """
        variable_pool = self.graph_runtime_state.variable_pool

        # Build a map from context index to its messages
        context_messages_map: dict[int, list[PromptMessage]] = {}
        # Process messages in DSL order: iterate once and handle each type directly
        combined_messages: list[PromptMessage] = []
        context_idx = 0
        for idx, type_ in template_order:
        static_idx = 0
        for _, type_ in template_order:
            if type_ == "context":
                # Handle context reference
                ctx_ref = context_refs[context_idx]
                ctx_var = variable_pool.get(ctx_ref.value_selector)
                if ctx_var is None:
                    raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
                if not isinstance(ctx_var, ArrayPromptMessageSegment):
                    raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
                context_messages_map[idx] = list(ctx_var.value)
                combined_messages.extend(ctx_var.value)
                context_idx += 1

        # Process static messages
        static_prompt_messages: Sequence[PromptMessage] = []
        stop: Sequence[str] | None = None
        if static_messages:
            static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
                sys_query=query,
                sys_files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
                memory_config=self.node_data.memory,
                vision_enabled=self.node_data.vision.enabled,
                vision_detail=self.node_data.vision.configs.detail,
                variable_pool=variable_pool,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                tenant_id=self.tenant_id,
                context_files=context_files,
            )

        # Combine messages according to original DSL order
        combined_messages: list[PromptMessage] = []
        static_msg_iter = iter(static_prompt_messages)
        for idx, type_ in template_order:
            if type_ == "context":
                combined_messages.extend(context_messages_map[idx])
            else:
                if msg := next(static_msg_iter, None):
                    combined_messages.append(msg)
        # Append any remaining static messages (e.g., memory messages)
        combined_messages.extend(static_msg_iter)
                # Handle static message
                static_msg = static_messages[static_idx]
                processed_msgs = LLMNode.handle_list_messages(
                    messages=[static_msg],
                    context=context,
                    jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
                    variable_pool=variable_pool,
                    vision_detail_config=self.node_data.vision.configs.detail,
                )
                combined_messages.extend(processed_msgs)
                static_idx += 1

        # Append memory messages
        memory_messages = _handle_memory_chat_mode(
            memory=memory,
            memory_config=self.node_data.memory,
            model_config=model_config,
        )
        combined_messages.extend(memory_messages)

        # Append current query if provided
        if query:
            query_message = LLMNodeChatModelMessage(
                text=query,
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
            query_msgs = LLMNode.handle_list_messages(
                messages=[query_message],
                context="",
                jinja2_variables=[],
                variable_pool=variable_pool,
                vision_detail_config=self.node_data.vision.configs.detail,
            )
            combined_messages.extend(query_msgs)

        # Handle files (sys_files and context_files)
        combined_messages = self._append_files_to_messages(
            messages=combined_messages,
            sys_files=files,
            context_files=context_files,
            model_config=model_config,
        )

        # Filter empty messages and get stop sequences
        combined_messages = self._filter_messages(combined_messages, model_config)
        stop = self._get_stop_sequences(model_config)

        return combined_messages, stop
    def _append_files_to_messages(
        self,
        *,
        messages: list[PromptMessage],
        sys_files: Sequence[File],
        context_files: list[File],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        """Append sys_files and context_files to messages."""
        vision_enabled = self.node_data.vision.enabled
        vision_detail = self.node_data.vision.configs.detail

        # Handle sys_files (will be deprecated later)
        if vision_enabled and sys_files:
            file_prompts = [
                file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
            ]
            if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
                messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
            else:
                messages.append(UserPromptMessage(content=file_prompts))

        # Handle context_files
        if vision_enabled and context_files:
            file_prompts = [
                file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                for file in context_files
            ]
            if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
                messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
            else:
                messages.append(UserPromptMessage(content=file_prompts))

        return messages

    def _filter_messages(
        self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> list[PromptMessage]:
        """Filter empty messages and unsupported content types."""
        filtered_messages: list[PromptMessage] = []
        for message in messages:
            if isinstance(message.content, list):
                filtered_content: list[PromptMessageContentUnionTypes] = []
                for content_item in message.content:
                    # Skip non-text content if features are not defined
                    if not model_config.model_schema.features:
                        if content_item.type != PromptMessageContentType.TEXT:
                            continue
                        filtered_content.append(content_item)
                        continue

                    # Skip content if corresponding feature is not supported
                    feature_map = {
                        PromptMessageContentType.IMAGE: ModelFeature.VISION,
                        PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
                        PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
                        PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
                    }
                    required_feature = feature_map.get(content_item.type)
                    if required_feature and required_feature not in model_config.model_schema.features:
                        continue
                    filtered_content.append(content_item)

                # Simplify single text content
                if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
                    message.content = filtered_content[0].data
                else:
                    message.content = filtered_content

            if not message.is_empty():
                filtered_messages.append(message)

        if not filtered_messages:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )

        return filtered_messages

    def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
        """Get stop sequences from model config."""
        return model_config.stop

    def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
        variables: dict[str, Any] = {}
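A simplified, self-contained sketch (not part of the diff) of the single-pass assembly the refactored loop above performs; plain strings stand in for prompt messages and the template_order entries are hypothetical:

# Walk template_order once, splicing context messages in place of "context" entries
# and keeping static messages in their original DSL order.
template_order = [(0, "static"), (1, "context"), (2, "static")]
static_messages = ["system: you are a helpful assistant", "user: summarise the report"]
context_lists = [["user: earlier question", "assistant: earlier answer"]]  # one list per context ref

combined: list[str] = []
context_idx = 0
static_idx = 0
for _, type_ in template_order:
    if type_ == "context":
        combined.extend(context_lists[context_idx])
        context_idx += 1
    else:
        combined.append(static_messages[static_idx])
        static_idx += 1

# The real implementation then appends memory messages, the current query, and files,
# before filtering empty or unsupported content.
print(combined)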

View File

@@ -246,13 +246,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        # transform result into standard format
        result = self._transform_result(data=node_data, result=result or {})

        # Save to node memory if in node memory mode
        llm_utils.save_node_memory(
            memory=memory,
            variable_pool=variable_pool,
            user_query=query,
            assistant_response=json.dumps(result, ensure_ascii=False),
        )
        # Build context from prompt messages and response
        assistant_response = json.dumps(result, ensure_ascii=False)
        context = llm_utils.build_context(prompt_messages, assistant_response)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -262,6 +258,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                "__is_success": 1 if not error else 0,
                "__reason": error,
                "__usage": jsonable_encoder(usage),
                "context": context,
                **result,
            },
            metadata={
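A minimal sketch (not part of the diff) of the assistant turn the parameter extractor now feeds into llm_utils.build_context; the extracted fields below are hypothetical:

import json

result = {"name": "Alice", "age": 30}  # hypothetical extracted parameters
assistant_response = json.dumps(result, ensure_ascii=False)
print(assistant_response)  # {"name": "Alice", "age": 30}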

View File

@@ -199,20 +199,17 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
                "model_provider": model_config.provider,
                "model_name": model_config.model,
            }
            # Build context from prompt messages and response
            assistant_response = f"class_name: {category_name}, class_id: {category_id}"
            context = llm_utils.build_context(prompt_messages, assistant_response)
            outputs = {
                "class_name": category_name,
                "class_id": category_id,
                "usage": jsonable_encoder(usage),
                "context": context,
            }

            # Save to node memory if in node memory mode
            llm_utils.save_node_memory(
                memory=memory,
                variable_pool=variable_pool,
                user_query=query or "",
                assistant_response=f"class_name: {category_name}, class_id: {category_id}",
            )

            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
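Similarly, a minimal sketch (not part of the diff) of the assistant turn the question classifier records in its context; the category values below are hypothetical:

category_name = "billing"
category_id = "2"
assistant_response = f"class_name: {category_name}, class_id: {category_id}"
print(assistant_response)  # class_name: billing, class_id: 2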