Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -206,6 +206,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5

View File

@ -9,12 +9,13 @@ from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from app_factory import create_app
from models import Account, DifySetup, Tenant, TenantAccountJoin, db
from extensions.ext_database import db
from models import Account, DifySetup, Tenant, TenantAccountJoin
from services.account_service import AccountService, RegisterService
# Loading the .env file if it exists
def _load_env() -> None:
def _load_env():
current_file_path = pathlib.Path(__file__).absolute()
# Items later in the list have higher precedence.
files_to_load = [".env", "vdb.env"]

View File

@ -0,0 +1,206 @@
"""Integration tests for ChatMessageApi permission verification."""
import uuid
from types import SimpleNamespace
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.app import completion as completion_api
from controllers.console.app import message as message_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
from models.account import TenantAccountRole
from models.model import AppMode
from services.app_generate_service import AppGenerateService
class TestChatMessageApiPermissions:
"""Test permission verification for ChatMessageApi endpoint."""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
return app
@pytest.fixture
def mock_account(self):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant()
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
account._current_tenant = tenant
return account
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_post_with_owner_role_succeeds(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Test that OWNER role can access chat-messages endpoint."""
"""Setup common mocks for testing."""
# Mock app loading
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Mock current user
monkeypatch.setattr(completion_api, "current_user", mock_account)
mock_generate = mock.Mock(return_value={"message": "Test response"})
monkeypatch.setattr(AppGenerateService, "generate", mock_generate)
# Set user role to OWNER
mock_account.role = role
response = test_client.post(
f"/console/api/apps/{mock_app_model.id}/chat-messages",
headers=auth_header,
json={
"inputs": {},
"query": "Hello, world!",
"model_config": {
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}
},
"response_mode": "blocking",
},
)
assert response.status_code == status
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_get_requires_edit_permission(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Ensure GET chat-messages endpoint enforces edit permissions."""
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
conversation_id = uuid.uuid4()
created_at = naive_utc_now()
mock_conversation = SimpleNamespace(id=str(conversation_id), app_id=str(mock_app_model.id))
mock_message = SimpleNamespace(
id=str(uuid.uuid4()),
conversation_id=str(conversation_id),
inputs=[],
query="hello",
message=[{"text": "hello"}],
message_tokens=0,
re_sign_file_url_answer="",
answer_tokens=0,
provider_response_latency=0.0,
from_source="console",
from_end_user_id=None,
from_account_id=mock_account.id,
feedbacks=[],
workflow_run_id=None,
annotation=None,
annotation_hit_history=None,
created_at=created_at,
agent_thoughts=[],
message_files=[],
message_metadata_dict={},
status="success",
error="",
parent_message_id=None,
)
class MockQuery:
def __init__(self, model):
self.model = model
def where(self, *args, **kwargs):
return self
def first(self):
if getattr(self.model, "__name__", "") == "Conversation":
return mock_conversation
return None
def order_by(self, *args, **kwargs):
return self
def limit(self, *_):
return self
def all(self):
if getattr(self.model, "__name__", "") == "Message":
return [mock_message]
return []
mock_session = mock.Mock()
mock_session.query.side_effect = MockQuery
mock_session.scalar.return_value = False
monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
monkeypatch.setattr(message_api, "current_user", mock_account)
class DummyPagination:
def __init__(self, data, limit, has_more):
self.data = data
self.limit = limit
self.has_more = has_more
monkeypatch.setattr(message_api, "InfiniteScrollPagination", DummyPagination)
mock_account.role = role
response = test_client.get(
f"/console/api/apps/{mock_app_model.id}/chat-messages",
headers=auth_header,
query_string={"conversation_id": str(conversation_id)},
)
assert response.status_code == status

View File

@ -0,0 +1,129 @@
"""Integration tests for ModelConfigResource permission verification."""
import uuid
from unittest import mock
import pytest
from flask.testing import FlaskClient
from controllers.console.app import model_config as model_config_api
from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import Account, App, Tenant
from models.account import TenantAccountRole
from models.model import AppMode
from services.app_model_config_service import AppModelConfigService
class TestModelConfigResourcePermissions:
"""Test permission verification for ModelConfigResource endpoint."""
@pytest.fixture
def mock_app_model(self):
"""Create a mock App model for testing."""
app = App()
app.id = str(uuid.uuid4())
app.mode = AppMode.CHAT.value
app.tenant_id = str(uuid.uuid4())
app.status = "normal"
app.app_model_config_id = str(uuid.uuid4())
return app
@pytest.fixture
def mock_account(self):
"""Create a mock Account for testing."""
account = Account()
account.id = str(uuid.uuid4())
account.name = "Test User"
account.email = "test@example.com"
account.last_active_at = naive_utc_now()
account.created_at = naive_utc_now()
account.updated_at = naive_utc_now()
# Create mock tenant
tenant = Tenant()
tenant.id = str(uuid.uuid4())
tenant.name = "Test Tenant"
account._current_tenant = tenant
return account
@pytest.mark.parametrize(
("role", "status"),
[
(TenantAccountRole.OWNER, 200),
(TenantAccountRole.ADMIN, 200),
(TenantAccountRole.EDITOR, 200),
(TenantAccountRole.NORMAL, 403),
(TenantAccountRole.DATASET_OPERATOR, 403),
],
)
def test_post_with_owner_role_succeeds(
self,
test_client: FlaskClient,
auth_header,
monkeypatch,
mock_app_model,
mock_account,
role: TenantAccountRole,
status: int,
):
"""Test that OWNER role can access model-config endpoint."""
# Set user role to OWNER
mock_account.role = role
# Mock app loading
mock_load_app_model = mock.Mock(return_value=mock_app_model)
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
# Mock current user
monkeypatch.setattr(model_config_api, "current_user", mock_account)
# Mock AccountService.load_user to prevent authentication issues
from services.account_service import AccountService
mock_load_user = mock.Mock(return_value=mock_account)
monkeypatch.setattr(AccountService, "load_user", mock_load_user)
mock_validate_config = mock.Mock(
return_value={
"model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}},
"pre_prompt": "You are a helpful assistant.",
"user_input_form": [],
"dataset_query_variable": "",
"agent_mode": {"enabled": False, "tools": []},
}
)
monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config)
# Mock database operations
mock_db_session = mock.Mock()
mock_db_session.add = mock.Mock()
mock_db_session.flush = mock.Mock()
mock_db_session.commit = mock.Mock()
monkeypatch.setattr(model_config_api.db, "session", mock_db_session)
# Mock app_model_config_was_updated event
mock_event = mock.Mock()
mock_event.send = mock.Mock()
monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event)
response = test_client.post(
f"/console/api/apps/{mock_app_model.id}/model-config",
headers=auth_header,
json={
"model": {
"provider": "openai",
"name": "gpt-4",
"mode": "chat",
"completion_params": {"temperature": 0.7, "max_tokens": 1000},
},
"user_input_form": [],
"dataset_query_variable": "",
"pre_prompt": "You are a helpful assistant.",
"agent_mode": {"enabled": False, "tools": []},
},
)
assert response.status_code == status

