Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

This commit is contained in:
-LAN-
2025-09-08 14:30:43 +08:00
828 changed files with 7240 additions and 2951 deletions

View File

@ -20,7 +20,7 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
def update(self, conversation_id: str, variable: "Variable"):
"""
Updates the value of the specified conversation variable in the underlying storage.

View File

@ -1,173 +0,0 @@
# GraphEngine Worker Pool Configuration
## Overview
The GraphEngine now supports **dynamic worker pool management** to optimize performance and resource usage. Instead of a fixed 10-worker pool, the engine can:
1. **Start with optimal worker count** based on graph complexity
1. **Scale up** when workload increases
1. **Scale down** when workers are idle
1. **Respect configurable min/max limits**
## Benefits
- **Resource Efficiency**: Uses fewer workers for simple sequential workflows
- **Better Performance**: Scales up for parallel-heavy workflows
- **Gevent Optimization**: Works efficiently with Gevent's greenlet model
- **Memory Savings**: Reduces memory footprint for simple workflows
## Configuration
### Configuration Variables (via dify_config)
| Variable | Default | Description |
|----------|---------|-------------|
| `GRAPH_ENGINE_MIN_WORKERS` | 1 | Minimum number of workers per engine |
| `GRAPH_ENGINE_MAX_WORKERS` | 10 | Maximum number of workers per engine |
| `GRAPH_ENGINE_SCALE_UP_THRESHOLD` | 3 | Queue depth that triggers scale up |
| `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME` | 5.0 | Seconds of idle time before scaling down |
### Example Configurations
#### Low-Resource Environment
```bash
export GRAPH_ENGINE_MIN_WORKERS=1
export GRAPH_ENGINE_MAX_WORKERS=3
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=2
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=3.0
```
#### High-Performance Environment
```bash
export GRAPH_ENGINE_MIN_WORKERS=2
export GRAPH_ENGINE_MAX_WORKERS=20
export GRAPH_ENGINE_SCALE_UP_THRESHOLD=5
export GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=10.0
```
#### Default (Balanced)
```bash
# Uses defaults: min=1, max=10, threshold=3, idle_time=5.0
```
## How It Works
### Initial Worker Calculation
The engine analyzes the graph structure at startup:
- **Sequential graphs** (no branches): 1 worker
- **Limited parallelism** (few branches): 2 workers
- **Moderate parallelism**: 3 workers
- **High parallelism** (many branches): 5 workers
### Dynamic Scaling
During execution:
1. **Scale Up** triggers when:
- Queue depth exceeds `SCALE_UP_THRESHOLD`
- All workers are busy and queue has items
- Not at `MAX_WORKERS` limit
1. **Scale Down** triggers when:
- Worker idle for more than `SCALE_DOWN_IDLE_TIME` seconds
- Above `MIN_WORKERS` limit
### Gevent Compatibility
Since Gevent patches threading to use greenlets:
- Workers are lightweight coroutines, not OS threads
- Dynamic scaling has minimal overhead
- Can efficiently handle many concurrent workers
## Migration Guide
### Before (Fixed 10 Workers)
```python
# Every GraphEngine instance created 10 workers
# Resource waste for simple workflows
# No adaptation to workload
```
### After (Dynamic Workers)
```python
# GraphEngine creates 1-5 initial workers based on graph
# Scales up/down based on workload
# Configurable via environment variables
```
### Backward Compatibility
The default configuration (`max=10`) maintains compatibility with existing deployments. To get the old behavior exactly:
```bash
export GRAPH_ENGINE_MIN_WORKERS=10
export GRAPH_ENGINE_MAX_WORKERS=10
```
## Performance Impact
### Memory Usage
- **Simple workflows**: ~80% reduction (1 vs 10 workers)
- **Complex workflows**: Similar or slightly better
### Execution Time
- **Sequential workflows**: No change
- **Parallel workflows**: Improved with proper scaling
- **Bursty workloads**: Better adaptation
### Example Metrics
| Workflow Type | Old (10 workers) | New (Dynamic) | Improvement |
|--------------|------------------|---------------|-------------|
| Sequential | 10 workers idle | 1 worker active | 90% fewer workers |
| 3-way parallel | 7 workers idle | 3 workers active | 70% fewer workers |
| Heavy parallel | 10 workers busy | 10+ workers (scales up) | Better throughput |
## Monitoring
Log messages indicate scaling activity:
```shell
INFO: GraphEngine initialized with 2 workers (min: 1, max: 10)
INFO: Scaled up workers: 2 -> 3 (queue_depth: 4)
INFO: Scaled down workers: 3 -> 2 (removed 1 idle workers)
```
## Best Practices
1. **Start with defaults** - They work well for most cases
1. **Monitor queue depth** - Adjust `SCALE_UP_THRESHOLD` if queues back up
1. **Consider workload patterns**:
- Bursty: Lower `SCALE_DOWN_IDLE_TIME`
- Steady: Higher `SCALE_DOWN_IDLE_TIME`
1. **Test with your workloads** - Measure and tune
## Troubleshooting
### Workers not scaling up
- Check `GRAPH_ENGINE_MAX_WORKERS` limit
- Verify queue depth exceeds threshold
- Check logs for scaling messages
### Workers scaling down too quickly
- Increase `GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME`
- Consider workload patterns
### Out of memory
- Reduce `GRAPH_ENGINE_MAX_WORKERS`
- Check for memory leaks in nodes

View File

@ -25,7 +25,7 @@ class GraphRuntimeState(BaseModel):
llm_usage: LLMUsage | None = None,
outputs: dict[str, Any] | None = None,
node_run_steps: int = 0,
**kwargs,
**kwargs: object,
):
"""Initialize the GraphRuntimeState with validation."""
super().__init__(**kwargs)

View File

@ -19,7 +19,7 @@ from core.workflow.constants import (
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
VariableValue = Union[str, int, float, dict[str, object], list[object], File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@ -45,18 +45,18 @@ class VariablePool(BaseModel):
)
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
default_factory=list[VariableUnion],
)
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
default_factory=list[VariableUnion],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
def model_post_init(self, context: Any, /):
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
@ -76,7 +76,7 @@ class VariablePool(BaseModel):
for key, value in rag_pipeline_variables_map.items():
self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value)
def add(self, selector: Sequence[str], value: Any, /) -> None:
def add(self, selector: Sequence[str], value: Any, /):
"""
Add a variable to the variable pool.
@ -180,11 +180,11 @@ class VariablePool(BaseModel):
# Return result as Segment
return result if isinstance(result, Segment) else variable_factory.build_segment(result)
def _extract_value(self, obj: Any) -> Any:
def _extract_value(self, obj: Any):
"""Extract the actual value from an ObjectSegment."""
return obj.value if isinstance(obj, ObjectSegment) else obj
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Any:
def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str):
"""Get a nested attribute from a dictionary-like object."""
if not isinstance(obj, dict):
return None
@ -210,7 +210,7 @@ class VariablePool(BaseModel):
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
segments = []
segments: list[Segment] = []
for part in filter(lambda x: x, parts):
if "." in part and (variable := self.get(part.split("."))):
segments.append(variable)

