mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
merge main
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
@ -114,8 +115,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
||||
)
|
||||
|
||||
if isinstance(value, float):
|
||||
decimal_value = Decimal(str(value)).normalize()
|
||||
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
|
||||
# raise error if precision is too high
|
||||
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
|
||||
if precision > dify_config.CODE_MAX_PRECISION:
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
|
||||
@ -521,18 +521,52 @@ class IterationNode(BaseNode[IterationNodeData]):
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": None},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
outputs[current_index] = None
|
||||
|
||||
# clean nodes resources
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# iteration run failed
|
||||
if self.node_data.is_parallel:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
else:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": outputs},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
# stop the iterator
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
return
|
||||
yield metadata_event
|
||||
|
||||
current_output_segment = variable_pool.get(self.node_data.output_selector)
|
||||
|
||||
@ -144,6 +144,8 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
|
||||
available_datasets = []
|
||||
@ -171,6 +173,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
.all()
|
||||
)
|
||||
|
||||
# avoid blocking at retrieval
|
||||
db.session.close()
|
||||
|
||||
for dataset in results:
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
|
||||
@ -1,11 +1,29 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import AfterValidator, BaseModel, Field
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
|
||||
if seg_type not in _VALID_VAR_TYPE:
|
||||
raise ValueError(...)
|
||||
return seg_type
|
||||
|
||||
|
||||
class LoopVariableData(BaseModel):
|
||||
"""
|
||||
@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
|
||||
"""
|
||||
|
||||
label: str
|
||||
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Optional[Any | list[str]] = None
|
||||
|
||||
|
||||
@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import (
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayStringSegment,
|
||||
IntegerSegment,
|
||||
ObjectSegment,
|
||||
Segment,
|
||||
SegmentType,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
|
||||
"string": (StringSegment, SegmentType.STRING),
|
||||
"number": (IntegerSegment, SegmentType.NUMBER),
|
||||
"object": (ObjectSegment, SegmentType.OBJECT),
|
||||
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
|
||||
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
|
||||
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
|
||||
}
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
else:
|
||||
value = []
|
||||
segment_info = segment_mapping.get(var_type)
|
||||
if not segment_info:
|
||||
raise ValueError(f"Invalid variable type: {var_type}")
|
||||
segment_class, value_type = segment_info
|
||||
return segment_class(value=value, value_type=value_type)
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]):
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables
|
||||
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
|
||||
|
||||
# TODO: System variables should be directly accessible, no need for special handling
|
||||
# Set system variables as node outputs.
|
||||
|
||||
@ -22,7 +22,7 @@ from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
@ -373,6 +373,12 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage)
|
||||
yield RunRetrieverResourceEvent(
|
||||
retriever_resources=message.message.retriever_resources,
|
||||
context=message.message.context,
|
||||
)
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
|
||||
|
||||
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return variable_factory.build_segment([])
|
||||
@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType):
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
return variable_factory.build_segment("")
|
||||
case SegmentType.INTEGER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.FLOAT:
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case _:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from core.variables import SegmentType
|
||||
|
||||
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
|
||||
@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
case Operation.OVER_WRITE | Operation.CLEAR:
|
||||
return True
|
||||
case Operation.SET:
|
||||
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
|
||||
return variable_type in {
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type == SegmentType.NUMBER
|
||||
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER:
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return operation in {
|
||||
Operation.OVER_WRITE,
|
||||
Operation.SET,
|
||||
@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.NUMBER:
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
if operation == Operation.DIVIDE and value == 0:
|
||||
|
||||
Reference in New Issue
Block a user