View File

@ -1,6 +1,5 @@
import unittest
from datetime import UTC, datetime
from typing import Optional
from unittest.mock import patch
from uuid import uuid4
@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase):
self.session.rollback()
def _create_upload_file(
self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
) -> UploadFile:
"""Helper method to create an UploadFile record for testing."""
if file_id is None:
@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase):
return upload_file
def _create_tool_file(
self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
) -> ToolFile:
"""Helper method to create a ToolFile record for testing."""
if file_id is None:
@ -84,26 +83,24 @@ class TestStorageKeyLoader(unittest.TestCase):
if tenant_id is None:
tenant_id = self.tenant_id
tool_file = ToolFile()
tool_file = ToolFile(
user_id=self.user_id,
tenant_id=tenant_id,
conversation_id=self.conversation_id,
file_key=file_key,
mimetype="text/plain",
original_url="http://example.com/file.txt",
name="test_tool_file.txt",
size=2048,
)
tool_file.id = file_id
tool_file.user_id = self.user_id
tool_file.tenant_id = tenant_id
tool_file.conversation_id = self.conversation_id
tool_file.file_key = file_key
tool_file.mimetype = "text/plain"
tool_file.original_url = "http://example.com/file.txt"
tool_file.name = "test_tool_file.txt"
tool_file.size = 2048
self.session.add(tool_file)
self.session.flush()
self.test_tool_files.append(tool_file)
return tool_file
def _create_file(
self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
) -> File:
def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
"""Helper method to create a File object for testing."""
if tenant_id is None:
tenant_id = self.tenant_id

View File