View File

@ -127,7 +127,7 @@ class WorkflowNodeExecution(BaseModel):
process_data: Optional[Mapping[str, Any]] = None,
outputs: Optional[Mapping[str, Any]] = None,
metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None,
) -> None:
):
"""
Update the model from mappings.

View File

@ -1,16 +1,18 @@
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
class ReadOnlyVariablePool(Protocol):
"""Read-only interface for VariablePool."""
def get(self, node_id: str, variable_key: str) -> Any:
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (read-only)."""
...
def get_all_by_node(self, node_id: str) -> dict[str, Any]:
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (read-only)."""
...

View File

@ -1,7 +1,9 @@
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
@ -12,19 +14,18 @@ class ReadOnlyVariablePoolWrapper:
def __init__(self, variable_pool: VariablePool):
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Any:
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
value = self._variable_pool.get(node_id, variable_key)
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> dict[str, Any]:
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
variables = {}
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# FIXME(-LAN-): Handle the actual Variable object structure
value = var.value if hasattr(var, "value") else var
variables[key] = deepcopy(value)
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
return variables

View File

@ -113,17 +113,6 @@ class DebugLoggingLayer(GraphEngineLayer):
# Log initial state
self.logger.info("Initial State:")
# Log inputs if available
if self.graph_runtime_state.variable_pool:
initial_vars: dict[str, Any] = {}
# Access the variable dictionary directly
for node_id, variables in self.graph_runtime_state.variable_pool.variable_dictionary.items():
for var_key, var in variables.items():
initial_vars[f"{node_id}.{var_key}"] = str(var.value) if hasattr(var, "value") else str(var)
if initial_vars:
self.logger.info(" Initial variables: %s", self._format_dict(initial_vars))
@override
def on_event(self, event: GraphEngineEvent) -> None:
"""Log individual events based on their type."""

View File

@ -217,7 +217,7 @@ class WorkerPool:
return False
# Find and remove idle workers that have been idle long enough
workers_to_remove = []
workers_to_remove: list[tuple[Worker, int]] = []
for worker in self._workers:
# Check if worker is idle and has exceeded idle time threshold

View File

@ -13,6 +13,11 @@ class NodeEventBase(BaseModel):
pass
def _default_metadata():
v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
return v
class NodeRunResult(BaseModel):
"""
Node Run Result.
@ -23,7 +28,7 @@ class NodeRunResult(BaseModel):
inputs: Mapping[str, Any] = Field(default_factory=dict)
process_data: Mapping[str, Any] = Field(default_factory=dict)
outputs: Mapping[str, Any] = Field(default_factory=dict)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=dict)
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata)
llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
edge_source_handle: str = "source" # source handle id of node with multiple branches

