mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
feat: basic app add thought field
This commit is contained in:
@ -183,7 +183,24 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
|
|
||||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||||
pass
|
if current_agent_thought_id is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
thought_text = output.data.get("thought")
|
||||||
|
self.save_agent_thought(
|
||||||
|
agent_thought_id=current_agent_thought_id,
|
||||||
|
tool_name=None,
|
||||||
|
tool_input=None,
|
||||||
|
thought=thought_text,
|
||||||
|
observation=None,
|
||||||
|
tool_invoke_meta=None,
|
||||||
|
answer=None,
|
||||||
|
messages_ids=[],
|
||||||
|
)
|
||||||
|
self.queue_manager.publish(
|
||||||
|
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||||
|
PublishFrom.APPLICATION_MANAGER,
|
||||||
|
)
|
||||||
|
|
||||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||||
if current_agent_thought_id is None:
|
if current_agent_thought_id is None:
|
||||||
@ -269,15 +286,20 @@ class AgentAppRunner(BaseAgentRunner):
|
|||||||
"""
|
"""
|
||||||
Initialize system message
|
Initialize system message
|
||||||
"""
|
"""
|
||||||
if not prompt_messages and prompt_template:
|
if not prompt_template:
|
||||||
return [
|
return prompt_messages or []
|
||||||
SystemPromptMessage(content=prompt_template),
|
|
||||||
]
|
|
||||||
|
|
||||||
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
|
prompt_messages = prompt_messages or []
|
||||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
|
||||||
|
|
||||||
return prompt_messages or []
|
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||||
|
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
|
if not prompt_messages:
|
||||||
|
return [SystemPromptMessage(content=prompt_template)]
|
||||||
|
|
||||||
|
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||||
|
return prompt_messages
|
||||||
|
|
||||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,67 +1,55 @@
|
|||||||
# Agent Patterns
|
# Agent Patterns
|
||||||
|
|
||||||
A unified agent pattern module that provides common agent execution strategies for both Agent V2 nodes and Agent Applications in Dify.
|
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
This module implements a strategy pattern for agent execution, automatically selecting the appropriate strategy based on model capabilities. It serves as the core engine for agent-based interactions across different components of the Dify platform.
|
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
|
||||||
|
|
||||||
## Key Features
|
## Key Features
|
||||||
|
|
||||||
### 1. Multiple Agent Strategies
|
- **Dual strategies**
|
||||||
|
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
|
||||||
- **Function Call Strategy**: Leverages native function/tool calling capabilities of advanced LLMs (e.g., GPT-4, Claude)
|
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
|
||||||
- **ReAct Strategy**: Implements the ReAct (Reasoning + Acting) approach for models without native function calling support
|
- **Explicit or auto selection**
|
||||||
|
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
|
||||||
### 2. Automatic Strategy Selection
|
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
|
||||||
|
- **Unified execution contract**
|
||||||
The `StrategyFactory` intelligently selects the optimal strategy based on model features:
|
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
|
||||||
|
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
|
||||||
- Models with `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL` capabilities → Function Call Strategy
|
- **Tool handling and hooks**
|
||||||
- Other models → ReAct Strategy
|
- Tools convert to `PromptMessageTool` objects before invocation.
|
||||||
|
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
|
||||||
### 3. Unified Interface
|
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
|
||||||
|
- **File-aware arguments**
|
||||||
- Common base class (`AgentPattern`) ensures consistent behavior across strategies
|
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
|
||||||
- Seamless integration with both workflow nodes and standalone agent applications
|
- **ReAct prompt shaping**
|
||||||
- Standardized input/output formats for easy consumption
|
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
|
||||||
|
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
|
||||||
### 4. Advanced Capabilities
|
- **Observability and accounting**
|
||||||
|
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
|
||||||
- **Streaming Support**: Real-time response streaming for better user experience
|
|
||||||
- **File Handling**: Built-in support for processing and managing files during agent execution
|
|
||||||
- **Iteration Control**: Configurable maximum iterations with safety limits (capped at 99)
|
|
||||||
- **Tool Management**: Flexible tool integration supporting various tool types
|
|
||||||
- **Context Propagation**: Execution context for tracing, auditing, and debugging
|
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
agent/patterns/
|
agent/patterns/
|
||||||
├── base.py # Abstract base class defining the agent pattern interface
|
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
|
||||||
├── function_call.py # Implementation using native LLM function calling
|
├── function_call.py # Native function-calling loop with tool execution
|
||||||
├── react.py # Implementation using ReAct prompting approach
|
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
|
||||||
└── strategy_factory.py # Factory for automatic strategy selection
|
└── strategy_factory.py # Strategy selection by model features or explicit override
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
The module is designed to be used by:
|
- For auto-selection:
|
||||||
|
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
|
||||||
1. **Agent V2 Nodes**: In workflow orchestration for complex agent tasks
|
- For explicit behavior:
|
||||||
1. **Agent Applications**: For standalone conversational agents
|
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
|
||||||
1. **Custom Implementations**: As a foundation for building specialized agent behaviors
|
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`.
|
||||||
|
|
||||||
## Integration Points
|
## Integration Points
|
||||||
|
|
||||||
- **Model Runtime**: Interfaces with Dify's model runtime for LLM interactions
|
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
|
||||||
- **Tool System**: Integrates with the tool framework for external capabilities
|
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
|
||||||
- **Memory Management**: Compatible with conversation memory systems
|
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
|
||||||
- **File Management**: Handles file inputs/outputs during agent execution
|
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.
|
||||||
|
|
||||||
## Benefits
|
|
||||||
|
|
||||||
1. **Consistency**: Unified implementation reduces code duplication and maintenance overhead
|
|
||||||
1. **Flexibility**: Easy to extend with new strategies or customize existing ones
|
|
||||||
1. **Performance**: Optimized for each model's capabilities to ensure best performance
|
|
||||||
1. **Reliability**: Built-in safety mechanisms and error handling
|
|
||||||
|
|||||||
@ -457,6 +457,9 @@ class WorkflowBasedAppRunner:
|
|||||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||||
from core.app.entities.queue_entities import ChunkType as QueueChunkType
|
from core.app.entities.queue_entities import ChunkType as QueueChunkType
|
||||||
|
|
||||||
|
if event.is_final and not event.chunk:
|
||||||
|
return
|
||||||
|
|
||||||
self._publish_event(
|
self._publish_event(
|
||||||
QueueTextChunkEvent(
|
QueueTextChunkEvent(
|
||||||
text=event.chunk,
|
text=event.chunk,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -68,6 +69,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
_task_state: EasyUITaskState
|
_task_state: EasyUITaskState
|
||||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||||
|
|
||||||
@ -441,7 +444,13 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
for thought in agent_thoughts:
|
for thought in agent_thoughts:
|
||||||
# Add thought/reasoning
|
# Add thought/reasoning
|
||||||
if thought.thought:
|
if thought.thought:
|
||||||
reasoning_list.append(thought.thought)
|
reasoning_text = thought.thought
|
||||||
|
if "<think" in reasoning_text.lower():
|
||||||
|
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
|
||||||
|
if extracted_reasoning:
|
||||||
|
reasoning_text = extracted_reasoning
|
||||||
|
thought.thought = clean_text or extracted_reasoning
|
||||||
|
reasoning_list.append(reasoning_text)
|
||||||
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
|
||||||
|
|
||||||
# Add tool calls
|
# Add tool calls
|
||||||
@ -464,6 +473,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
else:
|
else:
|
||||||
# Completion/Chat mode: use reasoning_content from llm_result
|
# Completion/Chat mode: use reasoning_content from llm_result
|
||||||
reasoning_content = llm_result.reasoning_content
|
reasoning_content = llm_result.reasoning_content
|
||||||
|
if not reasoning_content and answer:
|
||||||
|
# Extract reasoning from <think> blocks and clean the final answer
|
||||||
|
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
|
||||||
|
if reasoning_content:
|
||||||
|
answer = clean_answer
|
||||||
|
llm_result.message.content = clean_answer
|
||||||
|
llm_result.reasoning_content = reasoning_content
|
||||||
|
message.answer = clean_answer
|
||||||
if reasoning_content:
|
if reasoning_content:
|
||||||
reasoning_list = [reasoning_content]
|
reasoning_list = [reasoning_content]
|
||||||
# Content comes first, then reasoning
|
# Content comes first, then reasoning
|
||||||
@ -493,6 +510,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
)
|
)
|
||||||
session.add(generation_detail)
|
session.add(generation_detail)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
|
||||||
|
"""
|
||||||
|
matches = cls._THINK_PATTERN.findall(text)
|
||||||
|
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
|
||||||
|
|
||||||
|
clean_text = cls._THINK_PATTERN.sub("", text)
|
||||||
|
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
|
||||||
|
|
||||||
|
return clean_text, reasoning_content or ""
|
||||||
|
|
||||||
def _handle_stop(self, event: QueueStopEvent):
|
def _handle_stop(self, event: QueueStopEvent):
|
||||||
"""
|
"""
|
||||||
Handle stop.
|
Handle stop.
|
||||||
|
|||||||
@ -474,57 +474,67 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
outputs = execution.outputs or {}
|
outputs = execution.outputs or {}
|
||||||
metadata = execution.metadata or {}
|
metadata = execution.metadata or {}
|
||||||
|
|
||||||
# Extract reasoning_content from outputs
|
reasoning_list = self._extract_reasoning(outputs)
|
||||||
reasoning_content = outputs.get("reasoning_content")
|
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))
|
||||||
reasoning_list: list[str] = []
|
|
||||||
if reasoning_content:
|
|
||||||
# reasoning_content could be a string or already a list
|
|
||||||
if isinstance(reasoning_content, str):
|
|
||||||
reasoning_list = [reasoning_content] if reasoning_content.strip() else []
|
|
||||||
elif isinstance(reasoning_content, list):
|
|
||||||
# Filter out empty or whitespace-only strings
|
|
||||||
reasoning_list = [r.strip() for r in reasoning_content if isinstance(r, str) and r.strip()]
|
|
||||||
|
|
||||||
# Extract tool_calls from metadata.agent_log
|
if not reasoning_list and not tool_calls_list:
|
||||||
tool_calls_list: list[dict] = []
|
|
||||||
agent_log = metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG)
|
|
||||||
if agent_log and isinstance(agent_log, list):
|
|
||||||
for log in agent_log:
|
|
||||||
# Each log entry has label, data, status, etc.
|
|
||||||
log_data = log.data if hasattr(log, "data") else log.get("data", {})
|
|
||||||
tool_name = log_data.get("tool_name")
|
|
||||||
# Only include tool calls with valid tool_name
|
|
||||||
if tool_name and str(tool_name).strip():
|
|
||||||
tool_calls_list.append(
|
|
||||||
{
|
|
||||||
"id": log_data.get("tool_call_id", ""),
|
|
||||||
"name": tool_name,
|
|
||||||
"arguments": json.dumps(log_data.get("tool_args", {})),
|
|
||||||
"result": str(log_data.get("output", "")),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only save if there's meaningful generation detail (reasoning or tool calls)
|
|
||||||
has_valid_reasoning = bool(reasoning_list)
|
|
||||||
has_valid_tool_calls = bool(tool_calls_list)
|
|
||||||
|
|
||||||
if not has_valid_reasoning and not has_valid_tool_calls:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Build sequence based on content, reasoning, and tool_calls
|
sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
|
||||||
sequence: list[dict] = []
|
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)
|
||||||
text = outputs.get("text", "")
|
|
||||||
|
|
||||||
# For now, use a simple sequence: content -> reasoning -> tool_calls
|
def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
|
||||||
# This can be enhanced later to track actual streaming order
|
"""Extract reasoning_content as a clean list of non-empty strings."""
|
||||||
|
reasoning_content = outputs.get("reasoning_content")
|
||||||
|
if isinstance(reasoning_content, str):
|
||||||
|
trimmed = reasoning_content.strip()
|
||||||
|
return [trimmed] if trimmed else []
|
||||||
|
if isinstance(reasoning_content, list):
|
||||||
|
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
|
||||||
|
"""Extract tool call records from agent logs."""
|
||||||
|
if not agent_log or not isinstance(agent_log, list):
|
||||||
|
return []
|
||||||
|
|
||||||
|
tool_calls: list[dict[str, str]] = []
|
||||||
|
for log in agent_log:
|
||||||
|
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
|
||||||
|
tool_name = log_data.get("tool_name")
|
||||||
|
if tool_name and str(tool_name).strip():
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": log_data.get("tool_call_id", ""),
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": json.dumps(log_data.get("tool_args", {})),
|
||||||
|
"result": str(log_data.get("output", "")),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
|
def _build_generation_sequence(
|
||||||
|
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Build a simple content/reasoning/tool_call sequence."""
|
||||||
|
sequence: list[dict[str, Any]] = []
|
||||||
if text:
|
if text:
|
||||||
sequence.append({"type": "content", "start": 0, "end": len(text)})
|
sequence.append({"type": "content", "start": 0, "end": len(text)})
|
||||||
for i, _ in enumerate(reasoning_list):
|
for index in range(len(reasoning_list)):
|
||||||
sequence.append({"type": "reasoning", "index": i})
|
sequence.append({"type": "reasoning", "index": index})
|
||||||
for i in range(len(tool_calls_list)):
|
for index in range(len(tool_calls_list)):
|
||||||
sequence.append({"type": "tool_call", "index": i})
|
sequence.append({"type": "tool_call", "index": index})
|
||||||
|
return sequence
|
||||||
|
|
||||||
# Check if generation detail already exists for this node execution
|
def _upsert_generation_detail(
|
||||||
|
self,
|
||||||
|
session,
|
||||||
|
execution: WorkflowNodeExecution,
|
||||||
|
reasoning_list: list[str],
|
||||||
|
tool_calls_list: list[dict[str, str]],
|
||||||
|
sequence: list[dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Insert or update LLMGenerationDetail with serialized fields."""
|
||||||
existing = (
|
existing = (
|
||||||
session.query(LLMGenerationDetail)
|
session.query(LLMGenerationDetail)
|
||||||
.filter_by(
|
.filter_by(
|
||||||
@ -534,23 +544,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
|
||||||
|
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
|
||||||
|
sequence_json = json.dumps(sequence) if sequence else None
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
# Update existing record
|
existing.reasoning_content = reasoning_json
|
||||||
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
|
existing.tool_calls = tool_calls_json
|
||||||
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
|
existing.sequence = sequence_json
|
||||||
existing.sequence = json.dumps(sequence) if sequence else None
|
return
|
||||||
else:
|
|
||||||
# Create new record
|
generation_detail = LLMGenerationDetail(
|
||||||
generation_detail = LLMGenerationDetail(
|
tenant_id=self._tenant_id,
|
||||||
tenant_id=self._tenant_id,
|
app_id=self._app_id,
|
||||||
app_id=self._app_id,
|
workflow_run_id=execution.workflow_execution_id,
|
||||||
workflow_run_id=execution.workflow_execution_id,
|
node_id=execution.node_id,
|
||||||
node_id=execution.node_id,
|
reasoning_content=reasoning_json,
|
||||||
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
|
tool_calls=tool_calls_json,
|
||||||
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
|
sequence=sequence_json,
|
||||||
sequence=json.dumps(sequence) if sequence else None,
|
)
|
||||||
)
|
session.add(generation_detail)
|
||||||
session.add(generation_detail)
|
|
||||||
|
|
||||||
def get_db_models_by_workflow_run(
|
def get_db_models_by_workflow_run(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -391,12 +391,9 @@ class ResponseStreamCoordinator:
|
|||||||
# Determine which node to attribute the output to
|
# Determine which node to attribute the output to
|
||||||
# For special selectors (sys, env, conversation), use the active response node
|
# For special selectors (sys, env, conversation), use the active response node
|
||||||
# For regular selectors, use the source node
|
# For regular selectors, use the source node
|
||||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
active_session = self._active_session
|
||||||
# Special selector - use active response node
|
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
|
||||||
output_node_id = self._active_session.node_id
|
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
|
||||||
else:
|
|
||||||
# Regular node selector
|
|
||||||
output_node_id = source_selector_prefix
|
|
||||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||||
|
|
||||||
# Check if there's a direct stream for this selector
|
# Check if there's a direct stream for this selector
|
||||||
@ -404,65 +401,27 @@ class ResponseStreamCoordinator:
|
|||||||
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
|
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_direct_stream:
|
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
|
||||||
# Stream all available chunks for direct stream
|
|
||||||
while self._has_unread_stream(segment.selector):
|
if stream_targets:
|
||||||
if event := self._pop_stream_chunk(segment.selector):
|
all_complete = True
|
||||||
# For special selectors, update the event to use active response node's information
|
|
||||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
for target_selector in stream_targets:
|
||||||
response_node = self._graph.nodes[self._active_session.node_id]
|
while self._has_unread_stream(target_selector):
|
||||||
updated_event = NodeRunStreamChunkEvent(
|
if event := self._pop_stream_chunk(target_selector):
|
||||||
id=execution_id,
|
events.append(
|
||||||
node_id=response_node.id,
|
self._rewrite_stream_event(
|
||||||
node_type=response_node.node_type,
|
event=event,
|
||||||
selector=event.selector,
|
output_node_id=output_node_id,
|
||||||
chunk=event.chunk,
|
execution_id=execution_id,
|
||||||
is_final=event.is_final,
|
special_selector=bool(special_selector),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
events.append(updated_event)
|
|
||||||
else:
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
# Check if stream is closed
|
if not self._is_stream_closed(target_selector):
|
||||||
if self._is_stream_closed(segment.selector):
|
all_complete = False
|
||||||
is_complete = True
|
|
||||||
|
|
||||||
else:
|
is_complete = all_complete
|
||||||
# No direct stream - check for child field streams (for object types)
|
|
||||||
child_streams = self._find_child_streams(segment.selector)
|
|
||||||
|
|
||||||
if child_streams:
|
|
||||||
# Process all child streams
|
|
||||||
all_children_complete = True
|
|
||||||
|
|
||||||
for child_selector in sorted(child_streams):
|
|
||||||
# Stream all available chunks from this child
|
|
||||||
while self._has_unread_stream(child_selector):
|
|
||||||
if event := self._pop_stream_chunk(child_selector):
|
|
||||||
# Forward child stream event
|
|
||||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
|
||||||
response_node = self._graph.nodes[self._active_session.node_id]
|
|
||||||
updated_event = NodeRunStreamChunkEvent(
|
|
||||||
id=execution_id,
|
|
||||||
node_id=response_node.id,
|
|
||||||
node_type=response_node.node_type,
|
|
||||||
selector=event.selector,
|
|
||||||
chunk=event.chunk,
|
|
||||||
is_final=event.is_final,
|
|
||||||
chunk_type=event.chunk_type,
|
|
||||||
tool_call=event.tool_call,
|
|
||||||
tool_result=event.tool_result,
|
|
||||||
)
|
|
||||||
events.append(updated_event)
|
|
||||||
else:
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
# Check if this child stream is complete
|
|
||||||
if not self._is_stream_closed(child_selector):
|
|
||||||
all_children_complete = False
|
|
||||||
|
|
||||||
# Object segment is complete only when all children are complete
|
|
||||||
is_complete = all_children_complete
|
|
||||||
|
|
||||||
# Fallback: check if scalar value exists in variable pool
|
# Fallback: check if scalar value exists in variable pool
|
||||||
if not is_complete and not has_direct_stream:
|
if not is_complete and not has_direct_stream:
|
||||||
@ -485,6 +444,28 @@ class ResponseStreamCoordinator:
|
|||||||
|
|
||||||
return events, is_complete
|
return events, is_complete
|
||||||
|
|
||||||
|
def _rewrite_stream_event(
|
||||||
|
self,
|
||||||
|
event: NodeRunStreamChunkEvent,
|
||||||
|
output_node_id: str,
|
||||||
|
execution_id: str,
|
||||||
|
special_selector: bool,
|
||||||
|
) -> NodeRunStreamChunkEvent:
|
||||||
|
"""Rewrite event to attribute to active response node when selector is special."""
|
||||||
|
if not special_selector:
|
||||||
|
return event
|
||||||
|
|
||||||
|
return self._create_stream_chunk_event(
|
||||||
|
node_id=output_node_id,
|
||||||
|
execution_id=execution_id,
|
||||||
|
selector=event.selector,
|
||||||
|
chunk=event.chunk,
|
||||||
|
is_final=event.is_final,
|
||||||
|
chunk_type=event.chunk_type,
|
||||||
|
tool_call=event.tool_call,
|
||||||
|
tool_result=event.tool_result,
|
||||||
|
)
|
||||||
|
|
||||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||||
"""Process a text segment. Returns (events, is_complete)."""
|
"""Process a text segment. Returns (events, is_complete)."""
|
||||||
assert self._active_session is not None
|
assert self._active_session is not None
|
||||||
|
|||||||
@ -1203,6 +1203,7 @@ class Message(Base):
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FIXME (Novice) -- It's easy to cause N+1 query problem here.
|
||||||
@property
|
@property
|
||||||
def generation_detail(self) -> dict[str, Any] | None:
|
def generation_detail(self) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -695,6 +695,10 @@ class WorkflowRun(Base):
|
|||||||
def workflow(self):
|
def workflow(self):
|
||||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs_as_generation(self):
|
||||||
|
return is_generation_outputs(self.outputs_dict)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
@ -708,7 +712,7 @@ class WorkflowRun(Base):
|
|||||||
"inputs": self.inputs_dict,
|
"inputs": self.inputs_dict,
|
||||||
"status": self.status,
|
"status": self.status,
|
||||||
"outputs": self.outputs_dict,
|
"outputs": self.outputs_dict,
|
||||||
"outputs_as_generation": is_generation_outputs(self.outputs_dict),
|
"outputs_as_generation": self.outputs_as_generation,
|
||||||
"error": self.error,
|
"error": self.error,
|
||||||
"elapsed_time": self.elapsed_time,
|
"elapsed_time": self.elapsed_time,
|
||||||
"total_tokens": self.total_tokens,
|
"total_tokens": self.total_tokens,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
"""
|
"""
|
||||||
Mark agent test modules as a package to avoid import name collisions.
|
Mark agent test modules as a package to avoid import name collisions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,48 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||||
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
|
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||||
|
from core.workflow.nodes import NodeType
|
||||||
|
|
||||||
|
|
||||||
|
class DummyQueueManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.published = []
|
||||||
|
|
||||||
|
def publish(self, event, publish_from: PublishFrom) -> None:
|
||||||
|
self.published.append((event, publish_from))
|
||||||
|
|
||||||
|
|
||||||
|
def test_skip_empty_final_chunk() -> None:
|
||||||
|
queue_manager = DummyQueueManager()
|
||||||
|
runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app")
|
||||||
|
|
||||||
|
empty_final_event = NodeRunStreamChunkEvent(
|
||||||
|
id="exec",
|
||||||
|
node_id="node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
selector=["node", "text"],
|
||||||
|
chunk="",
|
||||||
|
is_final=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event)
|
||||||
|
assert queue_manager.published == []
|
||||||
|
|
||||||
|
normal_event = NodeRunStreamChunkEvent(
|
||||||
|
id="exec",
|
||||||
|
node_id="node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
selector=["node", "text"],
|
||||||
|
chunk="hi",
|
||||||
|
is_final=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner._handle_event(workflow_entry=MagicMock(), event=normal_event)
|
||||||
|
|
||||||
|
assert len(queue_manager.published) == 1
|
||||||
|
published_event, publish_from = queue_manager.published[0]
|
||||||
|
assert publish_from == PublishFrom.APPLICATION_MANAGER
|
||||||
|
assert published_event.text == "hi"
|
||||||
|
|
||||||
@ -6,6 +6,7 @@ from core.workflow.entities.tool_entities import ToolResultStatus
|
|||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from core.workflow.graph.graph import Graph
|
from core.workflow.graph.graph import Graph
|
||||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||||
|
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
ChunkType,
|
ChunkType,
|
||||||
NodeRunStreamChunkEvent,
|
NodeRunStreamChunkEvent,
|
||||||
@ -13,6 +14,7 @@ from core.workflow.graph_events import (
|
|||||||
ToolResult,
|
ToolResult,
|
||||||
)
|
)
|
||||||
from core.workflow.nodes.base.entities import BaseNodeData
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
|
from core.workflow.nodes.base.template import Template, VariableSegment
|
||||||
from core.workflow.runtime import VariablePool
|
from core.workflow.runtime import VariablePool
|
||||||
|
|
||||||
|
|
||||||
@ -186,3 +188,44 @@ class TestResponseCoordinatorObjectStreaming:
|
|||||||
assert ("node1", "generation", "content") in children
|
assert ("node1", "generation", "content") in children
|
||||||
assert ("node1", "generation", "tool_calls") in children
|
assert ("node1", "generation", "tool_calls") in children
|
||||||
assert ("node1", "generation", "thought") in children
|
assert ("node1", "generation", "thought") in children
|
||||||
|
|
||||||
|
def test_special_selector_rewrites_to_active_response_node(self):
|
||||||
|
"""Ensure special selectors attribute streams to the active response node."""
|
||||||
|
graph = MagicMock(spec=Graph)
|
||||||
|
variable_pool = MagicMock(spec=VariablePool)
|
||||||
|
|
||||||
|
response_node = MagicMock()
|
||||||
|
response_node.id = "response_node"
|
||||||
|
response_node.node_type = NodeType.ANSWER
|
||||||
|
graph.nodes = {"response_node": response_node}
|
||||||
|
graph.root_node = response_node
|
||||||
|
|
||||||
|
coordinator = ResponseStreamCoordinator(variable_pool, graph)
|
||||||
|
coordinator.track_node_execution("response_node", "exec_resp")
|
||||||
|
|
||||||
|
coordinator._active_session = ResponseSession(
|
||||||
|
node_id="response_node",
|
||||||
|
template=Template(segments=[VariableSegment(selector=["sys", "foo"])]),
|
||||||
|
)
|
||||||
|
|
||||||
|
event = NodeRunStreamChunkEvent(
|
||||||
|
id="stream_1",
|
||||||
|
node_id="llm_node",
|
||||||
|
node_type=NodeType.LLM,
|
||||||
|
selector=["sys", "foo"],
|
||||||
|
chunk="hi",
|
||||||
|
is_final=True,
|
||||||
|
chunk_type=ChunkType.TEXT,
|
||||||
|
)
|
||||||
|
|
||||||
|
coordinator._stream_buffers[("sys", "foo")] = [event]
|
||||||
|
coordinator._stream_positions[("sys", "foo")] = 0
|
||||||
|
coordinator._closed_streams.add(("sys", "foo"))
|
||||||
|
|
||||||
|
events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"]))
|
||||||
|
|
||||||
|
assert is_complete
|
||||||
|
assert len(events) == 1
|
||||||
|
rewritten = events[0]
|
||||||
|
assert rewritten.node_id == "response_node"
|
||||||
|
assert rewritten.id == "exec_resp"
|
||||||
|
|||||||
@ -146,3 +146,4 @@ def test_serialize_tool_call_strips_files_to_ids():
|
|||||||
assert serialized["name"] == "do"
|
assert serialized["name"] == "do"
|
||||||
assert serialized["arguments"] == '{"a":1}'
|
assert serialized["arguments"] == '{"a":1}'
|
||||||
assert serialized["output"] == "ok"
|
assert serialized["output"] == "ok"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user