mirror of https://github.com/langgenius/dify.git (synced 2026-05-03 17:08:03 +08:00)
feat: agent add context
@@ -17,6 +17,12 @@ from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessage,
+    ToolPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.entities.advanced_prompt_entities import MemoryMode
@@ -527,6 +533,95 @@ class AgentNode(Node[AgentNodeData]):
         # Conversation-level memory doesn't need saving here
         return None
 
+    def _build_context(
+        self,
+        parameters_for_log: dict[str, Any],
+        user_query: str,
+        assistant_response: str,
+        agent_logs: list[AgentLogEvent],
+    ) -> list[PromptMessage]:
+        """
+        Build context from user query, tool calls, and assistant response.
+        Format: user -> assistant (with tool_calls) -> tool -> assistant
+
+        The context includes:
+        - Current user query (always present, may be empty)
+        - Assistant message with tool_calls (if tools were called)
+        - Tool results
+        - Assistant's final response
+        """
+        context_messages: list[PromptMessage] = []
+
+        # Always add user query (even if empty, to maintain conversation structure)
+        context_messages.append(UserPromptMessage(content=user_query or ""))
+
+        # Extract actual tool calls from agent logs.
+        # Only include logs with a label starting with "CALL " - these are real tool invocations.
+        tool_calls: list[AssistantPromptMessage.ToolCall] = []
+        tool_results: list[tuple[str, str, str]] = []  # (tool_call_id, tool_name, result)
+
+        for log in agent_logs:
+            if log.status == "success" and log.label and log.label.startswith("CALL "):
+                # Extract tool name from label (format: "CALL tool_name")
+                tool_name = log.label[5:]  # Remove "CALL " prefix
+                tool_call_id = log.message_id
+
+                # Parse tool response from data
+                data = log.data or {}
+                tool_response = ""
+
+                # Try to extract the actual tool response
+                if "tool_response" in data:
+                    tool_response = data["tool_response"]
+                elif "output" in data:
+                    tool_response = data["output"]
+                elif "result" in data:
+                    tool_response = data["result"]
+
+                if isinstance(tool_response, dict):
+                    tool_response = str(tool_response)
+
+                # Get tool input for arguments
+                tool_input = data.get("tool_call_input", {}) or data.get("input", {})
+                if isinstance(tool_input, dict):
+                    import json
+
+                    tool_input_str = json.dumps(tool_input, ensure_ascii=False)
+                else:
+                    tool_input_str = str(tool_input) if tool_input else ""
+
+                if tool_response:
+                    tool_calls.append(
+                        AssistantPromptMessage.ToolCall(
+                            id=tool_call_id,
+                            type="function",
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_name,
+                                arguments=tool_input_str,
+                            ),
+                        )
+                    )
+                    tool_results.append((tool_call_id, tool_name, str(tool_response)))
+
+        # Add assistant message with tool_calls if there were tool calls
+        if tool_calls:
+            context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
+
+        # Add tool result messages
+        for tool_call_id, tool_name, result in tool_results:
+            context_messages.append(
+                ToolPromptMessage(
+                    content=result,
+                    tool_call_id=tool_call_id,
+                    name=tool_name,
+                )
+            )
+
+        # Add final assistant response
+        context_messages.append(AssistantPromptMessage(content=assistant_response))
+
+        return context_messages
+
     def _transform_message(
         self,
         messages: Generator[ToolInvokeMessage, None, None],
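Note: for orientation, here is a stand-alone sketch of the user -> assistant(tool_calls) -> tool -> assistant ordering `_build_context` emits. Plain dicts stand in for Dify's PromptMessage classes and AgentLogEvent (hypothetical stand-ins); the elif-chain over response keys is collapsed into an or-chain.

import json

def build_context_sketch(user_query, agent_logs, assistant_response):
    messages = [{"role": "user", "content": user_query or ""}]
    tool_calls, tool_results = [], []
    for log in agent_logs:
        # Only logs labeled "CALL <tool_name>" count as real tool invocations
        if log["status"] == "success" and log["label"].startswith("CALL "):
            name = log["label"][5:]
            data = log.get("data") or {}
            response = data.get("tool_response") or data.get("output") or data.get("result") or ""
            if response:
                args = data.get("tool_call_input") or data.get("input") or {}
                tool_calls.append({
                    "id": log["message_id"],
                    "type": "function",
                    "function": {"name": name, "arguments": json.dumps(args, ensure_ascii=False)},
                })
                tool_results.append((log["message_id"], name, str(response)))
    if tool_calls:
        messages.append({"role": "assistant", "content": "", "tool_calls": tool_calls})
    for call_id, name, result in tool_results:
        messages.append({"role": "tool", "tool_call_id": call_id, "name": name, "content": result})
    messages.append({"role": "assistant", "content": assistant_response})
    return messages

logs = [{"status": "success", "label": "CALL get_weather", "message_id": "m-1",
         "data": {"tool_call_input": {"city": "Berlin"}, "tool_response": "12°C, cloudy"}}]
for m in build_context_sketch("Weather in Berlin?", logs, "It is 12°C and cloudy."):
    print(m["role"])  # user, assistant, tool, assistant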
@@ -782,20 +877,11 @@ class AgentNode(Node[AgentNodeData]):
                 is_final=True,
             )
 
-            # Save to node memory if in node memory mode
-            from core.workflow.nodes.llm import llm_utils
+            # Get user query from parameters for building context
+            user_query = parameters_for_log.get("query", "")
 
-            # Get user query from sys.query
-            user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
-            user_query = user_query_var.text if user_query_var else ""
-
-            llm_utils.save_node_memory(
-                memory=memory,
-                variable_pool=self.graph_runtime_state.variable_pool,
-                user_query=user_query,
-                assistant_response=text,
-                assistant_files=files,
-            )
+            # Build context from history, user query, tool calls and assistant response
+            context = self._build_context(parameters_for_log, user_query, text, agent_logs)
 
             yield StreamCompletedEvent(
                 node_run_result=NodeRunResult(
@@ -805,6 +891,7 @@ class AgentNode(Node[AgentNodeData]):
                         "usage": jsonable_encoder(llm_usage),
                         "files": ArrayFileSegment(value=files),
                         "json": json_output,
+                        "context": context,
                         **variables,
                     },
                     metadata={
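Note: with "context" in the agent node's outputs, a downstream node can reference it through a variable selector and receive the message list intact. A rough sketch of that flow, with a plain dict standing in for the variable pool (illustrative names only):

# Hypothetical stand-in: node outputs keyed by (node_id, output_name)
variable_pool = {}

agent_outputs = {"text": "It is 12°C and cloudy.", "context": [
    {"role": "user", "content": "Weather in Berlin?"},
    {"role": "assistant", "content": "It is 12°C and cloudy."},
]}
variable_pool[("agent_1", "context")] = agent_outputs["context"]

# A downstream LLM node configured with a context reference to agent_1.context
# would splice these messages into its own prompt (see the template_order
# handling in the LLM node changes further down).
downstream_prompt = [{"role": "system", "content": "You are helpful."}]
downstream_prompt += variable_pool[("agent_1", "context")]
print([m["role"] for m in downstream_prompt])  # ['system', 'user', 'assistant']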
@@ -285,7 +285,7 @@ class Node(Generic[NodeDataT]):
             extractor_configs.append(node_config)
         return extractor_configs
 
-    def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
+    def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
         """
         Execute all extractor nodes associated with this node.
 
@@ -349,7 +349,7 @@ class Node(Generic[NodeDataT]):
         self._start_at = naive_utc_now()
 
         # Step 1: Execute associated extractor nodes before main node execution
-        yield from self._execute_extractor_nodes()
+        yield from self._execute_mention_nodes()
 
         # Create and push start event with required fields
         start_event = NodeRunStartedEvent(
@@ -12,6 +12,13 @@ from core.memory import NodeTokenBufferMemory, TokenBufferMemory
 from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    MultiModalPromptMessageContent,
+    PromptMessage,
+    PromptMessageContentUnionTypes,
+    PromptMessageRole,
+)
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
@@ -139,50 +146,6 @@ def fetch_memory(
     return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
 
-def save_node_memory(
-    memory: BaseMemory | None,
-    variable_pool: VariablePool,
-    user_query: str,
-    assistant_response: str,
-    user_files: Sequence["File"] | None = None,
-    assistant_files: Sequence["File"] | None = None,
-) -> None:
-    """
-    Save dialogue turn to node memory if applicable.
-
-    This function handles the storage logic for NodeTokenBufferMemory.
-    For TokenBufferMemory (conversation-level), no action is taken as it uses
-    the Message table which is managed elsewhere.
-
-    :param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
-    :param variable_pool: Variable pool containing system variables
-    :param user_query: User's input text
-    :param assistant_response: Assistant's response text
-    :param user_files: Files attached by user (optional)
-    :param assistant_files: Files generated by assistant (optional)
-    """
-    if not isinstance(memory, NodeTokenBufferMemory):
-        return
-
-    # Get workflow_run_id as the key for this execution
-    workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
-    if not isinstance(workflow_run_id_var, StringSegment):
-        return
-
-    workflow_run_id = workflow_run_id_var.value
-    if not workflow_run_id:
-        return
-
-    memory.add_messages(
-        workflow_run_id=workflow_run_id,
-        user_content=user_query,
-        user_files=list(user_files) if user_files else None,
-        assistant_content=assistant_response,
-        assistant_files=list(assistant_files) if assistant_files else None,
-    )
-    memory.flush()
-
-
 def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
     provider_model_bundle = model_instance.provider_model_bundle
     provider_configuration = provider_model_bundle.configuration
@@ -246,3 +209,45 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
         )
         session.execute(stmt)
         session.commit()
+
+
+def build_context(
+    prompt_messages: Sequence[PromptMessage],
+    assistant_response: str,
+) -> list[PromptMessage]:
+    """
+    Build context from prompt messages and assistant response.
+    Excludes system messages and includes the current LLM response.
+    Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
+
+    Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
+    """
+    context_messages: list[PromptMessage] = [
+        _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
+    ]
+    context_messages.append(AssistantPromptMessage(content=assistant_response))
+    return context_messages
+
+
+def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
+    """
+    Truncate multi-modal content base64 data in a message to avoid storing large data.
+    Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
+    """
+    content = message.content
+    if content is None or isinstance(content, str):
+        return message
+
+    # Process list content, truncating multi-modal base64 data
+    new_content: list[PromptMessageContentUnionTypes] = []
+    for item in content:
+        if isinstance(item, MultiModalPromptMessageContent):
+            # Truncate base64_data similar to prompt_messages_to_prompt_for_saving
+            truncated_base64 = ""
+            if item.base64_data:
+                truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
+            new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
+        else:
+            new_content.append(item)
+
+    return message.model_copy(update={"content": new_content})
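Note: a quick stand-alone check of the truncation scheme used above — keep the first and last 10 characters of the base64 payload around a marker, so a stored context never carries the full image bytes:

b64 = "iVBORw0KGgo" * 5000  # pretend image payload, 55,000 chars
truncated = b64[:10] + "...[TRUNCATED]..." + b64[-10:]
print(len(b64), "->", len(truncated))  # 55000 -> 37
assert truncated.startswith(b64[:10]) and truncated.endswith(b64[-10:])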
@@ -20,7 +20,6 @@ from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
-    MultiModalPromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
@@ -327,25 +326,13 @@ class LLMNode(Node[LLMNodeData]):
             "reasoning_content": reasoning_content,
             "usage": jsonable_encoder(usage),
             "finish_reason": finish_reason,
-            "context": self._build_context(prompt_messages, clean_text),
+            "context": llm_utils.build_context(prompt_messages, clean_text),
         }
         if structured_output:
             outputs["structured_output"] = structured_output.structured_output
         if self._file_outputs:
             outputs["files"] = ArrayFileSegment(value=self._file_outputs)
 
-        # Write to Node Memory if in node memory mode
-        # Resolve the query template to get actual user content
-        actual_query = variable_pool.convert_template(query or "").text
-        llm_utils.save_node_memory(
-            memory=memory,
-            variable_pool=variable_pool,
-            user_query=actual_query,
-            assistant_response=clean_text,
-            user_files=files,
-            assistant_files=self._file_outputs,
-        )
-
         # Send final chunk event to indicate streaming is complete
         yield StreamChunkEvent(
             selector=[self._node_id, "text"],
@@ -607,48 +594,6 @@ class LLMNode(Node[LLMNodeData]):
         # Separated mode: always return clean text and reasoning_content
         return clean_text, reasoning_content or ""
 
-    @staticmethod
-    def _build_context(
-        prompt_messages: Sequence[PromptMessage],
-        assistant_response: str,
-    ) -> list[PromptMessage]:
-        """
-        Build context from prompt messages and assistant response.
-        Excludes system messages and includes the current LLM response.
-        Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
-
-        Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
-        """
-        context_messages: list[PromptMessage] = [
-            LLMNode._truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
-        ]
-        context_messages.append(AssistantPromptMessage(content=assistant_response))
-        return context_messages
-
-    @staticmethod
-    def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
-        """
-        Truncate multi-modal content base64 data in a message to avoid storing large data.
-        Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
-        """
-        content = message.content
-        if content is None or isinstance(content, str):
-            return message
-
-        # Process list content, truncating multi-modal base64 data
-        new_content: list[PromptMessageContentUnionTypes] = []
-        for item in content:
-            if isinstance(item, MultiModalPromptMessageContent):
-                # Truncate base64_data similar to prompt_messages_to_prompt_for_saving
-                truncated_base64 = ""
-                if item.base64_data:
-                    truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
-                new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
-            else:
-                new_content.append(item)
-
-        return message.model_copy(update={"content": new_content})
-
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -716,54 +661,158 @@ class LLMNode(Node[LLMNodeData]):
         """
         variable_pool = self.graph_runtime_state.variable_pool
 
-        # Build a map from context index to its messages
-        context_messages_map: dict[int, list[PromptMessage]] = {}
+        # Process messages in DSL order: iterate once and handle each type directly
+        combined_messages: list[PromptMessage] = []
         context_idx = 0
-        for idx, type_ in template_order:
+        static_idx = 0
+
+        for _, type_ in template_order:
             if type_ == "context":
                 # Handle context reference
                 ctx_ref = context_refs[context_idx]
                 ctx_var = variable_pool.get(ctx_ref.value_selector)
                 if ctx_var is None:
                     raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
                 if not isinstance(ctx_var, ArrayPromptMessageSegment):
                     raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
-                context_messages_map[idx] = list(ctx_var.value)
+                combined_messages.extend(ctx_var.value)
                 context_idx += 1
-
-        # Process static messages
-        static_prompt_messages: Sequence[PromptMessage] = []
-        stop: Sequence[str] | None = None
-        if static_messages:
-            static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
-                sys_query=query,
-                sys_files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config,
-                prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
-                memory_config=self.node_data.memory,
-                vision_enabled=self.node_data.vision.enabled,
-                vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=variable_pool,
-                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
-                tenant_id=self.tenant_id,
-                context_files=context_files,
-            )
-
-        # Combine messages according to original DSL order
-        combined_messages: list[PromptMessage] = []
-        static_msg_iter = iter(static_prompt_messages)
-        for idx, type_ in template_order:
-            if type_ == "context":
-                combined_messages.extend(context_messages_map[idx])
             else:
-                if msg := next(static_msg_iter, None):
-                    combined_messages.append(msg)
-        # Append any remaining static messages (e.g., memory messages)
-        combined_messages.extend(static_msg_iter)
+                # Handle static message
+                static_msg = static_messages[static_idx]
+                processed_msgs = LLMNode.handle_list_messages(
+                    messages=[static_msg],
+                    context=context,
+                    jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
+                    variable_pool=variable_pool,
+                    vision_detail_config=self.node_data.vision.configs.detail,
+                )
+                combined_messages.extend(processed_msgs)
+                static_idx += 1
+
+        # Append memory messages
+        memory_messages = _handle_memory_chat_mode(
+            memory=memory,
+            memory_config=self.node_data.memory,
+            model_config=model_config,
+        )
+        combined_messages.extend(memory_messages)
+
+        # Append current query if provided
+        if query:
+            query_message = LLMNodeChatModelMessage(
+                text=query,
+                role=PromptMessageRole.USER,
+                edition_type="basic",
+            )
+            query_msgs = LLMNode.handle_list_messages(
+                messages=[query_message],
+                context="",
+                jinja2_variables=[],
+                variable_pool=variable_pool,
+                vision_detail_config=self.node_data.vision.configs.detail,
+            )
+            combined_messages.extend(query_msgs)
+
+        # Handle files (sys_files and context_files)
+        combined_messages = self._append_files_to_messages(
+            messages=combined_messages,
+            sys_files=files,
+            context_files=context_files,
+            model_config=model_config,
+        )
+
+        # Filter empty messages and get stop sequences
+        combined_messages = self._filter_messages(combined_messages, model_config)
+        stop = self._get_stop_sequences(model_config)
+
+        return combined_messages, stop
+
+    def _append_files_to_messages(
+        self,
+        *,
+        messages: list[PromptMessage],
+        sys_files: Sequence[File],
+        context_files: list[File],
+        model_config: ModelConfigWithCredentialsEntity,
+    ) -> list[PromptMessage]:
+        """Append sys_files and context_files to messages."""
+        vision_enabled = self.node_data.vision.enabled
+        vision_detail = self.node_data.vision.configs.detail
+
+        # Handle sys_files (will be deprecated later)
+        if vision_enabled and sys_files:
+            file_prompts = [
+                file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
+            ]
+            if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
+                messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
+            else:
+                messages.append(UserPromptMessage(content=file_prompts))
+
+        # Handle context_files
+        if vision_enabled and context_files:
+            file_prompts = [
+                file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
+                for file in context_files
+            ]
+            if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
+                messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
+            else:
+                messages.append(UserPromptMessage(content=file_prompts))
+
+        return messages
+
+    def _filter_messages(
+        self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    ) -> list[PromptMessage]:
+        """Filter empty messages and unsupported content types."""
+        filtered_messages: list[PromptMessage] = []
+
+        for message in messages:
+            if isinstance(message.content, list):
+                filtered_content: list[PromptMessageContentUnionTypes] = []
+                for content_item in message.content:
+                    # Skip non-text content if features are not defined
+                    if not model_config.model_schema.features:
+                        if content_item.type != PromptMessageContentType.TEXT:
+                            continue
+                        filtered_content.append(content_item)
+                        continue
+
+                    # Skip content if corresponding feature is not supported
+                    feature_map = {
+                        PromptMessageContentType.IMAGE: ModelFeature.VISION,
+                        PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
+                        PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
+                        PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
+                    }
+                    required_feature = feature_map.get(content_item.type)
+                    if required_feature and required_feature not in model_config.model_schema.features:
+                        continue
+                    filtered_content.append(content_item)
+
+                # Simplify single text content
+                if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
+                    message.content = filtered_content[0].data
+                else:
+                    message.content = filtered_content
+
+            if not message.is_empty():
+                filtered_messages.append(message)
+
+        if not filtered_messages:
+            raise NoPromptFoundError(
+                "No prompt found in the LLM configuration. "
+                "Please ensure a prompt is properly configured before proceeding."
+            )
+
+        return filtered_messages
+
+    def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
+        """Get stop sequences from model config."""
+        return model_config.stop
+
     def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
         variables: dict[str, Any] = {}
 
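Note: the rewritten loop walks template_order once and splices each context block in place between static prompt messages, instead of first building an index map and merging afterwards. A simplified sketch of that interleave, with strings standing in for prompt messages and illustrative names throughout; memory messages, the current query, files, and the feature filter are then applied to the combined list, in that order:

def combine_in_dsl_order(template_order, context_blocks, static_messages):
    combined, ctx_i, static_i = [], 0, 0
    for _, type_ in template_order:
        if type_ == "context":
            combined.extend(context_blocks[ctx_i])  # splice context messages in place
            ctx_i += 1
        else:
            combined.append(static_messages[static_i])  # keep static slot order
            static_i += 1
    return combined

order = [(0, "message"), (1, "context"), (2, "message")]
print(combine_in_dsl_order(order, [["ctx: user", "ctx: assistant"]],
                           ["system prompt", "current question"]))
# ['system prompt', 'ctx: user', 'ctx: assistant', 'current question']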
@@ -246,13 +246,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         # transform result into standard format
         result = self._transform_result(data=node_data, result=result or {})
 
-        # Save to node memory if in node memory mode
-        llm_utils.save_node_memory(
-            memory=memory,
-            variable_pool=variable_pool,
-            user_query=query,
-            assistant_response=json.dumps(result, ensure_ascii=False),
-        )
+        # Build context from prompt messages and response
+        assistant_response = json.dumps(result, ensure_ascii=False)
+        context = llm_utils.build_context(prompt_messages, assistant_response)
 
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -262,6 +258,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                 "__is_success": 1 if not error else 0,
                 "__reason": error,
                 "__usage": jsonable_encoder(usage),
+                "context": context,
                 **result,
             },
             metadata={
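Note: one detail worth calling out in the extractor change — the structured result is serialized with ensure_ascii=False before becoming the assistant turn of the context, so non-ASCII field values stay human-readable rather than being escaped:

import json

result = {"name": "张三", "city": "Berlin"}
print(json.dumps(result))                      # {"name": "\u5f20\u4e09", "city": "Berlin"}
print(json.dumps(result, ensure_ascii=False))  # {"name": "张三", "city": "Berlin"}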
@@ -199,20 +199,17 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             "model_provider": model_config.provider,
             "model_name": model_config.model,
         }
+        # Build context from prompt messages and response
+        assistant_response = f"class_name: {category_name}, class_id: {category_id}"
+        context = llm_utils.build_context(prompt_messages, assistant_response)
 
         outputs = {
             "class_name": category_name,
             "class_id": category_id,
             "usage": jsonable_encoder(usage),
+            "context": context,
         }
 
-        # Save to node memory if in node memory mode
-        llm_utils.save_node_memory(
-            memory=memory,
-            variable_pool=variable_pool,
-            user_query=query or "",
-            assistant_response=f"class_name: {category_name}, class_id: {category_id}",
-        )
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=variables,