View File

@ -19,6 +19,7 @@ class ModelInvokeCompletedEvent(NodeEventBase):
text: str
usage: LLMUsage
finish_reason: str | None = None
reasoning_content: str | None = None
class RunRetryEvent(NodeEventBase):

View File

@ -68,7 +68,7 @@ class AgentNode(Node):
node_type = NodeType.AGENT
_node_data: AgentNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AgentNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -17,7 +17,7 @@ class AnswerNode(Node):
_node_data: AnswerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = AnswerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -50,7 +50,7 @@ class DefaultValue(BaseModel):
key: str
@staticmethod
def _parse_json(value: str) -> Any:
def _parse_json(value: str):
"""Unified JSON parsing handler"""
try:
return json.loads(value)

View File

@ -57,7 +57,7 @@ class VariableTemplateParser:
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
def extract(self):
"""
Extracts all the template variable keys from the template string.

View File

@ -27,7 +27,7 @@ class CodeNode(Node):
_node_data: CodeNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = CodeNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -49,7 +49,7 @@ class CodeNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@ -46,7 +46,7 @@ class DocumentExtractorNode(Node):
_node_data: DocumentExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = DocumentExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -15,7 +15,7 @@ class EndNode(Node):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -421,7 +421,10 @@ class Executor:
body_string += f"--{boundary}--\r\n"
elif self.node_data.body:
if self.content:
body_string = self.content.decode("utf-8", errors="replace")
if isinstance(self.content, bytes):
body_string = self.content.decode("utf-8", errors="replace")
else:
body_string = self.content
elif self.data and self.node_data.body.type == "x-www-form-urlencoded":
body_string = urlencode(self.data)
elif self.data and self.node_data.body.type == "form-data":

View File

@ -36,7 +36,7 @@ class HttpRequestNode(Node):
_node_data: HttpRequestNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = HttpRequestNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -58,7 +58,7 @@ class HttpRequestNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
def get_default_config(cls, filters: Optional[dict[str, Any]] = None):
return {
"type": "http-request",
"config": {

View File

@ -19,7 +19,7 @@ class IfElseNode(Node):
_node_data: IfElseNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IfElseNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -3,7 +3,7 @@ from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (
@ -55,7 +55,7 @@ class IterationNode(Node):
execution_type = NodeExecutionType.CONTAINER
_node_data: IterationNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -77,7 +77,7 @@ class IterationNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {
"type": "iteration",
"config": {
@ -97,10 +97,10 @@ class IterationNode(Node):
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
if isinstance(variable, NoneSegment) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})

View File

@ -17,7 +17,7 @@ class IterationStartNode(Node):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -184,7 +184,6 @@ class KnowledgeIndexNode(Node):
def version(cls) -> str:
return "1"
def get_streaming_template(self) -> Template:
"""
Get the template for streaming.
@ -192,4 +191,4 @@ class KnowledgeIndexNode(Node):
Returns:
Template instance for this knowledge index node
"""
return Template(segments=[])
return Template(segments=[])

View File

@ -99,7 +99,7 @@ class KnowledgeRetrievalNode(Node):
graph_runtime_state: "GraphRuntimeState",
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -116,7 +116,7 @@ class KnowledgeRetrievalNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -40,7 +40,7 @@ class ListOperatorNode(Node):
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -171,6 +171,8 @@ class ListOperatorNode(Node):
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})
else:
if not isinstance(condition.value, bool):
raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}")
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
result = list(filter(filter_func, variable.value))
variable = variable.model_copy(update={"value": result})

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
@ -68,6 +68,23 @@ class LLMNodeData(BaseNodeData):
structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
# Keep tagged as default for backward compatibility
default="tagged",
description=(
"""
Strategy for handling model reasoning output.
separated: Return clean text (without <think> tags) + reasoning_content field.
Recommended for new workflows. Enables safe downstream parsing and
workflow variable access: {{#node_id.reasoning_content#}}
tagged : Return original text (with <think> tags) + reasoning_content field.
Maintains full backward compatibility while still providing reasoning_content
for workflow automation. Frontend thinking panels work as before.
"""
),
)
@field_validator("prompt_config", mode="before")
@classmethod

View File

@ -41,5 +41,5 @@ class FileTypeNotSupportError(LLMNodeError):
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
def __init__(self, *, type_name: str):
super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@ -107,7 +107,7 @@ def fetch_memory(
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration

View File

@ -2,8 +2,9 @@ import base64
import io
import json
import logging
import re
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
@ -101,6 +102,9 @@ class LLMNode(Node):
_node_data: LLMNodeData
# Compiled regex for extracting <think> blocks (with compatibility for attributes)
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@ -115,7 +119,7 @@ class LLMNode(Node):
graph_runtime_state: "GraphRuntimeState",
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -132,7 +136,7 @@ class LLMNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LLMNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -163,6 +167,7 @@ class LLMNode(Node):
result_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = None
variable_pool = self.graph_runtime_state.variable_pool
try:
@ -250,6 +255,7 @@ class LLMNode(Node):
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self._node_data.reasoning_format,
)
structured_output: LLMStructuredOutput | None = None
@ -258,9 +264,20 @@ class LLMNode(Node):
if isinstance(event, StreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompletedEvent):
# Raw text
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
reasoning_content = event.reasoning_content or ""
# For downstream nodes, determine clean text based on reasoning_format
if self._node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility
clean_text = result_text
else:
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self._node_data.reasoning_format)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
@ -278,7 +295,12 @@ class LLMNode(Node):
"model_name": model_config.model,
}
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs:
@ -340,6 +362,7 @@ class LLMNode(Node):
file_outputs: list["File"],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@ -377,6 +400,7 @@ class LLMNode(Node):
file_outputs=file_outputs,
node_id=node_id,
node_type=node_type,
reasoning_format=reasoning_format,
)
@staticmethod
@ -387,6 +411,7 @@ class LLMNode(Node):
file_outputs: list["File"],
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
@ -394,6 +419,7 @@ class LLMNode(Node):
invoke_result=invoke_result,
saver=file_saver,
file_outputs=file_outputs,
reasoning_format=reasoning_format,
)
yield event
return
@ -438,13 +464,66 @@ class LLMNode(Node):
except OutputParserError as e:
raise LLMNodeError(f"Failed to parse structured output: {e}")
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
# Extract reasoning content from <think> tags in the main text
full_text = full_text_buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
yield ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=usage,
finish_reason=finish_reason,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
)
@staticmethod
def _image_file_to_markdown(file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
@classmethod
def _split_reasoning(
cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged"
) -> tuple[str, str]:
"""
Split reasoning content from text based on reasoning_format strategy.
Args:
text: Full text that may contain <think> blocks
reasoning_format: Strategy for handling reasoning content
- "separated": Remove <think> tags and return clean text + reasoning_content field
- "tagged": Keep <think> tags in text, return empty reasoning_content
Returns:
tuple of (clean_text, reasoning_content)
"""
if reasoning_format == "tagged":
return text, ""
# Find all <think>...</think> blocks (case-insensitive)
matches = cls._THINK_PATTERN.findall(text)
# Extract reasoning content from all <think> blocks
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
# Remove all <think>...</think> blocks from original text
clean_text = cls._THINK_PATTERN.sub("", text)
# Clean up extra whitespace
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
# Separated mode: always return clean text and reasoning_content
return clean_text, reasoning_content or ""
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@ -880,7 +959,7 @@ class LLMNode(Node):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {
"type": "llm",
"config": {
@ -972,6 +1051,7 @@ class LLMNode(Node):
invoke_result: LLMResult,
saver: LLMFileSaver,
file_outputs: list["File"],
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
@ -981,10 +1061,24 @@ class LLMNode(Node):
):
buffer.write(text_part)
# Extract reasoning content from <think> tags in the main text
full_text = buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
return ModelInvokeCompletedEvent(
text=buffer.getvalue(),
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
usage=invoke_result.usage,
finish_reason=None,
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
)
@staticmethod

View File

@ -17,7 +17,7 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -49,7 +49,7 @@ class LoopNode(Node):
_node_data: LoopNodeData
execution_type = NodeExecutionType.CONTAINER
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -17,7 +17,7 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -96,7 +96,7 @@ class ParameterExtractorNodeData(BaseNodeData):
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def get_parameter_json_schema(self) -> dict:
def get_parameter_json_schema(self):
"""
Get parameter json schema.

View File

@ -63,7 +63,7 @@ class InvalidValueTypeError(ParameterExtractorNodeError):
expected_type: SegmentType,
actual_type: SegmentType | None,
value: Any,
) -> None:
):
message = (
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"

View File

@ -50,6 +50,7 @@ from .exc import (
)
from .prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_PROMPT,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
@ -92,7 +93,7 @@ class ParameterExtractorNode(Node):
_node_data: ParameterExtractorNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ParameterExtractorNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -117,7 +118,7 @@ class ParameterExtractorNode(Node):
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
return {
"model": {
"prompt_templates": {
@ -538,7 +539,7 @@ class ParameterExtractorNode(Node):
return prompt_messages
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
def _validate_result(self, data: ParameterExtractorNodeData, result: dict):
if len(data.parameters) != len(result):
raise InvalidNumberOfParametersError("Invalid number of parameters")
@ -591,7 +592,7 @@ class ParameterExtractorNode(Node):
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
"""
Transform result into standard format.
"""
@ -684,7 +685,7 @@ class ParameterExtractorNode(Node):
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
def _generate_default_result(self, data: ParameterExtractorNodeData):
"""
Generate default result.
"""
@ -746,7 +747,7 @@ class ParameterExtractorNode(Node):
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
role=PromptMessageRole.SYSTEM,
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction),
text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str).replace("{{instructions}}", instruction),
)
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
return [system_prompt_messages, user_prompt_message]

View File

@ -60,7 +60,7 @@ class QuestionClassifierNode(Node):
graph_runtime_state: "GraphRuntimeState",
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
):
super().__init__(
id=id,
config=config,
@ -77,7 +77,7 @@ class QuestionClassifierNode(Node):
)
self._llm_file_saver = llm_file_saver
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = QuestionClassifierNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -271,7 +271,7 @@ class QuestionClassifierNode(Node):
return variable_mapping
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
"""
Get default config of node.
:param filters: filter by node config parameters (not used in this implementation).

View File

@ -15,7 +15,7 @@ class StartNode(Node):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -17,7 +17,7 @@ class TemplateTransformNode(Node):
_node_data: TemplateTransformNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = TemplateTransformNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -39,7 +39,7 @@ class TemplateTransformNode(Node):
return self._node_data
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
def get_default_config(cls, filters: Optional[dict] = None):
"""
Get default config of node.
:param filters: filter by node config parameters.

