mirror of
https://github.com/langgenius/dify.git
synced 2026-04-22 19:57:40 +08:00
feat: fix i18n missing keys and merge upstream/main (#24615)
Signed-off-by: -LAN- <laipz8200@outlook.com> Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com> Signed-off-by: yihong0618 <zouzou0208@gmail.com> Signed-off-by: zhanluxianshen <zhanluxianshen@163.com> Co-authored-by: -LAN- <laipz8200@outlook.com> Co-authored-by: GuanMu <ballmanjq@gmail.com> Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com> Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com> Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com> Co-authored-by: Qiang Lee <18018968632@163.com> Co-authored-by: 李强04 <liqiang04@gaotu.cn> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Matri Qi <matrixdom@126.com> Co-authored-by: huayaoyue6 <huayaoyue@163.com> Co-authored-by: Bowen Liang <liangbowen@gf.com.cn> Co-authored-by: znn <jubinkumarsoni@gmail.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: yihong <zouzou0208@gmail.com> Co-authored-by: Muke Wang <shaodwaaron@gmail.com> Co-authored-by: wangmuke <wangmuke@kingsware.cn> Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com> Co-authored-by: quicksand <quicksandzn@gmail.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: Eric Guo <eric.guocz@gmail.com> Co-authored-by: Zhedong Cen <cenzhedong2@126.com> Co-authored-by: jiangbo721 <jiangbo721@163.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: hjlarry <25834719+hjlarry@users.noreply.github.com> Co-authored-by: lxsummer <35754229+lxjustdoit@users.noreply.github.com> Co-authored-by: 湛露先生 <zhanluxianshen@163.com> Co-authored-by: Guangdong Liu <liugddx@gmail.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yessenia-d <yessenia.contact@gmail.com> Co-authored-by: huangzhuo1949 <167434202+huangzhuo1949@users.noreply.github.com> Co-authored-by: huangzhuo <huangzhuo1@xiaomi.com> Co-authored-by: 17hz <0x149527@gmail.com> Co-authored-by: Amy <1530140574@qq.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Nite Knite <nkCoding@gmail.com> Co-authored-by: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Co-authored-by: Petrus Han <petrus.hanks@gmail.com> Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com> Co-authored-by: Kalo Chin <frog.beepers.0n@icloud.com> Co-authored-by: Ujjwal Maurya <ujjwalsbx@gmail.com> Co-authored-by: Maries <xh001x@hotmail.com>
This commit is contained in:
@ -6,12 +6,14 @@ implementation details like tenant_id, app_id, etc.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
@ -60,7 +62,7 @@ class WorkflowExecution(BaseModel):
|
||||
Calculate elapsed time in seconds.
|
||||
If workflow is not finished, use current time.
|
||||
"""
|
||||
end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
|
||||
end_time = self.finished_at or naive_utc_now()
|
||||
return (end_time - self.started_at).total_seconds()
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
@ -71,7 +72,7 @@ class RouteNodeState(BaseModel):
|
||||
raise Exception(f"Invalid route status {run_result.status}")
|
||||
|
||||
self.node_run_result = run_result
|
||||
self.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
self.finished_at = naive_utc_now()
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
|
||||
@ -89,7 +90,7 @@ class RuntimeRouteState(BaseModel):
|
||||
|
||||
:param node_id: node id
|
||||
"""
|
||||
state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
|
||||
state = RouteNodeState(node_id=node_id, start_at=naive_utc_now())
|
||||
self.node_state_mapping[state.id] = state
|
||||
return state
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
@ -51,6 +50,7 @@ from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
@ -640,7 +640,7 @@ class GraphEngine:
|
||||
while should_continue_retry and retries <= max_retries:
|
||||
try:
|
||||
# run node
|
||||
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
retry_start_at = naive_utc_now()
|
||||
# yield control to other threads
|
||||
time.sleep(0.001)
|
||||
event_stream = node.run()
|
||||
|
||||
@ -13,8 +13,9 @@ from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
@ -558,7 +559,7 @@ class AgentNode(BaseNode):
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == NodeType.AGENT:
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(msg_metadata)
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
@ -692,7 +693,13 @@ class AgentNode(BaseNode):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
|
||||
@ -8,6 +8,7 @@ from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
from core.variables.segments import ArrayFileSegment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
@ -119,6 +120,14 @@ class CodeNode(BaseNode):
|
||||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_boolean(self, value: bool | None, variable: str) -> bool | None:
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, bool):
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a boolean")
|
||||
|
||||
return value
|
||||
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
@ -173,6 +182,8 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}" if prefix else output_name,
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif isinstance(output_value, bool):
|
||||
self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name)
|
||||
elif isinstance(output_value, int | float):
|
||||
self._check_number(
|
||||
value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name
|
||||
@ -232,7 +243,7 @@ class CodeNode(BaseNode):
|
||||
if output_name not in result:
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
if output_config.type == SegmentType.OBJECT:
|
||||
# check if output is object
|
||||
if not isinstance(result.get(output_name), dict):
|
||||
if result[output_name] is None:
|
||||
@ -249,18 +260,28 @@ class CodeNode(BaseNode):
|
||||
prefix=f"{prefix}.{output_name}",
|
||||
depth=depth + 1,
|
||||
)
|
||||
elif output_config.type == "number":
|
||||
elif output_config.type == SegmentType.NUMBER:
|
||||
# check if number available
|
||||
transformed_result[output_name] = self._check_number(
|
||||
value=result[output_name], variable=f"{prefix}{dot}{output_name}"
|
||||
)
|
||||
elif output_config.type == "string":
|
||||
checked = self._check_number(value=result[output_name], variable=f"{prefix}{dot}{output_name}")
|
||||
# If the output is a boolean and the output schema specifies a NUMBER type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
transformed_result[output_name] = self._convert_boolean_to_int(checked)
|
||||
|
||||
elif output_config.type == SegmentType.STRING:
|
||||
# check if string available
|
||||
transformed_result[output_name] = self._check_string(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == "array[number]":
|
||||
elif output_config.type == SegmentType.BOOLEAN:
|
||||
transformed_result[output_name] = self._check_boolean(
|
||||
value=result[output_name],
|
||||
variable=f"{prefix}{dot}{output_name}",
|
||||
)
|
||||
elif output_config.type == SegmentType.ARRAY_NUMBER:
|
||||
# check if array of number available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -278,10 +299,17 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
|
||||
transformed_result[output_name] = [
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
# If the element is a boolean and the output schema specifies a `array[number]` type,
|
||||
# convert the boolean value to an integer.
|
||||
#
|
||||
# This ensures compatibility with existing workflows that may use
|
||||
# `True` and `False` as values for NUMBER type outputs.
|
||||
self._convert_boolean_to_int(
|
||||
self._check_number(value=value, variable=f"{prefix}{dot}{output_name}[{i}]"),
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[string]":
|
||||
elif output_config.type == SegmentType.ARRAY_STRING:
|
||||
# check if array of string available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -302,7 +330,7 @@ class CodeNode(BaseNode):
|
||||
self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == "array[object]":
|
||||
elif output_config.type == SegmentType.ARRAY_OBJECT:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
@ -340,6 +368,22 @@ class CodeNode(BaseNode):
|
||||
)
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
elif output_config.type == SegmentType.ARRAY_BOOLEAN:
|
||||
# check if array of object available
|
||||
if not isinstance(result[output_name], list):
|
||||
if result[output_name] is None:
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
transformed_result[output_name] = [
|
||||
self._check_boolean(value=value, variable=f"{prefix}{dot}{output_name}[{i}]")
|
||||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
|
||||
else:
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
@ -374,3 +418,16 @@ class CodeNode(BaseNode):
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None:
|
||||
"""This function convert boolean to integers when the output schema specifies a NUMBER type.
|
||||
|
||||
This ensures compatibility with existing workflows that may use
|
||||
`True` and `False` as values for NUMBER type outputs.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
return value
|
||||
|
||||
@ -1,11 +1,31 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import AfterValidator, BaseModel
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeLanguage
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_ALLOWED_OUTPUT_FROM_CODE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(segment_type: SegmentType) -> SegmentType:
|
||||
if segment_type not in _ALLOWED_OUTPUT_FROM_CODE:
|
||||
raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}")
|
||||
return segment_type
|
||||
|
||||
|
||||
class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
@ -13,7 +33,7 @@ class CodeNodeData(BaseNodeData):
|
||||
"""
|
||||
|
||||
class Output(BaseModel):
|
||||
type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, AfterValidator(_validate_type)]
|
||||
children: Optional[dict[str, "CodeNodeData.Output"]] = None
|
||||
|
||||
class Dependency(BaseModel):
|
||||
|
||||
@ -12,6 +12,7 @@ from json_repair import repair_json
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import file_manager
|
||||
from core.file.enums import FileTransferMethod
|
||||
from core.helper import ssrf_proxy
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -228,7 +229,9 @@ class Executor:
|
||||
files: dict[str, list[tuple[str | None, bytes, str]]] = {}
|
||||
for key, files_in_segment in files_list:
|
||||
for file in files_in_segment:
|
||||
if file.related_id is not None:
|
||||
if file.related_id is not None or (
|
||||
file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None
|
||||
):
|
||||
file_tuple = (
|
||||
file.filename,
|
||||
file_manager.download(file),
|
||||
@ -326,22 +329,16 @@ class Executor:
|
||||
"""
|
||||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {
|
||||
"get",
|
||||
"head",
|
||||
"post",
|
||||
"put",
|
||||
"delete",
|
||||
"patch",
|
||||
"options",
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"PATCH",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
"OPTIONS",
|
||||
}:
|
||||
_METHOD_MAP = {
|
||||
"get": ssrf_proxy.get,
|
||||
"head": ssrf_proxy.head,
|
||||
"post": ssrf_proxy.post,
|
||||
"put": ssrf_proxy.put,
|
||||
"delete": ssrf_proxy.delete,
|
||||
"patch": ssrf_proxy.patch,
|
||||
}
|
||||
method_lc = self.method.lower()
|
||||
if method_lc not in _METHOD_MAP:
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
@ -359,11 +356,11 @@ class Executor:
|
||||
}
|
||||
# request_args = {k: v for k, v in request_args.items() if v is not None}
|
||||
try:
|
||||
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
|
||||
response: httpx.Response = _METHOD_MAP[method_lc](**request_args)
|
||||
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
|
||||
raise HttpRequestNodeError(str(e)) from e
|
||||
# FIXME: fix type ignore, this maybe httpx type issue
|
||||
return response # type: ignore
|
||||
return response
|
||||
|
||||
def invoke(self) -> Response:
|
||||
# assemble headers
|
||||
|
||||
@ -4,7 +4,7 @@ import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, wait
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from queue import Empty, Queue
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
@ -41,6 +41,7 @@ from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from factories.variable_factory import build_segment
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
@ -179,7 +180,7 @@ class IterationNode(BaseNode):
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
start_at = naive_utc_now()
|
||||
|
||||
yield IterationRunStartedEvent(
|
||||
iteration_id=self.id,
|
||||
@ -428,7 +429,7 @@ class IterationNode(BaseNode):
|
||||
"""
|
||||
run single iteration
|
||||
"""
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
iter_start_at = naive_utc_now()
|
||||
|
||||
try:
|
||||
rst = graph_engine.run()
|
||||
@ -505,7 +506,7 @@ class IterationNode(BaseNode):
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
duration = (naive_utc_now() - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
@ -526,7 +527,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
duration = (naive_utc_now() - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
@ -602,7 +603,7 @@ class IterationNode(BaseNode):
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
duration = (naive_utc_now() - iter_start_at).total_seconds()
|
||||
iter_run_map[iteration_run_id] = duration
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
|
||||
@ -1,36 +1,43 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
_Condition = Literal[
|
||||
|
||||
class FilterOperator(StrEnum):
|
||||
# string conditions
|
||||
"contains",
|
||||
"start with",
|
||||
"end with",
|
||||
"is",
|
||||
"in",
|
||||
"empty",
|
||||
"not contains",
|
||||
"is not",
|
||||
"not in",
|
||||
"not empty",
|
||||
CONTAINS = "contains"
|
||||
START_WITH = "start with"
|
||||
END_WITH = "end with"
|
||||
IS = "is"
|
||||
IN = "in"
|
||||
EMPTY = "empty"
|
||||
NOT_CONTAINS = "not contains"
|
||||
IS_NOT = "is not"
|
||||
NOT_IN = "not in"
|
||||
NOT_EMPTY = "not empty"
|
||||
# number conditions
|
||||
"=",
|
||||
"≠",
|
||||
"<",
|
||||
">",
|
||||
"≥",
|
||||
"≤",
|
||||
]
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "≠"
|
||||
LESS_THAN = "<"
|
||||
GREATER_THAN = ">"
|
||||
GREATER_THAN_OR_EQUAL = "≥"
|
||||
LESS_THAN_OR_EQUAL = "≤"
|
||||
|
||||
|
||||
class Order(StrEnum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class FilterCondition(BaseModel):
|
||||
key: str = ""
|
||||
comparison_operator: _Condition = "contains"
|
||||
value: str | Sequence[str] = ""
|
||||
comparison_operator: FilterOperator = FilterOperator.CONTAINS
|
||||
# the value is bool if the filter operator is comparing with
|
||||
# a boolean constant.
|
||||
value: str | Sequence[str] | bool = ""
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
@ -38,10 +45,10 @@ class FilterBy(BaseModel):
|
||||
conditions: Sequence[FilterCondition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class OrderBy(BaseModel):
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
key: str = ""
|
||||
value: Literal["asc", "desc"] = "asc"
|
||||
value: Order = Order.ASC
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
@ -57,6 +64,6 @@ class ExtractConfig(BaseModel):
|
||||
class ListOperatorNodeData(BaseNodeData):
|
||||
variable: Sequence[str] = Field(default_factory=list)
|
||||
filter_by: FilterBy
|
||||
order_by: OrderBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
extract_by: ExtractConfig = Field(default_factory=ExtractConfig)
|
||||
|
||||
@ -1,18 +1,40 @@
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Optional, TypeAlias, TypeVar
|
||||
|
||||
from core.file import File
|
||||
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
|
||||
from .entities import ListOperatorNodeData
|
||||
from .entities import FilterOperator, ListOperatorNodeData, Order
|
||||
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
|
||||
|
||||
_SUPPORTED_TYPES_TUPLE = (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayStringSegment,
|
||||
ArrayBooleanSegment,
|
||||
)
|
||||
_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]:
|
||||
"""Returns the negation of a given filter function. If the original filter
|
||||
returns `True` for a value, the negated filter will return `False`, and vice versa.
|
||||
"""
|
||||
|
||||
def wrapper(value: _T) -> bool:
|
||||
return not filter_(value)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ListOperatorNode(BaseNode):
|
||||
_node_type = NodeType.LIST_OPERATOR
|
||||
@ -69,11 +91,8 @@ class ListOperatorNode(BaseNode):
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
)
|
||||
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
|
||||
error_message = (
|
||||
f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
|
||||
"or ArrayStringSegment"
|
||||
)
|
||||
if not isinstance(variable, _SUPPORTED_TYPES_TUPLE):
|
||||
error_message = f"Variable {self._node_data.variable} is not an array type, actual type: {type(variable)}"
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
|
||||
)
|
||||
@ -122,9 +141,7 @@ class ListOperatorNode(BaseNode):
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
def _apply_filter(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
filter_func: Callable[[Any], bool]
|
||||
result: list[Any] = []
|
||||
for condition in self._node_data.filter_by.conditions:
|
||||
@ -154,33 +171,35 @@ class ListOperatorNode(BaseNode):
|
||||
)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayBooleanSegment):
|
||||
if not isinstance(condition.value, bool):
|
||||
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
|
||||
filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value)
|
||||
result = list(filter(filter_func, variable.value))
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statment should be unreachable.")
|
||||
return variable
|
||||
|
||||
def _apply_order(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
if isinstance(variable, ArrayStringSegment):
|
||||
result = _order_string(order=self._node_data.order_by.value, array=variable.value)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayNumberSegment):
|
||||
result = _order_number(order=self._node_data.order_by.value, array=variable.value)
|
||||
def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)):
|
||||
result = sorted(variable.value, reverse=self._node_data.order_by == Order.DESC)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
result = _order_file(
|
||||
order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
|
||||
)
|
||||
variable = variable.model_copy(update={"value": result})
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable")
|
||||
|
||||
return variable
|
||||
|
||||
def _apply_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
result = variable.value[: self._node_data.limit.size]
|
||||
return variable.model_copy(update={"value": result})
|
||||
|
||||
def _extract_slice(
|
||||
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
|
||||
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
|
||||
def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS:
|
||||
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
|
||||
if value < 1:
|
||||
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
|
||||
@ -232,11 +251,11 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo
|
||||
case "empty":
|
||||
return lambda x: x == ""
|
||||
case "not contains":
|
||||
return lambda x: not _contains(value)(x)
|
||||
return _negation(_contains(value))
|
||||
case "is not":
|
||||
return lambda x: not _is(value)(x)
|
||||
return _negation(_is(value))
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case "not empty":
|
||||
return lambda x: x != ""
|
||||
case _:
|
||||
@ -248,7 +267,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab
|
||||
case "in":
|
||||
return _in(value)
|
||||
case "not in":
|
||||
return lambda x: not _in(value)(x)
|
||||
return _negation(_in(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
@ -271,6 +290,16 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]:
|
||||
match condition:
|
||||
case FilterOperator.IS:
|
||||
return _is(value)
|
||||
case FilterOperator.IS_NOT:
|
||||
return _negation(_is(value))
|
||||
case _:
|
||||
raise InvalidConditionError(f"Invalid condition: {condition}")
|
||||
|
||||
|
||||
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
|
||||
extract_func: Callable[[File], Any]
|
||||
if key in {"name", "extension", "mime_type", "url"} and isinstance(value, str):
|
||||
@ -298,7 +327,7 @@ def _endswith(value: str) -> Callable[[str], bool]:
|
||||
return lambda x: x.endswith(value)
|
||||
|
||||
|
||||
def _is(value: str) -> Callable[[str], bool]:
|
||||
def _is(value: _T) -> Callable[[_T], bool]:
|
||||
return lambda x: x == value
|
||||
|
||||
|
||||
@ -330,21 +359,13 @@ def _ge(value: int | float) -> Callable[[int | float], bool]:
|
||||
return lambda x: x >= value
|
||||
|
||||
|
||||
def _order_number(*, order: Literal["asc", "desc"], array: Sequence[int | float]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_string(*, order: Literal["asc", "desc"], array: Sequence[str]):
|
||||
return sorted(array, key=lambda x: x, reverse=order == "desc")
|
||||
|
||||
|
||||
def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Sequence[File]):
|
||||
def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]):
|
||||
extract_func: Callable[[File], Any]
|
||||
if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url"}:
|
||||
extract_func = _get_file_extract_string_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
elif order_by == "size":
|
||||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC)
|
||||
else:
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
@ -20,6 +19,7 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
@ -149,7 +149,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
last_used=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
|
||||
@ -3,7 +3,7 @@ import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import FileType, file_manager
|
||||
@ -55,7 +55,6 @@ from core.workflow.entities.variable_entities import VariableSelector
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.enums import ErrorStrategy, NodeType
|
||||
@ -90,6 +89,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.event import InNodeEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -161,7 +161,7 @@ class LLMNode(BaseNode):
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def _run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
|
||||
node_inputs: Optional[dict[str, Any]] = None
|
||||
process_data = None
|
||||
result_text = ""
|
||||
@ -737,7 +737,7 @@ class LLMNode(BaseNode):
|
||||
and isinstance(prompt_messages[-1], UserPromptMessage)
|
||||
and isinstance(prompt_messages[-1].content, list)
|
||||
):
|
||||
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
|
||||
@ -12,9 +12,11 @@ _VALID_VAR_TYPE = frozenset(
|
||||
SegmentType.STRING,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
@ -36,6 +36,7 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -143,7 +144,7 @@ class LoopNode(BaseNode):
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
start_at = naive_utc_now()
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
# Start Loop event
|
||||
@ -171,7 +172,7 @@ class LoopNode(BaseNode):
|
||||
try:
|
||||
check_break_result = False
|
||||
for i in range(loop_count):
|
||||
loop_start_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
loop_start_time = naive_utc_now()
|
||||
# run single loop
|
||||
loop_result = yield from self._run_single_loop(
|
||||
graph_engine=graph_engine,
|
||||
@ -185,7 +186,7 @@ class LoopNode(BaseNode):
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
)
|
||||
loop_end_time = datetime.now(UTC).replace(tzinfo=None)
|
||||
loop_end_time = naive_utc_now()
|
||||
|
||||
single_loop_variable = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
@ -403,11 +404,11 @@ class LoopNode(BaseNode):
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
_outputs = {}
|
||||
_outputs: dict[str, Segment | int | None] = {}
|
||||
for loop_variable_key, loop_variable_selector in loop_variable_selectors.items():
|
||||
_loop_variable_segment = variable_pool.get(loop_variable_selector)
|
||||
if _loop_variable_segment:
|
||||
_outputs[loop_variable_key] = _loop_variable_segment.value
|
||||
_outputs[loop_variable_key] = _loop_variable_segment
|
||||
else:
|
||||
_outputs[loop_variable_key] = None
|
||||
|
||||
@ -521,21 +522,30 @@ class LoopNode(BaseNode):
|
||||
return variable_mapping
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
if var_type in ["array[string]", "array[number]", "array[object]"]:
|
||||
if value and isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
if not var_type.is_array_type() or var_type == SegmentType.BOOLEAN:
|
||||
value = original_value
|
||||
elif var_type in [
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
]:
|
||||
if original_value and isinstance(original_value, str):
|
||||
value = json.loads(original_value)
|
||||
else:
|
||||
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
|
||||
value = []
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
try:
|
||||
return build_segment_with_type(var_type, value)
|
||||
return build_segment_with_type(var_type, value=value)
|
||||
except TypeMismatchError as type_exc:
|
||||
# Attempt to parse the value as a JSON-encoded string, if applicable.
|
||||
if not isinstance(value, str):
|
||||
if not isinstance(original_value, str):
|
||||
raise
|
||||
try:
|
||||
value = json.loads(value)
|
||||
value = json.loads(original_value)
|
||||
except ValueError:
|
||||
raise type_exc
|
||||
return build_segment_with_type(var_type, value)
|
||||
|
||||
@ -1,10 +1,46 @@
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Annotated, Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
|
||||
|
||||
_OLD_BOOL_TYPE_NAME = "bool"
|
||||
_OLD_SELECT_TYPE_NAME = "select"
|
||||
|
||||
_VALID_PARAMETER_TYPES = frozenset(
|
||||
[
|
||||
SegmentType.STRING, # "string",
|
||||
SegmentType.NUMBER, # "number",
|
||||
SegmentType.BOOLEAN,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_BOOLEAN,
|
||||
_OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node
|
||||
_OLD_SELECT_TYPE_NAME, # string type with enumeration choices.
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _validate_type(parameter_type: str) -> SegmentType:
|
||||
if not isinstance(parameter_type, str):
|
||||
raise TypeError(f"type should be str, got {type(parameter_type)}, value={parameter_type}")
|
||||
if parameter_type not in _VALID_PARAMETER_TYPES:
|
||||
raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.")
|
||||
|
||||
if parameter_type == _OLD_BOOL_TYPE_NAME:
|
||||
return SegmentType.BOOLEAN
|
||||
elif parameter_type == _OLD_SELECT_TYPE_NAME:
|
||||
return SegmentType.STRING
|
||||
return SegmentType(parameter_type)
|
||||
|
||||
|
||||
class _ParameterConfigError(Exception):
|
||||
@ -17,7 +53,7 @@ class ParameterConfig(BaseModel):
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: Literal["string", "number", "bool", "select", "array[string]", "array[number]", "array[object]"]
|
||||
type: Annotated[SegmentType, BeforeValidator(_validate_type)]
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: bool
|
||||
@ -32,17 +68,20 @@ class ParameterConfig(BaseModel):
|
||||
return str(value)
|
||||
|
||||
def is_array_type(self) -> bool:
|
||||
return self.type in ("array[string]", "array[number]", "array[object]")
|
||||
return self.type.is_array_type()
|
||||
|
||||
def element_type(self) -> Literal["string", "number", "object"]:
|
||||
if self.type == "array[number]":
|
||||
return "number"
|
||||
elif self.type == "array[string]":
|
||||
return "string"
|
||||
elif self.type == "array[object]":
|
||||
return "object"
|
||||
else:
|
||||
raise _ParameterConfigError(f"{self.type} is not array type.")
|
||||
def element_type(self) -> SegmentType:
|
||||
"""Return the element type of the parameter.
|
||||
|
||||
Raises a ValueError if the parameter's type is not an array type.
|
||||
"""
|
||||
element_type = self.type.element_type()
|
||||
# At this point, self.type is guaranteed to be one of `ARRAY_STRING`,
|
||||
# `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`.
|
||||
#
|
||||
# See: _VALID_PARAMETER_TYPES for reference.
|
||||
assert element_type is not None, f"the element type should not be None, {self.type=}"
|
||||
return element_type
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
@ -74,16 +113,18 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
|
||||
if parameter.type in {"string", "select"}:
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
elif parameter.type.startswith("array"):
|
||||
elif parameter.type.is_array_type():
|
||||
parameter_schema["type"] = "array"
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema["items"] = {"type": nested_type}
|
||||
element_type = parameter.type.element_type()
|
||||
if element_type is None:
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
|
||||
if parameter.type == "select":
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
parameters["properties"][parameter.name] = parameter_schema
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
@ -48,3 +53,23 @@ class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
||||
|
||||
class InvalidValueTypeError(ParameterExtractorNodeError):
|
||||
def __init__(
|
||||
self,
|
||||
/,
|
||||
parameter_name: str,
|
||||
expected_type: SegmentType,
|
||||
actual_type: SegmentType | None,
|
||||
value: Any,
|
||||
) -> None:
|
||||
message = (
|
||||
f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, "
|
||||
f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}"
|
||||
)
|
||||
super().__init__(message)
|
||||
self.parameter_name = parameter_name
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
self.value = value
|
||||
|
||||
@ -26,7 +26,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables.types import SegmentType
|
||||
from core.variables.types import ArrayValidation, SegmentType
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@ -39,16 +39,13 @@ from factories.variable_factory import build_segment_with_type
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
InvalidValueTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
@ -549,9 +546,6 @@ class ParameterExtractorNode(BaseNode):
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
@ -559,101 +553,106 @@ class ParameterExtractorNode(BaseNode):
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
param_value = result.get(parameter.name)
|
||||
if not parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL):
|
||||
inferred_type = SegmentType.infer_segment_type(param_value)
|
||||
raise InvalidValueTypeError(
|
||||
parameter_name=parameter.name,
|
||||
expected_type=parameter.type,
|
||||
actual_type=inferred_type,
|
||||
value=param_value,
|
||||
)
|
||||
if parameter.type == SegmentType.STRING and parameter.options:
|
||||
if param_value not in parameter.options:
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _transform_number(value: int | float | str | bool) -> int | float | None:
|
||||
"""
|
||||
Attempts to transform the input into an integer or float.
|
||||
|
||||
Returns:
|
||||
int or float: The transformed number if the conversion is successful.
|
||||
None: If the transformation fails.
|
||||
|
||||
Note:
|
||||
Boolean values `True` and `False` are converted to integers `1` and `0`, respectively.
|
||||
This behavior ensures compatibility with existing workflows that may use boolean types as integers.
|
||||
"""
|
||||
if isinstance(value, bool):
|
||||
return int(value)
|
||||
elif isinstance(value, (int, float)):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
return None
|
||||
if "." in value:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result = {}
|
||||
transformed_result: dict[str, Any] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == "number":
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if "." in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
# TODO: bool is not supported in the current version
|
||||
# elif parameter.type == 'bool':
|
||||
# if isinstance(result[parameter.name], bool):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ['true', 'false']:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in {"string", "select"}:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ["true", "false"]:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true")
|
||||
elif parameter.type == SegmentType.STRING:
|
||||
if isinstance(param_value, str):
|
||||
transformed_result[parameter.name] = param_value
|
||||
elif parameter.is_array_type():
|
||||
if isinstance(result[parameter.name], list):
|
||||
if isinstance(param_value, list):
|
||||
nested_type = parameter.element_type()
|
||||
assert nested_type is not None
|
||||
segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
|
||||
transformed_result[parameter.name] = segment_value
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == "number":
|
||||
if isinstance(item, int | float):
|
||||
segment_value.value.append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if "." in item:
|
||||
segment_value.value.append(float(item))
|
||||
else:
|
||||
segment_value.value.append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == "string":
|
||||
for item in param_value:
|
||||
if nested_type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(item)
|
||||
if transformed is not None:
|
||||
segment_value.value.append(transformed)
|
||||
elif nested_type == SegmentType.STRING:
|
||||
if isinstance(item, str):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == "object":
|
||||
elif nested_type == SegmentType.OBJECT:
|
||||
if isinstance(item, dict):
|
||||
segment_value.value.append(item)
|
||||
elif nested_type == SegmentType.BOOLEAN:
|
||||
if isinstance(item, bool):
|
||||
segment_value.value.append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == "number":
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == "bool":
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in {"string", "select"}:
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type.startswith("array"):
|
||||
if parameter.type.is_array_type():
|
||||
transformed_result[parameter.name] = build_segment_with_type(
|
||||
segment_type=SegmentType(parameter.type), value=[]
|
||||
)
|
||||
elif parameter.type in (SegmentType.STRING, SegmentType.SECRET):
|
||||
transformed_result[parameter.name] = ""
|
||||
elif parameter.type == SegmentType.NUMBER:
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
transformed_result[parameter.name] = False
|
||||
else:
|
||||
raise AssertionError("this statement should be unreachable.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeAlias
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables.segments import BooleanSegment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
@ -158,8 +159,8 @@ class VariableAssignerNode(BaseNode):
|
||||
def get_zero_value(t: SegmentType):
|
||||
# TODO(QuantumGhost): this should be a method of `SegmentType`.
|
||||
match t:
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
|
||||
return variable_factory.build_segment([])
|
||||
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN:
|
||||
return variable_factory.build_segment_with_type(t, [])
|
||||
case SegmentType.OBJECT:
|
||||
return variable_factory.build_segment({})
|
||||
case SegmentType.STRING:
|
||||
@ -170,5 +171,7 @@ def get_zero_value(t: SegmentType):
|
||||
return variable_factory.build_segment(0.0)
|
||||
case SegmentType.NUMBER:
|
||||
return variable_factory.build_segment(0)
|
||||
case SegmentType.BOOLEAN:
|
||||
return BooleanSegment(value=False)
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported variable type: {t}")
|
||||
|
||||
@ -4,9 +4,11 @@ from core.variables import SegmentType
|
||||
EMPTY_VALUE_MAPPING = {
|
||||
SegmentType.STRING: "",
|
||||
SegmentType.NUMBER: 0,
|
||||
SegmentType.BOOLEAN: False,
|
||||
SegmentType.OBJECT: {},
|
||||
SegmentType.ARRAY_ANY: [],
|
||||
SegmentType.ARRAY_STRING: [],
|
||||
SegmentType.ARRAY_NUMBER: [],
|
||||
SegmentType.ARRAY_OBJECT: [],
|
||||
SegmentType.ARRAY_BOOLEAN: [],
|
||||
}
|
||||
|
||||
@ -16,28 +16,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.BOOLEAN,
|
||||
}
|
||||
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
|
||||
# Only number variable can be added, subtracted, multiplied or divided
|
||||
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
|
||||
case Operation.APPEND | Operation.EXTEND:
|
||||
case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can be appended or extended
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
|
||||
# Only array variable can have elements removed
|
||||
return variable_type in {
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_FILE,
|
||||
}
|
||||
return variable_type.is_array_type()
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -50,7 +37,7 @@ def is_variable_input_supported(*, operation: Operation):
|
||||
|
||||
def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation):
|
||||
match variable_type:
|
||||
case SegmentType.STRING | SegmentType.OBJECT:
|
||||
case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN:
|
||||
return operation in {Operation.OVER_WRITE, Operation.SET}
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
return operation in {
|
||||
@ -72,6 +59,9 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
case SegmentType.STRING:
|
||||
return isinstance(value, str)
|
||||
|
||||
case SegmentType.BOOLEAN:
|
||||
return isinstance(value, bool)
|
||||
|
||||
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
|
||||
if not isinstance(value, int | float):
|
||||
return False
|
||||
@ -91,6 +81,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, int | float)
|
||||
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
|
||||
return isinstance(value, dict)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND:
|
||||
return isinstance(value, bool)
|
||||
|
||||
# Array & Extend / Overwrite
|
||||
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
@ -101,6 +93,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
|
||||
return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
|
||||
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, dict) for item in value)
|
||||
case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
|
||||
return isinstance(value, list) and all(isinstance(item, bool) for item in value)
|
||||
|
||||
case _:
|
||||
return False
|
||||
|
||||
@ -45,5 +45,5 @@ class SubVariableCondition(BaseModel):
|
||||
class Condition(BaseModel):
|
||||
variable_selector: list[str]
|
||||
comparison_operator: SupportedComparisonOperator
|
||||
value: str | Sequence[str] | None = None
|
||||
value: str | Sequence[str] | bool | None = None
|
||||
sub_variable_condition: SubVariableCondition | None = None
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from core.file import FileAttribute, file_manager
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from .entities import Condition, SubCondition, SupportedComparisonOperator
|
||||
|
||||
|
||||
def _convert_to_bool(value: Any) -> bool:
|
||||
if isinstance(value, int):
|
||||
return bool(value)
|
||||
|
||||
if isinstance(value, str):
|
||||
loaded = json.loads(value)
|
||||
if isinstance(loaded, (int, bool)):
|
||||
return bool(loaded)
|
||||
|
||||
raise TypeError(f"unexpected value: type={type(value)}, value={value}")
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process_conditions(
|
||||
self,
|
||||
@ -48,9 +62,16 @@ class ConditionProcessor:
|
||||
)
|
||||
else:
|
||||
actual_value = variable.value if variable else None
|
||||
expected_value = condition.value
|
||||
expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value
|
||||
if isinstance(expected_value, str):
|
||||
expected_value = variable_pool.convert_template(expected_value).text
|
||||
# Here we need to explicit convet the input string to boolean.
|
||||
if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None:
|
||||
# The following two lines is for compatibility with existing workflows.
|
||||
if isinstance(expected_value, list):
|
||||
expected_value = [_convert_to_bool(i) for i in expected_value]
|
||||
else:
|
||||
expected_value = _convert_to_bool(expected_value)
|
||||
input_conditions.append(
|
||||
{
|
||||
"actual_value": actual_value,
|
||||
@ -77,7 +98,7 @@ def _evaluate_condition(
|
||||
*,
|
||||
operator: SupportedComparisonOperator,
|
||||
value: Any,
|
||||
expected: str | Sequence[str] | None,
|
||||
expected: Union[str, Sequence[str], bool | Sequence[bool], None],
|
||||
) -> bool:
|
||||
match operator:
|
||||
case "contains":
|
||||
@ -130,7 +151,7 @@ def _assert_contains(*, value: Any, expected: Any) -> bool:
|
||||
if not value:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected not in value:
|
||||
@ -142,7 +163,7 @@ def _assert_not_contains(*, value: Any, expected: Any) -> bool:
|
||||
if not value:
|
||||
return True
|
||||
|
||||
if not isinstance(value, str | list):
|
||||
if not isinstance(value, (str, list)):
|
||||
raise ValueError("Invalid actual value type: string or array")
|
||||
|
||||
if expected in value:
|
||||
@ -178,8 +199,8 @@ def _assert_is(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value != expected:
|
||||
return False
|
||||
@ -190,8 +211,8 @@ def _assert_is_not(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("Invalid actual value type: string")
|
||||
if not isinstance(value, (str, bool)):
|
||||
raise ValueError("Invalid actual value type: string or boolean")
|
||||
|
||||
if value == expected:
|
||||
return False
|
||||
@ -214,10 +235,13 @@ def _assert_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
@ -231,10 +255,13 @@ def _assert_not_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
if not isinstance(value, (int, float, bool)):
|
||||
raise ValueError("Invalid actual value type: number or boolean")
|
||||
|
||||
if isinstance(value, int):
|
||||
# Handle boolean comparison
|
||||
if isinstance(value, bool):
|
||||
expected = bool(expected)
|
||||
elif isinstance(value, int):
|
||||
expected = int(expected)
|
||||
else:
|
||||
expected = float(expected)
|
||||
@ -248,7 +275,7 @@ def _assert_greater_than(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@ -265,7 +292,7 @@ def _assert_less_than(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@ -282,7 +309,7 @@ def _assert_greater_than_or_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
@ -299,7 +326,7 @@ def _assert_less_than_or_equal(*, value: Any, expected: Any) -> bool:
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(value, int | float):
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("Invalid actual value type: number")
|
||||
|
||||
if isinstance(value, int):
|
||||
|
||||
Reference in New Issue
Block a user