merge main

This commit is contained in:
Joel
2024-10-28 10:51:02 +08:00
858 changed files with 16206 additions and 17932 deletions

View File

@ -2,7 +2,7 @@ from uuid import uuid4
import pytest
from core.app.segments import (
from core.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
@ -11,43 +11,43 @@ from core.app.segments import (
ObjectSegment,
SecretVariable,
StringVariable,
factory,
)
from core.app.segments.exc import VariableError
from core.variables.exc import VariableError
from factories import variable_factory
def test_string_variable():
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = factory.build_variable_from_mapping(test_data)
result = variable_factory.build_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = factory.build_variable_from_mapping(test_data)
result = variable_factory.build_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = factory.build_variable_from_mapping(test_data)
result = variable_factory.build_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = factory.build_variable_from_mapping(test_data)
result = variable_factory.build_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
factory.build_variable_from_mapping(test_data)
variable_factory.build_variable_from_mapping(test_data)
def test_build_a_blank_string():
result = factory.build_variable_from_mapping(
result = variable_factory.build_variable_from_mapping(
{
"value_type": "string",
"name": "blank",
@ -59,7 +59,7 @@ def test_build_a_blank_string():
def test_build_a_object_variable_with_none_value():
var = factory.build_segment(
var = variable_factory.build_segment(
{
"key1": None,
}
@ -79,7 +79,7 @@ def test_object_variable():
"key2": 2,
},
}
variable = factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
@ -96,7 +96,7 @@ def test_array_string_variable():
"text",
],
}
variable = factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayStringVariable)
assert isinstance(variable.value[0], str)
assert isinstance(variable.value[1], str)
@ -113,7 +113,7 @@ def test_array_number_variable():
2.0,
],
}
variable = factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayNumberVariable)
assert isinstance(variable.value[0], int)
assert isinstance(variable.value[1], float)
@ -136,7 +136,7 @@ def test_array_object_variable():
},
],
}
variable = factory.build_variable_from_mapping(mapping)
variable = variable_factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
@ -146,13 +146,13 @@ def test_array_object_variable():
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_large_than_5_kb():
def test_variable_cannot_large_than_200_kb():
with pytest.raises(VariableError):
factory.build_variable_from_mapping(
variable_factory.build_variable_from_mapping(
{
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 6,
"value": "a" * 1024 * 201,
}
)

View File

@ -1,5 +1,5 @@
from core.app.segments import SecretVariable, StringSegment, parser
from core.helper import encrypter
from core.variables import SecretVariable, StringSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
@ -13,12 +13,13 @@ def test_segment_group_to_text():
environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"),
],
conversation_variables=[],
)
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
assert segments_group.log == (
@ -32,9 +33,10 @@ def test_convert_constant_to_segment_group():
system_variables={},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
template = "Hello, world!"
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "Hello, world!"
assert segments_group.log == "Hello, world!"
@ -46,9 +48,10 @@ def test_convert_variable_to_segment_group():
},
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
template = "{{#sys.user_id#}}"
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
segments_group = variable_pool.convert_template(template)
assert segments_group.text == "fake-user-id"
assert segments_group.log == "fake-user-id"
assert segments_group.value == [StringSegment(value="fake-user-id")]

View File

@ -1,7 +1,7 @@
import pytest
from pydantic import ValidationError
from core.app.segments import (
from core.variables import (
FloatVariable,
IntegerVariable,
ObjectVariable,

View File

@ -6,7 +6,7 @@ import pytest
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
@patch("httpx.request")
@patch("httpx.Client.request")
def test_successful_request(mock_request):
mock_response = MagicMock()
mock_response.status_code = 200
@ -16,7 +16,7 @@ def test_successful_request(mock_request):
assert response.status_code == 200
@patch("httpx.request")
@patch("httpx.Client.request")
def test_retry_exceed_max_retries(mock_request):
mock_response = MagicMock()
mock_response.status_code = 500
@ -29,7 +29,7 @@ def test_retry_exceed_max_retries(mock_request):
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch("httpx.request")
@patch("httpx.Client.request")
def test_retry_logic_success(mock_request):
side_effects = []

View File

@ -1,3 +1,5 @@
import string
import numpy as np
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
@ -31,7 +33,7 @@ def test_max_chunks():
max_chunks = embedding_model._get_max_chunks(model, credentials)
embedding_model._create_text_embedding = _create_text_embedding
texts = ["0123456789" for i in range(0, max_chunks * 2)]
texts = [string.digits for i in range(0, max_chunks * 2)]
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
assert len(result.embeddings) == max_chunks * 2

View File

@ -1,11 +1,16 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from core.app.app_config.entities import ModelConfigEntity
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessageRole,
UserPromptMessage,
)
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@ -123,32 +128,30 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
model_config_mock, _, messages, inputs, context = get_chat_model_args
files = [
FileVar(
File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="https://example.com/image1.jpg",
extra_config=FileExtraConfig(
image_config={
"detail": "high",
}
),
remote_url="https://example.com/image1.jpg",
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
)
]
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock,
)
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
prompt_template=messages,
inputs=inputs,
query=None,
files=files,
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock,
)
assert len(prompt_messages) == 4
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
@ -157,7 +160,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
)
assert isinstance(prompt_messages[3].content, list)
assert len(prompt_messages[3].content) == 2
assert prompt_messages[3].content[1].data == files[0].url
assert prompt_messages[3].content[1].data == files[0].remote_url
@pytest.fixture