View File

@ -48,7 +48,7 @@ class ToolNode(Node):
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ToolNodeData.model_validate(data)
@classmethod

View File

@ -14,7 +14,7 @@ class VariableAggregatorNode(Node):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -30,7 +30,7 @@ class VariableAssignerNode(Node):
_node_data: VariableAssignerData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
@ -58,7 +58,7 @@ class VariableAssignerNode(Node):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
) -> None:
):
super().__init__(
id=id,
config=config,

View File

@ -32,5 +32,5 @@ class ConversationIDNotFoundError(VariableOperatorNodeError):
class InvalidDataError(VariableOperatorNodeError):
def __init__(self, message: str) -> None:
def __init__(self, message: str):
super().__init__(message)

View File

@ -57,7 +57,7 @@ class VariableAssignerNode(Node):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:

View File

@ -16,7 +16,7 @@ class WorkflowExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowExecution) -> None:
def save(self, execution: WorkflowExecution):
"""
Save or update a WorkflowExecution instance.

View File

@ -26,7 +26,7 @@ class WorkflowNodeExecutionRepository(Protocol):
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
def save(self, execution: WorkflowNodeExecution):
"""
Save or update a NodeExecution instance.

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import Any, Literal, NamedTuple, Union
from typing import Literal, NamedTuple
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
@ -10,7 +10,7 @@ from core.workflow.entities import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator
def _convert_to_bool(value: Any) -> bool:
def _convert_to_bool(value: object) -> bool:
if isinstance(value, int):
return bool(value)
@ -23,7 +23,7 @@ def _convert_to_bool(value: Any) -> bool:
class ConditionCheckResult(NamedTuple):
inputs: Sequence[Mapping[str, Any]]
inputs: Sequence[Mapping[str, object]]
group_results: Sequence[bool]
final_result: bool
@ -36,7 +36,7 @@ class ConditionProcessor:
conditions: Sequence[Condition],
operator: Literal["and", "or"],
) -> ConditionCheckResult:
input_conditions: list[Mapping[str, Any]] = []
input_conditions: list[Mapping[str, object]] = []
group_results: list[bool] = []
for condition in conditions:
@ -103,8 +103,8 @@ class ConditionProcessor:
def _evaluate_condition(
*,
operator: SupportedComparisonOperator,
value: Any,
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
value: object,
expected: str | Sequence[str] | bool | Sequence[bool] | None,
) -> bool:
match operator:
case "contains":
@ -144,7 +144,17 @@ def _evaluate_condition(
case "not in":
return _assert_not_in(value=value, expected=expected)
case "all of" if isinstance(expected, list):
return _assert_all_of(value=value, expected=expected)
# Type narrowing: at this point expected is a list, could be list[str] or list[bool]
if all(isinstance(item, str) for item in expected):
# Create a new typed list to satisfy type checker
str_list: list[str] = [item for item in expected if isinstance(item, str)]
return _assert_all_of(value=value, expected=str_list)
elif all(isinstance(item, bool) for item in expected):
# Create a new typed list to satisfy type checker
bool_list: list[bool] = [item for item in expected if isinstance(item, bool)]
return _assert_all_of_bool(value=value, expected=bool_list)
else:
raise ValueError("all of operator expects homogeneous list of strings or booleans")
case "exists":
return _assert_exists(value=value)
case "not exists":
@ -153,55 +163,73 @@ def _evaluate_condition(
raise ValueError(f"Unsupported operator: {operator}")
def _assert_contains(*, value: Any, expected: Any) -> bool:
def _assert_contains(*, value: object, expected: object) -> bool:
if not value:
return False
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected not in value:
return False
# Type checking ensures value is str or list at this point
if isinstance(value, str):
if not isinstance(expected, str):
expected = str(expected)
if expected not in value:
return False
else: # value is list
if expected not in value:
return False
return True
def _assert_not_contains(*, value: Any, expected: Any) -> bool:
def _assert_not_contains(*, value: object, expected: object) -> bool:
if not value:
return True
if not isinstance(value, (str, list)):
raise ValueError("Invalid actual value type: string or array")
if expected in value:
return False
# Type checking ensures value is str or list at this point
if isinstance(value, str):
if not isinstance(expected, str):
expected = str(expected)
if expected in value:
return False
else: # value is list
if expected in value:
return False
return True
def _assert_start_with(*, value: Any, expected: Any) -> bool:
def _assert_start_with(*, value: object, expected: object) -> bool:
if not value:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(expected, str):
raise ValueError("Expected value must be a string for startswith")
if not value.startswith(expected):
return False
return True
def _assert_end_with(*, value: Any, expected: Any) -> bool:
def _assert_end_with(*, value: object, expected: object) -> bool:
if not value:
return False
if not isinstance(value, str):
raise ValueError("Invalid actual value type: string")
if not isinstance(expected, str):
raise ValueError("Expected value must be a string for endswith")
if not value.endswith(expected):
return False
return True
def _assert_is(*, value: Any, expected: Any) -> bool:
def _assert_is(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -213,7 +241,7 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
return True
def _assert_is_not(*, value: Any, expected: Any) -> bool:
def _assert_is_not(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -225,19 +253,19 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
return True
def _assert_empty(*, value: Any) -> bool:
def _assert_empty(*, value: object) -> bool:
if not value:
return True
return False
def _assert_not_empty(*, value: Any) -> bool:
def _assert_not_empty(*, value: object) -> bool:
if value:
return True
return False
def _assert_equal(*, value: Any, expected: Any) -> bool:
def _assert_equal(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -246,10 +274,16 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
# Handle boolean comparison
if isinstance(value, bool):
if not isinstance(expected, (bool, int, str)):
raise ValueError(f"Cannot convert {type(expected)} to bool")
expected = bool(expected)
elif isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value != expected:
@ -257,7 +291,7 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
return True
def _assert_not_equal(*, value: Any, expected: Any) -> bool:
def _assert_not_equal(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -266,10 +300,16 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
# Handle boolean comparison
if isinstance(value, bool):
if not isinstance(expected, (bool, int, str)):
raise ValueError(f"Cannot convert {type(expected)} to bool")
expected = bool(expected)
elif isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value == expected:
@ -277,7 +317,7 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
return True
def _assert_greater_than(*, value: Any, expected: Any) -> bool:
def _assert_greater_than(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -285,8 +325,12 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value <= expected:
@ -294,7 +338,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
return True
def _assert_less_than(*, value: Any, expected: Any) -> bool:
def _assert_less_than(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -302,8 +346,12 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value >= expected:
@ -311,7 +359,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
return True
def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -319,8 +367,12 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value < expected:
@ -328,7 +380,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
return True
def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
def _assert_less_than_or_equal(*, value: object, expected: object) -> bool:
if value is None:
return False
@ -336,8 +388,12 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
raise ValueError("Invalid actual value type: number")
if isinstance(value, int):
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to int")
expected = int(expected)
else:
if not isinstance(expected, (int, float, str)):
raise ValueError(f"Cannot convert {type(expected)} to float")
expected = float(expected)
if value > expected:
@ -345,19 +401,19 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
return True
def _assert_null(*, value: Any) -> bool:
def _assert_null(*, value: object) -> bool:
if value is None:
return True
return False
def _assert_not_null(*, value: Any) -> bool:
def _assert_not_null(*, value: object) -> bool:
if value is not None:
return True
return False
def _assert_in(*, value: Any, expected: Any) -> bool:
def _assert_in(*, value: object, expected: object) -> bool:
if not value:
return False
@ -369,7 +425,7 @@ def _assert_in(*, value: Any, expected: Any) -> bool:
return True
def _assert_not_in(*, value: Any, expected: Any) -> bool:
def _assert_not_in(*, value: object, expected: object) -> bool:
if not value:
return True
@ -381,20 +437,33 @@ def _assert_not_in(*, value: Any, expected: Any) -> bool:
return True
def _assert_all_of(*, value: Any, expected: Sequence[str]) -> bool:
def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool:
if not value:
return False
if not all(item in value for item in expected):
# Ensure value is a container that supports 'in' operator
if not isinstance(value, (list, tuple, set, str)):
return False
return True
return all(item in value for item in expected)
def _assert_exists(*, value: Any) -> bool:
def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool:
if not value:
return False
# Ensure value is a container that supports 'in' operator
if not isinstance(value, (list, tuple, set)):
return False
return all(item in value for item in expected)
def _assert_exists(*, value: object) -> bool:
return value is not None
def _assert_not_exists(*, value: Any) -> bool:
def _assert_not_exists(*, value: object) -> bool:
return value is None
@ -404,7 +473,7 @@ def _process_sub_conditions(
operator: Literal["and", "or"],
) -> bool:
files = variable.value
group_results = []
group_results: list[bool] = []
for condition in sub_conditions:
key = FileAttribute(condition.key)
values = [file_manager.get_attr(file=file, attr=key) for file in files]
@ -415,14 +484,14 @@ def _process_sub_conditions(
if expected_value and not expected_value.startswith("."):
expected_value = "." + expected_value
normalized_values = []
normalized_values: list[object] = []
for value in values:
if value and isinstance(value, str):
if not value.startswith("."):
value = "." + value
normalized_values.append(value)
values = normalized_values
sub_group_results = [
sub_group_results: list[bool] = [
_evaluate_condition(
value=value,
operator=condition.comparison_operator,

View File

@ -50,7 +50,7 @@ class WorkflowCycleManager:
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
):
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
self._workflow_info = workflow_info
@ -305,7 +305,7 @@ class WorkflowCycleManager:
error_message: Optional[str] = None,
exceptions_count: int = 0,
finished_at: Optional[datetime] = None,
) -> None:
):
"""Update workflow execution with completion data."""
execution.status = status
execution.outputs = outputs or {}
@ -322,7 +322,7 @@ class WorkflowCycleManager:
workflow_execution: WorkflowExecution,
conversation_id: Optional[str],
external_trace_id: Optional[str],
) -> None:
):
"""Add trace task if trace manager is provided."""
if trace_manager:
trace_manager.add_trace_task(
@ -340,7 +340,7 @@ class WorkflowCycleManager:
workflow_execution_id: str,
error_message: str,
now: datetime,
) -> None:
):
"""Fail all running node executions for a workflow."""
running_node_executions = [
node_exec
@ -410,7 +410,7 @@ class WorkflowCycleManager:
status: WorkflowNodeExecutionStatus,
error: Optional[str] = None,
handle_special_values: bool = False,
) -> None:
):
"""Update node execution with completion data."""
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()

View File

@ -41,6 +41,7 @@ class WorkflowEntry:
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
variable_pool: VariablePool,
graph_runtime_state: GraphRuntimeState,
command_channel: Optional[CommandChannel] = None,
) -> None:
@ -351,7 +352,7 @@ class WorkflowEntry:
return result if isinstance(result, Mapping) or result is None else dict(result)
@staticmethod
def _handle_special_values(value: Any) -> Any:
def _handle_special_values(value: Any):
if value is None:
return value
if isinstance(value, dict):
@ -376,7 +377,7 @@ class WorkflowEntry:
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:
):
# NOTE(QuantumGhost): This logic should remain synchronized with
# the implementation of `load_into_variable_pool`, specifically the logic about
# variable existence checking.

View File

@ -18,7 +18,7 @@ class WorkflowRuntimeTypeConverter:
result = self._to_json_encodable_recursive(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
def _to_json_encodable_recursive(self, value: Any) -> Any:
def _to_json_encodable_recursive(self, value: Any):
if value is None:
return value
if isinstance(value, (bool, int, str, float)):