mirror of https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00

Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol):
    """

    @abc.abstractmethod
-    def update(self, conversation_id: str, variable: "Variable") -> None:
+    def update(self, conversation_id: str, variable: "Variable"):
        """
        Updates the value of the specified conversation variable in the underlying storage.
@ -1,173 +0,0 @@
# GraphEngine Worker Pool Configuration

## Overview

The GraphEngine now supports **dynamic worker pool management** to optimize performance and resource usage. Instead of a fixed 10-worker pool, the engine can:

1. **Start with an optimal worker count** based on graph complexity
2. **Scale up** when workload increases
3. **Scale down** when workers are idle
4. **Respect configurable min/max limits**

## Benefits

- **Resource Efficiency**: Uses fewer workers for simple sequential workflows
- **Better Performance**: Scales up for parallel-heavy workflows
- **Gevent Optimization**: Works efficiently with Gevent's greenlet model
- **Memory Savings**: Reduces the memory footprint of simple workflows

## Configuration

### Configuration Variables (via dify_config)

| Variable | Default | Description |
|----------|---------|-------------|
| `GRAPH_ENGINE_MIN_WORKERS` | 1 | Minimum number of workers per engine |
| `GRAPH_ENGINE_MAX_WORKERS` | 10 | Maximum number of workers per engine |
| `GRAPH_ENGINE_SCALE_UP_THRESHOLD` | 3 | Queue depth that triggers a scale-up |
| `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME` | 5.0 | Seconds of idle time before scaling down |

### Example Configurations

#### Low-Resource Environment

```bash
export GRAPH_ENGINE_MIN_WORKERS=1
export GRAPH_ENGINE_MAX_WORKERS=3
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=2
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=3.0
```

#### High-Performance Environment

```bash
export GRAPH_ENGINE_MIN_WORKERS=2
export GRAPH_ENGINE_MAX_WORKERS=20
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=5
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=10.0
```

#### Default (Balanced)

```bash
# Uses the defaults: min=1, max=10, threshold=3, idle_time=5.0
```

## How It Works

### Initial Worker Calculation

The engine analyzes the graph structure at startup:

- **Sequential graphs** (no branches): 1 worker
- **Limited parallelism** (few branches): 2 workers
- **Moderate parallelism**: 3 workers
- **High parallelism** (many branches): 5 workers

A sketch of this heuristic follows.
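A minimal sketch of such a calculation, clamped to the configured limits (the function name and branch-count thresholds here are illustrative, not the engine's actual implementation):

```python
def suggest_initial_workers(branch_count: int, min_workers: int = 1, max_workers: int = 10) -> int:
    """Map graph branching to a starting worker count, clamped to configured limits."""
    if branch_count == 0:  # purely sequential graph
        suggested = 1
    elif branch_count <= 2:  # limited parallelism
        suggested = 2
    elif branch_count <= 4:  # moderate parallelism
        suggested = 3
    else:  # high parallelism
        suggested = 5
    return max(min_workers, min(suggested, max_workers))
```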
### Dynamic Scaling

During execution:

1. **Scale Up** triggers when:

   - Queue depth exceeds `SCALE_UP_THRESHOLD`
   - All workers are busy and the queue has items
   - The pool is not at the `MAX_WORKERS` limit

2. **Scale Down** triggers when:

   - A worker has been idle for more than `SCALE_DOWN_IDLE_TIME` seconds
   - The pool is above the `MIN_WORKERS` limit

Both checks are sketched below.
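A minimal sketch of the two checks, assuming hypothetical inputs (`queue_depth`, `busy`, `total`, `idle_seconds`) rather than the real WorkerPool internals:

```python
def should_scale_up(queue_depth: int, busy: int, total: int,
                    threshold: int = 3, max_workers: int = 10) -> bool:
    """Scale up on a deep queue or full utilization, but never past the max."""
    if total >= max_workers:
        return False
    return queue_depth > threshold or (busy == total and queue_depth > 0)


def should_scale_down(idle_seconds: float, total: int,
                      idle_limit: float = 5.0, min_workers: int = 1) -> bool:
    """Retire a worker idle past the limit, but never drop below the min."""
    return total > min_workers and idle_seconds > idle_limit
```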
### Gevent Compatibility

Since Gevent patches threading to use greenlets:

- Workers are lightweight coroutines, not OS threads
- Dynamic scaling has minimal overhead
- The engine can efficiently handle many concurrent workers

## Migration Guide

### Before (Fixed 10 Workers)

```python
# Every GraphEngine instance created 10 workers
# Resource waste for simple workflows
# No adaptation to workload
```

### After (Dynamic Workers)

```python
# GraphEngine creates 1-5 initial workers based on the graph
# Scales up/down based on workload
# Configurable via environment variables
```

### Backward Compatibility

The default configuration (`max=10`) maintains compatibility with existing deployments. To reproduce the old behavior exactly:

```bash
export GRAPH_ENGINE_MIN_WORKERS=10
export GRAPH_ENGINE_MAX_WORKERS=10
```

## Performance Impact

### Memory Usage

- **Simple workflows**: ~80% reduction (1 worker vs 10)
- **Complex workflows**: Similar or slightly better

### Execution Time

- **Sequential workflows**: No change
- **Parallel workflows**: Improved with proper scaling
- **Bursty workloads**: Better adaptation

### Example Metrics

| Workflow Type | Old (10 workers) | New (Dynamic) | Improvement |
|--------------|------------------|---------------|-------------|
| Sequential | 10 workers idle | 1 worker active | 90% fewer workers |
| 3-way parallel | 7 workers idle | 3 workers active | 70% fewer workers |
| Heavy parallel | 10 workers busy | 10+ workers (scales up) | Better throughput |

## Monitoring

Log messages indicate scaling activity:

```shell
INFO: GraphEngine initialized with 2 workers (min: 1, max: 10)
INFO: Scaled up workers: 2 -> 3 (queue_depth: 4)
INFO: Scaled down workers: 3 -> 2 (removed 1 idle workers)
```

## Best Practices

1. **Start with the defaults** - they work well for most cases
2. **Monitor queue depth** - adjust `SCALE_UP_THRESHOLD` if queues back up
3. **Consider workload patterns**:
   - Bursty: lower `SCALE_DOWN_IDLE_TIME`
   - Steady: higher `SCALE_DOWN_IDLE_TIME`
4. **Test with your own workloads** - measure and tune

## Troubleshooting

### Workers not scaling up

- Check the `GRAPH_ENGINE_MAX_WORKERS` limit
- Verify that queue depth exceeds the threshold
- Check the logs for scaling messages

### Workers scaling down too quickly

- Increase `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME`
- Consider your workload patterns

### Out of memory

- Reduce `GRAPH_ENGINE_MAX_WORKERS`
- Check for memory leaks in nodes
@ -25,7 +25,7 @@ class GraphRuntimeState(BaseModel):
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, Any] | None = None,
        node_run_steps: int = 0,
-        **kwargs,
+        **kwargs: object,
    ):
        """Initialize the GraphRuntimeState with validation."""
        super().__init__(**kwargs)

@ -19,7 +19,7 @@ from core.workflow.constants import (
from core.workflow.system_variable import SystemVariable
from factories import variable_factory

-VariableValue = Union[str, int, float, dict, list, File]
+VariableValue = Union[str, int, float, dict[str, object], list[object], File]

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
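For reference, this pattern matches workflow variable references of the form `{{#node_id.field#}}`. A quick standalone check (the selector names here are made up for illustration):

```python
import re

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

template = "Answer: {{#llm_node.text#}} ({{#llm_node.reasoning_content#}})"
print(VARIABLE_PATTERN.findall(template))
# ['llm_node.text', 'llm_node.reasoning_content']
```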
@ -45,18 +45,18 @@ class VariablePool(BaseModel):
    )
    environment_variables: Sequence[VariableUnion] = Field(
        description="Environment variables.",
-        default_factory=list,
+        default_factory=list[VariableUnion],
    )
    conversation_variables: Sequence[VariableUnion] = Field(
        description="Conversation variables.",
-        default_factory=list,
+        default_factory=list[VariableUnion],
    )
    rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
        description="RAG pipeline variables.",
        default_factory=list,
    )

-    def model_post_init(self, context: Any, /) -> None:
+    def model_post_init(self, context: Any, /):
        # Create a mapping from field names to SystemVariableKey enum values
        self._add_system_variables(self.system_variables)
        # Add environment variables to the variable pool

@ -76,7 +76,7 @@ class VariablePool(BaseModel):
        for key, value in rag_pipeline_variables_map.items():
            self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)

-    def add(self, selector: Sequence[str], value: Any, /) -> None:
+    def add(self, selector: Sequence[str], value: Any, /):
        """
        Add a variable to the variable pool.

@ -180,11 +180,11 @@ class VariablePool(BaseModel):
        # Return result as Segment
        return result if isinstance(result, Segment) else variable_factory.build_segment(result)

-    def _extract_value(self, obj: Any) -> Any:
+    def _extract_value(self, obj: Any):
        """Extract the actual value from an ObjectSegment."""
        return obj.value if isinstance(obj, ObjectSegment) else obj

-    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
+    def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
        """Get a nested attribute from a dictionary-like object."""
        if not isinstance(obj, dict):
            return None

@ -210,7 +210,7 @@ class VariablePool(BaseModel):
    def convert_template(self, template: str, /):
        parts = VARIABLE_PATTERN.split(template)
-        segments = []
+        segments: list[Segment] = []
        for part in filter(lambda x: x, parts):
            if "." in part and (variable := self.get(part.split("."))):
                segments.append(variable)

@ -127,7 +127,7 @@ class WorkflowNodeExecution(BaseModel):
        process_data: Optional[Mapping[str, Any]] = None,
        outputs: Optional[Mapping[str, Any]] = None,
        metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
-    ) -> None:
+    ):
        """
        Update the model from mappings.
@ -1,16 +1,18 @@
+from collections.abc import Mapping
from typing import Any, Protocol

from core.model_runtime.entities.llm_entities import LLMUsage
+from core.variables.segments import Segment


class ReadOnlyVariablePool(Protocol):
    """Read-only interface for VariablePool."""

-    def get(self, node_id: str, variable_key: str) -> Any:
+    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Get a variable value (read-only)."""
        ...

-    def get_all_by_node(self, node_id: str) -> dict[str, Any]:
+    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Get all variables for a node (read-only)."""
        ...

@ -1,7 +1,9 @@
+from collections.abc import Mapping
from copy import deepcopy
from typing import Any

from core.model_runtime.entities.llm_entities import LLMUsage
+from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool

@ -12,19 +14,18 @@ class ReadOnlyVariablePoolWrapper:
    def __init__(self, variable_pool: VariablePool):
        self._variable_pool = variable_pool

-    def get(self, node_id: str, variable_key: str) -> Any:
+    def get(self, node_id: str, variable_key: str) -> Segment | None:
        """Get a variable value (returns a defensive copy)."""
-        value = self._variable_pool.get(node_id, variable_key)
+        value = self._variable_pool.get([node_id, variable_key])
        return deepcopy(value) if value is not None else None

-    def get_all_by_node(self, node_id: str) -> dict[str, Any]:
+    def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
        """Get all variables for a node (returns defensive copies)."""
-        variables = {}
+        variables: dict[str, object] = {}
        if node_id in self._variable_pool.variable_dictionary:
            for key, var in self._variable_pool.variable_dictionary[node_id].items():
-                # FIXME(-LAN-): Handle the actual Variable object structure
-                value = var.value if hasattr(var, "value") else var
-                variables[key] = deepcopy(value)
+                # Variables have a value property that contains the actual data
+                variables[key] = deepcopy(var.value)
        return variables
@ -113,17 +113,6 @@ class DebugLoggingLayer(GraphEngineLayer):
        # Log initial state
        self.logger.info("Initial State:")

-        # Log inputs if available
-        if self.graph_runtime_state.variable_pool:
-            initial_vars: dict[str, Any] = {}
-            # Access the variable dictionary directly
-            for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
-                for var_key, var in variables.items():
-                    initial_vars[f"{node_id}.{var_key}"] = str(var.value) if hasattr(var, "value") else str(var)
-
-            if initial_vars:
-                self.logger.info("  Initial variables: %s", self._format_dict(initial_vars))

    @override
    def on_event(self, event: GraphEngineEvent) -> None:
        """Log individual events based on their type."""

@ -217,7 +217,7 @@ class WorkerPool:
            return False

        # Find and remove idle workers that have been idle long enough
-        workers_to_remove = []
+        workers_to_remove: list[tuple[Worker, int]] = []

        for worker in self._workers:
            # Check if worker is idle and has exceeded idle time threshold
@ -13,6 +13,11 @@ class NodeEventBase(BaseModel):
    pass


+def _default_metadata():
+    v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
+    return v


class NodeRunResult(BaseModel):
    """
    Node Run Result.

@ -23,7 +28,7 @@ class NodeRunResult(BaseModel):
    inputs: Mapping[str, Any] = Field(default_factory=dict)
    process_data: Mapping[str, Any] = Field(default_factory=dict)
    outputs: Mapping[str, Any] = Field(default_factory=dict)
-    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=dict)
+    metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata)
    llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)

    edge_source_handle: str = "source"  # source handle id of node with multiple branches

@ -19,6 +19,7 @@ class ModelInvokeCompletedEvent(NodeEventBase):
    text: str
    usage: LLMUsage
    finish_reason: str | None = None
+    reasoning_content: str | None = None


class RunRetryEvent(NodeEventBase):
@ -68,7 +68,7 @@ class AgentNode(Node):
    node_type = NodeType.AGENT
    _node_data: AgentNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = AgentNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -17,7 +17,7 @@ class AnswerNode(Node):
    _node_data: AnswerNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = AnswerNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -50,7 +50,7 @@ class DefaultValue(BaseModel):
    key: str

    @staticmethod
-    def _parse_json(value: str) -> Any:
+    def _parse_json(value: str):
        """Unified JSON parsing handler"""
        try:
            return json.loads(value)

@ -57,7 +57,7 @@ class VariableTemplateParser:
        self.template = template
        self.variable_keys = self.extract()

-    def extract(self) -> list:
+    def extract(self):
        """
        Extracts all the template variable keys from the template string.

@ -27,7 +27,7 @@ class CodeNode(Node):
    _node_data: CodeNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = CodeNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -49,7 +49,7 @@ class CodeNode(Node):
        return self._node_data

    @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict] = None):
        """
        Get default config of node.
        :param filters: filter by node config parameters.

@ -46,7 +46,7 @@ class DocumentExtractorNode(Node):
    _node_data: DocumentExtractorNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = DocumentExtractorNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -15,7 +15,7 @@ class EndNode(Node):
    _node_data: EndNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = EndNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -421,7 +421,10 @@ class Executor:
            body_string += f"--{boundary}--\r\n"
        elif self.node_data.body:
            if self.content:
-                body_string = self.content.decode("utf-8", errors="replace")
+                if isinstance(self.content, bytes):
+                    body_string = self.content.decode("utf-8", errors="replace")
+                else:
+                    body_string = self.content
        elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
            body_string = urlencode(self.data)
        elif self.data and self.node_data.body.type == "form-data":

@ -36,7 +36,7 @@ class HttpRequestNode(Node):
    _node_data: HttpRequestNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = HttpRequestNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -58,7 +58,7 @@ class HttpRequestNode(Node):
        return self._node_data

    @classmethod
-    def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
        return {
            "type": "http-request",
            "config": {
@ -19,7 +19,7 @@ class IfElseNode(Node):
    _node_data: IfElseNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IfElseNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -3,7 +3,7 @@ from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, Union, cast

-from core.variables import ArrayVariable, IntegerVariable, NoneVariable
+from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (

@ -55,7 +55,7 @@ class IterationNode(Node):
    execution_type = NodeExecutionType.CONTAINER
    _node_data: IterationNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IterationNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -77,7 +77,7 @@ class IterationNode(Node):
        return self._node_data

    @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict] = None):
        return {
            "type": "iteration",
            "config": {

@ -97,10 +97,10 @@ class IterationNode(Node):
        if not variable:
            raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")

-        if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
+        if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
            raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")

-        if isinstance(variable, NoneVariable) or len(variable.value) == 0:
+        if isinstance(variable, NoneSegment) or len(variable.value) == 0:
            # Try our best to preserve the type information.
            if isinstance(variable, ArraySegment):
                output = variable.model_copy(update={"value": []})

@ -17,7 +17,7 @@ class IterationStartNode(Node):
    _node_data: IterationStartNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = IterationStartNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -184,7 +184,6 @@ class KnowledgeIndexNode(Node):
    def version(cls) -> str:
        return "1"
-

    def get_streaming_template(self) -> Template:
        """
        Get the template for streaming.

@ -192,4 +191,4 @@ class KnowledgeIndexNode(Node):
        Returns:
            Template instance for this knowledge index node
        """
-        return Template(segments=[])
+        return Template(segments=[])

@ -99,7 +99,7 @@ class KnowledgeRetrievalNode(Node):
        graph_runtime_state: "GraphRuntimeState",
        *,
        llm_file_saver: LLMFileSaver | None = None,
-    ) -> None:
+    ):
        super().__init__(
            id=id,
            config=config,

@ -116,7 +116,7 @@ class KnowledgeRetrievalNode(Node):
        )
        self._llm_file_saver = llm_file_saver

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = KnowledgeRetrievalNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -40,7 +40,7 @@ class ListOperatorNode(Node):
    _node_data: ListOperatorNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ListOperatorNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -171,6 +171,8 @@ class ListOperatorNode(Node):
            result = list(filter(filter_func, variable.value))
            variable = variable.model_copy(update={"value": result})
        else:
+            if not isinstance(condition.value, bool):
+                raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}")
            filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
            result = list(filter(filter_func, variable.value))
            variable = variable.model_copy(update={"value": result})
@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
-from typing import Any, Optional
+from typing import Any, Literal, Optional

from pydantic import BaseModel, Field, field_validator

@ -68,6 +68,23 @@ class LLMNodeData(BaseNodeData):
    structured_output: Mapping[str, Any] | None = None
    # We used 'structured_output_enabled' in the past, but it's not a good name.
    structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
+    reasoning_format: Literal["separated", "tagged"] = Field(
+        # Keep tagged as default for backward compatibility
+        default="tagged",
+        description=(
+            """
+            Strategy for handling model reasoning output.
+
+            separated: Return clean text (without <think> tags) + reasoning_content field.
+                Recommended for new workflows. Enables safe downstream parsing and
+                workflow variable access: {{#node_id.reasoning_content#}}
+
+            tagged: Return the original text (with <think> tags) + reasoning_content field.
+                Maintains full backward compatibility while still providing reasoning_content
+                for workflow automation. Frontend thinking panels work as before.
+            """
+        ),
+    )

    @field_validator("prompt_config", mode="before")
    @classmethod

@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):

class UnsupportedPromptContentTypeError(LLMNodeError):
-    def __init__(self, *, type_name: str) -> None:
+    def __init__(self, *, type_name: str):
        super().__init__(f"Prompt content type {type_name} is not supported.")
@ -107,7 +107,7 @@ def fetch_memory(
    return memory


-def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
+def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
    provider_model_bundle = model_instance.provider_model_bundle
    provider_configuration = provider_model_bundle.configuration
@ -2,8 +2,9 @@ import base64
import io
import json
import logging
+import re
from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager

@ -101,6 +102,9 @@ class LLMNode(Node):
    _node_data: LLMNodeData

+    # Compiled regex for extracting <think> blocks (with compatibility for attributes)
+    _THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)

    # Instance attributes specific to LLMNode.
    # Output variable for file
    _file_outputs: list["File"]

@ -115,7 +119,7 @@ class LLMNode(Node):
        graph_runtime_state: "GraphRuntimeState",
        *,
        llm_file_saver: LLMFileSaver | None = None,
-    ) -> None:
+    ):
        super().__init__(
            id=id,
            config=config,

@ -132,7 +136,7 @@ class LLMNode(Node):
        )
        self._llm_file_saver = llm_file_saver

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LLMNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -163,6 +167,7 @@ class LLMNode(Node):
        result_text = ""
        usage = LLMUsage.empty_usage()
        finish_reason = None
+        reasoning_content = None
        variable_pool = self.graph_runtime_state.variable_pool

        try:

@ -250,6 +255,7 @@ class LLMNode(Node):
            file_outputs=self._file_outputs,
            node_id=self._node_id,
            node_type=self.node_type,
+            reasoning_format=self._node_data.reasoning_format,
        )

        structured_output: LLMStructuredOutput | None = None
@ -258,9 +264,20 @@ class LLMNode(Node):
            if isinstance(event, StreamChunkEvent):
                yield event
            elif isinstance(event, ModelInvokeCompletedEvent):
+                # Raw text
                result_text = event.text
                usage = event.usage
                finish_reason = event.finish_reason
+                reasoning_content = event.reasoning_content or ""
+
+                # For downstream nodes, determine clean text based on reasoning_format
+                if self._node_data.reasoning_format == "tagged":
+                    # Keep <think> tags for backward compatibility
+                    clean_text = result_text
+                else:
+                    # Extract clean text from <think> tags
+                    clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)

                # deduct quota
                llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                break

@ -278,7 +295,12 @@ class LLMNode(Node):
            "model_name": model_config.model,
        }

-        outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
+        outputs = {
+            "text": clean_text,
+            "reasoning_content": reasoning_content,
+            "usage": jsonable_encoder(usage),
+            "finish_reason": finish_reason,
+        }
        if structured_output:
            outputs["structured_output"] = structured_output.structured_output
        if self._file_outputs:
@ -340,6 +362,7 @@ class LLMNode(Node):
        file_outputs: list["File"],
        node_id: str,
        node_type: NodeType,
+        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        model_schema = model_instance.model_type_instance.get_model_schema(
            node_data_model.name, model_instance.credentials

@ -377,6 +400,7 @@ class LLMNode(Node):
            file_outputs=file_outputs,
            node_id=node_id,
            node_type=node_type,
+            reasoning_format=reasoning_format,
        )

    @staticmethod

@ -387,6 +411,7 @@ class LLMNode(Node):
        file_outputs: list["File"],
        node_id: str,
        node_type: NodeType,
+        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        # For blocking mode
        if isinstance(invoke_result, LLMResult):

@ -394,6 +419,7 @@ class LLMNode(Node):
                invoke_result=invoke_result,
                saver=file_saver,
                file_outputs=file_outputs,
+                reasoning_format=reasoning_format,
            )
            yield event
            return
@ -438,13 +464,66 @@ class LLMNode(Node):
            except OutputParserError as e:
                raise LLMNodeError(f"Failed to parse structured output: {e}")

-        yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
+        # Extract reasoning content from <think> tags in the main text
+        full_text = full_text_buffer.getvalue()
+
+        if reasoning_format == "tagged":
+            # Keep <think> tags in text for backward compatibility
+            clean_text = full_text
+            reasoning_content = ""
+        else:
+            # Extract clean text and reasoning from <think> tags
+            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+
+        yield ModelInvokeCompletedEvent(
+            # Use clean_text for separated mode, full_text for tagged mode
+            text=clean_text if reasoning_format == "separated" else full_text,
+            usage=usage,
+            finish_reason=finish_reason,
+            # Reasoning content for workflow variables and downstream nodes
+            reasoning_content=reasoning_content,
+        )
    @staticmethod
    def _image_file_to_markdown(file: "File", /):
        text_chunk = f"![]({file.generate_url()})"
        return text_chunk
+    @classmethod
+    def _split_reasoning(
+        cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
+    ) -> tuple[str, str]:
+        """
+        Split reasoning content from text based on the reasoning_format strategy.
+
+        Args:
+            text: Full text that may contain <think> blocks
+            reasoning_format: Strategy for handling reasoning content
+                - "separated": Remove <think> tags and return clean text + reasoning_content field
+                - "tagged": Keep <think> tags in text, return empty reasoning_content
+
+        Returns:
+            tuple of (clean_text, reasoning_content)
+        """
+        if reasoning_format == "tagged":
+            return text, ""
+
+        # Find all <think>...</think> blocks (case-insensitive)
+        matches = cls._THINK_PATTERN.findall(text)
+
+        # Extract reasoning content from all <think> blocks
+        reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
+
+        # Remove all <think>...</think> blocks from the original text
+        clean_text = cls._THINK_PATTERN.sub("", text)
+
+        # Clean up extra whitespace
+        clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
+
+        # Separated mode: always return clean text and reasoning_content
+        return clean_text, reasoning_content or ""
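A small standalone check of the two strategies (this mirrors the `_split_reasoning` logic above with the same regex; it is an illustration, not the class method itself):

```python
import re

THINK = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)


def split_reasoning(text: str, reasoning_format: str = "tagged") -> tuple[str, str]:
    """Standalone mirror of the tagged/separated strategies."""
    if reasoning_format == "tagged":
        return text, ""
    matches = THINK.findall(text)
    reasoning = "\n".join(m.strip() for m in matches) if matches else ""
    clean = re.sub(r"\n\s*\n", "\n\n", THINK.sub("", text)).strip()
    return clean, reasoning


text = "<think>compare 3 and 5</think>5 is larger."
print(split_reasoning(text, "tagged"))     # keeps the <think> block in the text
print(split_reasoning(text, "separated"))  # ('5 is larger.', 'compare 3 and 5')
```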
    def _transform_chat_messages(
        self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
    ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:

@ -880,7 +959,7 @@ class LLMNode(Node):
        return variable_mapping

    @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict] = None):
        return {
            "type": "llm",
            "config": {

@ -972,6 +1051,7 @@ class LLMNode(Node):
        invoke_result: LLMResult,
        saver: LLMFileSaver,
        file_outputs: list["File"],
+        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> ModelInvokeCompletedEvent:
        buffer = io.StringIO()
        for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(

@ -981,10 +1061,24 @@ class LLMNode(Node):
        ):
            buffer.write(text_part)

+        # Extract reasoning content from <think> tags in the main text
+        full_text = buffer.getvalue()
+
+        if reasoning_format == "tagged":
+            # Keep <think> tags in text for backward compatibility
+            clean_text = full_text
+            reasoning_content = ""
+        else:
+            # Extract clean text and reasoning from <think> tags
+            clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
+
        return ModelInvokeCompletedEvent(
-            text=buffer.getvalue(),
+            # Use clean_text for separated mode, full_text for tagged mode
+            text=clean_text if reasoning_format == "separated" else full_text,
            usage=invoke_result.usage,
            finish_reason=None,
+            # Reasoning content for workflow variables and downstream nodes
+            reasoning_content=reasoning_content,
        )

    @staticmethod
@ -17,7 +17,7 @@ class LoopEndNode(Node):
    _node_data: LoopEndNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopEndNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -49,7 +49,7 @@ class LoopNode(Node):
    _node_data: LoopNodeData
    execution_type = NodeExecutionType.CONTAINER

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -17,7 +17,7 @@ class LoopStartNode(Node):
    _node_data: LoopStartNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = LoopStartNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -96,7 +96,7 @@ class ParameterExtractorNodeData(BaseNodeData):
    def set_reasoning_mode(cls, v) -> str:
        return v or "function_call"

-    def get_parameter_json_schema(self) -> dict:
+    def get_parameter_json_schema(self):
        """
        Get parameter json schema.

@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError):
        expected_type: SegmentType,
        actual_type: SegmentType | None,
        value: Any,
-    ) -> None:
+    ):
        message = (
            f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
            f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
@ -50,6 +50,7 @@ from .exc import (
)
from .prompts import (
    CHAT_EXAMPLE,
+    CHAT_GENERATE_JSON_PROMPT,
    CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
    COMPLETION_GENERATE_JSON_PROMPT,
    FUNCTION_CALLING_EXTRACTOR_EXAMPLE,

@ -92,7 +93,7 @@ class ParameterExtractorNode(Node):
    _node_data: ParameterExtractorNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ParameterExtractorNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -117,7 +118,7 @@ class ParameterExtractorNode(Node):
    _model_config: Optional[ModelConfigWithCredentialsEntity] = None

    @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict] = None):
        return {
            "model": {
                "prompt_templates": {

@ -538,7 +539,7 @@ class ParameterExtractorNode(Node):
        return prompt_messages

-    def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
+    def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
        if len(data.parameters) != len(result):
            raise InvalidNumberOfParametersError("Invalid number of parameters")

@ -591,7 +592,7 @@ class ParameterExtractorNode(Node):
        else:
            return None

-    def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
+    def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
        """
        Transform result into standard format.
        """

@ -684,7 +685,7 @@ class ParameterExtractorNode(Node):
        logger.info("extra error: %s", result)
        return None

-    def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
+    def _generate_default_result(self, data: ParameterExtractorNodeData):
        """
        Generate default result.
        """

@ -746,7 +747,7 @@ class ParameterExtractorNode(Node):
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = ChatModelMessage(
                role=PromptMessageRole.SYSTEM,
-                text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
+                text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
            )
            user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
            return [system_prompt_messages, user_prompt_message]
@ -60,7 +60,7 @@ class QuestionClassifierNode(Node):
        graph_runtime_state: "GraphRuntimeState",
        *,
        llm_file_saver: LLMFileSaver | None = None,
-    ) -> None:
+    ):
        super().__init__(
            id=id,
            config=config,

@ -77,7 +77,7 @@ class QuestionClassifierNode(Node):
        )
        self._llm_file_saver = llm_file_saver

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = QuestionClassifierNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
        return variable_mapping

    @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict] = None):
        """
        Get default config of node.
        :param filters: filter by node config parameters (not used in this implementation).

@ -15,7 +15,7 @@ class StartNode(Node):
    _node_data: StartNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = StartNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -17,7 +17,7 @@ class TemplateTransformNode(Node):
    _node_data: TemplateTransformNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = TemplateTransformNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
        return self._node_data

    @classmethod
-    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
+    def get_default_config(cls, filters: Optional[dict] = None):
        """
        Get default config of node.
        :param filters: filter by node config parameters.

@ -48,7 +48,7 @@ class ToolNode(Node):
    _node_data: ToolNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = ToolNodeData.model_validate(data)

    @classmethod
@ -14,7 +14,7 @@ class VariableAggregatorNode(Node):
    _node_data: VariableAssignerNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = VariableAssignerNodeData(**data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -30,7 +30,7 @@ class VariableAssignerNode(Node):
    _node_data: VariableAssignerData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = VariableAssignerData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -58,7 +58,7 @@ class VariableAssignerNode(Node):
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
-    ) -> None:
+    ):
        super().__init__(
            id=id,
            config=config,

@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):

class InvalidDataError(VariableOperatorNodeError):
-    def __init__(self, message: str) -> None:
+    def __init__(self, message: str):
        super().__init__(message)

@ -57,7 +57,7 @@ class VariableAssignerNode(Node):
    _node_data: VariableAssignerNodeData

-    def init_node_data(self, data: Mapping[str, Any]) -> None:
+    def init_node_data(self, data: Mapping[str, Any]):
        self._node_data = VariableAssignerNodeData.model_validate(data)

    def _get_error_strategy(self) -> Optional[ErrorStrategy]:

@ -16,7 +16,7 @@ class WorkflowExecutionRepository(Protocol):
    application domains or deployment scenarios.
    """

-    def save(self, execution: WorkflowExecution) -> None:
+    def save(self, execution: WorkflowExecution):
        """
        Save or update a WorkflowExecution instance.

@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
    application domains or deployment scenarios.
    """

-    def save(self, execution: WorkflowNodeExecution) -> None:
+    def save(self, execution: WorkflowNodeExecution):
        """
        Save or update a NodeExecution instance.
@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
-from typing import Any, Literal, NamedTuple, Union
+from typing import Literal, NamedTuple

from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment

@ -10,7 +10,7 @@ from core.workflow.entities import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator


-def _convert_to_bool(value: Any) -> bool:
+def _convert_to_bool(value: object) -> bool:
    if isinstance(value, int):
        return bool(value)

@ -23,7 +23,7 @@ def _convert_to_bool(value: Any) -> bool:


class ConditionCheckResult(NamedTuple):
-    inputs: Sequence[Mapping[str, Any]]
+    inputs: Sequence[Mapping[str, object]]
    group_results: Sequence[bool]
    final_result: bool

@ -36,7 +36,7 @@ class ConditionProcessor:
        conditions: Sequence[Condition],
        operator: Literal["and", "or"],
    ) -> ConditionCheckResult:
-        input_conditions: list[Mapping[str, Any]] = []
+        input_conditions: list[Mapping[str, object]] = []
        group_results: list[bool] = []

        for condition in conditions:

@ -103,8 +103,8 @@ class ConditionProcessor:
def _evaluate_condition(
    *,
    operator: SupportedComparisonOperator,
-    value: Any,
-    expected: Union[str, Sequence[str], bool | Sequence[bool], None],
+    value: object,
+    expected: str | Sequence[str] | bool | Sequence[bool] | None,
) -> bool:
    match operator:
        case "contains":

@ -144,7 +144,17 @@ def _evaluate_condition(
        case "not in":
            return _assert_not_in(value=value, expected=expected)
        case "all of" if isinstance(expected, list):
-            return _assert_all_of(value=value, expected=expected)
+            # Type narrowing: at this point expected is a list, could be list[str] or list[bool]
+            if all(isinstance(item, str) for item in expected):
+                # Create a new typed list to satisfy type checker
+                str_list: list[str] = [item for item in expected if isinstance(item, str)]
+                return _assert_all_of(value=value, expected=str_list)
+            elif all(isinstance(item, bool) for item in expected):
+                # Create a new typed list to satisfy type checker
+                bool_list: list[bool] = [item for item in expected if isinstance(item, bool)]
+                return _assert_all_of_bool(value=value, expected=bool_list)
+            else:
+                raise ValueError("all of operator expects homogeneous list of strings or booleans")
        case "exists":
            return _assert_exists(value=value)
        case "not exists":

@ -153,55 +163,73 @@ def _evaluate_condition(
            raise ValueError(f"Unsupported operator: {operator}")
-def _assert_contains(*, value: Any, expected: Any) -> bool:
+def _assert_contains(*, value: object, expected: object) -> bool:
    if not value:
        return False

    if not isinstance(value, (str, list)):
        raise ValueError("Invalid actual value type: string or array")

-    if expected not in value:
-        return False
+    # Type checking ensures value is str or list at this point
+    if isinstance(value, str):
+        if not isinstance(expected, str):
+            expected = str(expected)
+        if expected not in value:
+            return False
+    else:  # value is list
+        if expected not in value:
+            return False
    return True


-def _assert_not_contains(*, value: Any, expected: Any) -> bool:
+def _assert_not_contains(*, value: object, expected: object) -> bool:
    if not value:
        return True

    if not isinstance(value, (str, list)):
        raise ValueError("Invalid actual value type: string or array")

-    if expected in value:
-        return False
+    # Type checking ensures value is str or list at this point
+    if isinstance(value, str):
+        if not isinstance(expected, str):
+            expected = str(expected)
+        if expected in value:
+            return False
+    else:  # value is list
+        if expected in value:
+            return False
    return True


-def _assert_start_with(*, value: Any, expected: Any) -> bool:
+def _assert_start_with(*, value: object, expected: object) -> bool:
    if not value:
        return False

    if not isinstance(value, str):
        raise ValueError("Invalid actual value type: string")

+    if not isinstance(expected, str):
+        raise ValueError("Expected value must be a string for startswith")
    if not value.startswith(expected):
        return False
    return True


-def _assert_end_with(*, value: Any, expected: Any) -> bool:
+def _assert_end_with(*, value: object, expected: object) -> bool:
    if not value:
        return False

    if not isinstance(value, str):
        raise ValueError("Invalid actual value type: string")

+    if not isinstance(expected, str):
+        raise ValueError("Expected value must be a string for endswith")
    if not value.endswith(expected):
        return False
    return True


-def _assert_is(*, value: Any, expected: Any) -> bool:
+def _assert_is(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -213,7 +241,7 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_is_not(*, value: Any, expected: Any) -> bool:
+def _assert_is_not(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -225,19 +253,19 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_empty(*, value: Any) -> bool:
+def _assert_empty(*, value: object) -> bool:
    if not value:
        return True
    return False


-def _assert_not_empty(*, value: Any) -> bool:
+def _assert_not_empty(*, value: object) -> bool:
    if value:
        return True
    return False
-def _assert_equal(*, value: Any, expected: Any) -> bool:
+def _assert_equal(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -246,10 +274,16 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:

    # Handle boolean comparison
    if isinstance(value, bool):
+        if not isinstance(expected, (bool, int, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to bool")
        expected = bool(expected)
    elif isinstance(value, int):
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value != expected:

@ -257,7 +291,7 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_not_equal(*, value: Any, expected: Any) -> bool:
+def _assert_not_equal(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -266,10 +300,16 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:

    # Handle boolean comparison
    if isinstance(value, bool):
+        if not isinstance(expected, (bool, int, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to bool")
        expected = bool(expected)
    elif isinstance(value, int):
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value == expected:

@ -277,7 +317,7 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_greater_than(*, value: Any, expected: Any) -> bool:
+def _assert_greater_than(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -285,8 +325,12 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value <= expected:

@ -294,7 +338,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_less_than(*, value: Any, expected: Any) -> bool:
+def _assert_less_than(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -302,8 +346,12 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value >= expected:

@ -311,7 +359,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
+def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -319,8 +367,12 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value < expected:

@ -328,7 +380,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
+def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
    if value is None:
        return False

@ -336,8 +388,12 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
        raise ValueError("Invalid actual value type: number")

    if isinstance(value, int):
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to int")
        expected = int(expected)
    else:
+        if not isinstance(expected, (int, float, str)):
+            raise ValueError(f"Cannot convert {type(expected)} to float")
        expected = float(expected)

    if value > expected:

@ -345,19 +401,19 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
    return True
-def _assert_null(*, value: Any) -> bool:
+def _assert_null(*, value: object) -> bool:
    if value is None:
        return True
    return False


-def _assert_not_null(*, value: Any) -> bool:
+def _assert_not_null(*, value: object) -> bool:
    if value is not None:
        return True
    return False


-def _assert_in(*, value: Any, expected: Any) -> bool:
+def _assert_in(*, value: object, expected: object) -> bool:
    if not value:
        return False

@ -369,7 +425,7 @@ def _assert_in(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_not_in(*, value: Any, expected: Any) -> bool:
+def _assert_not_in(*, value: object, expected: object) -> bool:
    if not value:
        return True

@ -381,20 +437,33 @@ def _assert_not_in(*, value: Any, expected: Any) -> bool:
    return True


-def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool:
+def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool:
    if not value:
        return False

-    if not all(item in value for item in expected):
-        return False
-    return True
+    # Ensure value is a container that supports 'in' operator
+    if not isinstance(value, (list, tuple, set, str)):
+        return False
+
+    return all(item in value for item in expected)


+def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool:
+    if not value:
+        return False
+
+    # Ensure value is a container that supports 'in' operator
+    if not isinstance(value, (list, tuple, set)):
+        return False
+
+    return all(item in value for item in expected)


-def _assert_exists(*, value: Any) -> bool:
+def _assert_exists(*, value: object) -> bool:
    return value is not None


-def _assert_not_exists(*, value: Any) -> bool:
+def _assert_not_exists(*, value: object) -> bool:
    return value is None
@ -404,7 +473,7 @@ def _process_sub_conditions(
    operator: Literal["and", "or"],
) -> bool:
    files = variable.value
-    group_results = []
+    group_results: list[bool] = []
    for condition in sub_conditions:
        key = FileAttribute(condition.key)
        values = [file_manager.get_attr(file=file, attr=key) for file in files]

@ -415,14 +484,14 @@ def _process_sub_conditions(
            if expected_value and not expected_value.startswith("."):
                expected_value = "." + expected_value

-            normalized_values = []
+            normalized_values: list[object] = []
            for value in values:
                if value and isinstance(value, str):
                    if not value.startswith("."):
                        value = "." + value
                    normalized_values.append(value)
            values = normalized_values
-        sub_group_results = [
+        sub_group_results: list[bool] = [
            _evaluate_condition(
                value=value,
                operator=condition.comparison_operator,
@ -50,7 +50,7 @@ class WorkflowCycleManager:
        workflow_info: CycleManagerWorkflowInfo,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
-    ) -> None:
+    ):
        self._application_generate_entity = application_generate_entity
        self._workflow_system_variables = workflow_system_variables
        self._workflow_info = workflow_info

@ -305,7 +305,7 @@ class WorkflowCycleManager:
        error_message: Optional[str] = None,
        exceptions_count: int = 0,
        finished_at: Optional[datetime] = None,
-    ) -> None:
+    ):
        """Update workflow execution with completion data."""
        execution.status = status
        execution.outputs = outputs or {}

@ -322,7 +322,7 @@ class WorkflowCycleManager:
        workflow_execution: WorkflowExecution,
        conversation_id: Optional[str],
        external_trace_id: Optional[str],
-    ) -> None:
+    ):
        """Add trace task if trace manager is provided."""
        if trace_manager:
            trace_manager.add_trace_task(

@ -340,7 +340,7 @@ class WorkflowCycleManager:
        workflow_execution_id: str,
        error_message: str,
        now: datetime,
-    ) -> None:
+    ):
        """Fail all running node executions for a workflow."""
        running_node_executions = [
            node_exec

@ -410,7 +410,7 @@ class WorkflowCycleManager:
        status: WorkflowNodeExecutionStatus,
        error: Optional[str] = None,
        handle_special_values: bool = False,
-    ) -> None:
+    ):
        """Update node execution with completion data."""
        finished_at = naive_utc_now()
        elapsed_time = (finished_at - event.start_at).total_seconds()
@ -41,6 +41,7 @@ class WorkflowEntry:
        user_from: UserFrom,
        invoke_from: InvokeFrom,
        call_depth: int,
        variable_pool: VariablePool,
        graph_runtime_state: GraphRuntimeState,
        command_channel: Optional[CommandChannel] = None,
    ) -> None:

@ -351,7 +352,7 @@ class WorkflowEntry:
        return result if isinstance(result, Mapping) or result is None else dict(result)

    @staticmethod
-    def _handle_special_values(value: Any) -> Any:
+    def _handle_special_values(value: Any):
        if value is None:
            return value
        if isinstance(value, dict):

@ -376,7 +377,7 @@ class WorkflowEntry:
        user_inputs: Mapping[str, Any],
        variable_pool: VariablePool,
        tenant_id: str,
-    ) -> None:
+    ):
        # NOTE(QuantumGhost): This logic should remain synchronized with
        # the implementation of `load_into_variable_pool`, specifically the logic about
        # variable existence checking.
@ -18,7 +18,7 @@ class WorkflowRuntimeTypeConverter:
        result = self._to_json_encodable_recursive(value)
        return result if isinstance(result, Mapping) or result is None else dict(result)

-    def _to_json_encodable_recursive(self, value: Any) -> Any:
+    def _to_json_encodable_recursive(self, value: Any):
        if value is None:
            return value
        if isinstance(value, (bool, int, str, float)):