View File

@ -1,56 +0,0 @@
import pytest
from core.tools.entities.tool_entities import ToolParameter
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
def test_get_parameter_type():
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
with pytest.raises(ValueError):
ToolParameterConverter.get_parameter_type("unsupported_type")
def test_cast_parameter_by_type():
# string
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
# secret input
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
# select
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
# boolean
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
for value in true_values:
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
for value in false_values:
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
# number
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
# unknown
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None

View File

@ -1,7 +1,7 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
@ -18,7 +18,8 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -86,7 +87,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
{"role": "system", "text": "say hi"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high"}, "enabled": False},
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
},
"id": "llm1",
},
@ -105,7 +106,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
{"role": "system", "text": "say bye"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high"}, "enabled": False},
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
},
"id": "llm2",
},
@ -124,7 +125,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
{"role": "system", "text": "say good morning"},
{"role": "user", "text": "{{#start.query#}}"},
],
"vision": {"configs": {"detail": "high"}, "enabled": False},
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
},
"id": "llm3",
},

View File

@ -3,7 +3,6 @@ import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

View File

@ -2,7 +2,6 @@ import uuid
from collections.abc import Generator
from datetime import datetime, timezone
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
@ -14,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.start.entities import StartNodeData
@ -39,7 +39,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
node_execution_id = str(uuid.uuid4())
node_config = graph.node_id_config_mapping[next_node_id]
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
node_type = NodeType(node_config.get("data", {}).get("type"))
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
yield NodeRunStartedEvent(

View File

@ -3,7 +3,7 @@ import uuid
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

View File

@ -3,7 +3,6 @@ import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.answer.answer_node import AnswerNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

View File

@ -65,7 +65,8 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
@pytest.mark.parametrize(
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
[
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
("text/plain", b"Hello, world!",
["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
(
"application/pdf",
b"%PDF-1.5\n%Test PDF content",
@ -80,7 +81,8 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
FileTransferMethod.REMOTE_URL,
"",
),
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
("text/plain", b"Remote content",
["Remote content"], FileTransferMethod.REMOTE_URL, None),
],
)
def test_run_extract_text(
@ -117,20 +119,23 @@ def test_run_extract_text(
if mime_type == "application/pdf":
mock_pdf_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
monkeypatch.setattr(
"core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
elif mime_type.startswith("application/vnd.openxmlformats"):
mock_docx_extract = Mock(return_value=expected_text[0])
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
monkeypatch.setattr(
"core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
result = document_extractor_node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
assert result.outputs is not None
assert result.outputs["text"] == expected_text
if transfer_method == FileTransferMethod.REMOTE_URL:
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
mock_ssrf_proxy_get.assert_called_once_with(
"https://example.com/file.txt")
elif transfer_method == FileTransferMethod.LOCAL_FILE:
mock_download.assert_called_once_with(mock_file)

View File

@ -1,16 +1,20 @@
import time
import uuid
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import UserFrom
from core.file import File, FileTransferMethod, FileType
from core.variables import ArrayFileSegment
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
@ -51,6 +55,7 @@ def test_execute_if_else_result_true():
pool.add(["start", "less_than"], 21)
pool.add(["start", "greater_than_or_equal"], 22)
pool.add(["start", "less_than_or_equal"], 21)
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
node = IfElseNode(
@ -111,6 +116,7 @@ def test_execute_if_else_result_true():
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True
@ -191,4 +197,63 @@ def test_execute_if_else_result_false():
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is False
def test_array_file_contains_file_name():
node_data = IfElseNodeData(
title="123",
logical_operator="and",
cases=[
IfElseNodeData.Case(
case_id="true",
logical_operator="and",
conditions=[
Condition(
comparison_operator="contains",
variable_selector=["start", "array_contains"],
sub_variable_condition=SubVariableCondition(
logical_operator="and",
conditions=[
SubCondition(
key="name",
comparison_operator="contains",
value="ab",
)
],
),
)
],
)
],
)
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
config={
"id": "if-else",
"data": node_data.model_dump(),
},
)
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
),
],
)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] is True

View File

@ -4,14 +4,14 @@ from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.segments import ArrayStringVariable, StringVariable
from core.workflow.entities.node_entities import UserFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
from models.enums import UserFrom
from models.workflow import WorkflowType
DEFAULT_NODE_ID = "node_id"