Mirror of https://github.com/langgenius/dify.git (synced 2026-02-23 03:17:57 +08:00)

Merge remote-tracking branch 'origin/main' into feat/trigger
@@ -206,6 +206,7 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
 # Reset password token expiry minutes
 RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
 EMAIL_REGISTER_TOKEN_EXPIRY_MINUTES=5
 CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
@@ -9,12 +9,13 @@ from flask.testing import FlaskClient
 from sqlalchemy.orm import Session
 
 from app_factory import create_app
-from models import Account, DifySetup, Tenant, TenantAccountJoin, db
+from extensions.ext_database import db
+from models import Account, DifySetup, Tenant, TenantAccountJoin
 from services.account_service import AccountService, RegisterService
 
 
 # Loading the .env file if it exists
-def _load_env() -> None:
+def _load_env():
     current_file_path = pathlib.Path(__file__).absolute()
     # Items later in the list have higher precedence.
     files_to_load = [".env", "vdb.env"]
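Note on the _load_env hunk above: the precedence comment means entries later in files_to_load override earlier ones. A minimal sketch of that behavior, assuming the python-dotenv package (the loader body itself is not shown in this hunk):

import pathlib

from dotenv import load_dotenv  # assumption: python-dotenv is available

def load_layered_env(base_dir: pathlib.Path, files_to_load: list[str]) -> None:
    # Apply each file in order with override=True, so later files win.
    for name in files_to_load:
        env_path = base_dir / name
        if env_path.exists():
            load_dotenv(env_path, override=True)

load_layered_env(pathlib.Path(__file__).parent, [".env", "vdb.env"])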
@@ -0,0 +1,206 @@
+"""Integration tests for ChatMessageApi permission verification."""
+
+import uuid
+from types import SimpleNamespace
+from unittest import mock
+
+import pytest
+from flask.testing import FlaskClient
+
+from controllers.console.app import completion as completion_api
+from controllers.console.app import message as message_api
+from controllers.console.app import wraps
+from libs.datetime_utils import naive_utc_now
+from models import Account, App, Tenant
+from models.account import TenantAccountRole
+from models.model import AppMode
+from services.app_generate_service import AppGenerateService
+
+
+class TestChatMessageApiPermissions:
+    """Test permission verification for the ChatMessageApi endpoint."""
+
+    @pytest.fixture
+    def mock_app_model(self):
+        """Create a mock App model for testing."""
+        app = App()
+        app.id = str(uuid.uuid4())
+        app.mode = AppMode.CHAT.value
+        app.tenant_id = str(uuid.uuid4())
+        app.status = "normal"
+        return app
+
+    @pytest.fixture
+    def mock_account(self):
+        """Create a mock Account for testing."""
+        account = Account()
+        account.id = str(uuid.uuid4())
+        account.name = "Test User"
+        account.email = "test@example.com"
+        account.last_active_at = naive_utc_now()
+        account.created_at = naive_utc_now()
+        account.updated_at = naive_utc_now()
+
+        # Create mock tenant
+        tenant = Tenant()
+        tenant.id = str(uuid.uuid4())
+        tenant.name = "Test Tenant"
+
+        account._current_tenant = tenant
+        return account
+
+    @pytest.mark.parametrize(
+        ("role", "status"),
+        [
+            (TenantAccountRole.OWNER, 200),
+            (TenantAccountRole.ADMIN, 200),
+            (TenantAccountRole.EDITOR, 200),
+            (TenantAccountRole.NORMAL, 403),
+            (TenantAccountRole.DATASET_OPERATOR, 403),
+        ],
+    )
+    def test_post_with_owner_role_succeeds(
+        self,
+        test_client: FlaskClient,
+        auth_header,
+        monkeypatch,
+        mock_app_model,
+        mock_account,
+        role: TenantAccountRole,
+        status: int,
+    ):
+        """Verify POST chat-messages succeeds for editor-level roles and returns 403 otherwise."""
+        # Mock app loading
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        # Mock current user
+        monkeypatch.setattr(completion_api, "current_user", mock_account)
+
+        mock_generate = mock.Mock(return_value={"message": "Test response"})
+        monkeypatch.setattr(AppGenerateService, "generate", mock_generate)
+
+        # Assign the parametrized role
+        mock_account.role = role
+
+        response = test_client.post(
+            f"/console/api/apps/{mock_app_model.id}/chat-messages",
+            headers=auth_header,
+            json={
+                "inputs": {},
+                "query": "Hello, world!",
+                "model_config": {
+                    "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}}
+                },
+                "response_mode": "blocking",
+            },
+        )
+
+        assert response.status_code == status
+
+    @pytest.mark.parametrize(
+        ("role", "status"),
+        [
+            (TenantAccountRole.OWNER, 200),
+            (TenantAccountRole.ADMIN, 200),
+            (TenantAccountRole.EDITOR, 200),
+            (TenantAccountRole.NORMAL, 403),
+            (TenantAccountRole.DATASET_OPERATOR, 403),
+        ],
+    )
+    def test_get_requires_edit_permission(
+        self,
+        test_client: FlaskClient,
+        auth_header,
+        monkeypatch,
+        mock_app_model,
+        mock_account,
+        role: TenantAccountRole,
+        status: int,
+    ):
+        """Ensure the GET chat-messages endpoint enforces edit permissions."""
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        conversation_id = uuid.uuid4()
+        created_at = naive_utc_now()
+
+        mock_conversation = SimpleNamespace(id=str(conversation_id), app_id=str(mock_app_model.id))
+        mock_message = SimpleNamespace(
+            id=str(uuid.uuid4()),
+            conversation_id=str(conversation_id),
+            inputs=[],
+            query="hello",
+            message=[{"text": "hello"}],
+            message_tokens=0,
+            re_sign_file_url_answer="",
+            answer_tokens=0,
+            provider_response_latency=0.0,
+            from_source="console",
+            from_end_user_id=None,
+            from_account_id=mock_account.id,
+            feedbacks=[],
+            workflow_run_id=None,
+            annotation=None,
+            annotation_hit_history=None,
+            created_at=created_at,
+            agent_thoughts=[],
+            message_files=[],
+            message_metadata_dict={},
+            status="success",
+            error="",
+            parent_message_id=None,
+        )
+
+        class MockQuery:
+            def __init__(self, model):
+                self.model = model
+
+            def where(self, *args, **kwargs):
+                return self
+
+            def first(self):
+                if getattr(self.model, "__name__", "") == "Conversation":
+                    return mock_conversation
+                return None
+
+            def order_by(self, *args, **kwargs):
+                return self
+
+            def limit(self, *_):
+                return self
+
+            def all(self):
+                if getattr(self.model, "__name__", "") == "Message":
+                    return [mock_message]
+                return []
+
+        mock_session = mock.Mock()
+        mock_session.query.side_effect = MockQuery
+        mock_session.scalar.return_value = False
+
+        monkeypatch.setattr(message_api, "db", SimpleNamespace(session=mock_session))
+        monkeypatch.setattr(message_api, "current_user", mock_account)
+
+        class DummyPagination:
+            def __init__(self, data, limit, has_more):
+                self.data = data
+                self.limit = limit
+                self.has_more = has_more
+
+        monkeypatch.setattr(message_api, "InfiniteScrollPagination", DummyPagination)
+
+        mock_account.role = role
+
+        response = test_client.get(
+            f"/console/api/apps/{mock_app_model.id}/chat-messages",
+            headers=auth_header,
+            query_string={"conversation_id": str(conversation_id)},
+        )
+
+        assert response.status_code == status
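Note on the test file above: the role/status table exercises a permission decorator on the console controllers; app loading and generation are stubbed out so only the role check decides the status code. The shape of the check being tested, as a hypothetical sketch (not Dify's actual decorator; the role strings are assumptions):

from functools import wraps

from flask import abort, g

EDIT_ROLES = {"owner", "admin", "editor"}  # the roles the table expects to get 200

def require_edit_permission(view):
    # Hypothetical: reject tenant roles outside the editor set with 403.
    @wraps(view)
    def wrapper(*args, **kwargs):
        user = getattr(g, "current_user", None)  # assumes the user was stashed on flask.g
        if user is None or user.role not in EDIT_ROLES:
            abort(403)
        return view(*args, **kwargs)

    return wrapper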
@@ -0,0 +1,129 @@
+"""Integration tests for ModelConfigResource permission verification."""
+
+import uuid
+from unittest import mock
+
+import pytest
+from flask.testing import FlaskClient
+
+from controllers.console.app import model_config as model_config_api
+from controllers.console.app import wraps
+from libs.datetime_utils import naive_utc_now
+from models import Account, App, Tenant
+from models.account import TenantAccountRole
+from models.model import AppMode
+from services.app_model_config_service import AppModelConfigService
+
+
+class TestModelConfigResourcePermissions:
+    """Test permission verification for the ModelConfigResource endpoint."""
+
+    @pytest.fixture
+    def mock_app_model(self):
+        """Create a mock App model for testing."""
+        app = App()
+        app.id = str(uuid.uuid4())
+        app.mode = AppMode.CHAT.value
+        app.tenant_id = str(uuid.uuid4())
+        app.status = "normal"
+        app.app_model_config_id = str(uuid.uuid4())
+        return app
+
+    @pytest.fixture
+    def mock_account(self):
+        """Create a mock Account for testing."""
+        account = Account()
+        account.id = str(uuid.uuid4())
+        account.name = "Test User"
+        account.email = "test@example.com"
+        account.last_active_at = naive_utc_now()
+        account.created_at = naive_utc_now()
+        account.updated_at = naive_utc_now()
+
+        # Create mock tenant
+        tenant = Tenant()
+        tenant.id = str(uuid.uuid4())
+        tenant.name = "Test Tenant"
+
+        account._current_tenant = tenant
+        return account
+
+    @pytest.mark.parametrize(
+        ("role", "status"),
+        [
+            (TenantAccountRole.OWNER, 200),
+            (TenantAccountRole.ADMIN, 200),
+            (TenantAccountRole.EDITOR, 200),
+            (TenantAccountRole.NORMAL, 403),
+            (TenantAccountRole.DATASET_OPERATOR, 403),
+        ],
+    )
+    def test_post_with_owner_role_succeeds(
+        self,
+        test_client: FlaskClient,
+        auth_header,
+        monkeypatch,
+        mock_app_model,
+        mock_account,
+        role: TenantAccountRole,
+        status: int,
+    ):
+        """Verify POST model-config succeeds for editor-level roles and returns 403 otherwise."""
+        # Assign the parametrized role
+        mock_account.role = role
+
+        # Mock app loading
+        mock_load_app_model = mock.Mock(return_value=mock_app_model)
+        monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
+
+        # Mock current user
+        monkeypatch.setattr(model_config_api, "current_user", mock_account)
+
+        # Mock AccountService.load_user to prevent authentication issues
+        from services.account_service import AccountService
+
+        mock_load_user = mock.Mock(return_value=mock_account)
+        monkeypatch.setattr(AccountService, "load_user", mock_load_user)
+
+        mock_validate_config = mock.Mock(
+            return_value={
+                "model": {"provider": "openai", "name": "gpt-4", "mode": "chat", "completion_params": {}},
+                "pre_prompt": "You are a helpful assistant.",
+                "user_input_form": [],
+                "dataset_query_variable": "",
+                "agent_mode": {"enabled": False, "tools": []},
+            }
+        )
+        monkeypatch.setattr(AppModelConfigService, "validate_configuration", mock_validate_config)
+
+        # Mock database operations
+        mock_db_session = mock.Mock()
+        mock_db_session.add = mock.Mock()
+        mock_db_session.flush = mock.Mock()
+        mock_db_session.commit = mock.Mock()
+        monkeypatch.setattr(model_config_api.db, "session", mock_db_session)
+
+        # Mock the app_model_config_was_updated event
+        mock_event = mock.Mock()
+        mock_event.send = mock.Mock()
+        monkeypatch.setattr(model_config_api, "app_model_config_was_updated", mock_event)
+
+        response = test_client.post(
+            f"/console/api/apps/{mock_app_model.id}/model-config",
+            headers=auth_header,
+            json={
+                "model": {
+                    "provider": "openai",
+                    "name": "gpt-4",
+                    "mode": "chat",
+                    "completion_params": {"temperature": 0.7, "max_tokens": 1000},
+                },
+                "user_input_form": [],
+                "dataset_query_variable": "",
+                "pre_prompt": "You are a helpful assistant.",
+                "agent_mode": {"enabled": False, "tools": []},
+            },
+        )
+
+        assert response.status_code == status
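Both permission suites lean on the same pytest mechanic: monkeypatch.setattr swaps the service method out from under the controller for the duration of one test and restores it at teardown, so the HTTP layer and the decorator are tested without touching a model provider or the database. The mechanic in isolation:

from unittest import mock

class Service:
    @staticmethod
    def generate(**kwargs):
        raise RuntimeError("would call a real model provider")

def test_generate_is_stubbed(monkeypatch):
    stub = mock.Mock(return_value={"message": "Test response"})
    monkeypatch.setattr(Service, "generate", stub)  # reverted automatically at teardown
    assert Service.generate(query="hi") == {"message": "Test response"}
    stub.assert_called_once_with(query="hi")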
@@ -1,6 +1,5 @@
 import unittest
 from datetime import UTC, datetime
-from typing import Optional
 from unittest.mock import patch
 from uuid import uuid4
@@ -42,7 +41,7 @@ class TestStorageKeyLoader(unittest.TestCase):
         self.session.rollback()
 
     def _create_upload_file(
-        self, file_id: Optional[str] = None, storage_key: Optional[str] = None, tenant_id: Optional[str] = None
+        self, file_id: str | None = None, storage_key: str | None = None, tenant_id: str | None = None
     ) -> UploadFile:
         """Helper method to create an UploadFile record for testing."""
         if file_id is None:
@@ -74,7 +73,7 @@ class TestStorageKeyLoader(unittest.TestCase):
         return upload_file
 
     def _create_tool_file(
-        self, file_id: Optional[str] = None, file_key: Optional[str] = None, tenant_id: Optional[str] = None
+        self, file_id: str | None = None, file_key: str | None = None, tenant_id: str | None = None
     ) -> ToolFile:
         """Helper method to create a ToolFile record for testing."""
         if file_id is None:
@@ -84,26 +83,24 @@ class TestStorageKeyLoader(unittest.TestCase):
         if tenant_id is None:
             tenant_id = self.tenant_id
 
-        tool_file = ToolFile()
+        tool_file = ToolFile(
+            user_id=self.user_id,
+            tenant_id=tenant_id,
+            conversation_id=self.conversation_id,
+            file_key=file_key,
+            mimetype="text/plain",
+            original_url="http://example.com/file.txt",
+            name="test_tool_file.txt",
+            size=2048,
+        )
         tool_file.id = file_id
-        tool_file.user_id = self.user_id
-        tool_file.tenant_id = tenant_id
-        tool_file.conversation_id = self.conversation_id
-        tool_file.file_key = file_key
-        tool_file.mimetype = "text/plain"
-        tool_file.original_url = "http://example.com/file.txt"
-        tool_file.name = "test_tool_file.txt"
-        tool_file.size = 2048
 
         self.session.add(tool_file)
         self.session.flush()
         self.test_tool_files.append(tool_file)
 
         return tool_file
 
-    def _create_file(
-        self, related_id: str, transfer_method: FileTransferMethod, tenant_id: Optional[str] = None
-    ) -> File:
+    def _create_file(self, related_id: str, transfer_method: FileTransferMethod, tenant_id: str | None = None) -> File:
         """Helper method to create a File object for testing."""
         if tenant_id is None:
             tenant_id = self.tenant_id
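The signature rewrites in this file are the PEP 604 modernization that recurs through the rest of this diff: Optional[X] becomes X | None, which is the same type at runtime (Python 3.10+) and lets the typing import go. A quick equivalence check:

from typing import Optional

def old_style(file_id: Optional[str] = None) -> None: ...

def new_style(file_id: str | None = None) -> None: ...

# Identical annotations at runtime on Python 3.10+:
assert old_style.__annotations__["file_id"] == new_style.__annotations__["file_id"]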
@@ -17,7 +17,7 @@ def mock_plugin_daemon(
     :return: unpatch function
     """
 
-    def unpatch() -> None:
+    def unpatch():
         monkeypatch.undo()
 
     monkeypatch.setattr(PluginModelClient, "invoke_llm", MockModelClass.invoke_llm)
@@ -5,8 +5,6 @@ from decimal import Decimal
 from json import dumps
 
-# import monkeypatch
-from typing import Optional
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
@@ -113,8 +111,8 @@ class MockModelClass(PluginModelClient):
 
     @staticmethod
     def generate_function_call(
-        tools: Optional[list[PromptMessageTool]],
-    ) -> Optional[AssistantPromptMessage.ToolCall]:
+        tools: list[PromptMessageTool] | None,
+    ) -> AssistantPromptMessage.ToolCall | None:
         if not tools or len(tools) == 0:
             return None
         function: PromptMessageTool = tools[0]
@@ -157,7 +155,7 @@ class MockModelClass(PluginModelClient):
     def mocked_chat_create_sync(
         model: str,
         prompt_messages: list[PromptMessage],
-        tools: Optional[list[PromptMessageTool]] = None,
+        tools: list[PromptMessageTool] | None = None,
     ) -> LLMResult:
         tool_call = MockModelClass.generate_function_call(tools=tools)
 
@@ -186,7 +184,7 @@ class MockModelClass(PluginModelClient):
     def mocked_chat_create_stream(
         model: str,
         prompt_messages: list[PromptMessage],
-        tools: Optional[list[PromptMessageTool]] = None,
+        tools: list[PromptMessageTool] | None = None,
     ) -> Generator[LLMResultChunk, None, None]:
         tool_call = MockModelClass.generate_function_call(tools=tools)
 
@@ -241,9 +239,9 @@ class MockModelClass(PluginModelClient):
         model: str,
         credentials: dict,
         prompt_messages: list[PromptMessage],
-        model_parameters: Optional[dict] = None,
-        tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        model_parameters: dict | None = None,
+        tools: list[PromptMessageTool] | None = None,
+        stop: list[str] | None = None,
         stream: bool = True,
     ):
         return MockModelClass.mocked_chat_create_stream(model=model, prompt_messages=prompt_messages, tools=tools)
@@ -1,8 +1,8 @@
 import os
 from typing import Literal
 
+import httpx
 import pytest
-import requests
 
 from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
 from core.tools.entities.common_entities import I18nObject
@@ -27,13 +27,11 @@ class MockedHttp:
     @classmethod
     def requests_request(
         cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
-    ) -> requests.Response:
+    ) -> httpx.Response:
         """
-        Mocked requests.request
+        Mocked httpx.request
         """
-        request = requests.PreparedRequest()
-        request.method = method
-        request.url = url
+        request = httpx.Request(method, url)
         if url.endswith("/tools"):
             content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
                 code=0, message="success", data=cls.list_tools()
@@ -41,8 +39,7 @@ class MockedHttp:
         else:
             raise ValueError("")
 
-        response = requests.Response()
-        response.status_code = 200
+        response = httpx.Response(status_code=200)
         response.request = request
         response._content = content.encode("utf-8")
         return response
@@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 @pytest.fixture
 def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
     if MOCK_SWITCH:
-        monkeypatch.setattr(requests, "request", MockedHttp.requests_request)
+        monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
 
         def unpatch():
             monkeypatch.undo()
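On the requests-to-httpx migration above: httpx builds the response in a single constructor call, and the request can be attached the same way instead of assigned afterwards; setting _content by hand also becomes unnecessary when content= is passed. A sketch of the equivalent construction (the URL is a placeholder):

import httpx

request = httpx.Request("GET", "http://example.local/tools")
response = httpx.Response(status_code=200, request=request, content=b'{"code": 0}')
assert response.json() == {"code": 0}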
@@ -3,16 +3,27 @@ import unittest
 import uuid
 
 import pytest
+from sqlalchemy import delete
 from sqlalchemy.orm import Session
 
 from core.variables.segments import StringSegment
+from core.variables.types import SegmentType
 from core.variables.variables import StringVariable
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from core.workflow.nodes import NodeType
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from factories.variable_factory import build_segment
 from libs import datetime_utils
-from models import db
-from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel
-from services.workflow_draft_variable_service import DraftVarLoader, VariableResetError, WorkflowDraftVariableService
+from models.enums import CreatorUserRole
+from models.model import UploadFile
+from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, WorkflowNodeExecutionModel
+from services.workflow_draft_variable_service import (
+    DraftVariableSaver,
+    DraftVarLoader,
+    VariableResetError,
+    WorkflowDraftVariableService,
+)
 
 
 @pytest.mark.usefixtures("flask_req_ctx")
@@ -175,6 +186,23 @@ class TestDraftVariableLoader(unittest.TestCase):
     _node1_id = "test_loader_node_1"
     _node_exec_id = str(uuid.uuid4())
 
+    # @pytest.fixture
+    # def test_app_id(self):
+    #     return str(uuid.uuid4())
+
+    # @pytest.fixture
+    # def test_tenant_id(self):
+    #     return str(uuid.uuid4())
+
+    # @pytest.fixture
+    # def session(self):
+    #     with Session(bind=db.engine, expire_on_commit=False) as session:
+    #         yield session
+
+    # @pytest.fixture
+    # def node_var(self, session):
+    #     pass
+
     def setUp(self):
         self._test_app_id = str(uuid.uuid4())
         self._test_tenant_id = str(uuid.uuid4())
@@ -241,6 +269,246 @@ class TestDraftVariableLoader(unittest.TestCase):
         node1_var = next(v for v in variables if v.selector[0] == self._node1_id)
         assert node1_var.id == self._node_var_id
 
+    @pytest.mark.usefixtures("setup_account")
+    def test_load_offloaded_variable_string_type_integration(self, setup_account):
+        """Test _load_offloaded_variable with string type using DraftVariableSaver for data creation."""
+        # Create a large string that will be offloaded
+        test_content = "x" * 15000  # larger than LARGE_VARIABLE_THRESHOLD (10 KB)
+        large_string_segment = StringSegment(value=test_content)
+
+        node_execution_id = str(uuid.uuid4())
+
+        try:
+            with Session(bind=db.engine, expire_on_commit=False) as session:
+                # Use DraftVariableSaver to create the offloaded variable (this mimics production)
+                saver = DraftVariableSaver(
+                    session=session,
+                    app_id=self._test_app_id,
+                    node_id="test_offload_node",
+                    node_type=NodeType.LLM,  # use a real node type
+                    node_execution_id=node_execution_id,
+                    user=setup_account,
+                )
+
+                # Save the variable - this triggers offloading due to the large size
+                saver.save(outputs={"offloaded_string_var": large_string_segment})
+                session.commit()
+
+            # Now test loading using DraftVarLoader
+            var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
+
+            # Load the variable using the standard workflow
+            variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]])
+
+            # Verify results
+            assert len(variables) == 1
+            loaded_variable = variables[0]
+            assert loaded_variable.name == "offloaded_string_var"
+            assert loaded_variable.selector == ["test_offload_node", "offloaded_string_var"]
+            assert isinstance(loaded_variable.value, StringSegment)
+            assert loaded_variable.value.value == test_content
+
+        finally:
+            # Clean up - delete all draft variables for this app
+            with Session(bind=db.engine) as session:
+                service = WorkflowDraftVariableService(session)
+                service.delete_workflow_variables(self._test_app_id)
+                session.commit()
+
+    def test_load_offloaded_variable_object_type_integration(self):
+        """Test _load_offloaded_variable with object type using real storage and service."""
+        # Create a test object
+        test_object = {"key1": "value1", "key2": 42, "nested": {"inner": "data"}}
+        test_json = json.dumps(test_object, ensure_ascii=False, separators=(",", ":"))
+        content_bytes = test_json.encode()
+
+        # Create an upload file record
+        upload_file = UploadFile(
+            tenant_id=self._test_tenant_id,
+            storage_type="local",
+            key=f"test_offload_{uuid.uuid4()}.json",
+            name="test_offload.json",
+            size=len(content_bytes),
+            extension="json",
+            mime_type="application/json",
+            created_by_role=CreatorUserRole.ACCOUNT,
+            created_by=str(uuid.uuid4()),
+            created_at=datetime_utils.naive_utc_now(),
+            used=True,
+            used_by=str(uuid.uuid4()),
+            used_at=datetime_utils.naive_utc_now(),
+        )
+
+        # Store the content in storage
+        storage.save(upload_file.key, content_bytes)
+
+        # Create a variable file record
+        variable_file = WorkflowDraftVariableFile(
+            upload_file_id=upload_file.id,
+            value_type=SegmentType.OBJECT,
+            tenant_id=self._test_tenant_id,
+            app_id=self._test_app_id,
+            user_id=str(uuid.uuid4()),
+            size=len(content_bytes),
+            created_at=datetime_utils.naive_utc_now(),
+        )
+
+        try:
+            with Session(bind=db.engine, expire_on_commit=False) as session:
+                # Add the upload file and variable file first to get their IDs
+                session.add_all([upload_file, variable_file])
+                session.flush()  # this generates the IDs
+
+                # Now create the offloaded draft variable with the correct file_id
+                offloaded_var = WorkflowDraftVariable.new_node_variable(
+                    app_id=self._test_app_id,
+                    node_id="test_offload_node",
+                    name="offloaded_object_var",
+                    value=build_segment({"truncated": True}),
+                    visible=True,
+                    node_execution_id=str(uuid.uuid4()),
+                )
+                offloaded_var.file_id = variable_file.id
+
+                session.add(offloaded_var)
+                session.flush()
+                session.commit()
+
+                # Use the service method that properly preloads relationships
+                service = WorkflowDraftVariableService(session)
+                draft_vars = service.get_draft_variables_by_selectors(
+                    self._test_app_id, [["test_offload_node", "offloaded_object_var"]]
+                )
+
+                assert len(draft_vars) == 1
+                loaded_var = draft_vars[0]
+                assert loaded_var.is_truncated()
+
+                # Create DraftVarLoader and test loading
+                var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
+
+                # Test the _load_offloaded_variable method
+                selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var)
+
+                # Verify the results
+                assert selector_tuple == ("test_offload_node", "offloaded_object_var")
+                assert variable.id == loaded_var.id
+                assert variable.name == "offloaded_object_var"
+                assert variable.value.value == test_object
+
+        finally:
+            # Clean up
+            with Session(bind=db.engine) as session:
+                # Query and delete by ID to ensure they're tracked in this session
+                session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
+                session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
+                session.query(UploadFile).filter_by(id=upload_file.id).delete()
+                session.commit()
+            # Clean up storage
+            try:
+                storage.delete(upload_file.key)
+            except Exception:
+                pass  # ignore cleanup failures
+
+    def test_load_variables_with_offloaded_variables_integration(self):
+        """Test load_variables with a mix of regular and offloaded variables using real storage."""
+        # The regular variable already exists from setUp.
+        # Create offloaded variable content
+        test_content = "This is offloaded content for integration test"
+        content_bytes = test_content.encode()
+
+        # Create upload file record
+        upload_file = UploadFile(
+            tenant_id=self._test_tenant_id,
+            storage_type="local",
+            key=f"test_integration_{uuid.uuid4()}.txt",
+            name="test_integration.txt",
+            size=len(content_bytes),
+            extension="txt",
+            mime_type="text/plain",
+            created_by_role=CreatorUserRole.ACCOUNT,
+            created_by=str(uuid.uuid4()),
+            created_at=datetime_utils.naive_utc_now(),
+            used=True,
+            used_by=str(uuid.uuid4()),
+            used_at=datetime_utils.naive_utc_now(),
+        )
+
+        # Store the content
+        storage.save(upload_file.key, content_bytes)
+
+        # Create variable file
+        variable_file = WorkflowDraftVariableFile(
+            upload_file_id=upload_file.id,
+            value_type=SegmentType.STRING,
+            tenant_id=self._test_tenant_id,
+            app_id=self._test_app_id,
+            user_id=str(uuid.uuid4()),
+            size=len(content_bytes),
+            created_at=datetime_utils.naive_utc_now(),
+        )
+
+        try:
+            with Session(bind=db.engine, expire_on_commit=False) as session:
+                # Add the upload file and variable file first to get their IDs
+                session.add_all([upload_file, variable_file])
+                session.flush()  # this generates the IDs
+
+                # Now create the offloaded draft variable with the correct file_id
+                offloaded_var = WorkflowDraftVariable.new_node_variable(
+                    app_id=self._test_app_id,
+                    node_id="test_integration_node",
+                    name="offloaded_integration_var",
+                    value=build_segment("truncated"),
+                    visible=True,
+                    node_execution_id=str(uuid.uuid4()),
+                )
+                offloaded_var.file_id = variable_file.id
+
+                session.add(offloaded_var)
+                session.flush()
+                session.commit()
+
+            # Test load_variables with both regular and offloaded variables;
+            # this method should handle the relationship preloading internally.
+            var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
+
+            variables = var_loader.load_variables(
+                [
+                    [SYSTEM_VARIABLE_NODE_ID, "sys_var"],  # regular variable from setUp
+                    ["test_integration_node", "offloaded_integration_var"],  # offloaded variable
+                ]
+            )
+
+            # Verify results
+            assert len(variables) == 2
+
+            # Find the regular variable
+            regular_var = next(v for v in variables if v.selector[0] == SYSTEM_VARIABLE_NODE_ID)
+            assert regular_var.id == self._sys_var_id
+            assert regular_var.value == "sys_value"
+
+            # Find the offloaded variable
+            offloaded_loaded_var = next(v for v in variables if v.selector[0] == "test_integration_node")
+            assert offloaded_loaded_var.id == offloaded_var.id
+            assert offloaded_loaded_var.value == test_content
+
+        finally:
+            # Clean up
+            with Session(bind=db.engine) as session:
+                # Query and delete by ID to ensure they're tracked in this session
+                session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
+                session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
+                session.query(UploadFile).filter_by(id=upload_file.id).delete()
+                session.commit()
+            # Clean up storage
+            try:
+                storage.delete(upload_file.key)
+            except Exception:
+                pass  # ignore cleanup failures
+
+
 @pytest.mark.usefixtures("flask_req_ctx")
 class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
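The string-type test above depends on the saver's offload path: 15,000 characters exceeds the 10 KB threshold the comment calls LARGE_VARIABLE_THRESHOLD, so the value goes to storage and the row keeps only a truncated preview. The size test being relied on, in sketch form (the helper and the exact check are assumptions; only the threshold figure comes from the test comment):

LARGE_VARIABLE_THRESHOLD = 10 * 1024  # 10 KB, per the comment in the test

def should_offload(value: str) -> bool:
    # Hypothetical: offload once the encoded payload crosses the threshold.
    return len(value.encode("utf-8")) > LARGE_VARIABLE_THRESHOLD

assert should_offload("x" * 15000)  # offloaded, as the test expects
assert not should_offload("x" * 1000)  # stays inline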
@@ -272,7 +540,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
             triggered_from="workflow-run",
             workflow_run_id=str(uuid.uuid4()),
             index=1,
-            node_execution_id=self._node_exec_id,
+            node_execution_id=str(uuid.uuid4()),
             node_id=self._node_id,
             node_type=NodeType.LLM.value,
             title="Test Node",
@@ -281,7 +549,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
             outputs='{"test_var": "output_value", "other_var": "other_output"}',
             status="succeeded",
             elapsed_time=1.5,
-            created_by_role="account",
+            created_by_role=CreatorUserRole.ACCOUNT,
             created_by=str(uuid.uuid4()),
         )
@@ -336,10 +604,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
         )
         self._conv_var.last_edited_at = datetime_utils.naive_utc_now()
 
+        with Session(db.engine, expire_on_commit=False) as persistent_session, persistent_session.begin():
+            persistent_session.add(
+                self._workflow_node_execution,
+            )
+
         # Add all to database
         db.session.add_all(
             [
-                self._workflow_node_execution,
                 self._node_var_with_exec,
                 self._node_var_without_exec,
                 self._node_var_missing_exec,
@@ -354,6 +626,14 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
         self._node_var_missing_exec_id = self._node_var_missing_exec.id
         self._conv_var_id = self._conv_var.id
 
+    def tearDown(self):
+        self._session.rollback()
+        with Session(db.engine) as session, session.begin():
+            stmt = delete(WorkflowNodeExecutionModel).where(
+                WorkflowNodeExecutionModel.id == self._workflow_node_execution.id
+            )
+            session.execute(stmt)
+
     def _get_test_srv(self) -> WorkflowDraftVariableService:
         return WorkflowDraftVariableService(session=self._session)
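The setUp/tearDown pair above works because Session.begin() used as a context manager commits when the block exits cleanly, so the node-execution row outlives the per-test transaction that tearDown rolls back and must then be deleted explicitly. The commit-on-exit behavior it relies on, demonstrated with a throwaway SQLite engine:

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # stand-in engine for the sketch
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE t (id INTEGER PRIMARY KEY)"))

with Session(engine) as session, session.begin():
    session.execute(text("INSERT INTO t (id) VALUES (1)"))
# session.begin() committed on block exit; the row is durable now.

with Session(engine) as session:
    assert session.execute(text("SELECT COUNT(*) FROM t")).scalar_one() == 1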
@@ -377,12 +657,10 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase):
             created_by=str(uuid.uuid4()),
             environment_variables=[],
             conversation_variables=conversation_vars,
+            rag_pipeline_variables=[],
         )
         return workflow
 
-    def tearDown(self):
-        self._session.rollback()
-
     def test_reset_node_variable_with_valid_execution_record(self):
         """Test resetting a node variable with a valid execution record - should restore from execution"""
         srv = self._get_test_srv()
@@ -3,6 +3,7 @@
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 import pytest
 
@@ -60,8 +61,7 @@ class TestClickZettaVolumeStorage(unittest.TestCase):
         # Test download
         with tempfile.NamedTemporaryFile() as temp_file:
             storage.download(test_filename, temp_file.name)
-            with open(temp_file.name, "rb") as f:
-                downloaded_content = f.read()
+            downloaded_content = Path(temp_file.name).read_bytes()
             assert downloaded_content == test_content
 
         # Test scan
@@ -1,12 +1,15 @@
 import uuid
 from unittest.mock import patch
 
 import pytest
+from sqlalchemy import delete
 
 from core.variables.segments import StringSegment
-from models import Tenant, db
-from models.model import App
-from models.workflow import WorkflowDraftVariable
+from extensions.ext_database import db
+from models import Tenant
+from models.enums import CreatorUserRole
+from models.model import App, UploadFile
+from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
 from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
@@ -212,3 +215,256 @@ class TestDeleteDraftVariablesIntegration:
             .execution_options(synchronize_session=False)
         )
         db.session.execute(query)
+
+
+class TestDeleteDraftVariablesWithOffloadIntegration:
+    """Integration tests for draft variable deletion with Offload data."""
+
+    @pytest.fixture
+    def setup_offload_test_data(self, app_and_tenant):
+        """Create test data with draft variables that have associated Offload files."""
+        tenant, app = app_and_tenant
+
+        # Create UploadFile records
+        from libs.datetime_utils import naive_utc_now
+
+        upload_file1 = UploadFile(
+            tenant_id=tenant.id,
+            storage_type="local",
+            key="test/file1.json",
+            name="file1.json",
+            size=1024,
+            extension="json",
+            mime_type="application/json",
+            created_by_role=CreatorUserRole.ACCOUNT,
+            created_by=str(uuid.uuid4()),
+            created_at=naive_utc_now(),
+            used=False,
+        )
+        upload_file2 = UploadFile(
+            tenant_id=tenant.id,
+            storage_type="local",
+            key="test/file2.json",
+            name="file2.json",
+            size=2048,
+            extension="json",
+            mime_type="application/json",
+            created_by_role=CreatorUserRole.ACCOUNT,
+            created_by=str(uuid.uuid4()),
+            created_at=naive_utc_now(),
+            used=False,
+        )
+        db.session.add(upload_file1)
+        db.session.add(upload_file2)
+        db.session.flush()
+
+        # Create WorkflowDraftVariableFile records
+        from core.variables.types import SegmentType
+
+        var_file1 = WorkflowDraftVariableFile(
+            tenant_id=tenant.id,
+            app_id=app.id,
+            user_id=str(uuid.uuid4()),
+            upload_file_id=upload_file1.id,
+            size=1024,
+            length=10,
+            value_type=SegmentType.STRING,
+        )
+        var_file2 = WorkflowDraftVariableFile(
+            tenant_id=tenant.id,
+            app_id=app.id,
+            user_id=str(uuid.uuid4()),
+            upload_file_id=upload_file2.id,
+            size=2048,
+            length=20,
+            value_type=SegmentType.OBJECT,
+        )
+        db.session.add(var_file1)
+        db.session.add(var_file2)
+        db.session.flush()
+
+        # Create WorkflowDraftVariable records with file associations
+        draft_var1 = WorkflowDraftVariable.new_node_variable(
+            app_id=app.id,
+            node_id="node_1",
+            name="large_var_1",
+            value=StringSegment(value="truncated..."),
+            node_execution_id=str(uuid.uuid4()),
+            file_id=var_file1.id,
+        )
+        draft_var2 = WorkflowDraftVariable.new_node_variable(
+            app_id=app.id,
+            node_id="node_2",
+            name="large_var_2",
+            value=StringSegment(value="truncated..."),
+            node_execution_id=str(uuid.uuid4()),
+            file_id=var_file2.id,
+        )
+        # Create a regular variable without Offload data
+        draft_var3 = WorkflowDraftVariable.new_node_variable(
+            app_id=app.id,
+            node_id="node_3",
+            name="regular_var",
+            value=StringSegment(value="regular_value"),
+            node_execution_id=str(uuid.uuid4()),
+        )
+
+        db.session.add(draft_var1)
+        db.session.add(draft_var2)
+        db.session.add(draft_var3)
+        db.session.commit()
+
+        yield {
+            "app": app,
+            "tenant": tenant,
+            "upload_files": [upload_file1, upload_file2],
+            "variable_files": [var_file1, var_file2],
+            "draft_variables": [draft_var1, draft_var2, draft_var3],
+        }
+
+        # Cleanup
+        db.session.rollback()
+
+        # Clean up any remaining records
+        for table, ids in [
+            (WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
+            (WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
+            (UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
+        ]:
+            cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
+            db.session.execute(cleanup_query)
+
+        db.session.commit()
+
+    @patch("extensions.ext_storage.storage")
+    def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
+        """Test that deleting draft variables also cleans up associated Offload data."""
+        data = setup_offload_test_data
+        app_id = data["app"].id
+
+        # Mock storage deletion to succeed
+        mock_storage.delete.return_value = None
+
+        # Verify initial state
+        draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+        var_files_before = db.session.query(WorkflowDraftVariableFile).count()
+        upload_files_before = db.session.query(UploadFile).count()
+
+        assert draft_vars_before == 3  # 2 with files + 1 regular
+        assert var_files_before == 2
+        assert upload_files_before == 2
+
+        # Delete draft variables
+        deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
+
+        # Verify results
+        assert deleted_count == 3
+
+        # Check that all draft variables are deleted
+        draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+        assert draft_vars_after == 0
+
+        # Check that associated Offload data is cleaned up
+        var_files_after = db.session.query(WorkflowDraftVariableFile).count()
+        upload_files_after = db.session.query(UploadFile).count()
+
+        assert var_files_after == 0  # all variable files should be deleted
+        assert upload_files_after == 0  # all upload files should be deleted
+
+        # Verify storage deletion was called for both files
+        assert mock_storage.delete.call_count == 2
+        storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
+        assert "test/file1.json" in storage_keys_deleted
+        assert "test/file2.json" in storage_keys_deleted
+
+    @patch("extensions.ext_storage.storage")
+    def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
+        """Test that database cleanup continues even when storage deletion fails."""
+        data = setup_offload_test_data
+        app_id = data["app"].id
+
+        # Mock storage deletion to fail for the first file and succeed for the second
+        mock_storage.delete.side_effect = [Exception("Storage error"), None]
+
+        # Delete draft variables
+        deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
+
+        # Verify that all draft variables are still deleted
+        assert deleted_count == 3
+
+        draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+        assert draft_vars_after == 0
+
+        # Database cleanup should still succeed even with storage errors
+        var_files_after = db.session.query(WorkflowDraftVariableFile).count()
+        upload_files_after = db.session.query(UploadFile).count()
+
+        assert var_files_after == 0
+        assert upload_files_after == 0
+
+        # Verify storage deletion was attempted for both files
+        assert mock_storage.delete.call_count == 2
+
+    @patch("extensions.ext_storage.storage")
+    def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
+        """Test deletion with a mix of variables with and without Offload data."""
+        data = setup_offload_test_data
+        app_id = data["app"].id
+
+        # Create an additional app with only regular variables (no offload data)
+        tenant = data["tenant"]
+        app2 = App(
+            tenant_id=tenant.id,
+            name="Test App 2",
+            mode="workflow",
+            enable_site=True,
+            enable_api=True,
+        )
+        db.session.add(app2)
+        db.session.flush()
+
+        # Add regular variables to app2
+        regular_vars = []
+        for i in range(3):
+            var = WorkflowDraftVariable.new_node_variable(
+                app_id=app2.id,
+                node_id=f"node_{i}",
+                name=f"var_{i}",
+                value=StringSegment(value="regular_value"),
+                node_execution_id=str(uuid.uuid4()),
+            )
+            db.session.add(var)
+            regular_vars.append(var)
+        db.session.commit()
+
+        try:
+            # Mock storage deletion
+            mock_storage.delete.return_value = None
+
+            # Delete variables for app2 (no offload data)
+            deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
+            assert deleted_count_app2 == 3
+
+            # Verify storage wasn't called for app2 (no offload files)
+            mock_storage.delete.assert_not_called()
+
+            # Delete variables for the original app (with offload data)
+            deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
+            assert deleted_count_app1 == 3
+
+            # Now storage should be called for the offload files
+            assert mock_storage.delete.call_count == 2
+
+        finally:
+            # Clean up app2 and its variables
+            cleanup_vars_query = (
+                delete(WorkflowDraftVariable)
+                .where(WorkflowDraftVariable.app_id == app2.id)
+                .execution_options(synchronize_session=False)
+            )
+            db.session.execute(cleanup_vars_query)
+
+            app2_obj = db.session.get(App, app2.id)
+            if app2_obj:
+                db.session.delete(app2_obj)
+            db.session.commit()
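The storage-failure test above pins down the contract for delete_draft_variables_batch: database cleanup must complete even when storage.delete raises. The general shape of that best-effort pattern, as a sketch rather than the task's actual implementation:

def delete_rows_with_best_effort_storage_cleanup(rows, storage, session):
    # Database rows always go; storage errors are swallowed per file.
    for row in rows:
        try:
            storage.delete(row.storage_key)
        except Exception:
            pass  # a real implementation would log this
        session.delete(row)
    session.commit()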
@@ -1,6 +1,5 @@
 import os
 from collections import UserDict
-from typing import Optional
 from unittest.mock import MagicMock
 
 import pytest
@@ -22,7 +21,7 @@ class MockBaiduVectorDBClass:
     def mock_vector_db_client(
         self,
         config=None,
-        adapter: Optional[HTTPAdapter] = None,
+        adapter: HTTPAdapter | None = None,
     ):
         self.conn = MagicMock()
         self._config = MagicMock()
@@ -101,8 +100,8 @@ class MockBaiduVectorDBClass:
                 "row": {
                     "id": primary_key.get("id"),
                     "vector": [0.23432432, 0.8923744, 0.89238432],
-                    "text": "text",
-                    "metadata": '{"doc_id": "doc_id_001"}',
+                    "page_content": "text",
+                    "metadata": {"doc_id": "doc_id_001"},
                 },
                 "code": 0,
                 "msg": "Success",
@@ -128,8 +127,8 @@ class MockBaiduVectorDBClass:
                     "row": {
                         "id": "doc_id_001",
                         "vector": [0.23432432, 0.8923744, 0.89238432],
-                        "text": "text",
-                        "metadata": '{"doc_id": "doc_id_001"}',
+                        "page_content": "text",
+                        "metadata": {"doc_id": "doc_id_001"},
                    },
                     "distance": 0.1,
                     "score": 0.5,
@@ -1,5 +1,5 @@
 import os
-from typing import Optional, Union
+from typing import Union
 
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
@@ -23,16 +23,16 @@ class MockTcvectordbClass:
         key="",
         read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
         timeout=10,
-        adapter: Optional[HTTPAdapter] = None,
+        adapter: HTTPAdapter | None = None,
         pool_size: int = 2,
-        proxies: Optional[dict] = None,
-        password: Optional[str] = None,
+        proxies: dict | None = None,
+        password: str | None = None,
         **kwargs,
     ):
         self._conn = None
         self._read_consistency = read_consistency
 
-    def create_database_if_not_exists(self, database_name: str, timeout: Optional[float] = None) -> RPCDatabase:
+    def create_database_if_not_exists(self, database_name: str, timeout: float | None = None) -> RPCDatabase:
         return RPCDatabase(
             name="dify",
             read_consistency=self._read_consistency,
@@ -42,7 +42,7 @@ class MockTcvectordbClass:
         return True
 
     def describe_collection(
-        self, database_name: str, collection_name: str, timeout: Optional[float] = None
+        self, database_name: str, collection_name: str, timeout: float | None = None
     ) -> RPCCollection:
         index = Index(
             FilterIndex("id", enum.FieldType.String, enum.IndexType.PRIMARY_KEY),
@@ -71,13 +71,13 @@ class MockTcvectordbClass:
         collection_name: str,
         shard: int,
         replicas: int,
-        description: Optional[str] = None,
-        index: Optional[Index] = None,
-        embedding: Optional[Embedding] = None,
-        timeout: Optional[float] = None,
-        ttl_config: Optional[dict] = None,
-        filter_index_config: Optional[FilterIndexConfig] = None,
-        indexes: Optional[list[IndexField]] = None,
+        description: str | None = None,
+        index: Index | None = None,
+        embedding: Embedding | None = None,
+        timeout: float | None = None,
+        ttl_config: dict | None = None,
+        filter_index_config: FilterIndexConfig | None = None,
+        indexes: list[IndexField] | None = None,
     ) -> RPCCollection:
         return RPCCollection(
             RPCDatabase(
@@ -102,7 +102,7 @@ class MockTcvectordbClass:
         database_name: str,
         collection_name: str,
         documents: list[Union[Document, dict]],
-        timeout: Optional[float] = None,
+        timeout: float | None = None,
         build_index: bool = True,
         **kwargs,
     ):
@@ -113,12 +113,12 @@ class MockTcvectordbClass:
         database_name: str,
         collection_name: str,
         vectors: list[list[float]],
-        filter: Optional[Filter] = None,
+        filter: Filter | None = None,
         params=None,
         retrieve_vector: bool = False,
         limit: int = 10,
-        output_fields: Optional[list[str]] = None,
-        timeout: Optional[float] = None,
+        output_fields: list[str] | None = None,
+        timeout: float | None = None,
     ) -> list[list[dict]]:
         return [[{"metadata": {"doc_id": "foo1"}, "text": "text", "doc_id": "foo1", "score": 0.1}]]
 
@@ -126,14 +126,14 @@ class MockTcvectordbClass:
         self,
         database_name: str,
         collection_name: str,
-        ann: Optional[Union[list[AnnSearch], AnnSearch]] = None,
-        match: Optional[Union[list[KeywordSearch], KeywordSearch]] = None,
-        filter: Optional[Union[Filter, str]] = None,
-        rerank: Optional[Rerank] = None,
-        retrieve_vector: Optional[bool] = None,
-        output_fields: Optional[list[str]] = None,
-        limit: Optional[int] = None,
-        timeout: Optional[float] = None,
+        ann: Union[list[AnnSearch], AnnSearch] | None = None,
+        match: Union[list[KeywordSearch], KeywordSearch] | None = None,
+        filter: Union[Filter, str] | None = None,
+        rerank: Rerank | None = None,
+        retrieve_vector: bool | None = None,
+        output_fields: list[str] | None = None,
+        limit: int | None = None,
+        timeout: float | None = None,
         return_pd_object=False,
         **kwargs,
     ) -> list[list[dict]]:
@@ -143,27 +143,27 @@ class MockTcvectordbClass:
         self,
         database_name: str,
         collection_name: str,
-        document_ids: Optional[list] = None,
+        document_ids: list | None = None,
         retrieve_vector: bool = False,
-        limit: Optional[int] = None,
-        offset: Optional[int] = None,
-        filter: Optional[Filter] = None,
-        output_fields: Optional[list[str]] = None,
-        timeout: Optional[float] = None,
-    ) -> list[dict]:
+        limit: int | None = None,
+        offset: int | None = None,
+        filter: Filter | None = None,
+        output_fields: list[str] | None = None,
+        timeout: float | None = None,
+    ):
         return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
 
     def collection_delete(
         self,
         database_name: str,
         collection_name: str,
-        document_ids: Optional[list[str]] = None,
-        filter: Optional[Filter] = None,
-        timeout: Optional[float] = None,
+        document_ids: list[str] | None = None,
+        filter: Filter | None = None,
+        timeout: float | None = None,
     ):
         return {"code": 0, "msg": "operation success"}
 
-    def drop_collection(self, database_name: str, collection_name: str, timeout: Optional[float] = None) -> dict:
+    def drop_collection(self, database_name: str, collection_name: str, timeout: float | None = None):
         return {"code": 0, "msg": "operation success"}
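One detail worth noting in these signatures: inner Union[...] spellings are kept while only the outer Optional becomes | None, which is legal because typing.Union objects support the | operator. Equivalence check (Python 3.10+):

from typing import Optional, Union

assert Union[list[int], int] | None == Optional[Union[list[int], int]]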
@@ -1,6 +1,5 @@
 import os
 from collections import UserDict
-from typing import Optional
 
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
@@ -34,7 +33,7 @@ class MockIndex:
         include_vectors: bool = False,
         include_metadata: bool = False,
         filter: str = "",
-        data: Optional[str] = None,
+        data: str | None = None,
         namespace: str = "",
         include_data: bool = False,
     ):
@@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
 import os
 import time
 
-import requests
+import httpx
 from clickzetta import connect
 
 
@@ -66,7 +66,7 @@ def test_dify_api():
     max_retries = 30
     for i in range(max_retries):
         try:
-            response = requests.get(f"{base_url}/console/api/health")
+            response = httpx.get(f"{base_url}/console/api/health")
             if response.status_code == 200:
                 print("✓ Dify API is ready")
                 break
@@ -1,16 +1,16 @@
-import environs
+import os
 
 from core.rag.datasource.vdb.lindorm.lindorm_vector import LindormVectorStore, LindormVectorStoreConfig
 from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis
 
-env = environs.Env()
-
 
 class Config:
-    SEARCH_ENDPOINT = env.str("SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070")
-    SEARCH_USERNAME = env.str("SEARCH_USERNAME", "ADMIN")
-    SEARCH_PWD = env.str("SEARCH_PWD", "ADMIN")
-    USING_UGC = env.bool("USING_UGC", True)
+    SEARCH_ENDPOINT = os.environ.get(
+        "SEARCH_ENDPOINT", "http://ld-************-proxy-search-pub.lindorm.aliyuncs.com:30070"
+    )
+    SEARCH_USERNAME = os.environ.get("SEARCH_USERNAME", "ADMIN")
+    SEARCH_PWD = os.environ.get("SEARCH_PWD", "ADMIN")
+    USING_UGC = os.environ.get("USING_UGC", "True").lower() == "true"
 
 
 class TestLindormVectorStore(AbstractVectorTest):
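Dropping environs means boolean parsing is now manual, and the replacement uses the standard form: normalize the string and compare, since truthiness of the raw value is a trap (every non-empty string is truthy). For example:

import os

os.environ["USING_UGC"] = "false"

# Naive truthiness gets it wrong:
assert bool(os.environ["USING_UGC"]) is True

# The form used in the hunk parses it correctly:
using_ugc = os.environ.get("USING_UGC", "True").lower() == "true"
assert using_ugc is False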
@@ -26,7 +26,7 @@ def get_example_document(doc_id: str) -> Document:
 
 
 @pytest.fixture
-def setup_mock_redis() -> None:
+def setup_mock_redis():
     # get
     ext_redis.redis_client.get = MagicMock(return_value=None)
 
@@ -48,7 +48,7 @@ class AbstractVectorTest:
         self.example_doc_id = str(uuid.uuid4())
         self.example_embedding = [1.001 * i for i in range(128)]
 
-    def create_vector(self) -> None:
+    def create_vector(self):
         self.vector.create(
             texts=[get_example_document(doc_id=self.example_doc_id)],
             embeddings=[self.example_embedding],
@@ -12,7 +12,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
 
 class MockedCodeExecutor:
     @classmethod
-    def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict) -> dict:
+    def invoke(cls, language: Literal["python3", "javascript", "jinja2"], code: str, inputs: dict):
         # invoke directly
         match language:
             case CodeLanguage.PYTHON3:
@@ -5,16 +5,14 @@ from os import getenv
 import pytest
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.graph_engine.entities.graph import Graph
-from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
-from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.enums import WorkflowNodeExecutionStatus
+from core.workflow.graph import Graph
+from core.workflow.node_events import NodeRunResult
 from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.node_factory import DifyNodeFactory
 from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
 from models.workflow import WorkflowType
 from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
 
 CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
@ -29,15 +27,12 @@ def init_code_node(code_config: dict):
                "target": "code",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, code_config],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
@ -56,12 +51,21 @@ def init_code_node(code_config: dict):
    variable_pool.add(["code", "args1"], 1)
    variable_pool.add(["code", "args2"], 2)

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = CodeNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=code_config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
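Reviewer note: the hunk above is the template applied to every init_* helper in this commit: build the runtime state once, bind it together with the init params into a DifyNodeFactory, hand the factory to Graph.init, and construct the node against the same shared state (the old graph= kwarg disappears). A minimal sketch of the new wiring, assuming variable_pool, init_params, graph_config, and code_config are prepared exactly as in the hunks above:

    import time
    import uuid

    from core.workflow.entities import GraphRuntimeState
    from core.workflow.graph import Graph
    from core.workflow.nodes.code.code_node import CodeNode
    from core.workflow.nodes.node_factory import DifyNodeFactory

    # One runtime state shared by factory, graph, and node.
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # The factory carries init params and runtime state so Graph.init
    # can construct node instances itself.
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    # The node constructor no longer takes graph=; it reuses the shared runtime state.
    node = CodeNode(
        id=str(uuid.uuid4()),
        config=code_config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(code_config.get("data", {}))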
@ -74,7 +78,7 @@ def init_code_node(code_config: dict):
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
    code = """
    def main(args1: int, args2: int) -> dict:
    def main(args1: int, args2: int):
        return {
            "result": args1 + args2,
        }
@ -85,6 +89,7 @@ def test_execute_code(setup_code_executor_mock):
    code_config = {
        "id": "code",
        "data": {
            "type": "code",
            "outputs": {
                "result": {
                    "type": "number",
@ -114,13 +119,13 @@ def test_execute_code(setup_code_executor_mock):
    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs is not None
    assert result.outputs["result"] == 3
    assert result.error is None
    assert result.error == ""


@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code_output_validator(setup_code_executor_mock):
    code = """
    def main(args1: int, args2: int) -> dict:
    def main(args1: int, args2: int):
        return {
            "result": args1 + args2,
        }
@ -131,6 +136,7 @@ def test_execute_code_output_validator(setup_code_executor_mock):
    code_config = {
        "id": "code",
        "data": {
            "type": "code",
            "outputs": {
                "result": {
                    "type": "string",
@ -158,12 +164,12 @@ def test_execute_code_output_validator(setup_code_executor_mock):
    result = node._run()
    assert isinstance(result, NodeRunResult)
    assert result.status == WorkflowNodeExecutionStatus.FAILED
    assert result.error == "Output variable `result` must be a string"
    assert result.error == "Output result must be a string, got int instead"


def test_execute_code_output_validator_depth():
    code = """
    def main(args1: int, args2: int) -> dict:
    def main(args1: int, args2: int):
        return {
            "result": {
                "result": args1 + args2,
@ -176,6 +182,7 @@ def test_execute_code_output_validator_depth():
    code_config = {
        "id": "code",
        "data": {
            "type": "code",
            "outputs": {
                "string_validator": {
                    "type": "string",
@ -281,7 +288,7 @@ def test_execute_code_output_validator_depth():

def test_execute_code_output_object_list():
    code = """
    def main(args1: int, args2: int) -> dict:
    def main(args1: int, args2: int):
        return {
            "result": {
                "result": args1 + args2,
@ -294,6 +301,7 @@ def test_execute_code_output_object_list():
    code_config = {
        "id": "code",
        "data": {
            "type": "code",
            "outputs": {
                "object_list": {
                    "type": "array[object]",
@ -354,9 +362,10 @@ def test_execute_code_output_object_list():
    node._transform_result(result, node._node_data.outputs)


def test_execute_code_scientific_notation():
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code_scientific_notation(setup_code_executor_mock):
    code = """
    def main() -> dict:
    def main():
        return {
            "result": -8.0E-5
        }
@ -366,6 +375,7 @@ def test_execute_code_scientific_notation():
    code_config = {
        "id": "code",
        "data": {
            "type": "code",
            "outputs": {
                "result": {
                    "type": "number",
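Reviewer note: two behavioral shifts are encoded in the assertions above. A successful NodeRunResult now reports error == "" where it previously reported None; validation failures now name the observed type ("Output result must be a string, got int instead"); and test_execute_code_scientific_notation now runs under the parametrized executor mock rather than a live sandbox. A hedged restatement of the contract the tests now pin down:

    result = node._run()
    if result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
        # success: error is an empty string, no longer None
        assert result.error == ""
    else:
        # failure: the message names the offending output and its actual type
        assert "must be" in result.error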
@ -5,14 +5,12 @@ from urllib.parse import urlencode
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock


@ -25,15 +23,12 @@ def init_http_node(config: dict):
                "target": "1",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
@ -52,12 +47,21 @@ def init_http_node(config: dict):
    variable_pool.add(["a", "args1"], 1)
    variable_pool.add(["a", "args2"], 2)

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = HttpRequestNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
@ -73,6 +77,7 @@ def test_get(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -106,6 +111,7 @@ def test_no_auth(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -135,6 +141,7 @@ def test_custom_authorization_header(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -227,6 +234,7 @@ def test_bearer_authorization_with_custom_header_ignored(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -267,6 +275,7 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -306,6 +315,7 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -339,6 +349,7 @@ def test_template(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -374,6 +385,7 @@ def test_json(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "post",
@ -416,6 +428,7 @@ def test_x_www_form_urlencoded(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "post",
@ -463,6 +476,7 @@ def test_form_data(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "post",
@ -513,6 +527,7 @@ def test_none_data(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "post",
@ -546,6 +561,7 @@ def test_mock_404(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -575,6 +591,7 @@ def test_multi_colons_parse(setup_http_mock):
        config={
            "id": "1",
            "data": {
                "type": "http-request",
                "title": "http",
                "desc": "",
                "method": "get",
@ -627,10 +644,11 @@ def test_nested_object_variable_selector(setup_http_mock):
            },
        ],
        "nodes": [
            {"data": {"type": "start"}, "id": "start"},
            {"data": {"type": "start", "title": "Start"}, "id": "start"},
            {
                "id": "1",
                "data": {
                    "type": "http-request",
                    "title": "http",
                    "desc": "",
                    "method": "get",
@ -651,12 +669,9 @@ def test_nested_object_variable_selector(setup_http_mock):
        ],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
@ -676,12 +691,21 @@ def test_nested_object_variable_selector(setup_http_mock):
    variable_pool.add(["a", "args2"], 2)
    variable_pool.add(["a", "args3"], {"nested": "nested_value"})  # Only for this test

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = HttpRequestNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=graph_config["nodes"][1],
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data

@ -6,17 +6,15 @@ from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType

"""FOR MOCK FIXTURES, DO NOT REMOVE"""

@ -30,11 +28,9 @@ def init_llm_node(config: dict) -> LLMNode:
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    # Use proper UUIDs for database compatibility
    tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
    app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
@ -44,7 +40,6 @@ def init_llm_node(config: dict) -> LLMNode:
    init_params = GraphInitParams(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id=workflow_id,
        graph_config=graph_config,
        user_id=user_id,
@ -69,12 +64,21 @@ def init_llm_node(config: dict) -> LLMNode:
    )
    variable_pool.add(["abc", "output"], "sunny")

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = LLMNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    # Initialize node data
@ -173,15 +177,15 @@ def test_execute_llm():
    assert isinstance(result, Generator)

    for item in result:
        if isinstance(item, RunCompletedEvent):
            if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
                print(f"Error: {item.run_result.error}")
                print(f"Error type: {item.run_result.error_type}")
            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.run_result.process_data is not None
            assert item.run_result.outputs is not None
            assert item.run_result.outputs.get("text") is not None
            assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
        if isinstance(item, StreamCompletedEvent):
            if item.node_run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
                print(f"Error: {item.node_run_result.error}")
                print(f"Error type: {item.node_run_result.error_type}")
            assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.node_run_result.process_data is not None
            assert item.node_run_result.outputs is not None
            assert item.node_run_result.outputs.get("text") is not None
            assert item.node_run_result.outputs.get("usage", {})["total_tokens"] > 0


def test_execute_llm_with_jinja2():
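Reviewer note: the streaming event API is renamed in lockstep across the LLM and tool tests: RunCompletedEvent (from core.workflow.nodes.event) becomes StreamCompletedEvent (from core.workflow.node_events), and its payload moves from .run_result to .node_run_result. A minimal consumer matching the updated tests, assuming a node built as in the init_* helpers above:

    from core.workflow.enums import WorkflowNodeExecutionStatus
    from core.workflow.node_events import StreamCompletedEvent

    for item in node._run():
        if isinstance(item, StreamCompletedEvent):   # was: RunCompletedEvent
            run_result = item.node_run_result        # was: item.run_result
            assert run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED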
@ -284,11 +288,11 @@ def test_execute_llm_with_jinja2():
    result = node._run()

    for item in result:
        if isinstance(item, RunCompletedEvent):
            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.run_result.process_data is not None
            assert "sunny" in json.dumps(item.run_result.process_data)
            assert "what's the weather today?" in json.dumps(item.run_result.process_data)
        if isinstance(item, StreamCompletedEvent):
            assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.node_run_result.process_data is not None
            assert "sunny" in json.dumps(item.node_run_result.process_data)
            assert "what's the weather today?" in json.dumps(item.node_run_result.process_data)


def test_extract_json():

@ -1,16 +1,14 @@
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
@ -18,7 +16,6 @@ from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from models.workflow import WorkflowType
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock


@ -29,7 +26,7 @@ def get_mocked_fetch_memory(memory_text: str):
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: Optional[int] = None,
        message_limit: int | None = None,
    ):
        return memory_text

@ -45,15 +42,12 @@ def init_parameter_extractor_node(config: dict):
                "target": "llm",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
@ -74,12 +68,21 @@ def init_parameter_extractor_node(config: dict):
    variable_pool.add(["a", "args1"], 1)
    variable_pool.add(["a", "args2"], 2)

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = ParameterExtractorNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))
    return node

@ -4,15 +4,13 @@ import uuid
import pytest

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock


@ -22,6 +20,7 @@ def test_execute_code(setup_code_executor_mock):
    config = {
        "id": "1",
        "data": {
            "type": "template-transform",
            "title": "123",
            "variables": [
                {
@ -42,15 +41,12 @@ def test_execute_code(setup_code_executor_mock):
                "target": "1",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
@ -69,12 +65,21 @@ def test_execute_code(setup_code_executor_mock):
    variable_pool.add(["1", "args1"], 1)
    variable_pool.add(["1", "args2"], 3)

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = TemplateTransformNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))


@ -4,16 +4,14 @@ from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event.event import RunCompletedEvent
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType


def init_tool_node(config: dict):
@ -25,15 +23,12 @@ def init_tool_node(config: dict):
                "target": "1",
            },
        ],
        "nodes": [{"data": {"type": "start"}, "id": "start"}, config],
        "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config],
    }

    graph = Graph.init(graph_config=graph_config)

    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_type=WorkflowType.WORKFLOW,
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
@ -50,12 +45,21 @@ def init_tool_node(config: dict):
        conversation_variables=[],
    )

    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

    # Create node factory
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )

    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

    node = ToolNode(
        id=str(uuid.uuid4()),
        graph_init_params=init_params,
        graph=graph,
        graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
        config=config,
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    node.init_node_data(config.get("data", {}))
    return node
@ -66,6 +70,7 @@ def test_tool_variable_invoke():
        config={
            "id": "1",
            "data": {
                "type": "tool",
                "title": "a",
                "desc": "a",
                "provider_id": "time",
@ -86,10 +91,10 @@ def test_tool_variable_invoke():
    # execute node
    result = node._run()
    for item in result:
        if isinstance(item, RunCompletedEvent):
            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.run_result.outputs is not None
            assert item.run_result.outputs.get("text") is not None
        if isinstance(item, StreamCompletedEvent):
            assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.node_run_result.outputs is not None
            assert item.node_run_result.outputs.get("text") is not None


def test_tool_mixed_invoke():
@ -97,6 +102,7 @@ def test_tool_mixed_invoke():
        config={
            "id": "1",
            "data": {
                "type": "tool",
                "title": "a",
                "desc": "a",
                "provider_id": "time",
@ -117,7 +123,7 @@ def test_tool_mixed_invoke():
    # execute node
    result = node._run()
    for item in result:
        if isinstance(item, RunCompletedEvent):
            assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.run_result.outputs is not None
            assert item.run_result.outputs.get("text") is not None
        if isinstance(item, StreamCompletedEvent):
            assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
            assert item.node_run_result.outputs is not None
            assert item.node_run_result.outputs.get("text") is not None