@ -17,7 +17,7 @@ def mock_plugin_daemon(
:return: unpatch function
"""
def unpatch() -> None:
def unpatch():
monkeypatch.undo()
monkeypatch.setattr(PluginModelClient, "invoke_llm", MockModelClass.invoke_llm)

View File

@ -5,8 +5,6 @@ from decimal import Decimal
from json import dumps
# import monkeypatch
from typing import Optional
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient):
@staticmethod
def generate_function_call(
tools: Optional[list[PromptMessageTool]],
) -> Optional[AssistantPromptMessage.ToolCall]:
tools: list[PromptMessageTool] | None,
) -> AssistantPromptMessage.ToolCall | None:
if not tools or len(tools) == 0:
return None
function: PromptMessageTool = tools[0]
@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient):
def mocked_chat_create_sync(
model: str,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
tools: list[PromptMessageTool] | None = None,
) -> LLMResult:
tool_call = MockModelClass.generate_function_call(tools=tools)
@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient):
def mocked_chat_create_stream(
model: str,
prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None,
tools: list[PromptMessageTool] | None = None,
) -> Generator[LLMResultChunk, None, None]:
tool_call = MockModelClass.generate_function_call(tools=tools)
@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient):
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
):
return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)

View File

@ -1,8 +1,8 @@
import os
from typing import Literal
import httpx
import pytest
import requests
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
from core.tools.entities.common_entities import I18nObject
@ -27,13 +27,11 @@ class MockedHttp:
@classmethod
def requests_request(
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> requests.Response:
) -> httpx.Response:
"""
Mocked requests.request
Mocked httpx.request
"""
request = requests.PreparedRequest()
request.method = method
request.url = url
request = httpx.Request(method, url)
if url.endswith("/tools"):
content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
code=0, message="success", data=cls.list_tools()
@ -41,8 +39,7 @@ class MockedHttp:
else:
raise ValueError("")
response = requests.Response()
response.status_code = 200
response = httpx.Response(status_code=200)
response.request = request
response._content = content.encode("utf-8")
return response
@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK_SWITCH:
monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
def unpatch():
monkeypatch.undo()

View File

@ -3,16 +3,27 @@ import unittest
import uuid
import pytest
from sqlalchemy import delete
from sqlalchemy.orm import Session
from core.variables.segments import StringSegment
from core.variables.types import SegmentType
from core.variables.variables import StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes import NodeType
from extensions.ext_database import db
from extensions.ext_storage import storage
from factories.variable_factory import build_segment
from libs import datetime_utils
from models import db
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
from models.enums import CreatorUserRole
from models.model import UploadFile
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, WorkflowNodeExecutionModel
from services.workflow_draft_variable_service import (
DraftVariableSaver,
DraftVarLoader,
VariableResetError,
WorkflowDraftVariableService,
)
@pytest.mark.usefixtures("flask_req_ctx")
@ -175,6 +186,23 @@ class TestDraftVariableLoader(unittest.TestCase):
_node1_id = "test_loader_node_1"
_node_exec_id = str(uuid.uuid4())
# @pytest.fixture
# def test_app_id(self):
# return str(uuid.uuid4())
# @pytest.fixture
# def test_tenant_id(self):
# return str(uuid.uuid4())
# @pytest.fixture
# def session(self):
# with Session(bind=db.engine, expire_on_commit=False) as session:
# yield session
# @pytest.fixture
# def node_var(self, session):
# pass
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
@ -241,6 +269,246 @@ class TestDraftVariableLoader(unittest.TestCase):
node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
assert node1_var.id == self._node_var_id
@pytest.mark.usefixtures("setup_account")
def test_load_offloaded_variable_string_type_integration(self, setup_account):
"""Test _load_offloaded_variable with string type using DraftVariableSaver for data creation."""
# Create a large string that will be offloaded
test_content = "x" * 15000 # Create a string larger than LARGE_VARIABLE_THRESHOLD (10KB)
large_string_segment = StringSegment(value=test_content)
node_execution_id = str(uuid.uuid4())
try:
with Session(bind=db.engine, expire_on_commit=False) as session:
# Use DraftVariableSaver to create offloaded variable (this mimics production)
saver = DraftVariableSaver(
session=session,
app_id=self._test_app_id,
node_id="test_offload_node",
node_type=NodeType.LLM, # Use a real node type
node_execution_id=node_execution_id,
user=setup_account,
)
# Save the variable - this will trigger offloading due to large size
saver.save(outputs={"offloaded_string_var": large_string_segment})
session.commit()
# Now test loading using DraftVarLoader
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
# Load the variable using the standard workflow
variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])
# Verify results
assert len(variables) == 1
loaded_variable = variables[0]
assert loaded_variable.name == "offloaded_string_var"
assert loaded_variable.selector == ["test_offload_node", "offloaded_string_var"]
assert isinstance(loaded_variable.value, StringSegment)
assert loaded_variable.value.value == test_content
finally:
# Clean up - delete all draft variables for this app
with Session(bind=db.engine) as session:
service = WorkflowDraftVariableService(session)
service.delete_workflow_variables(self._test_app_id)
session.commit()
def test_load_offloaded_variable_object_type_integration(self):
"""Test _load_offloaded_variable with object type using real storage and service."""
# Create a test object
test_object = {"key1": "value1", "key2": 42, "nested": {"inner": "data"}}
test_json = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
content_bytes = test_json.encode()
# Create an upload file record
upload_file = UploadFile(
tenant_id=self._test_tenant_id,
storage_type="local",
key=f"test_offload_{uuid.uuid4()}.json",
name="test_offload.json",
size=len(content_bytes),
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=datetime_utils.naive_utc_now(),
used=True,
used_by=str(uuid.uuid4()),
used_at=datetime_utils.naive_utc_now(),
)
# Store the content in storage
storage.save(upload_file.key, content_bytes)
# Create a variable file record
variable_file = WorkflowDraftVariableFile(
upload_file_id=upload_file.id,
value_type=SegmentType.OBJECT,
tenant_id=self._test_tenant_id,
app_id=self._test_app_id,
user_id=str(uuid.uuid4()),
size=len(content_bytes),
created_at=datetime_utils.naive_utc_now(),
)
try:
with Session(bind=db.engine, expire_on_commit=False) as session:
# Add upload file and variable file first to get their IDs
session.add_all([upload_file, variable_file])
session.flush() # This generates the IDs
# Now create the offloaded draft variable with the correct file_id
offloaded_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id="test_offload_node",
name="offloaded_object_var",
value=build_segment({"truncated": True}),
visible=True,
node_execution_id=str(uuid.uuid4()),
)
offloaded_var.file_id = variable_file.id
session.add(offloaded_var)
session.flush()
session.commit()
# Use the service method that properly preloads relationships
service = WorkflowDraftVariableService(session)
draft_vars = service.get_draft_variables_by_selectors(
self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
)
assert len(draft_vars) == 1
loaded_var = draft_vars[0]
assert loaded_var.is_truncated()
# Create DraftVarLoader and test loading
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
# Test the _load_offloaded_variable method
selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)
# Verify the results
assert selector_tuple == ("test_offload_node", "offloaded_object_var")
assert variable.id == loaded_var.id
assert variable.name == "offloaded_object_var"
assert variable.value.value == test_object
finally:
# Clean up
with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
session.query(UploadFile).filter_by(id=upload_file.id).delete()
session.commit()
# Clean up storage
try:
storage.delete(upload_file.key)
except Exception:
pass # Ignore cleanup failures
def test_load_variables_with_offloaded_variables_integration(self):
"""Test load_variables method with mix of regular and offloaded variables using real storage."""
# Create a regular variable (already exists from setUp)
# Create offloaded variable content
test_content = "This is offloaded content for integration test"
content_bytes = test_content.encode()
# Create upload file record
upload_file = UploadFile(
tenant_id=self._test_tenant_id,
storage_type="local",
key=f"test_integration_{uuid.uuid4()}.txt",
name="test_integration.txt",
size=len(content_bytes),
extension="txt",
mime_type="text/plain",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=datetime_utils.naive_utc_now(),
used=True,
used_by=str(uuid.uuid4()),
used_at=datetime_utils.naive_utc_now(),
)
# Store the content
storage.save(upload_file.key, content_bytes)
# Create variable file
variable_file = WorkflowDraftVariableFile(
upload_file_id=upload_file.id,
value_type=SegmentType.STRING,
tenant_id=self._test_tenant_id,
app_id=self._test_app_id,
user_id=str(uuid.uuid4()),
size=len(content_bytes),
created_at=datetime_utils.naive_utc_now(),
)
try:
with Session(bind=db.engine, expire_on_commit=False) as session:
# Add upload file and variable file first to get their IDs
session.add_all([upload_file, variable_file])
session.flush() # This generates the IDs
# Now create the offloaded draft variable with the correct file_id
offloaded_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
node_id="test_integration_node",
name="offloaded_integration_var",
value=build_segment("truncated"),
visible=True,
node_execution_id=str(uuid.uuid4()),
)
offloaded_var.file_id = variable_file.id
session.add(offloaded_var)
session.flush()
session.commit()
# Test load_variables with both regular and offloaded variables
# This method should handle the relationship preloading internally
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables(
[
[SYSTEM_VARIABLE_NODE_ID, "sys_var"], # Regular variable from setUp
["test_integration_node", "offloaded_integration_var"], # Offloaded variable
]
)
# Verify results
assert len(variables) == 2
# Find regular variable
regular_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
assert regular_var.id == self._sys_var_id
assert regular_var.value == "sys_value"
# Find offloaded variable
offloaded_loaded_var = next(v for v in variables if v.selector[0] == "test_integration_node")
assert offloaded_loaded_var.id == offloaded_var.id
assert offloaded_loaded_var.value == test_content
finally:
# Clean up
with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
session.query(UploadFile).filter_by(id=upload_file.id).delete()
session.commit()
# Clean up storage
try:
storage.delete(upload_file.key)
except Exception:
pass # Ignore cleanup failures
@pytest.mark.usefixtures("flask_req_ctx")
class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
@ -272,7 +540,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
triggered_from="workflow-run",
workflow_run_id=str(uuid.uuid4()),
index=1,
node_execution_id=self._node_exec_id,
node_execution_id=str(uuid.uuid4()),
node_id=self._node_id,
node_type=NodeType.LLM.value,
title="Test Node",
@ -281,7 +549,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
outputs='{"test_var": "output_value", "other_var": "other_output"}',
status="succeeded",
elapsed_time=1.5,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
)
@ -336,10 +604,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
)
self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
with Session(db.engine, expire_on_commit=False) as persistent_session, persistent_session.begin():
persistent_session.add(
self._workflow_node_execution,
)
# Add all to database
db.session.add_all(
[
self._workflow_node_execution,
self._node_var_with_exec,
self._node_var_without_exec,
self._node_var_missing_exec,
@ -354,6 +626,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
self._node_var_missing_exec_id = self._node_var_missing_exec.id
self._conv_var_id = self._conv_var.id
def tearDown(self):
self._session.rollback()
with Session(db.engine) as session, session.begin():
stmt = delete(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.id == self._workflow_node_execution.id
)
session.execute(stmt)
def _get_test_srv(self) -> WorkflowDraftVariableService:
return WorkflowDraftVariableService(session=self._session)
@ -377,12 +657,10 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
created_by=str(uuid.uuid4()),
environment_variables=[],
conversation_variables=conversation_vars,
rag_pipeline_variables=[],
)
return workflow
def tearDown(self):
self._session.rollback()
def test_reset_node_variable_with_valid_execution_record(self):
"""Test resetting a node variable with valid execution record - should restore from execution"""
srv = self._get_test_srv()

View File

@ -3,6 +3,7 @@
import os
import tempfile
import unittest
from pathlib import Path
import pytest
@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
# Test download
with tempfile.NamedTemporaryFile() as temp_file:
storage.download(test_filename, temp_file.name)
with open(temp_file.name, "rb") as f:
downloaded_content = f.read()
downloaded_content = Path(temp_file.name).read_bytes()
assert downloaded_content == test_content
# Test scan

View File

@ -1,12 +1,15 @@
import uuid
from unittest.mock import patch
import pytest
from sqlalchemy import delete
from core.variables.segments import StringSegment
from models import Tenant, db
from models.model import App
from models.workflow import WorkflowDraftVariable
from extensions.ext_database import db
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
@ -212,3 +215,256 @@ class TestDeleteDraftVariablesIntegration:
.execution_options(synchronize_session=False)
)
db.session.execute(query)
class TestDeleteDraftVariablesWithOffloadIntegration:
"""Integration tests for draft variable deletion with Offload data."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with draft variables that have associated Offload files."""
tenant, app = app_and_tenant
# Create UploadFile records
from libs.datetime_utils import naive_utc_now
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
db.session.add(upload_file1)
db.session.add(upload_file2)
db.session.flush()
# Create WorkflowDraftVariableFile records
from core.variables.types import SegmentType
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
db.session.add(var_file1)
db.session.add(var_file2)
db.session.flush()
# Create WorkflowDraftVariable records with file associations
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
# Create a regular variable without Offload data
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(draft_var1)
db.session.add(draft_var2)
db.session.add(draft_var3)
db.session.commit()
yield {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
# Cleanup
db.session.rollback()
# Clean up any remaining records
for table, ids in [
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
db.session.execute(cleanup_query)
db.session.commit()
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
"""Test that deleting draft variables also cleans up associated Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to succeed
mock_storage.delete.return_value = None
# Verify initial state
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
upload_files_before = db.session.query(UploadFile).count()
assert draft_vars_before == 3 # 2 with files + 1 regular
assert var_files_before == 2
assert upload_files_before == 2
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify results
assert deleted_count == 3
# Check that all draft variables are deleted
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Check that associated Offload data is cleaned up
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
assert var_files_after == 0 # All variable files should be deleted
assert upload_files_after == 0 # All upload files should be deleted
# Verify storage deletion was called for both files
assert mock_storage.delete.call_count == 2
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
assert "test/file1.json" in storage_keys_deleted
assert "test/file2.json" in storage_keys_deleted
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that database cleanup continues even when storage deletion fails."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to fail for first file, succeed for second
mock_storage.delete.side_effect = [Exception("Storage error"), None]
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify that all draft variables are still deleted
assert deleted_count == 3
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Database cleanup should still succeed even with storage errors
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage deletion was attempted for both files
assert mock_storage.delete.call_count == 2
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
"""Test deletion with mix of variables with and without Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Create additional app with only regular variables (no offload data)
tenant = data["tenant"]
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.flush()
# Add regular variables to app2
regular_vars = []
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var)
regular_vars.append(var)
db.session.commit()
try:
# Mock storage deletion
mock_storage.delete.return_value = None
# Delete variables for app2 (no offload data)
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
assert deleted_count_app2 == 3
# Verify storage wasn't called for app2 (no offload files)
mock_storage.delete.assert_not_called()
# Delete variables for original app (with offload data)
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count_app1 == 3
# Now storage should be called for the offload files
assert mock_storage.delete.call_count == 2
finally:
# Cleanup app2 and its variables
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_vars_query)
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()

View File

@ -1,6 +1,5 @@
import os
from collections import UserDict
from typing import Optional
from unittest.mock import MagicMock
import pytest
@ -22,7 +21,7 @@ class MockBaiduVectorDBClass:
def mock_vector_db_client(
self,
config=None,
adapter: Optional[HTTPAdapter] = None,
adapter: HTTPAdapter | None = None,
):
self.conn = MagicMock()
self._config = MagicMock()
@ -101,8 +100,8 @@ class MockBaiduVectorDBClass:
"row": {
"id": primary_key.get("id"),
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": '{"doc_id": "doc_id_001"}',
"page_content": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"code": 0,
"msg": "Success",
@ -128,8 +127,8 @@ class MockBaiduVectorDBClass:
"row": {
"id": "doc_id_001",
"vector": [0.23432432, 0.8923744, 0.89238432],
"text": "text",
"metadata": '{"doc_id": "doc_id_001"}',
"page_content": "text",
"metadata": {"doc_id": "doc_id_001"},
},
"distance": 0.1,
"score": 0.5,

View File

@ -1,5 +1,5 @@
import os
from typing import Optional, Union
from typing import Union
import pytest
from _pytest.monkeypatch import MonkeyPatch
@ -23,16 +23,16 @@ class MockTcvectordbClass:
key="",
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=10,
adapter: Optional[HTTPAdapter] = None,
adapter: HTTPAdapter | None = None,
pool_size: int = 2,
proxies: Optional[dict] = None,
password: Optional[str] = None,
proxies: dict | None = None,
password: str | None = None,
**kwargs,
):
self._conn = None
self._read_consistency = read_consistency
def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
return RPCDatabase(
name="dify",
read_consistency=self._read_consistency,
@ -42,7 +42,7 @@ class MockTcvectordbClass:
return True
def describe_collection(
self, database_name: str, collection_name: str, timeout: Optional[float] = None
self, database_name: str, collection_name: str, timeout: float | None = None
) -> RPCCollection:
index = Index(
FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
@ -71,13 +71,13 @@ class MockTcvectordbClass:
collection_name: str,
shard: int,
replicas: int,
description: Optional[str] = None,
index: Optional[Index] = None,
embedding: Optional[Embedding] = None,
timeout: Optional[float] = None,
ttl_config: Optional[dict] = None,
filter_index_config: Optional[FilterIndexConfig] = None,
indexes: Optional[list[IndexField]] = None,
description: str | None = None,
index: Index | None = None,
embedding: Embedding | None = None,
timeout: float | None = None,
ttl_config: dict | None = None,
filter_index_config: FilterIndexConfig | None = None,
indexes: list[IndexField] | None = None,
) -> RPCCollection:
return RPCCollection(
RPCDatabase(
@ -102,7 +102,7 @@ class MockTcvectordbClass:
database_name: str,
collection_name: str,
documents: list[Union[Document, dict]],
timeout: Optional[float] = None,
timeout: float | None = None,
build_index: bool = True,
**kwargs,
):
@ -113,12 +113,12 @@ class MockTcvectordbClass:
database_name: str,
collection_name: str,
vectors: list[list[float]],
filter: Optional[Filter] = None,
filter: Filter | None = None,
params=None,
retrieve_vector: bool = False,
limit: int = 10,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
output_fields: list[str] | None = None,
timeout: float | None = None,
) -> list[list[dict]]:
return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
@ -126,14 +126,14 @@ class MockTcvectordbClass:
self,
database_name: str,
collection_name: str,
ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
filter: Optional[Union[Filter, str]] = None,
rerank: Optional[Rerank] = None,
retrieve_vector: Optional[bool] = None,
output_fields: Optional[list[str]] = None,
limit: Optional[int] = None,
timeout: Optional[float] = None,
ann: Union[list[AnnSearch], AnnSearch] | None = None,
match: Union[list[KeywordSearch], KeywordSearch] | None = None,
filter: Union[Filter, str] | None = None,
rerank: Rerank | None = None,
retrieve_vector: bool | None = None,
output_fields: list[str] | None = None,
limit: int | None = None,
timeout: float | None = None,
return_pd_object=False,
**kwargs,
) -> list[list[dict]]:
@ -143,27 +143,27 @@ class MockTcvectordbClass:
self,
database_name: str,
collection_name: str,
document_ids: Optional[list] = None,
document_ids: list | None = None,
retrieve_vector: bool = False,
limit: Optional[int] = None,
offset: Optional[int] = None,
filter: Optional[Filter] = None,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
) -> list[dict]:
limit: int | None = None,
offset: int | None = None,
filter: Filter | None = None,
output_fields: list[str] | None = None,
timeout: float | None = None,
):
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
def collection_delete(
self,
database_name: str,
collection_name: str,
document_ids: Optional[list[str]] = None,
filter: Optional[Filter] = None,
timeout: Optional[float] = None,
document_ids: list[str] | None = None,
filter: Filter | None = None,
timeout: float | None = None,
):
return {"code": 0, "msg": "operation success"}
def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
return {"code": 0, "msg": "operation success"}

