mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 17:08:03 +08:00
merge main
This commit is contained in:
@ -2,7 +2,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.segments import (
|
||||
from core.variables import (
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayStringVariable,
|
||||
@ -11,43 +11,43 @@ from core.app.segments import (
|
||||
ObjectSegment,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
factory,
|
||||
)
|
||||
from core.app.segments.exc import VariableError
|
||||
from core.variables.exc import VariableError
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, StringVariable)
|
||||
|
||||
|
||||
def test_integer_variable():
|
||||
test_data = {"value_type": "number", "name": "test_int", "value": 42}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, IntegerVariable)
|
||||
|
||||
|
||||
def test_float_variable():
|
||||
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, FloatVariable)
|
||||
|
||||
|
||||
def test_secret_variable():
|
||||
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
result = variable_factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, SecretVariable)
|
||||
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(test_data)
|
||||
variable_factory.build_variable_from_mapping(test_data)
|
||||
|
||||
|
||||
def test_build_a_blank_string():
|
||||
result = factory.build_variable_from_mapping(
|
||||
result = variable_factory.build_variable_from_mapping(
|
||||
{
|
||||
"value_type": "string",
|
||||
"name": "blank",
|
||||
@ -59,7 +59,7 @@ def test_build_a_blank_string():
|
||||
|
||||
|
||||
def test_build_a_object_variable_with_none_value():
|
||||
var = factory.build_segment(
|
||||
var = variable_factory.build_segment(
|
||||
{
|
||||
"key1": None,
|
||||
}
|
||||
@ -79,7 +79,7 @@ def test_object_variable():
|
||||
"key2": 2,
|
||||
},
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value["key1"], str)
|
||||
assert isinstance(variable.value["key2"], int)
|
||||
@ -96,7 +96,7 @@ def test_array_string_variable():
|
||||
"text",
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayStringVariable)
|
||||
assert isinstance(variable.value[0], str)
|
||||
assert isinstance(variable.value[1], str)
|
||||
@ -113,7 +113,7 @@ def test_array_number_variable():
|
||||
2.0,
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayNumberVariable)
|
||||
assert isinstance(variable.value[0], int)
|
||||
assert isinstance(variable.value[1], float)
|
||||
@ -136,7 +136,7 @@ def test_array_object_variable():
|
||||
},
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
variable = variable_factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
@ -146,13 +146,13 @@ def test_array_object_variable():
|
||||
assert isinstance(variable.value[1]["key2"], int)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_5_kb():
|
||||
def test_variable_cannot_large_than_200_kb():
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(
|
||||
variable_factory.build_variable_from_mapping(
|
||||
{
|
||||
"id": str(uuid4()),
|
||||
"value_type": "string",
|
||||
"name": "test_text",
|
||||
"value": "a" * 1024 * 6,
|
||||
"value": "a" * 1024 * 201,
|
||||
}
|
||||
)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from core.app.segments import SecretVariable, StringSegment, parser
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
@ -13,12 +13,13 @@ def test_segment_group_to_text():
|
||||
environment_variables=[
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
segments_group = variable_pool.convert_template(template)
|
||||
|
||||
assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
|
||||
assert segments_group.log == (
|
||||
@ -32,9 +33,10 @@ def test_convert_constant_to_segment_group():
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
template = "Hello, world!"
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
segments_group = variable_pool.convert_template(template)
|
||||
assert segments_group.text == "Hello, world!"
|
||||
assert segments_group.log == "Hello, world!"
|
||||
|
||||
@ -46,9 +48,10 @@ def test_convert_variable_to_segment_group():
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
template = "{{#sys.user_id#}}"
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
segments_group = variable_pool.convert_template(template)
|
||||
assert segments_group.text == "fake-user-id"
|
||||
assert segments_group.log == "fake-user-id"
|
||||
assert segments_group.value == [StringSegment(value="fake-user-id")]
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.segments import (
|
||||
from core.variables import (
|
||||
FloatVariable,
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
|
||||
@ -6,7 +6,7 @@ import pytest
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
|
||||
|
||||
@patch("httpx.request")
|
||||
@patch("httpx.Client.request")
|
||||
def test_successful_request(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
@ -16,7 +16,7 @@ def test_successful_request(mock_request):
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch("httpx.request")
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@ -29,7 +29,7 @@ def test_retry_exceed_max_retries(mock_request):
|
||||
assert str(e.value) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch("httpx.request")
|
||||
@patch("httpx.Client.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
@ -31,7 +33,7 @@ def test_max_chunks():
|
||||
max_chunks = embedding_model._get_max_chunks(model, credentials)
|
||||
embedding_model._create_text_embedding = _create_text_embedding
|
||||
|
||||
texts = ["0123456789" for i in range(0, max_chunks * 2)]
|
||||
texts = [string.digits for i in range(0, max_chunks * 2)]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
|
||||
assert len(result.embeddings) == max_chunks * 2
|
||||
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file.file_obj import FileExtraConfig, FileTransferMethod, FileType, FileVar
|
||||
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, UserPromptMessage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
@ -123,32 +128,30 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
model_config_mock, _, messages, inputs, context = get_chat_model_args
|
||||
|
||||
files = [
|
||||
FileVar(
|
||||
File(
|
||||
id="file1",
|
||||
tenant_id="tenant1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="https://example.com/image1.jpg",
|
||||
extra_config=FileExtraConfig(
|
||||
image_config={
|
||||
"detail": "high",
|
||||
}
|
||||
),
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
|
||||
)
|
||||
]
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
with patch("core.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(data=str(files[0].remote_url))
|
||||
prompt_messages = prompt_transform._get_chat_model_prompt_messages(
|
||||
prompt_template=messages,
|
||||
inputs=inputs,
|
||||
query=None,
|
||||
files=files,
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
@ -157,7 +160,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
)
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
assert prompt_messages[3].content[1].data == files[0].url
|
||||
assert prompt_messages[3].content[1].data == files[0].remote_url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
|
||||
def test_get_parameter_type():
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
|
||||
with pytest.raises(ValueError):
|
||||
ToolParameterConverter.get_parameter_type("unsupported_type")
|
||||
|
||||
|
||||
def test_cast_parameter_by_type():
|
||||
# string
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
|
||||
|
||||
# secret input
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
|
||||
|
||||
# select
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
|
||||
|
||||
# boolean
|
||||
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
|
||||
for value in true_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
|
||||
|
||||
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
|
||||
for value in false_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
|
||||
|
||||
# number
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
|
||||
# unknown
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
@ -1,7 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -18,7 +18,8 @@ from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@ -86,7 +87,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
{"role": "system", "text": "say hi"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
@ -105,7 +106,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
{"role": "system", "text": "say bye"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
@ -124,7 +125,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
{"role": "system", "text": "say good morning"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
"vision": {"configs": {"detail": "high", "variable_selector": []}, "enabled": False},
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
|
||||
@ -3,7 +3,6 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
@ -14,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
|
||||
|
||||
@ -39,7 +39,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
|
||||
|
||||
node_execution_id = str(uuid.uuid4())
|
||||
node_config = graph.node_id_config_mapping[next_node_id]
|
||||
node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
|
||||
node_type = NodeType(node_config.get("data", {}).get("type"))
|
||||
mock_node_data = StartNodeData(**{"title": "demo", "variables": []})
|
||||
|
||||
yield NodeRunStartedEvent(
|
||||
|
||||
@ -3,7 +3,7 @@ import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult, UserFrom
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@ -12,6 +12,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@ -11,6 +10,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
|
||||
@ -65,7 +65,8 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
|
||||
@pytest.mark.parametrize(
|
||||
("mime_type", "file_content", "expected_text", "transfer_method", "extension"),
|
||||
[
|
||||
("text/plain", b"Hello, world!", ["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
|
||||
("text/plain", b"Hello, world!",
|
||||
["Hello, world!"], FileTransferMethod.LOCAL_FILE, ".txt"),
|
||||
(
|
||||
"application/pdf",
|
||||
b"%PDF-1.5\n%Test PDF content",
|
||||
@ -80,7 +81,8 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s
|
||||
FileTransferMethod.REMOTE_URL,
|
||||
"",
|
||||
),
|
||||
("text/plain", b"Remote content", ["Remote content"], FileTransferMethod.REMOTE_URL, None),
|
||||
("text/plain", b"Remote content",
|
||||
["Remote content"], FileTransferMethod.REMOTE_URL, None),
|
||||
],
|
||||
)
|
||||
def test_run_extract_text(
|
||||
@ -117,20 +119,23 @@ def test_run_extract_text(
|
||||
|
||||
if mime_type == "application/pdf":
|
||||
mock_pdf_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract)
|
||||
elif mime_type.startswith("application/vnd.openxmlformats"):
|
||||
mock_docx_extract = Mock(return_value=expected_text[0])
|
||||
monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
|
||||
monkeypatch.setattr(
|
||||
"core.workflow.nodes.document_extractor.node._extract_text_from_doc", mock_docx_extract)
|
||||
|
||||
result = document_extractor_node._run()
|
||||
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["text"] == expected_text
|
||||
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt")
|
||||
mock_ssrf_proxy_get.assert_called_once_with(
|
||||
"https://example.com/file.txt")
|
||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
mock_download.assert_called_once_with(mock_file)
|
||||
|
||||
|
||||
@ -1,16 +1,20 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@ -51,6 +55,7 @@ def test_execute_if_else_result_true():
|
||||
pool.add(["start", "less_than"], 21)
|
||||
pool.add(["start", "greater_than_or_equal"], 22)
|
||||
pool.add(["start", "less_than_or_equal"], 21)
|
||||
pool.add(["start", "null"], None)
|
||||
pool.add(["start", "not_null"], "1212")
|
||||
|
||||
node = IfElseNode(
|
||||
@ -111,6 +116,7 @@ def test_execute_if_else_result_true():
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
@ -191,4 +197,63 @@ def test_execute_if_else_result_false():
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is False
|
||||
|
||||
|
||||
def test_array_file_contains_file_name():
|
||||
node_data = IfElseNodeData(
|
||||
title="123",
|
||||
logical_operator="and",
|
||||
cases=[
|
||||
IfElseNodeData.Case(
|
||||
case_id="true",
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
Condition(
|
||||
comparison_operator="contains",
|
||||
variable_selector=["start", "array_contains"],
|
||||
sub_variable_condition=SubVariableCondition(
|
||||
logical_operator="and",
|
||||
conditions=[
|
||||
SubCondition(
|
||||
key="name",
|
||||
comparison_operator="contains",
|
||||
value="ab",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
node = IfElseNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=Mock(),
|
||||
graph=Mock(),
|
||||
graph_runtime_state=Mock(),
|
||||
config={
|
||||
"id": "if-else",
|
||||
"data": node_data.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
|
||||
value=[
|
||||
File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="1",
|
||||
filename="ab",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
@ -4,14 +4,14 @@ from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
DEFAULT_NODE_ID = "node_id"
|
||||
|
||||
Reference in New Issue
Block a user