mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
test(api): fix broken tests (#25846)
This commit is contained in:
@ -397,14 +397,11 @@ class DatasetService:
|
|||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset not found")
|
raise ValueError("Dataset not found")
|
||||||
# check if dataset name is exists
|
# check if dataset name is exists
|
||||||
if (
|
|
||||||
db.session.query(Dataset)
|
if DatasetService._has_dataset_same_name(
|
||||||
.filter(
|
tenant_id=dataset.tenant_id,
|
||||||
Dataset.id != dataset_id,
|
dataset_id=dataset_id,
|
||||||
Dataset.name == data.get("name", dataset.name),
|
name=data.get("name", dataset.name),
|
||||||
Dataset.tenant_id == dataset.tenant_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
):
|
):
|
||||||
raise ValueError("Dataset name already exists")
|
raise ValueError("Dataset name already exists")
|
||||||
|
|
||||||
@ -417,6 +414,19 @@ class DatasetService:
|
|||||||
else:
|
else:
|
||||||
return DatasetService._update_internal_dataset(dataset, data, user)
|
return DatasetService._update_internal_dataset(dataset, data, user)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
|
||||||
|
dataset = (
|
||||||
|
db.session.query(Dataset)
|
||||||
|
.filter(
|
||||||
|
Dataset.id != dataset_id,
|
||||||
|
Dataset.name == name,
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
return dataset is not None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_external_dataset(dataset, data, user):
|
def _update_external_dataset(dataset, data, user):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -178,7 +178,7 @@ class TestWorkflowDraftVariableFields:
|
|||||||
)
|
)
|
||||||
|
|
||||||
node_var.id = str(uuid.uuid4())
|
node_var.id = str(uuid.uuid4())
|
||||||
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
node_var.last_edited_at = naive_utc_now()
|
||||||
variable_file = WorkflowDraftVariableFile(
|
variable_file = WorkflowDraftVariableFile(
|
||||||
id=str(uuidv7()),
|
id=str(uuidv7()),
|
||||||
upload_file_id=str(uuid.uuid4()),
|
upload_file_id=str(uuid.uuid4()),
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import json
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import Engine
|
||||||
|
|
||||||
@ -25,8 +25,6 @@ from models import Account, WorkflowNodeExecutionTriggeredFrom
|
|||||||
from models.enums import ExecutionOffLoadType
|
from models.enums import ExecutionOffLoadType
|
||||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||||
|
|
||||||
TRUNCATION_SIZE_THRESHOLD = 500
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TruncationTestCase:
|
class TruncationTestCase:
|
||||||
@ -166,35 +164,6 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation:
|
|||||||
assert domain_model.get_truncated_inputs() is None
|
assert domain_model.get_truncated_inputs() is None
|
||||||
assert domain_model.get_truncated_outputs() is None
|
assert domain_model.get_truncated_outputs() is None
|
||||||
|
|
||||||
@patch("core.repositories.sqlalchemy_workflow_node_execution_repository.FileService")
|
|
||||||
def test_save_with_truncation(self, mock_file_service_class):
|
|
||||||
"""Test the save method handles truncation and offload record creation."""
|
|
||||||
# Setup mock file service
|
|
||||||
mock_file_service = MagicMock()
|
|
||||||
mock_upload_file = MagicMock()
|
|
||||||
mock_upload_file.id = "mock-file-id"
|
|
||||||
mock_file_service.upload_file.return_value = mock_upload_file
|
|
||||||
mock_file_service_class.return_value = mock_file_service
|
|
||||||
|
|
||||||
large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1)}
|
|
||||||
|
|
||||||
repo = self.create_repository()
|
|
||||||
execution = create_workflow_node_execution(
|
|
||||||
inputs=large_data,
|
|
||||||
outputs=large_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock the session and database operations
|
|
||||||
with patch.object(repo, "_session_factory") as mock_session_factory:
|
|
||||||
mock_session = MagicMock()
|
|
||||||
mock_session_factory.return_value.__enter__.return_value = mock_session
|
|
||||||
|
|
||||||
repo.save(execution)
|
|
||||||
|
|
||||||
# Check that both merge operations were called (db_model and offload_record)
|
|
||||||
assert mock_session.merge.call_count == 1
|
|
||||||
mock_session.commit.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestWorkflowNodeExecutionModelTruncatedProperties:
|
class TestWorkflowNodeExecutionModelTruncatedProperties:
|
||||||
"""Test the truncated properties on WorkflowNodeExecutionModel."""
|
"""Test the truncated properties on WorkflowNodeExecutionModel."""
|
||||||
|
|||||||
@ -1,243 +0,0 @@
|
|||||||
"""
|
|
||||||
Test context preservation in GraphEngine workers.
|
|
||||||
|
|
||||||
This module tests that Flask app context and context variables are properly
|
|
||||||
preserved when executing nodes in worker threads.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import contextvars
|
|
||||||
import queue
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
from flask import Flask, g
|
|
||||||
|
|
||||||
from core.workflow.enums import NodeType
|
|
||||||
from core.workflow.graph import Graph
|
|
||||||
from core.workflow.graph_engine.worker import Worker
|
|
||||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunSucceededEvent
|
|
||||||
from core.workflow.nodes.base.node import Node
|
|
||||||
from libs.flask_utils import preserve_flask_contexts
|
|
||||||
|
|
||||||
|
|
||||||
class TestContextPreservation:
|
|
||||||
"""Test suite for context preservation in workers."""
|
|
||||||
|
|
||||||
def test_preserve_flask_contexts_with_flask_app(self) -> None:
|
|
||||||
"""Test that Flask app context is preserved in worker context."""
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
# Variable to check if context was available
|
|
||||||
context_available = False
|
|
||||||
|
|
||||||
def worker_task() -> None:
|
|
||||||
nonlocal context_available
|
|
||||||
with preserve_flask_contexts(flask_app=app, context_vars=contextvars.Context()):
|
|
||||||
# Check if we're in app context
|
|
||||||
from flask import has_app_context
|
|
||||||
|
|
||||||
context_available = has_app_context()
|
|
||||||
|
|
||||||
# Run worker task in thread
|
|
||||||
thread = threading.Thread(target=worker_task)
|
|
||||||
thread.start()
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
assert context_available, "Flask app context should be available in worker"
|
|
||||||
|
|
||||||
def test_preserve_flask_contexts_with_context_vars(self) -> None:
|
|
||||||
"""Test that context variables are preserved in worker context."""
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
# Create a context variable
|
|
||||||
test_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_var")
|
|
||||||
test_var.set("test_value")
|
|
||||||
|
|
||||||
# Capture context
|
|
||||||
context = contextvars.copy_context()
|
|
||||||
|
|
||||||
# Variable to store value from worker
|
|
||||||
worker_value: str | None = None
|
|
||||||
|
|
||||||
def worker_task() -> None:
|
|
||||||
nonlocal worker_value
|
|
||||||
with preserve_flask_contexts(flask_app=app, context_vars=context):
|
|
||||||
# Try to get the context variable
|
|
||||||
try:
|
|
||||||
worker_value = test_var.get()
|
|
||||||
except LookupError:
|
|
||||||
worker_value = None
|
|
||||||
|
|
||||||
# Run worker task in thread
|
|
||||||
thread = threading.Thread(target=worker_task)
|
|
||||||
thread.start()
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
assert worker_value == "test_value", "Context variable should be preserved in worker"
|
|
||||||
|
|
||||||
def test_preserve_flask_contexts_with_user(self) -> None:
|
|
||||||
"""Test that Flask app context allows user storage in worker context.
|
|
||||||
|
|
||||||
Note: The existing preserve_flask_contexts preserves user from request context,
|
|
||||||
not from context vars. In worker threads without request context, we can still
|
|
||||||
set user data in g within the app context.
|
|
||||||
"""
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
# Variable to store user from worker
|
|
||||||
worker_can_set_user = False
|
|
||||||
|
|
||||||
def worker_task() -> None:
|
|
||||||
nonlocal worker_can_set_user
|
|
||||||
with preserve_flask_contexts(flask_app=app, context_vars=contextvars.Context()):
|
|
||||||
# Set and verify user in the app context
|
|
||||||
g._login_user = "test_user"
|
|
||||||
worker_can_set_user = hasattr(g, "_login_user") and g._login_user == "test_user"
|
|
||||||
|
|
||||||
# Run worker task in thread
|
|
||||||
thread = threading.Thread(target=worker_task)
|
|
||||||
thread.start()
|
|
||||||
thread.join()
|
|
||||||
|
|
||||||
assert worker_can_set_user, "Should be able to set user in Flask app context within worker"
|
|
||||||
|
|
||||||
def test_worker_with_context(self) -> None:
|
|
||||||
"""Test that Worker class properly uses context preservation."""
|
|
||||||
# Setup Flask app and context
|
|
||||||
app = Flask(__name__)
|
|
||||||
test_var: contextvars.ContextVar[str] = contextvars.ContextVar("test_var")
|
|
||||||
test_var.set("worker_test_value")
|
|
||||||
context = contextvars.copy_context()
|
|
||||||
|
|
||||||
# Create queues
|
|
||||||
ready_queue: queue.Queue[str] = queue.Queue()
|
|
||||||
event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
|
||||||
|
|
||||||
# Create a mock graph with a test node
|
|
||||||
graph = MagicMock(spec=Graph)
|
|
||||||
test_node = MagicMock(spec=Node)
|
|
||||||
|
|
||||||
# Variable to capture context inside node execution
|
|
||||||
captured_value: str | None = None
|
|
||||||
context_available_in_node = False
|
|
||||||
|
|
||||||
def mock_run() -> list[GraphNodeEventBase]:
|
|
||||||
"""Mock node run that checks context."""
|
|
||||||
nonlocal captured_value, context_available_in_node
|
|
||||||
try:
|
|
||||||
captured_value = test_var.get()
|
|
||||||
except LookupError:
|
|
||||||
captured_value = None
|
|
||||||
|
|
||||||
from flask import has_app_context
|
|
||||||
|
|
||||||
context_available_in_node = has_app_context()
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
return [
|
|
||||||
NodeRunSucceededEvent(
|
|
||||||
id="test",
|
|
||||||
node_id="test_node",
|
|
||||||
node_type=NodeType.CODE,
|
|
||||||
in_iteration_id=None,
|
|
||||||
outputs={},
|
|
||||||
start_at=datetime.now(),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
test_node.run = mock_run
|
|
||||||
graph.nodes = {"test_node": test_node}
|
|
||||||
|
|
||||||
# Create worker with context
|
|
||||||
worker = Worker(
|
|
||||||
ready_queue=ready_queue,
|
|
||||||
event_queue=event_queue,
|
|
||||||
graph=graph,
|
|
||||||
worker_id=0,
|
|
||||||
flask_app=app,
|
|
||||||
context_vars=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start worker
|
|
||||||
worker.start()
|
|
||||||
|
|
||||||
# Queue a node for execution
|
|
||||||
ready_queue.put("test_node")
|
|
||||||
|
|
||||||
# Wait for execution
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
# Stop worker
|
|
||||||
worker.stop()
|
|
||||||
worker.join(timeout=1)
|
|
||||||
|
|
||||||
# Check results
|
|
||||||
assert captured_value == "worker_test_value", "Context variable should be available in node execution"
|
|
||||||
assert context_available_in_node, "Flask app context should be available in node execution"
|
|
||||||
|
|
||||||
# Check that event was pushed
|
|
||||||
assert not event_queue.empty(), "Event should be pushed to event queue"
|
|
||||||
event = event_queue.get()
|
|
||||||
assert isinstance(event, NodeRunSucceededEvent), "Should receive NodeRunSucceededEvent"
|
|
||||||
|
|
||||||
def test_worker_without_context(self) -> None:
|
|
||||||
"""Test that Worker still works without context."""
|
|
||||||
# Create queues
|
|
||||||
ready_queue: queue.Queue[str] = queue.Queue()
|
|
||||||
event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
|
||||||
|
|
||||||
# Create a mock graph with a test node
|
|
||||||
graph = MagicMock(spec=Graph)
|
|
||||||
test_node = MagicMock(spec=Node)
|
|
||||||
|
|
||||||
# Flag to check if node was executed
|
|
||||||
node_executed = False
|
|
||||||
|
|
||||||
def mock_run() -> list[GraphNodeEventBase]:
|
|
||||||
"""Mock node run."""
|
|
||||||
nonlocal node_executed
|
|
||||||
node_executed = True
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
return [
|
|
||||||
NodeRunSucceededEvent(
|
|
||||||
id="test",
|
|
||||||
node_id="test_node",
|
|
||||||
node_type=NodeType.CODE,
|
|
||||||
in_iteration_id=None,
|
|
||||||
outputs={},
|
|
||||||
start_at=datetime.now(),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
test_node.run = mock_run
|
|
||||||
graph.nodes = {"test_node": test_node}
|
|
||||||
|
|
||||||
# Create worker without context
|
|
||||||
worker = Worker(
|
|
||||||
ready_queue=ready_queue,
|
|
||||||
event_queue=event_queue,
|
|
||||||
graph=graph,
|
|
||||||
worker_id=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Start worker
|
|
||||||
worker.start()
|
|
||||||
|
|
||||||
# Queue a node for execution
|
|
||||||
ready_queue.put("test_node")
|
|
||||||
|
|
||||||
# Wait for execution
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
# Stop worker
|
|
||||||
worker.stop()
|
|
||||||
worker.join(timeout=1)
|
|
||||||
|
|
||||||
# Check that node was executed
|
|
||||||
assert node_executed, "Node should be executed even without context"
|
|
||||||
|
|
||||||
# Check that event was pushed
|
|
||||||
assert not event_queue.empty(), "Event should be pushed to event queue"
|
|
||||||
@ -3,6 +3,7 @@ Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from unittest.mock import MagicMock, PropertyMock
|
from unittest.mock import MagicMock, PropertyMock
|
||||||
@ -87,7 +88,7 @@ def test_save(repository, session):
|
|||||||
"""Test save method."""
|
"""Test save method."""
|
||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Create a mock execution
|
# Create a mock execution
|
||||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
execution.id = "test-id"
|
execution.id = "test-id"
|
||||||
execution.node_execution_id = "test-node-execution-id"
|
execution.node_execution_id = "test-node-execution-id"
|
||||||
execution.tenant_id = None
|
execution.tenant_id = None
|
||||||
@ -96,13 +97,14 @@ def test_save(repository, session):
|
|||||||
execution.process_data = None
|
execution.process_data = None
|
||||||
execution.outputs = None
|
execution.outputs = None
|
||||||
execution.metadata = None
|
execution.metadata = None
|
||||||
|
execution.workflow_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# Mock the to_db_model method to return the execution itself
|
# Mock the to_db_model method to return the execution itself
|
||||||
# This simulates the behavior of setting tenant_id and app_id
|
# This simulates the behavior of setting tenant_id and app_id
|
||||||
db_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
db_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||||
db_model.id = "test-id"
|
db_model.id = "test-id"
|
||||||
db_model.node_execution_id = "test-node-execution-id"
|
db_model.node_execution_id = "test-node-execution-id"
|
||||||
repository.to_db_model = MagicMock(return_value=db_model)
|
repository._to_db_model = MagicMock(return_value=db_model)
|
||||||
|
|
||||||
# Mock session.get to return None (no existing record)
|
# Mock session.get to return None (no existing record)
|
||||||
session_obj.get.return_value = None
|
session_obj.get.return_value = None
|
||||||
@ -111,7 +113,7 @@ def test_save(repository, session):
|
|||||||
repository.save(execution)
|
repository.save(execution)
|
||||||
|
|
||||||
# Assert to_db_model was called with the execution
|
# Assert to_db_model was called with the execution
|
||||||
repository.to_db_model.assert_called_once_with(execution)
|
repository._to_db_model.assert_called_once_with(execution)
|
||||||
|
|
||||||
# Assert session.get was called to check for existing record
|
# Assert session.get was called to check for existing record
|
||||||
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)
|
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, db_model.id)
|
||||||
@ -152,7 +154,7 @@ def test_save_with_existing_tenant_id(repository, session):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Mock the to_db_model method to return the modified execution
|
# Mock the to_db_model method to return the modified execution
|
||||||
repository.to_db_model = MagicMock(return_value=modified_execution)
|
repository._to_db_model = MagicMock(return_value=modified_execution)
|
||||||
|
|
||||||
# Mock session.get to return an existing record
|
# Mock session.get to return an existing record
|
||||||
existing_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
existing_model = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||||
@ -162,7 +164,7 @@ def test_save_with_existing_tenant_id(repository, session):
|
|||||||
repository.save(execution)
|
repository.save(execution)
|
||||||
|
|
||||||
# Assert to_db_model was called with the execution
|
# Assert to_db_model was called with the execution
|
||||||
repository.to_db_model.assert_called_once_with(execution)
|
repository._to_db_model.assert_called_once_with(execution)
|
||||||
|
|
||||||
# Assert session.get was called to check for existing record
|
# Assert session.get was called to check for existing record
|
||||||
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)
|
session_obj.get.assert_called_once_with(WorkflowNodeExecutionModel, modified_execution.id)
|
||||||
@ -179,10 +181,19 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
|||||||
session_obj, _ = session
|
session_obj, _ = session
|
||||||
# Set up mock
|
# Set up mock
|
||||||
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
mock_select = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.select")
|
||||||
|
mock_asc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.asc")
|
||||||
|
mock_desc = mocker.patch("core.repositories.sqlalchemy_workflow_node_execution_repository.desc")
|
||||||
|
|
||||||
|
mock_WorkflowNodeExecutionModel = mocker.patch(
|
||||||
|
"core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel"
|
||||||
|
)
|
||||||
mock_stmt = mocker.MagicMock()
|
mock_stmt = mocker.MagicMock()
|
||||||
mock_select.return_value = mock_stmt
|
mock_select.return_value = mock_stmt
|
||||||
mock_stmt.where.return_value = mock_stmt
|
mock_stmt.where.return_value = mock_stmt
|
||||||
mock_stmt.order_by.return_value = mock_stmt
|
mock_stmt.order_by.return_value = mock_stmt
|
||||||
|
mock_asc.return_value = mock_stmt
|
||||||
|
mock_desc.return_value = mock_stmt
|
||||||
|
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.return_value = mock_stmt
|
||||||
|
|
||||||
# Create a properly configured mock execution
|
# Create a properly configured mock execution
|
||||||
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
|
mock_execution = mocker.MagicMock(spec=WorkflowNodeExecutionModel)
|
||||||
@ -201,6 +212,7 @@ def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
|||||||
# Assert select was called with correct parameters
|
# Assert select was called with correct parameters
|
||||||
mock_select.assert_called_once()
|
mock_select.assert_called_once()
|
||||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||||
|
mock_WorkflowNodeExecutionModel.preload_offload_data_and_files.assert_called_once_with(mock_stmt)
|
||||||
# Assert _to_domain_model was called with the mock execution
|
# Assert _to_domain_model was called with the mock execution
|
||||||
repository._to_domain_model.assert_called_once_with(mock_execution)
|
repository._to_domain_model.assert_called_once_with(mock_execution)
|
||||||
# Assert the result contains our mock domain model
|
# Assert the result contains our mock domain model
|
||||||
@ -236,7 +248,7 @@ def test_to_db_model(repository):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Convert to DB model
|
# Convert to DB model
|
||||||
db_model = repository.to_db_model(domain_model)
|
db_model = repository._to_db_model(domain_model)
|
||||||
|
|
||||||
# Assert DB model has correct values
|
# Assert DB model has correct values
|
||||||
assert isinstance(db_model, WorkflowNodeExecutionModel)
|
assert isinstance(db_model, WorkflowNodeExecutionModel)
|
||||||
|
|||||||
@ -2,24 +2,18 @@
|
|||||||
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
|
Unit tests for SQLAlchemyWorkflowNodeExecutionRepository, focusing on process_data truncation functionality.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
from core.repositories.sqlalchemy_workflow_node_execution_repository import (
|
||||||
SQLAlchemyWorkflowNodeExecutionRepository,
|
SQLAlchemyWorkflowNodeExecutionRepository,
|
||||||
_InputsOutputsTruncationResult,
|
|
||||||
)
|
)
|
||||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
|
||||||
from core.workflow.enums import NodeType
|
from core.workflow.enums import NodeType
|
||||||
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.model import UploadFile
|
|
||||||
from models.workflow import WorkflowNodeExecutionOffload
|
|
||||||
|
|
||||||
|
|
||||||
class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
|
class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
|
||||||
@ -74,154 +68,6 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
|
|||||||
created_at=datetime.now(),
|
created_at=datetime.now(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("core.repositories.sqlalchemy_workflow_node_execution_repository.dify_config")
|
|
||||||
def test_to_db_model_with_small_process_data(self, mock_config):
|
|
||||||
"""Test _to_db_model with small process_data that doesn't need truncation."""
|
|
||||||
mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
|
|
||||||
mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
|
|
||||||
mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
|
|
||||||
|
|
||||||
repository = self.create_repository()
|
|
||||||
small_process_data = {"small": "data", "count": 5}
|
|
||||||
|
|
||||||
execution = self.create_workflow_node_execution(process_data=small_process_data)
|
|
||||||
|
|
||||||
with patch.object(repository, "_truncate_and_upload", return_value=None) as mock_truncate:
|
|
||||||
db_model = repository._to_db_model(execution)
|
|
||||||
|
|
||||||
# Should try to truncate but return None (no truncation needed)
|
|
||||||
mock_truncate.assert_called_once_with(small_process_data, execution.id, "_process_data")
|
|
||||||
|
|
||||||
# Process data should be stored directly in database
|
|
||||||
assert db_model.process_data is not None
|
|
||||||
stored_data = json.loads(db_model.process_data)
|
|
||||||
assert stored_data == small_process_data
|
|
||||||
|
|
||||||
# No offload data should be created for process_data
|
|
||||||
assert db_model.offload_data is None
|
|
||||||
|
|
||||||
def test_to_db_model_with_large_process_data(self):
|
|
||||||
"""Test _to_db_model with large process_data that needs truncation."""
|
|
||||||
repository = self.create_repository()
|
|
||||||
|
|
||||||
# Create large process_data that would need truncation
|
|
||||||
large_process_data = {
|
|
||||||
"large_field": "x" * 10000, # Very large string
|
|
||||||
"metadata": {"type": "processing", "timestamp": 1234567890},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock truncation result
|
|
||||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": {"type": "processing", "timestamp": 1234567890}}
|
|
||||||
|
|
||||||
mock_upload_file = Mock(spec=UploadFile)
|
|
||||||
mock_upload_file.id = "mock-file-id"
|
|
||||||
|
|
||||||
mock_offload = Mock(spec=WorkflowNodeExecutionOffload)
|
|
||||||
truncation_result = _InputsOutputsTruncationResult(
|
|
||||||
truncated_value=truncated_data, file=mock_upload_file, offload=mock_offload
|
|
||||||
)
|
|
||||||
|
|
||||||
execution = self.create_workflow_node_execution(process_data=large_process_data)
|
|
||||||
|
|
||||||
with patch.object(repository, "_truncate_and_upload", return_value=truncation_result) as mock_truncate:
|
|
||||||
db_model = repository._to_db_model(execution)
|
|
||||||
|
|
||||||
# Should call truncate with correct parameters
|
|
||||||
mock_truncate.assert_called_once_with(large_process_data, execution.id, "_process_data")
|
|
||||||
|
|
||||||
# Truncated data should be stored in database
|
|
||||||
assert db_model.process_data is not None
|
|
||||||
stored_data = json.loads(db_model.process_data)
|
|
||||||
assert stored_data == truncated_data
|
|
||||||
|
|
||||||
# Domain model should have truncated data set
|
|
||||||
assert execution.process_data_truncated is True
|
|
||||||
assert execution.get_truncated_process_data() == truncated_data
|
|
||||||
|
|
||||||
# Offload data should be created
|
|
||||||
assert db_model.offload_data is not None
|
|
||||||
assert len(db_model.offload_data) > 0
|
|
||||||
# Find the process_data offload entry
|
|
||||||
process_data_offload = next(
|
|
||||||
(item for item in db_model.offload_data if hasattr(item, "file_id") and item.file_id == "mock-file-id"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert process_data_offload is not None
|
|
||||||
|
|
||||||
def test_to_db_model_with_none_process_data(self):
|
|
||||||
"""Test _to_db_model with None process_data."""
|
|
||||||
repository = self.create_repository()
|
|
||||||
execution = self.create_workflow_node_execution(process_data=None)
|
|
||||||
|
|
||||||
with patch.object(repository, "_truncate_and_upload") as mock_truncate:
|
|
||||||
db_model = repository._to_db_model(execution)
|
|
||||||
|
|
||||||
# Should not call truncate for None data
|
|
||||||
mock_truncate.assert_not_called()
|
|
||||||
|
|
||||||
# Process data should be None
|
|
||||||
assert db_model.process_data is None
|
|
||||||
|
|
||||||
# No offload data should be created
|
|
||||||
assert db_model.offload_data == []
|
|
||||||
|
|
||||||
def test_to_domain_model_with_offloaded_process_data(self):
|
|
||||||
"""Test _to_domain_model with offloaded process_data."""
|
|
||||||
repository = self.create_repository()
|
|
||||||
|
|
||||||
# Create mock database model with offload data
|
|
||||||
db_model = Mock(spec=WorkflowNodeExecutionModel)
|
|
||||||
db_model.id = "test-execution-id"
|
|
||||||
db_model.node_execution_id = "test-node-execution-id"
|
|
||||||
db_model.workflow_id = "test-workflow-id"
|
|
||||||
db_model.workflow_run_id = None
|
|
||||||
db_model.index = 1
|
|
||||||
db_model.predecessor_node_id = None
|
|
||||||
db_model.node_id = "test-node-id"
|
|
||||||
db_model.node_type = "llm"
|
|
||||||
db_model.title = "Test Node"
|
|
||||||
db_model.status = "succeeded"
|
|
||||||
db_model.error = None
|
|
||||||
db_model.elapsed_time = 1.5
|
|
||||||
db_model.created_at = datetime.now()
|
|
||||||
db_model.finished_at = None
|
|
||||||
|
|
||||||
# Mock truncated process_data from database
|
|
||||||
truncated_process_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
|
||||||
db_model.process_data_dict = truncated_process_data
|
|
||||||
db_model.inputs_dict = None
|
|
||||||
db_model.outputs_dict = None
|
|
||||||
db_model.execution_metadata_dict = {}
|
|
||||||
|
|
||||||
# Mock offload data with process_data file
|
|
||||||
mock_offload_data = Mock(spec=WorkflowNodeExecutionOffload)
|
|
||||||
mock_offload_data.inputs_file_id = None
|
|
||||||
mock_offload_data.inputs_file = None
|
|
||||||
mock_offload_data.outputs_file_id = None
|
|
||||||
mock_offload_data.outputs_file = None
|
|
||||||
mock_offload_data.process_data_file_id = "process-data-file-id"
|
|
||||||
|
|
||||||
mock_process_data_file = Mock(spec=UploadFile)
|
|
||||||
mock_offload_data.process_data_file = mock_process_data_file
|
|
||||||
|
|
||||||
db_model.offload_data = [mock_offload_data]
|
|
||||||
|
|
||||||
# Mock the file loading
|
|
||||||
original_process_data = {"large_field": "x" * 10000, "metadata": "info"}
|
|
||||||
|
|
||||||
with patch.object(repository, "_load_file", return_value=original_process_data) as mock_load:
|
|
||||||
domain_model = repository._to_domain_model(db_model)
|
|
||||||
|
|
||||||
# Should load the file
|
|
||||||
mock_load.assert_called_once_with(mock_process_data_file)
|
|
||||||
|
|
||||||
# Domain model should have original data
|
|
||||||
assert domain_model.process_data == original_process_data
|
|
||||||
|
|
||||||
# Domain model should have truncated data set
|
|
||||||
assert domain_model.process_data_truncated is True
|
|
||||||
assert domain_model.get_truncated_process_data() == truncated_process_data
|
|
||||||
|
|
||||||
def test_to_domain_model_without_offload_data(self):
|
def test_to_domain_model_without_offload_data(self):
|
||||||
"""Test _to_domain_model without offload data."""
|
"""Test _to_domain_model without offload data."""
|
||||||
repository = self.create_repository()
|
repository = self.create_repository()
|
||||||
@ -258,116 +104,3 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData:
|
|||||||
# Should not be truncated
|
# Should not be truncated
|
||||||
assert domain_model.process_data_truncated is False
|
assert domain_model.process_data_truncated is False
|
||||||
assert domain_model.get_truncated_process_data() is None
|
assert domain_model.get_truncated_process_data() is None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TruncationScenario:
|
|
||||||
"""Test scenario for truncation functionality."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
process_data: dict[str, Any] | None
|
|
||||||
should_truncate: bool
|
|
||||||
expected_truncated: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessDataTruncationScenarios:
|
|
||||||
"""Test various scenarios for process_data truncation."""
|
|
||||||
|
|
||||||
def get_truncation_scenarios(self) -> list[TruncationScenario]:
|
|
||||||
"""Create test scenarios for truncation."""
|
|
||||||
return [
|
|
||||||
TruncationScenario(
|
|
||||||
name="none_data",
|
|
||||||
process_data=None,
|
|
||||||
should_truncate=False,
|
|
||||||
),
|
|
||||||
TruncationScenario(
|
|
||||||
name="small_data",
|
|
||||||
process_data={"key": "value"},
|
|
||||||
should_truncate=False,
|
|
||||||
),
|
|
||||||
TruncationScenario(
|
|
||||||
name="large_data",
|
|
||||||
process_data={"large": "x" * 10000},
|
|
||||||
should_truncate=True,
|
|
||||||
expected_truncated=True,
|
|
||||||
),
|
|
||||||
TruncationScenario(
|
|
||||||
name="empty_data",
|
|
||||||
process_data={},
|
|
||||||
should_truncate=False,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"scenario",
|
|
||||||
[
|
|
||||||
TruncationScenario("none_data", None, False, False),
|
|
||||||
TruncationScenario("small_data", {"small": "data"}, False, False),
|
|
||||||
TruncationScenario("large_data", {"large": "x" * 10000}, True, True),
|
|
||||||
TruncationScenario("empty_data", {}, False, False),
|
|
||||||
],
|
|
||||||
ids=["none_data", "small_data", "large_data", "empty_data"],
|
|
||||||
)
|
|
||||||
def test_process_data_truncation_scenarios(self, scenario: TruncationScenario):
|
|
||||||
"""Test various process_data truncation scenarios."""
|
|
||||||
repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
|
||||||
session_factory=MagicMock(spec=sessionmaker),
|
|
||||||
user=Mock(spec=Account, id="test-user", tenant_id="test-tenant"),
|
|
||||||
app_id="test-app",
|
|
||||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
|
||||||
)
|
|
||||||
|
|
||||||
execution = WorkflowNodeExecution(
|
|
||||||
id="test-execution-id",
|
|
||||||
workflow_id="test-workflow-id",
|
|
||||||
index=1,
|
|
||||||
node_id="test-node-id",
|
|
||||||
node_type=NodeType.LLM,
|
|
||||||
title="Test Node",
|
|
||||||
process_data=scenario.process_data,
|
|
||||||
created_at=datetime.now(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock truncation behavior
|
|
||||||
if scenario.should_truncate:
|
|
||||||
truncated_data = {"truncated": True}
|
|
||||||
mock_file = Mock(spec=UploadFile, id="file-id")
|
|
||||||
mock_offload = Mock(spec=WorkflowNodeExecutionOffload)
|
|
||||||
truncation_result = _InputsOutputsTruncationResult(
|
|
||||||
truncated_value=truncated_data, file=mock_file, offload=mock_offload
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(repository, "_truncate_and_upload", return_value=truncation_result):
|
|
||||||
db_model = repository._to_db_model(execution)
|
|
||||||
|
|
||||||
# Should create offload data
|
|
||||||
assert db_model.offload_data is not None
|
|
||||||
assert len(db_model.offload_data) > 0
|
|
||||||
# Find the process_data offload entry
|
|
||||||
process_data_offload = next(
|
|
||||||
(item for item in db_model.offload_data if hasattr(item, "file_id") and item.file_id == "file-id"),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert process_data_offload is not None
|
|
||||||
assert execution.process_data_truncated == scenario.expected_truncated
|
|
||||||
else:
|
|
||||||
with patch.object(repository, "_truncate_and_upload", return_value=None):
|
|
||||||
db_model = repository._to_db_model(execution)
|
|
||||||
|
|
||||||
# Should not create offload data or set truncation
|
|
||||||
if scenario.process_data is None:
|
|
||||||
assert db_model.offload_data == []
|
|
||||||
assert db_model.process_data is None
|
|
||||||
else:
|
|
||||||
# For small data, might have offload_data from other fields but not process_data
|
|
||||||
if db_model.offload_data:
|
|
||||||
# Check that no process_data offload entries exist
|
|
||||||
process_data_offloads = [
|
|
||||||
item
|
|
||||||
for item in db_model.offload_data
|
|
||||||
if hasattr(item, "type_") and item.type_.value == "process_data"
|
|
||||||
]
|
|
||||||
assert len(process_data_offloads) == 0
|
|
||||||
|
|
||||||
assert execution.process_data_truncated is False
|
|
||||||
|
|||||||
@ -104,6 +104,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||||
patch("extensions.ext_database.db.session") as mock_db,
|
patch("extensions.ext_database.db.session") as mock_db,
|
||||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||||
|
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
|
||||||
):
|
):
|
||||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||||
mock_naive_utc_now.return_value = current_time
|
mock_naive_utc_now.return_value = current_time
|
||||||
@ -114,6 +115,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"db_session": mock_db,
|
"db_session": mock_db,
|
||||||
"naive_utc_now": mock_naive_utc_now,
|
"naive_utc_now": mock_naive_utc_now,
|
||||||
"current_time": current_time,
|
"current_time": current_time,
|
||||||
|
"has_dataset_same_name": has_dataset_same_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -190,9 +192,9 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"external_knowledge_api_id": "new_api_id",
|
"external_knowledge_api_id": "new_api_id",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
# Verify permission check was called
|
|
||||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||||
|
|
||||||
# Verify dataset and binding updates
|
# Verify dataset and binding updates
|
||||||
@ -214,6 +216,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
|
|
||||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||||
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
with pytest.raises(ValueError) as context:
|
with pytest.raises(ValueError) as context:
|
||||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
@ -227,6 +230,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
|
|
||||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
with pytest.raises(ValueError) as context:
|
with pytest.raises(ValueError) as context:
|
||||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
@ -250,6 +254,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"external_knowledge_id": "knowledge_id",
|
"external_knowledge_id": "knowledge_id",
|
||||||
"external_knowledge_api_id": "api_id",
|
"external_knowledge_api_id": "api_id",
|
||||||
}
|
}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
with pytest.raises(ValueError) as context:
|
with pytest.raises(ValueError) as context:
|
||||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
@ -280,6 +285,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"embedding_model": "text-embedding-ada-002",
|
"embedding_model": "text-embedding-ada-002",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
# Verify permission check was called
|
# Verify permission check was called
|
||||||
@ -320,6 +326,8 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"embedding_model": None, # Should be filtered out
|
"embedding_model": None, # Should be filtered out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
# Verify database update was called with filtered data
|
# Verify database update was called with filtered data
|
||||||
@ -356,6 +364,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||||
|
|
||||||
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
@ -402,6 +411,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"embedding_model": "text-embedding-ada-002",
|
"embedding_model": "text-embedding-ada-002",
|
||||||
"retrieval_model": "new_model",
|
"retrieval_model": "new_model",
|
||||||
}
|
}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
@ -453,6 +463,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||||
|
|
||||||
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
@ -505,6 +516,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"embedding_model": "text-embedding-3-small",
|
"embedding_model": "text-embedding-3-small",
|
||||||
"retrieval_model": "new_model",
|
"retrieval_model": "new_model",
|
||||||
}
|
}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
@ -558,6 +570,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"indexing_technique": "high_quality", # Same as current
|
"indexing_technique": "high_quality", # Same as current
|
||||||
"retrieval_model": "new_model",
|
"retrieval_model": "new_model",
|
||||||
}
|
}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
@ -588,6 +601,7 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
|
|
||||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||||
update_data = {"name": "new_name"}
|
update_data = {"name": "new_name"}
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
with pytest.raises(ValueError) as context:
|
with pytest.raises(ValueError) as context:
|
||||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
@ -604,6 +618,8 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
|
|
||||||
update_data = {"name": "new_name"}
|
update_data = {"name": "new_name"}
|
||||||
|
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
with pytest.raises(NoPermissionError):
|
with pytest.raises(NoPermissionError):
|
||||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
@ -628,6 +644,8 @@ class TestDatasetServiceUpdateDataset:
|
|||||||
"retrieval_model": "new_model",
|
"retrieval_model": "new_model",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||||
|
|
||||||
with pytest.raises(Exception) as context:
|
with pytest.raises(Exception) as context:
|
||||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||||
|
|
||||||
|
|||||||
@ -310,7 +310,7 @@ class TestWorkflowDraftVariableService:
|
|||||||
|
|
||||||
# Create mock execution record
|
# Create mock execution record
|
||||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||||
mock_execution.outputs_dict = {"test_var": "output_value"}
|
mock_execution.load_full_outputs.return_value = {"test_var": "output_value"}
|
||||||
|
|
||||||
# Mock the repository to return the execution record
|
# Mock the repository to return the execution record
|
||||||
service._api_node_execution_repo = Mock()
|
service._api_node_execution_repo = Mock()
|
||||||
@ -383,7 +383,7 @@ class TestWorkflowDraftVariableService:
|
|||||||
|
|
||||||
# Create mock execution record
|
# Create mock execution record
|
||||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||||
mock_execution.outputs_dict = {"sys.files": "[]"}
|
mock_execution.load_full_outputs.return_value = {"sys.files": "[]"}
|
||||||
|
|
||||||
# Mock the repository to return the execution record
|
# Mock the repository to return the execution record
|
||||||
service._api_node_execution_repo = Mock()
|
service._api_node_execution_repo = Mock()
|
||||||
@ -415,7 +415,7 @@ class TestWorkflowDraftVariableService:
|
|||||||
|
|
||||||
# Create mock execution record
|
# Create mock execution record
|
||||||
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
mock_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||||
mock_execution.outputs_dict = {"sys.query": "reset query"}
|
mock_execution.load_full_outputs.return_value = {"sys.query": "reset query"}
|
||||||
|
|
||||||
# Mock the repository to return the execution record
|
# Mock the repository to return the execution record
|
||||||
service._api_node_execution_repo = Mock()
|
service._api_node_execution_repo = Mock()
|
||||||
|
|||||||
@ -313,7 +313,7 @@ class TestDeleteDraftVariableOffloadData:
|
|||||||
assert result == 1 # Only one storage deletion succeeded
|
assert result == 1 # Only one storage deletion succeeded
|
||||||
|
|
||||||
# Verify warning was logged
|
# Verify warning was logged
|
||||||
mock_logging.warning.assert_called_once_with("Failed to delete storage object storage/key/1: Storage error")
|
mock_logging.exception.assert_called_once_with("Failed to delete storage object %s", "storage/key/1")
|
||||||
|
|
||||||
# Verify both database cleanup calls still happened
|
# Verify both database cleanup calls still happened
|
||||||
assert mock_conn.execute.call_count == 3
|
assert mock_conn.execute.call_count == 3
|
||||||
@ -334,4 +334,4 @@ class TestDeleteDraftVariableOffloadData:
|
|||||||
assert result == 0
|
assert result == 0
|
||||||
|
|
||||||
# Verify error was logged
|
# Verify error was logged
|
||||||
mock_logging.error.assert_called_once_with("Error deleting draft variable offload data: Database error")
|
mock_logging.exception.assert_called_once_with("Error deleting draft variable offload data:")
|
||||||
|
|||||||
Reference in New Issue
Block a user