View File

@ -1,6 +1,5 @@
import os
from collections import UserDict
from typing import Optional
import pytest
from _pytest.monkeypatch import MonkeyPatch
@ -34,7 +33,7 @@ class MockIndex:
include_vectors: bool = False,
include_metadata: bool = False,
filter: str = "",
data: Optional[str] = None,
data: str | None = None,
namespace: str = "",
include_data: bool = False,
):

View File

@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
import os
import time
import requests
import httpx
from clickzetta import connect
@ -66,7 +66,7 @@ def test_dify_api():
max_retries = 30
for i in range(max_retries):
try:
response = requests.get(f"{base_url}/console/api/health")
response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200:
print("✓ Dify API is ready")
break

View File

@ -1,16 +1,16 @@
import environs
import os
from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
env = environs.Env()
class Config:
SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
USING_UGC = env.bool("USING_UGC", True)
SEARCH_ENDPOINT = os.environ.get(
"SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070"
)
SEARCH_USERNAME = os.environ.get("SEARCH_USERNAME", "ADMIN")
SEARCH_PWD = os.environ.get("SEARCH_PWD", "ADMIN")
USING_UGC = os.environ.get("USING_UGC", "True").lower() == "true"
class TestLindormVectorStore(AbstractVectorTest):

View File

@ -26,7 +26,7 @@ def get_example_document(doc_id: str) -> Document:
@pytest.fixture
def setup_mock_redis() -> None:
def setup_mock_redis():
# get
ext_redis.redis_client.get = MagicMock(return_value=None)
@ -48,7 +48,7 @@ class AbstractVectorTest:
self.example_doc_id = str(uuid.uuid4())
self.example_embedding = [1.001 * i for i in range(128)]
def create_vector(self) -> None:
def create_vector(self):
self.vector.create(
texts=[get_example_document(doc_id=self.example_doc_id)],
embeddings=[self.example_embedding],

View File

@ -12,7 +12,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockedCodeExecutor:
@classmethod
def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict:
def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict):
# invoke directly
match language:
case CodeLanguage.PYTHON3:

View File

@ -5,16 +5,14 @@ from os import getenv
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
@ -29,15 +27,12 @@ def init_code_node(code_config: dict):
"target": "code",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, code_config],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
@ -56,12 +51,21 @@ def init_code_node(code_config: dict):
variable_pool.add(["code", "args1"], 1)
variable_pool.add(["code", "args2"], 2)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = CodeNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=code_config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
@ -74,7 +78,7 @@ def init_code_node(code_config: dict):
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": args1 + args2,
}
@ -85,6 +89,7 @@ def test_execute_code(setup_code_executor_mock):
code_config = {
"id": "code",
"data": {
"type": "code",
"outputs": {
"result": {
"type": "number",
@ -114,13 +119,13 @@ def test_execute_code(setup_code_executor_mock):
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] == 3
assert result.error is None
assert result.error == ""
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code_output_validator(setup_code_executor_mock):
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": args1 + args2,
}
@ -131,6 +136,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
code_config = {
"id": "code",
"data": {
"type": "code",
"outputs": {
"result": {
"type": "string",
@ -158,12 +164,12 @@ def test_execute_code_output_validator(setup_code_executor_mock):
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "Output variable `result` must be a string"
assert result.error == "Output result must be a string, got int instead"
def test_execute_code_output_validator_depth():
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": {
"result": args1 + args2,
@ -176,6 +182,7 @@ def test_execute_code_output_validator_depth():
code_config = {
"id": "code",
"data": {
"type": "code",
"outputs": {
"string_validator": {
"type": "string",
@ -281,7 +288,7 @@ def test_execute_code_output_validator_depth():
def test_execute_code_output_object_list():
code = """
def main(args1: int, args2: int) -> dict:
def main(args1: int, args2: int):
return {
"result": {
"result": args1 + args2,
@ -294,6 +301,7 @@ def test_execute_code_output_object_list():
code_config = {
"id": "code",
"data": {
"type": "code",
"outputs": {
"object_list": {
"type": "array[object]",
@ -354,9 +362,10 @@ def test_execute_code_output_object_list():
node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation():
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code_scientific_notation(setup_code_executor_mock):
code = """
def main() -> dict:
def main():
return {
"result": -8.0E-5
}
@ -366,6 +375,7 @@ def test_execute_code_scientific_notation():
code_config = {
"id": "code",
"data": {
"type": "code",
"outputs": {
"result": {
"type": "number",

View File

@ -5,14 +5,12 @@ from urllib.parse import urlencode
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock
@ -25,15 +23,12 @@ def init_http_node(config: dict):
"target": "1",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
@ -52,12 +47,21 @@ def init_http_node(config: dict):
variable_pool.add(["a", "args1"], 1)
variable_pool.add(["a", "args2"], 2)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
@ -73,6 +77,7 @@ def test_get(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -106,6 +111,7 @@ def test_no_auth(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -135,6 +141,7 @@ def test_custom_authorization_header(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -227,6 +234,7 @@ def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -267,6 +275,7 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -306,6 +315,7 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -339,6 +349,7 @@ def test_template(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -374,6 +385,7 @@ def test_json(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "post",
@ -416,6 +428,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "post",
@ -463,6 +476,7 @@ def test_form_data(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "post",
@ -513,6 +527,7 @@ def test_none_data(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "post",
@ -546,6 +561,7 @@ def test_mock_404(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -575,6 +591,7 @@ def test_multi_colons_parse(setup_http_mock):
config={
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -627,10 +644,11 @@ def test_nested_object_variable_selector(setup_http_mock):
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{"data": {"type": "start", "title": "Start"}, "id": "start"},
{
"id": "1",
"data": {
"type": "http-request",
"title": "http",
"desc": "",
"method": "get",
@ -651,12 +669,9 @@ def test_nested_object_variable_selector(setup_http_mock):
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
@ -676,12 +691,21 @@ def test_nested_object_variable_selector(setup_http_mock):
variable_pool.add(["a", "args2"], 2)
variable_pool.add(["a", "args3"], {"nested": "nested_value"}) # Only for this test
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=graph_config["nodes"][1],
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
# Initialize node data

View File

@ -6,17 +6,15 @@ from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
@ -30,11 +28,9 @@ def init_llm_node(config: dict) -> LLMNode:
"target": "llm",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
# Use proper UUIDs for database compatibility
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
@ -44,7 +40,6 @@ def init_llm_node(config: dict) -> LLMNode:
init_params = GraphInitParams(
tenant_id=tenant_id,
app_id=app_id,
workflow_type=WorkflowType.WORKFLOW,
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
@ -69,12 +64,21 @@ def init_llm_node(config: dict) -> LLMNode:
)
variable_pool.add(["abc", "output"], "sunny")
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = LLMNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
@ -173,15 +177,15 @@ def test_execute_llm():
assert isinstance(result, Generator)
for item in result:
if isinstance(item, RunCompletedEvent):
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.run_result.error}")
print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
if isinstance(item, StreamCompletedEvent):
if item.node_run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.node_run_result.error}")
print(f"Error type: {item.node_run_result.error_type}")
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.process_data is not None
assert item.node_run_result.outputs is not None
assert item.node_run_result.outputs.get("text") is not None
assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0
def test_execute_llm_with_jinja2():
@ -284,11 +288,11 @@ def test_execute_llm_with_jinja2():
result = node._run()
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data)
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.process_data is not None
assert "sunny" in json.dumps(item.node_run_result.process_data)
assert "what's the weather today?" in json.dumps(item.node_run_result.process_data)
def test_extract_json():

View File

@ -1,16 +1,14 @@
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
@ -18,7 +16,6 @@ from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from models.workflow import WorkflowType
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@ -29,7 +26,7 @@ def get_mocked_fetch_memory(memory_text: str):
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: Optional[int] = None,
message_limit: int | None = None,
):
return memory_text
@ -45,15 +42,12 @@ def init_parameter_extractor_node(config: dict):
"target": "llm",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
@ -74,12 +68,21 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "args1"], 1)
variable_pool.add(["a", "args2"], 2)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
return node

View File

@ -4,15 +4,13 @@ import uuid
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@ -22,6 +20,7 @@ def test_execute_code(setup_code_executor_mock):
config = {
"id": "1",
"data": {
"type": "template-transform",
"title": "123",
"variables": [
{
@ -42,15 +41,12 @@ def test_execute_code(setup_code_executor_mock):
"target": "1",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
@ -69,12 +65,21 @@ def test_execute_code(setup_code_executor_mock):
variable_pool.add(["1", "args1"], 1)
variable_pool.add(["1", "args2"], 3)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = TemplateTransformNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))

View File

@ -4,16 +4,14 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
def init_tool_node(config: dict):
@ -25,15 +23,12 @@ def init_tool_node(config: dict):
"target": "1",
},
],
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
@ -50,12 +45,21 @@ def init_tool_node(config: dict):
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
return node
@ -66,6 +70,7 @@ def test_tool_variable_invoke():
config={
"id": "1",
"data": {
"type": "tool",
"title": "a",
"desc": "a",
"provider_id": "time",
@ -86,10 +91,10 @@ def test_tool_variable_invoke():
# execute node
result = node._run()
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs is not None
assert item.node_run_result.outputs.get("text") is not None
def test_tool_mixed_invoke():
@ -97,6 +102,7 @@ def test_tool_mixed_invoke():
config={
"id": "1",
"data": {
"type": "tool",
"title": "a",
"desc": "a",
"provider_id": "time",
@ -117,7 +123,7 @@ def test_tool_mixed_invoke():
# execute node
result = node._run()
for item in result:
if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None
if isinstance(item, StreamCompletedEvent):
assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.node_run_result.outputs is not None
assert item.node_run_result.outputs.get("text") is not None