Merge branch 'feat/agent-node-v2' into feat/support-agent-sandbox

This commit is contained in:
Novice
2026-01-09 14:19:27 +08:00
294 changed files with 14621 additions and 3601 deletions

View File

@ -444,6 +444,78 @@ class TestAnnotationService:
assert total == 1
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.

    This test verifies:
    - Special characters (%, _, \) in keyword are properly escaped
    - Search treats special characters as literal characters, not wildcards
    - SQL injection via LIKE wildcards is prevented
    """
    fake = Faker()
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

    # Seed annotations whose text embeds LIKE metacharacters, plus one decoy
    # ("100% different") that must NOT match a literal "50%" search.
    seed_payloads = [
        {"question": "Question with 50% discount", "answer": "Answer about 50% discount offer"},
        {"question": "Question with test_data", "answer": "Answer about test_data value"},
        {"question": "Question with path\\to\\file", "answer": "Answer about path\\to\\file location"},
        {"question": "Question with 100% different", "answer": "Answer about 100% different content"},
    ]
    for payload in seed_payloads:
        AppAnnotationService.insert_app_annotation_directly(payload, app.id)

    def _search(keyword):
        # All searches use the same pagination window.
        return AppAnnotationService.get_annotation_list_by_app_id(app.id, page=1, limit=10, keyword=keyword)

    # Test 1: % is treated as a literal character, not a wildcard.
    annotation_list, total = _search("50%")
    assert total == 1
    assert len(annotation_list) == 1
    assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content

    # Test 2: _ is treated as a literal character, not a single-char wildcard.
    annotation_list, total = _search("test_data")
    assert total == 1
    assert len(annotation_list) == 1
    assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content

    # Test 3: backslash (the escape character itself) is treated literally.
    annotation_list, total = _search("path\\to\\file")
    assert total == 1
    assert len(annotation_list) == 1
    assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content

    # Test 4: "50%" must not match the "100% different" decoy (escaping works).
    annotation_list, total = _search("50%")
    assert total == 1
    assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
from models import Account
from models.model import App, Site
from services.account_service import AccountService, TenantService
from services.app_service import AppService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
class TestAppService:
@ -71,6 +73,9 @@ class TestAppService:
}
# Create app
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -109,6 +114,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Test different app modes
@ -159,6 +167,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
created_app = app_service.create_app(tenant.id, app_args, account)
@ -194,6 +205,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create multiple apps
@ -245,6 +259,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create apps with different modes
@ -315,6 +332,9 @@ class TestAppService:
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Create an app
@ -392,6 +412,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -458,6 +481,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -508,6 +534,9 @@ class TestAppService:
"icon_background": "#45B7D1",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -562,6 +591,9 @@ class TestAppService:
"icon_background": "#74B9FF",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -617,6 +649,9 @@ class TestAppService:
"icon_background": "#A29BFE",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -672,6 +707,9 @@ class TestAppService:
"icon_background": "#FD79A8",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -720,6 +758,9 @@ class TestAppService:
"icon_background": "#E17055",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -768,6 +809,9 @@ class TestAppService:
"icon_background": "#00B894",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -826,6 +870,9 @@ class TestAppService:
"icon_background": "#6C5CE7",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -862,6 +909,9 @@ class TestAppService:
"icon_background": "#FDCB6E",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -899,6 +949,9 @@ class TestAppService:
"icon_background": "#E84393",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -947,8 +1000,132 @@ class TestAppService:
"icon_background": "#D63031",
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
# Attempt to create app with invalid mode
with pytest.raises(ValueError, match="invalid mode value"):
app_service.create_app(tenant.id, app_args, account)
def test_get_apps_with_special_characters_in_name(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Test app retrieval with special characters in name search to verify SQL injection prevention.

    This test verifies:
    - Special characters (%, _, \) in name search are properly escaped
    - Search treats special characters as literal characters, not wildcards
    - SQL injection via LIKE wildcards is prevented
    """
    fake = Faker()
    # Create account and tenant first
    account = AccountService.create_account(
        email=fake.email(),
        name=fake.name(),
        interface_language="en-US",
        password=fake.password(length=12),
    )
    TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
    tenant = account.current_tenant

    # Import here to avoid circular dependency
    from services.app_service import AppService

    app_service = AppService()

    def _create_named_app(name):
        # All apps share the same config except for the name under test.
        return app_service.create_app(
            tenant.id,
            {
                "name": name,
                "description": fake.text(max_nb_chars=100),
                "mode": "chat",
                "icon_type": "emoji",
                "icon": "🤖",
                "icon_background": "#FF6B6B",
                "api_rph": 100,
                "api_rpm": 10,
            },
            account,
        )

    app_with_percent = _create_named_app("App with 50% discount")
    app_with_underscore = _create_named_app("test_data_app")
    app_with_backslash = _create_named_app("path\\to\\app")
    # Decoy that should NOT match a literal "50%" search.
    app_no_match = _create_named_app("100% different")

    def _search(name):
        return app_service.get_paginate_apps(
            account.id, tenant.id, {"name": name, "mode": "chat", "page": 1, "limit": 10}
        )

    # Test 1: % treated literally.
    paginated_apps = _search("50%")
    assert paginated_apps is not None
    assert paginated_apps.total == 1
    assert len(paginated_apps.items) == 1
    assert paginated_apps.items[0].name == "App with 50% discount"

    # Test 2: _ treated literally.
    paginated_apps = _search("test_data")
    assert paginated_apps is not None
    assert paginated_apps.total == 1
    assert len(paginated_apps.items) == 1
    assert paginated_apps.items[0].name == "test_data_app"

    # Test 3: backslash treated literally.
    paginated_apps = _search("path\\to\\app")
    assert paginated_apps is not None
    assert paginated_apps.total == 1
    assert len(paginated_apps.items) == 1
    assert paginated_apps.items[0].name == "path\\to\\app"

    # Test 4: "50%" must not match the "100% different" decoy (escaping works).
    paginated_apps = _search("50%")
    assert paginated_apps is not None
    assert paginated_apps.total == 1
    assert all("50%" in app.name for app in paginated_apps.items)

View File

@ -1,3 +1,4 @@
import uuid
from unittest.mock import create_autospec, patch
import pytest
@ -312,6 +313,85 @@ class TestTagService:
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
assert len(result_no_match) == 0
def test_get_tags_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Test tag retrieval with special characters in keyword to verify SQL injection prevention.

    This test verifies:
    - Special characters (%, _, \) in keyword are properly escaped
    - Search treats special characters as literal characters, not wildcards
    - SQL injection via LIKE wildcards is prevented
    """
    # Arrange: Create test data
    fake = Faker()
    account, tenant = self._create_test_account_and_tenant(
        db_session_with_containers, mock_external_service_dependencies
    )
    from extensions.ext_database import db

    def _stage_tag(name):
        # Tags only differ by name; everything else is fixed.
        tag = Tag(
            name=name,
            type="app",
            tenant_id=tenant.id,
            created_by=account.id,
        )
        tag.id = str(uuid.uuid4())
        db.session.add(tag)
        return tag

    _stage_tag("50% discount")
    _stage_tag("test_data_tag")
    _stage_tag("path\\to\\tag")
    # Decoy that should NOT match a literal "50%" search.
    _stage_tag("100% different")
    db.session.commit()

    # Act & Assert: Test 1 - % treated literally.
    result = TagService.get_tags("app", tenant.id, keyword="50%")
    assert len(result) == 1
    assert result[0].name == "50% discount"

    # Test 2 - _ treated literally.
    result = TagService.get_tags("app", tenant.id, keyword="test_data")
    assert len(result) == 1
    assert result[0].name == "test_data_tag"

    # Test 3 - backslash treated literally.
    result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
    assert len(result) == 1
    assert result[0].name == "path\\to\\tag"

    # Test 4 - "50%" must not match the "100% different" decoy (escaping works).
    result = TagService.get_tags("app", tenant.id, keyword="50%")
    assert len(result) == 1
    assert all("50%" in item.name for item in result)
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test tag retrieval when no tags exist.

View File

@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
from services.account_service import AccountService, TenantService
from services.app_service import AppService
# Delay import of AppService to avoid circular dependency
# from services.app_service import AppService
from services.workflow_app_service import WorkflowAppService
@ -86,6 +88,9 @@ class TestWorkflowAppService:
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -147,6 +152,9 @@ class TestWorkflowAppService:
"api_rpm": 10,
}
# Import here to avoid circular dependency
from services.app_service import AppService
app_service = AppService()
app = app_service.create_app(tenant.id, app_args, account)
@ -308,6 +316,156 @@ class TestWorkflowAppService:
assert result_no_match["total"] == 0
assert len(result_no_match["data"]) == 0
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.

    This test verifies:
    - Special characters (%, _) in keyword are properly escaped
    - Search treats special characters as literal characters, not wildcards
    - SQL injection via LIKE wildcards is prevented
    """
    # Arrange: Create test data
    fake = Faker()
    app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
    workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
    from extensions.ext_database import db

    service = WorkflowAppService()

    def _add_run_with_log(search_term, result_text):
        # One WorkflowRun whose inputs/outputs embed the given text, plus the
        # WorkflowAppLog row that points at it; committed immediately.
        run = WorkflowRun(
            id=str(uuid.uuid4()),
            tenant_id=app.tenant_id,
            app_id=app.id,
            workflow_id=workflow.id,
            type="workflow",
            triggered_from="app-run",
            version="1.0.0",
            graph=json.dumps({"nodes": [], "edges": []}),
            status="succeeded",
            inputs=json.dumps({"search_term": search_term, "input2": "other_value"}),
            outputs=json.dumps({"result": result_text, "status": "success"}),
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=account.id,
            created_at=datetime.now(UTC),
        )
        db.session.add(run)
        db.session.flush()
        log = WorkflowAppLog(
            tenant_id=app.tenant_id,
            app_id=app.id,
            workflow_id=workflow.id,
            workflow_run_id=run.id,
            created_from="service-api",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by=account.id,
        )
        log.id = str(uuid.uuid4())
        log.created_at = datetime.now(UTC)
        db.session.add(log)
        db.session.commit()
        return run

    def _search(keyword):
        return service.get_paginate_workflow_app_logs(
            session=db_session_with_containers, app_model=app, keyword=keyword, page=1, limit=20
        )

    # Test 1: % treated literally.
    workflow_run_1 = _add_run_with_log("50% discount", "50% discount applied")
    result = _search("50%")
    assert result["total"] >= 1
    assert len(result["data"]) >= 1
    assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"])

    # Test 2: _ treated literally.
    workflow_run_2 = _add_run_with_log("test_data_value", "test_data_value found")
    result = _search("test_data")
    assert result["total"] >= 1
    assert len(result["data"]) >= 1
    assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"])

    # Test 3: "50%" must NOT match the 100% decoy (verifies escaping works).
    workflow_run_4 = _add_run_with_log("100% different", "100% different result")
    result = _search("50%")
    # Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4).
    assert result["total"] >= 1
    assert len(result["data"]) >= 1
    found_run_ids = [log.workflow_run_id for log in result["data"]]
    assert workflow_run_1.id in found_run_ids
    assert workflow_run_4.id not in found_run_ids
def test_get_paginate_workflow_app_logs_with_status_filter(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -165,7 +165,7 @@ class TestRagPipelineRunTasks:
"files": [],
"user_id": account.id,
"stream": False,
"invoke_from": "published",
"invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value,
"workflow_execution_id": str(uuid.uuid4()),
"pipeline_config": {
"app_id": str(uuid.uuid4()),
@ -249,7 +249,7 @@ class TestRagPipelineRunTasks:
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
@ -294,7 +294,7 @@ class TestRagPipelineRunTasks:
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
@ -743,7 +743,7 @@ class TestRagPipelineRunTasks:
assert call_kwargs["pipeline"].id == pipeline.id
assert call_kwargs["workflow_id"] == workflow.id
assert call_kwargs["user"].id == account.id
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
assert call_kwargs["streaming"] == False
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)

View File

@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
os.environ.clear()
# Set minimal required env vars
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
"""Test that DB_EXTRAS options are properly merged with default timezone setting"""
# Set environment variables
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
@ -199,6 +204,7 @@ def test_celery_broker_url_with_special_chars_password(
# Set up basic required environment variables (following existing pattern)
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("DB_TYPE", "postgresql")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")

View File

@ -0,0 +1,285 @@
from __future__ import annotations
import builtins
import sys
from datetime import datetime
from importlib import util
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
import pytest
from flask.views import MethodView
# kombu references MethodView as a global when importing celery/kombu pools.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_app_module():
    """Load ``controllers.console.app.app`` directly from its file, with stubs.

    The controller module's import-time dependencies (the flask-restx console
    namespace and the ops trace manager) are replaced with lightweight stubs
    while the module executes, then the original ``sys.modules`` entries are
    restored. The loaded module is cached in ``sys.modules`` and returned.
    """
    module_name = "controllers.console.app.app"
    # Reuse a previously loaded instance so stubbing only ever happens once.
    if module_name in sys.modules:
        return sys.modules[module_name]
    root = Path(__file__).resolve().parents[5]
    module_path = root / "controllers" / "console" / "app" / "app.py"

    class _StubNamespace:
        # Minimal stand-in for the flask-restx Namespace: records schema
        # models and turns every route/doc decorator into a no-op.
        def __init__(self):
            self.models: dict[str, Any] = {}
            self.payload = None

        def schema_model(self, name, schema):
            self.models[name] = schema

        def _decorator(self, obj):
            # Identity decorator shared by doc/expect/response below.
            return obj

        def doc(self, *args, **kwargs):
            return self._decorator

        def expect(self, *args, **kwargs):
            return self._decorator

        def response(self, *args, **kwargs):
            return self._decorator

        def route(self, *args, **kwargs):
            def decorator(obj):
                return obj

            return decorator

    stub_namespace = _StubNamespace()
    # Remember whatever was registered before so it can be restored afterwards.
    original_console = sys.modules.get("controllers.console")
    original_app_pkg = sys.modules.get("controllers.console.app")
    stubbed_modules: list[tuple[str, ModuleType | None]] = []
    # Fake the "controllers.console" package with the stub namespace attached.
    console_module = ModuleType("controllers.console")
    console_module.__path__ = [str(root / "controllers" / "console")]
    console_module.console_ns = stub_namespace
    console_module.api = None
    console_module.bp = None
    sys.modules["controllers.console"] = console_module
    # Fake the "controllers.console.app" subpackage as well.
    app_package = ModuleType("controllers.console.app")
    app_package.__path__ = [str(root / "controllers" / "console" / "app")]
    sys.modules["controllers.console.app"] = app_package
    console_module.app = app_package

    def _stub_module(name: str, attrs: dict[str, Any]):
        # Install a synthetic module under ``name``, recording the original
        # entry (or None) so the finally-block can undo the substitution.
        original = sys.modules.get(name)
        module = ModuleType(name)
        for key, value in attrs.items():
            setattr(module, key, value)
        sys.modules[name] = module
        stubbed_modules.append((name, original))

    class _OpsTraceManager:
        # No-op replacement for the real OpsTraceManager used at import time.
        @staticmethod
        def get_app_tracing_config(app_id: str) -> dict[str, Any]:
            return {}

        @staticmethod
        def update_app_tracing_config(app_id: str, **kwargs) -> None:
            return None

    _stub_module(
        "core.ops.ops_trace_manager",
        {
            "OpsTraceManager": _OpsTraceManager,
            "TraceQueueManager": object,
            "TraceTask": object,
        },
    )
    # Execute the target module from its source file under the stubbed imports.
    spec = util.spec_from_file_location(module_name, module_path)
    module = util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        assert spec.loader is not None
        spec.loader.exec_module(module)
    finally:
        # Undo the stubs in reverse insertion order, then restore (or remove)
        # the two package entries that were replaced above.
        for name, original in reversed(stubbed_modules):
            if original is not None:
                sys.modules[name] = original
            else:
                sys.modules.pop(name, None)
        if original_console is not None:
            sys.modules["controllers.console"] = original_console
        else:
            sys.modules.pop("controllers.console", None)
        if original_app_pkg is not None:
            sys.modules["controllers.console.app"] = original_app_pkg
        else:
            sys.modules.pop("controllers.console.app", None)
    return module
# Load the controller module once at import time and re-export the pydantic
# models under test at module level.
_app_module = _load_app_module()
AppDetailWithSite = _app_module.AppDetailWithSite
AppPagination = _app_module.AppPagination
AppPartial = _app_module.AppPartial
@pytest.fixture(autouse=True)
def patch_signed_url(monkeypatch):
    """Ensure icon URL generation uses a deterministic helper for tests."""

    def _fake_signed_url(key: str | None) -> str | None:
        # Empty/None keys yield no URL; anything else gets a stable prefix.
        return f"signed:{key}" if key else None

    monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
def _ts(hour: int = 12) -> datetime:
return datetime(2024, 1, 1, hour, 0, 0)
def _dummy_model_config():
    """Build a SimpleNamespace standing in for an app model-config record."""
    attrs = {
        "model_dict": {"provider": "openai", "name": "gpt-4o"},
        "pre_prompt": "hello",
        "created_by": "config-author",
        "created_at": _ts(9),
        "updated_by": "config-editor",
        "updated_at": _ts(10),
    }
    return SimpleNamespace(**attrs)
def _dummy_workflow():
    """Build a SimpleNamespace standing in for a workflow record."""
    attrs = {
        "id": "wf-1",
        "created_by": "workflow-author",
        "created_at": _ts(8),
        "updated_by": "workflow-editor",
        "updated_at": _ts(9),
    }
    return SimpleNamespace(**attrs)
def test_app_partial_serialization_uses_aliases():
    """AppPartial serialization maps internal attributes onto their public aliases."""
    created_at = _ts()
    raw = {
        "id": "app-1",
        "name": "My App",
        "desc_or_prompt": "Prompt snippet",
        "mode_compatible_with_agent": "chat",
        "icon_type": "image",
        "icon": "icon-key",
        "icon_background": "#fff",
        "app_model_config": _dummy_model_config(),
        "workflow": _dummy_workflow(),
        "created_by": "creator",
        "created_at": created_at,
        "updated_by": "editor",
        "updated_at": created_at,
        "tags": [SimpleNamespace(id="tag-1", name="Utilities", type="app")],
        "access_mode": "private",
        "create_user_name": "Creator",
        "author_name": "Author",
        "has_draft_trigger": True,
    }
    serialized = AppPartial.model_validate(SimpleNamespace(**raw), from_attributes=True).model_dump(mode="json")

    expected_epoch = int(created_at.timestamp())
    assert serialized["description"] == "Prompt snippet"
    assert serialized["mode"] == "chat"
    assert serialized["icon_url"] == "signed:icon-key"
    assert serialized["created_at"] == expected_epoch
    assert serialized["updated_at"] == expected_epoch
    assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"}
    assert serialized["workflow"]["id"] == "wf-1"
    assert serialized["tags"][0]["name"] == "Utilities"
def test_app_detail_with_site_includes_nested_serialization():
    """AppDetailWithSite serializes nested site/config/tool structures with aliases."""
    timestamp = _ts(14)
    site = SimpleNamespace(
        code="site-code",
        title="Public Site",
        icon_type="image",
        icon="site-icon",
        created_at=timestamp,
        updated_at=timestamp,
    )
    raw = {
        "id": "app-2",
        "name": "Detailed App",
        "description": "Desc",
        "mode_compatible_with_agent": "advanced-chat",
        "icon_type": "image",
        "icon": "detail-icon",
        "icon_background": "#123456",
        "enable_site": True,
        "enable_api": True,
        "app_model_config": {
            "opening_statement": "hi",
            "model": {"provider": "openai", "name": "gpt-4o"},
            "retriever_resource": {"enabled": True},
        },
        "workflow": _dummy_workflow(),
        "tracing": {"enabled": True},
        "use_icon_as_answer_icon": True,
        "created_by": "creator",
        "created_at": timestamp,
        "updated_by": "editor",
        "updated_at": timestamp,
        "access_mode": "public",
        "tags": [SimpleNamespace(id="tag-2", name="Prod", type="app")],
        "api_base_url": "https://api.example.com/v1",
        "max_active_requests": 5,
        "deleted_tools": [{"type": "api", "tool_name": "search", "provider_id": "prov"}],
        "site": site,
    }
    serialized = AppDetailWithSite.model_validate(SimpleNamespace(**raw), from_attributes=True).model_dump(mode="json")

    assert serialized["icon_url"] == "signed:detail-icon"
    assert serialized["model_config"]["retriever_resource"] == {"enabled": True}
    assert serialized["deleted_tools"][0]["tool_name"] == "search"
    assert serialized["site"]["icon_url"] == "signed:site-icon"
    assert serialized["site"]["created_at"] == int(timestamp.timestamp())
def test_app_pagination_aliases_per_page_and_has_next():
    """AppPagination exposes per_page as 'limit', has_next as 'has_more', items as 'data'."""

    def _item(app_id, name, mode, icon_type, icon, hour):
        # All items share the same summary; only identity fields vary.
        return SimpleNamespace(
            id=app_id,
            name=name,
            desc_or_prompt="Summary",
            mode_compatible_with_agent=mode,
            icon_type=icon_type,
            icon=icon,
            created_at=_ts(hour),
            updated_at=_ts(hour),
        )

    pagination = SimpleNamespace(
        page=2,
        per_page=10,
        total=50,
        has_next=True,
        items=[
            _item("app-10", "Paginated One", "chat", "image", "first-icon", 15),
            _item("app-11", "Paginated Two", "agent-chat", "emoji", "🙂", 16),
        ],
    )
    serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json")

    assert serialized["page"] == 2
    assert serialized["limit"] == 10
    assert serialized["has_more"] is True
    assert len(serialized["data"]) == 2
    assert serialized["data"][0]["icon_url"] == "signed:first-icon"
    # Emoji icons are not image keys, so no signed URL is produced.
    assert serialized["data"][1]["icon_url"] is None

View File

@ -1,7 +1,9 @@
import builtins
import io
from unittest.mock import patch
import pytest
from flask.views import MethodView
from werkzeug.exceptions import Forbidden
from controllers.common.errors import (
@ -14,6 +16,9 @@ from controllers.common.errors import (
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
class TestFileUploadSecurity:
"""Test file upload security logic without complex framework setup"""
@ -128,7 +133,7 @@ class TestFileUploadSecurity:
# Test passes if no exception is raised
# Test 4: Service error handling
@patch("services.file_service.FileService.upload_file")
@patch("controllers.console.files.FileService.upload_file")
def test_should_handle_file_too_large_error(self, mock_upload):
"""Test that service FileTooLargeError is properly converted"""
mock_upload.side_effect = ServiceFileTooLargeError("File too large")
@ -140,7 +145,7 @@ class TestFileUploadSecurity:
with pytest.raises(FileTooLargeError):
raise FileTooLargeError(e.description)
@patch("services.file_service.FileService.upload_file")
@patch("controllers.console.files.FileService.upload_file")
def test_should_handle_unsupported_file_type_error(self, mock_upload):
"""Test that service UnsupportedFileTypeError is properly converted"""
mock_upload.side_effect = ServiceUnsupportedFileTypeError()

View File

@ -0,0 +1,3 @@
"""
Mark agent test modules as a package to avoid import name collisions.
"""

View File

@ -0,0 +1,324 @@
"""Tests for AgentPattern base class."""
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentLog, ExecutionContext
from core.agent.patterns.base import AgentPattern
from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
    """Concrete subclass of the abstract AgentPattern, used only by these tests.

    Provides the minimal ``run`` needed so fixtures can instantiate the pattern.
    """

    def run(self, prompt_messages, model_parameters, stop=None, stream=True):
        """Minimal implementation for testing: emits no events.

        ``stop`` defaults to None rather than the original mutable ``[]``
        default (a shared-state pitfall); the body ignores it either way,
        so behavior is unchanged for all callers.
        """
        yield from []
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    stub = MagicMock()
    stub.model = "test-model"
    stub.provider = "test-provider"
    return stub
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
@pytest.fixture
def agent_pattern(mock_model_instance, mock_context):
    """Provide a ConcreteAgentPattern wired to the mock model and context."""
    pattern = ConcreteAgentPattern(
        model_instance=mock_model_instance,
        tools=[],
        context=mock_context,
        max_iterations=10,
    )
    return pattern
class TestAccumulateUsage:
    """Tests for _accumulate_usage method."""

    @staticmethod
    def _usage(prompt_tokens, prompt_price, completion_tokens, completion_price, total_tokens, total_price, latency):
        """Build an LLMUsage with the fixed unit prices shared by every case here."""
        return LLMUsage(
            prompt_tokens=prompt_tokens,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal(prompt_price),
            completion_tokens=completion_tokens,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal(completion_price),
            total_tokens=total_tokens,
            total_price=Decimal(total_price),
            currency="USD",
            latency=latency,
        )

    def test_accumulate_usage_to_empty_dict(self, agent_pattern):
        """Accumulating into an empty slot stores a copy of the delta."""
        totals: dict = {"usage": None}
        delta = self._usage(100, "0.1", 50, "0.1", 150, "0.2", 0.5)
        agent_pattern._accumulate_usage(totals, delta)
        stored = totals["usage"]
        assert stored is not None
        assert stored.total_tokens == 150
        assert stored.prompt_tokens == 100
        assert stored.completion_tokens == 50
        # The stored object must be a copy, never the delta itself.
        assert stored is not delta

    def test_accumulate_usage_adds_to_existing(self, agent_pattern):
        """Accumulating onto existing usage sums the token counts."""
        totals: dict = {"usage": self._usage(100, "0.1", 50, "0.1", 150, "0.2", 0.5)}
        agent_pattern._accumulate_usage(totals, self._usage(200, "0.2", 100, "0.2", 300, "0.4", 0.5))
        assert totals["usage"].total_tokens == 450  # 150 + 300
        assert totals["usage"].prompt_tokens == 300  # 100 + 200
        assert totals["usage"].completion_tokens == 150  # 50 + 100

    def test_accumulate_usage_multiple_rounds(self, agent_pattern):
        """Totals keep growing correctly across three successive rounds."""
        totals: dict = {"usage": None}
        # Round 1: 100 tokens
        agent_pattern._accumulate_usage(totals, self._usage(70, "0.07", 30, "0.06", 100, "0.13", 0.3))
        assert totals["usage"].total_tokens == 100
        # Round 2: 150 tokens
        agent_pattern._accumulate_usage(totals, self._usage(100, "0.1", 50, "0.1", 150, "0.2", 0.4))
        assert totals["usage"].total_tokens == 250  # 100 + 150
        # Round 3: 200 tokens
        agent_pattern._accumulate_usage(totals, self._usage(130, "0.13", 70, "0.14", 200, "0.27", 0.5))
        assert totals["usage"].total_tokens == 450  # 100 + 150 + 200
class TestCreateLog:
    """Tests for _create_log method."""

    def test_create_log_with_label_and_status(self, agent_pattern):
        """A fresh log carries the given label, type, status, and data."""
        created = agent_pattern._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"key": "value"},
        )
        assert created.label == "ROUND 1"
        assert created.log_type == AgentLog.LogType.ROUND
        assert created.status == AgentLog.LogStatus.START
        assert created.data == {"key": "value"}
        # No parent was supplied, so the log is a root entry.
        assert created.parent_id is None

    def test_create_log_with_parent_id(self, agent_pattern):
        """Passing parent_id links a child log to its parent."""
        parent = agent_pattern._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )
        child = agent_pattern._create_log(
            label="CALL tool",
            log_type=AgentLog.LogType.TOOL_CALL,
            status=AgentLog.LogStatus.START,
            data={},
            parent_id=parent.id,
        )
        assert child.parent_id == parent.id
        assert child.log_type == AgentLog.LogType.TOOL_CALL
class TestFinishLog:
    """Tests for _finish_log method."""

    def test_finish_log_updates_status(self, agent_pattern):
        """Finishing a log flips its status to SUCCESS and stores final data."""
        started = agent_pattern._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )
        done = agent_pattern._finish_log(started, data={"result": "done"})
        assert done.status == AgentLog.LogStatus.SUCCESS
        assert done.data == {"result": "done"}

    def test_finish_log_adds_usage_metadata(self, agent_pattern):
        """Finishing with usage attaches token, price, and currency metadata."""
        started = agent_pattern._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )
        spent = LLMUsage(
            prompt_tokens=100,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal("0.1"),
            completion_tokens=50,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal("0.1"),
            total_tokens=150,
            total_price=Decimal("0.2"),
            currency="USD",
            latency=0.5,
        )
        done = agent_pattern._finish_log(started, usage=spent)
        assert done.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150
        assert done.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2")
        assert done.metadata[AgentLog.LogMetadata.CURRENCY] == "USD"
        assert done.metadata[AgentLog.LogMetadata.LLM_USAGE] == spent
class TestFindToolByName:
    """Tests for _find_tool_by_name method."""

    @staticmethod
    def _pattern_with_tool(model_instance, context, tool_name):
        """Build a pattern holding a single mock tool named *tool_name*."""
        tool = MagicMock()
        tool.entity.identity.name = tool_name
        pattern = ConcreteAgentPattern(
            model_instance=model_instance,
            tools=[tool],
            context=context,
        )
        return pattern, tool

    def test_find_existing_tool(self, mock_model_instance, mock_context):
        """Looking up a registered tool name returns that tool."""
        pattern, tool = self._pattern_with_tool(mock_model_instance, mock_context, "test_tool")
        assert pattern._find_tool_by_name("test_tool") == tool

    def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context):
        """Looking up an unknown name returns None rather than raising."""
        pattern, _ = self._pattern_with_tool(mock_model_instance, mock_context, "test_tool")
        assert pattern._find_tool_by_name("nonexistent_tool") is None
class TestMaxIterationsCapping:
    """Tests for max_iterations capping."""

    @staticmethod
    def _build(model_instance, context, iterations):
        """Construct a pattern with the requested iteration budget."""
        return ConcreteAgentPattern(
            model_instance=model_instance,
            tools=[],
            context=context,
            max_iterations=iterations,
        )

    def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context):
        """Values above 99 are clamped down to the hard ceiling of 99."""
        assert self._build(mock_model_instance, mock_context, 150).max_iterations == 99

    def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context):
        """Values at or below the ceiling pass through unchanged."""
        assert self._build(mock_model_instance, mock_context, 50).max_iterations == 50

View File

@ -0,0 +1,332 @@
"""Tests for FunctionCallStrategy."""
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentLog, ExecutionContext
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    stub = MagicMock()
    stub.model = "test-model"
    stub.provider = "test-provider"
    return stub
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
@pytest.fixture
def mock_tool():
    """Provide a mock tool whose prompt representation is a real PromptMessageTool."""
    schema = {
        "type": "object",
        "properties": {"param1": {"type": "string", "description": "A parameter"}},
        "required": ["param1"],
    }
    tool = MagicMock()
    tool.entity.identity.name = "test_tool"
    tool.to_prompt_message_tool.return_value = PromptMessageTool(
        name="test_tool",
        description="A test tool",
        parameters=schema,
    )
    return tool
class TestFunctionCallStrategyInit:
    """Tests for FunctionCallStrategy initialization."""

    def test_initialization(self, mock_model_instance, mock_context, mock_tool):
        """Constructor stores model, context, iteration budget, and tools."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            max_iterations=10,
        )
        assert fc.model_instance == mock_model_instance
        assert fc.context == mock_context
        assert fc.max_iterations == 10
        assert len(fc.tools) == 1

    def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool):
        """A supplied tool_invoke_hook is kept on the strategy."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        hook = MagicMock()
        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            tool_invoke_hook=hook,
        )
        assert fc.tool_invoke_hook == hook
class TestConvertToolsToPromptFormat:
    """Tests for _convert_tools_to_prompt_format method."""

    def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool):
        """Each configured tool is converted into a PromptMessageTool."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
        )
        converted = fc._convert_tools_to_prompt_format()
        assert len(converted) == 1
        assert isinstance(converted[0], PromptMessageTool)
        assert converted[0].name == "test_tool"

    def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context):
        """No configured tools yields an empty list."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[],
            context=mock_context,
        )
        assert fc._convert_tools_to_prompt_format() == []
class TestAgentLogGeneration:
    """Tests for AgentLog generation during run."""

    def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool):
        """Round logs carry label, ROUND type, START status, and input data."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            max_iterations=1,
        )
        round_entry = fc._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"inputs": {"query": "test"}},
        )
        assert round_entry.label == "ROUND 1"
        assert round_entry.log_type == AgentLog.LogType.ROUND
        assert round_entry.status == AgentLog.LogStatus.START
        assert "inputs" in round_entry.data

    def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool):
        """Tool-call logs reference their parent round and record the call args."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
        )
        round_entry = fc._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )
        call_entry = fc._create_log(
            label="CALL test_tool",
            log_type=AgentLog.LogType.TOOL_CALL,
            status=AgentLog.LogStatus.START,
            data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}},
            parent_id=round_entry.id,
        )
        assert call_entry.label == "CALL test_tool"
        assert call_entry.log_type == AgentLog.LogType.TOOL_CALL
        assert call_entry.parent_id == round_entry.id
        assert call_entry.data["tool_name"] == "test_tool"
class TestToolInvocation:
    """Tests for tool invocation."""

    def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool):
        """When a hook is configured, _invoke_tool delegates to it."""
        from core.agent.patterns.function_call import FunctionCallStrategy
        from core.tools.entities.tool_entities import ToolInvokeMeta

        meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={"tool_provider_type": "test", "tool_provider": "test_id"},
        )
        hook = MagicMock(return_value=("Tool result", ["file-1"], meta))
        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            tool_invoke_hook=hook,
        )
        text, files, got_meta = fc._invoke_tool(mock_tool, {"param1": "value"}, "test_tool")
        hook.assert_called_once()
        assert text == "Tool result"
        # The hook yields file IDs, but _invoke_tool surfaces an empty File list.
        assert files == []
        assert got_meta == meta

    def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool):
        """Without a hook, the strategy's tool_invoke_hook stays None."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            tool_invoke_hook=None,
        )
        assert fc.tool_invoke_hook is None
class TestUsageTracking:
    """Tests for usage tracking across rounds."""

    @staticmethod
    def _usage(prompt_tokens, prompt_price, completion_tokens, completion_price, total_tokens, total_price):
        """Build an LLMUsage with the fixed unit prices used by these cases."""
        return LLMUsage(
            prompt_tokens=prompt_tokens,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal(prompt_price),
            completion_tokens=completion_tokens,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal(completion_price),
            total_tokens=total_tokens,
            total_price=Decimal(total_price),
            currency="USD",
            latency=0.5,
        )

    def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context):
        """Per-round accumulators stay independent while the total sums both."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        fc = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[],
            context=mock_context,
        )
        total: dict = {"usage": None}
        first_round: dict = {"usage": None}
        second_round: dict = {"usage": None}
        # Round 1
        usage_a = self._usage(100, "0.1", 50, "0.1", 150, "0.2")
        fc._accumulate_usage(first_round, usage_a)
        fc._accumulate_usage(total, usage_a)
        # Round 2
        usage_b = self._usage(200, "0.2", 100, "0.2", 300, "0.4")
        fc._accumulate_usage(second_round, usage_b)
        fc._accumulate_usage(total, usage_b)
        # Each round-level accumulator holds only its own round.
        assert first_round["usage"].total_tokens == 150
        assert second_round["usage"].total_tokens == 300
        # The grand total holds both rounds.
        assert total["usage"].total_tokens == 450
class TestPromptMessageHandling:
    """Tests for prompt message handling."""

    def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool):
        """System and user prompts can be assembled into a message list."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
        )
        conversation = [
            SystemPromptMessage(content="You are a helpful assistant."),
            UserPromptMessage(content="Hello"),
        ]
        # Smoke check that both message kinds construct and order correctly.
        assert len(conversation) == 2
        assert isinstance(conversation[0], SystemPromptMessage)
        assert isinstance(conversation[1], UserPromptMessage)

    def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool):
        """Assistant messages may embed structured tool calls."""
        from core.model_runtime.entities.message_entities import AssistantPromptMessage

        call = AssistantPromptMessage.ToolCall(
            id="call_123",
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name="test_tool",
                arguments='{"param1": "value1"}',
            ),
        )
        reply = AssistantPromptMessage(
            content="I'll help you with that.",
            tool_calls=[call],
        )
        assert len(reply.tool_calls) == 1
        assert reply.tool_calls[0].function.name == "test_tool"

View File

@ -0,0 +1,224 @@
"""Tests for ReActStrategy."""
from unittest.mock import MagicMock
import pytest
from core.agent.entities import ExecutionContext
from core.agent.patterns.react import ReActStrategy
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    stub = MagicMock()
    stub.model = "test-model"
    stub.provider = "test-provider"
    return stub
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
@pytest.fixture
def mock_tool():
    """Provide a mock tool backed by a real PromptMessageTool so it serializes."""
    from core.model_runtime.entities.message_entities import PromptMessageTool

    tool = MagicMock()
    tool.entity.identity.name = "test_tool"
    tool.entity.identity.provider = "test_provider"
    # A genuine PromptMessageTool (not a mock) keeps JSON serialization real.
    tool.to_prompt_message_tool.return_value = PromptMessageTool(
        name="test_tool",
        description="A test tool",
        parameters={"type": "object", "properties": {}},
    )
    return tool
class TestReActStrategyInit:
    """Tests for ReActStrategy initialization."""

    def test_init_with_instruction(self, mock_model_instance, mock_context):
        """The instruction passed to the constructor is stored verbatim."""
        text = "You are a helpful assistant."
        react = ReActStrategy(
            model_instance=mock_model_instance,
            tools=[],
            context=mock_context,
            instruction=text,
        )
        assert react.instruction == text

    def test_init_with_empty_instruction(self, mock_model_instance, mock_context):
        """Omitting the instruction defaults it to the empty string."""
        react = ReActStrategy(
            model_instance=mock_model_instance,
            tools=[],
            context=mock_context,
        )
        assert react.instruction == ""
class TestBuildPromptWithReactFormat:
    """Tests for _build_prompt_with_react_format method."""

    @staticmethod
    def _strategy(model_instance, context, tools=(), instruction=""):
        """Build a ReActStrategy with the given tools and instruction."""
        return ReActStrategy(
            model_instance=model_instance,
            tools=list(tools),
            context=context,
            instruction=instruction,
        )

    def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool):
        """{{tools}} is replaced with a JSON description of the tools."""
        react = self._strategy(mock_model_instance, mock_context, [mock_tool])
        prompts = [
            SystemPromptMessage(content="You have access to: {{tools}}"),
            UserPromptMessage(content="Hello"),
        ]
        rendered = react._build_prompt_with_react_format(prompts, [], True)
        assert "{{tools}}" not in rendered[0].content
        assert "test_tool" in rendered[0].content

    def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool):
        """{{tool_names}} is replaced with the quoted tool names."""
        react = self._strategy(mock_model_instance, mock_context, [mock_tool])
        rendered = react._build_prompt_with_react_format(
            [SystemPromptMessage(content="Valid actions: {{tool_names}}")], [], True
        )
        assert "{{tool_names}}" not in rendered[0].content
        assert '"test_tool"' in rendered[0].content

    def test_replace_instruction_placeholder(self, mock_model_instance, mock_context):
        """{{instruction}} is replaced with the configured instruction text."""
        text = "You are a helpful coding assistant."
        react = self._strategy(mock_model_instance, mock_context, instruction=text)
        rendered = react._build_prompt_with_react_format(
            [SystemPromptMessage(content="{{instruction}}\n\nYou have access to: {{tools}}")],
            [],
            True,
            text,
        )
        assert "{{instruction}}" not in rendered[0].content
        assert text in rendered[0].content

    def test_no_tools_available_message(self, mock_model_instance, mock_context):
        """With include_tools False the prompt reads 'No tools available'."""
        react = self._strategy(mock_model_instance, mock_context)
        rendered = react._build_prompt_with_react_format(
            [SystemPromptMessage(content="You have access to: {{tools}}")], [], False
        )
        assert "No tools available" in rendered[0].content

    def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context):
        """A non-empty scratchpad is appended as a trailing AssistantPromptMessage."""
        from core.agent.entities import AgentScratchpadUnit
        from core.model_runtime.entities import AssistantPromptMessage

        react = self._strategy(mock_model_instance, mock_context)
        prompts = [
            SystemPromptMessage(content="System prompt"),
            UserPromptMessage(content="User query"),
        ]
        notes = [
            AgentScratchpadUnit(
                thought="I need to search for information",
                action_str='{"action": "search", "action_input": "query"}',
                observation="Search results here",
            )
        ]
        rendered = react._build_prompt_with_react_format(prompts, notes, True)
        # Original two messages plus the scratchpad as a trailing assistant turn.
        assert len(rendered) == 3
        assert isinstance(rendered[-1], AssistantPromptMessage)
        assert "I need to search for information" in rendered[-1].content
        assert "Search results here" in rendered[-1].content

    def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context):
        """An empty scratchpad adds no extra message."""
        react = self._strategy(mock_model_instance, mock_context)
        rendered = react._build_prompt_with_react_format(
            [
                SystemPromptMessage(content="System prompt"),
                UserPromptMessage(content="User query"),
            ],
            [],
            True,
        )
        assert len(rendered) == 2

    def test_original_messages_not_modified(self, mock_model_instance, mock_context):
        """The caller's message list is left untouched by the build."""
        react = self._strategy(mock_model_instance, mock_context)
        untouched = "Original system prompt {{tools}}"
        prompts = [SystemPromptMessage(content=untouched)]
        react._build_prompt_with_react_format(prompts, [], True)
        assert prompts[0].content == untouched

View File

@ -0,0 +1,203 @@
"""Tests for StrategyFactory."""
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentEntity, ExecutionContext
from core.agent.patterns.function_call import FunctionCallStrategy
from core.agent.patterns.react import ReActStrategy
from core.agent.patterns.strategy_factory import StrategyFactory
from core.model_runtime.entities.model_entities import ModelFeature
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    stub = MagicMock()
    stub.model = "test-model"
    stub.provider = "test-provider"
    return stub
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
class TestStrategyFactory:
    """Tests for StrategyFactory.create_strategy method."""

    @staticmethod
    def _create(features, model_instance, context, **extra):
        """Invoke the factory with empty tools/files plus any extra kwargs."""
        return StrategyFactory.create_strategy(
            model_features=features,
            model_instance=model_instance,
            context=context,
            tools=[],
            files=[],
            **extra,
        )

    def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context):
        """TOOL_CALL support selects FunctionCallStrategy."""
        made = self._create([ModelFeature.TOOL_CALL], mock_model_instance, mock_context)
        assert isinstance(made, FunctionCallStrategy)

    def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context):
        """MULTI_TOOL_CALL support selects FunctionCallStrategy."""
        made = self._create([ModelFeature.MULTI_TOOL_CALL], mock_model_instance, mock_context)
        assert isinstance(made, FunctionCallStrategy)

    def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context):
        """STREAM_TOOL_CALL support selects FunctionCallStrategy."""
        made = self._create([ModelFeature.STREAM_TOOL_CALL], mock_model_instance, mock_context)
        assert isinstance(made, FunctionCallStrategy)

    def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context):
        """A model with only VISION (no tool calling) gets ReActStrategy."""
        made = self._create([ModelFeature.VISION], mock_model_instance, mock_context)
        assert isinstance(made, ReActStrategy)

    def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context):
        """A model with no features at all gets ReActStrategy."""
        features: list[ModelFeature] = []
        made = self._create(features, mock_model_instance, mock_context)
        assert isinstance(made, ReActStrategy)

    def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context):
        """Explicit FUNCTION_CALLING with model support yields FunctionCallStrategy."""
        made = self._create(
            [ModelFeature.TOOL_CALL],
            mock_model_instance,
            mock_context,
            agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
        )
        assert isinstance(made, FunctionCallStrategy)

    def test_explicit_function_calling_strategy_without_support_falls_back_to_react(
        self, mock_model_instance, mock_context
    ):
        """Explicit FUNCTION_CALLING without model support falls back to ReAct."""
        features: list[ModelFeature] = []  # No tool calling support
        made = self._create(
            features,
            mock_model_instance,
            mock_context,
            agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
        )
        assert isinstance(made, ReActStrategy)

    def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context):
        """Explicit CHAIN_OF_THOUGHT yields ReAct even with tool-call support."""
        made = self._create(
            [ModelFeature.TOOL_CALL],
            mock_model_instance,
            mock_context,
            agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
        )
        assert isinstance(made, ReActStrategy)

    def test_react_strategy_with_instruction(self, mock_model_instance, mock_context):
        """The instruction kwarg is forwarded to the ReAct strategy."""
        text = "You are a helpful assistant."
        features: list[ModelFeature] = []
        made = self._create(features, mock_model_instance, mock_context, instruction=text)
        assert isinstance(made, ReActStrategy)
        assert made.instruction == text

    def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context):
        """max_iterations is forwarded to the created strategy."""
        made = self._create(
            [ModelFeature.TOOL_CALL],
            mock_model_instance,
            mock_context,
            max_iterations=5,
        )
        assert made.max_iterations == 5

    def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context):
        """tool_invoke_hook is forwarded to the created strategy."""
        hook = MagicMock()
        made = self._create(
            [ModelFeature.TOOL_CALL],
            mock_model_instance,
            mock_context,
            tool_invoke_hook=hook,
        )
        assert made.tool_invoke_hook == hook

View File

@ -0,0 +1,388 @@
"""Tests for AgentAppRunner."""
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.llm_entities import LLMUsage
class TestOrganizePromptMessages:
    """Tests for _organize_prompt_messages method.

    The runner is built via ``__new__`` with ``BaseAgentRunner.__init__``
    patched out, so only the attributes assigned below exist on it.
    """

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        # We'll patch the class to avoid complex initialization
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            # Set up required attributes
            runner.config = MagicMock(spec=AgentEntity)
            runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
            runner.config.prompt = None
            runner.app_config = MagicMock()
            runner.app_config.prompt_template = MagicMock()
            runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
            runner.history_prompt_messages = []
            runner.query = "Hello"
            runner._current_thoughts = []
            runner.files = []
            runner.model_config = MagicMock()
            runner.memory = None
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.file_upload_config = None
            return runner

    def test_function_calling_uses_simple_prompt(self, mock_runner):
        """Test that function calling strategy uses simple_prompt_template."""
        mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
        with patch.object(mock_runner, "_init_system_message") as mock_init:
            mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
            with patch.object(mock_runner, "_organize_user_query") as mock_query:
                mock_query.return_value = [UserPromptMessage(content="Hello")]
                with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
                    mock_transform.return_value.get_prompt.return_value = [
                        SystemPromptMessage(content="You are a helpful assistant.")
                    ]
                    # NOTE(review): result itself is unused; only the mock call is inspected.
                    result = mock_runner._organize_prompt_messages()
                    # Verify _init_system_message was called with simple_prompt_template
                    mock_init.assert_called_once()
                    call_args = mock_init.call_args[0]
                    assert call_args[0] == "You are a helpful assistant."

    def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
        """Test that chain of thought strategy uses agent prompt template."""
        mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
        mock_runner.config.prompt = AgentPromptEntity(
            first_prompt="ReAct prompt template with {{tools}}",
            next_iteration="Continue...",
        )
        with patch.object(mock_runner, "_init_system_message") as mock_init:
            mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
            with patch.object(mock_runner, "_organize_user_query") as mock_query:
                mock_query.return_value = [UserPromptMessage(content="Hello")]
                with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
                    mock_transform.return_value.get_prompt.return_value = [
                        SystemPromptMessage(content="ReAct prompt template with {{tools}}")
                    ]
                    # NOTE(review): result itself is unused; only the mock call is inspected.
                    result = mock_runner._organize_prompt_messages()
                    # Verify _init_system_message was called with agent prompt
                    mock_init.assert_called_once()
                    call_args = mock_init.call_args[0]
                    assert call_args[0] == "ReAct prompt template with {{tools}}"

    def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
        """Test that chain of thought without prompt falls back to simple_prompt_template."""
        mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
        mock_runner.config.prompt = None
        with patch.object(mock_runner, "_init_system_message") as mock_init:
            mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
            with patch.object(mock_runner, "_organize_user_query") as mock_query:
                mock_query.return_value = [UserPromptMessage(content="Hello")]
                with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
                    mock_transform.return_value.get_prompt.return_value = [
                        SystemPromptMessage(content="You are a helpful assistant.")
                    ]
                    # NOTE(review): result itself is unused; only the mock call is inspected.
                    result = mock_runner._organize_prompt_messages()
                    # Verify _init_system_message was called with simple_prompt_template
                    mock_init.assert_called_once()
                    call_args = mock_init.call_args[0]
                    assert call_args[0] == "You are a helpful assistant."
class TestInitSystemMessage:
    """Behavioral tests for the _init_system_message helper."""

    @pytest.fixture
    def mock_runner(self):
        """Build an AgentAppRunner shell without running its real initializer."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            return AgentAppRunner.__new__(AgentAppRunner)

    def test_empty_messages_with_template(self, mock_runner):
        """A lone system message is synthesized from the template."""
        messages = mock_runner._init_system_message("System template", [])
        assert len(messages) == 1
        first = messages[0]
        assert isinstance(first, SystemPromptMessage)
        assert first.content == "System template"

    def test_empty_messages_without_template(self, mock_runner):
        """No template and no history yields an empty list."""
        assert mock_runner._init_system_message("", []) == []

    def test_existing_system_message_not_duplicated(self, mock_runner):
        """An existing leading system message suppresses insertion."""
        history = [
            SystemPromptMessage(content="Existing system"),
            UserPromptMessage(content="User message"),
        ]
        messages = mock_runner._init_system_message("New template", history)
        # No extra system message is prepended.
        assert len(messages) == 2
        assert messages[0].content == "Existing system"

    def test_system_message_inserted_when_missing(self, mock_runner):
        """When history starts with a user turn, the template is prepended."""
        history = [UserPromptMessage(content="User message")]
        messages = mock_runner._init_system_message("System template", history)
        assert len(messages) == 2
        assert isinstance(messages[0], SystemPromptMessage)
        assert messages[0].content == "System template"
class TestClearUserPromptImageMessages:
    """Tests for the _clear_user_prompt_image_messages helper."""

    @pytest.fixture
    def mock_runner(self):
        """Build an AgentAppRunner shell without running its real initializer."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            return AgentAppRunner.__new__(AgentAppRunner)

    def test_text_content_unchanged(self, mock_runner):
        """Plain-string content passes through untouched."""
        cleaned = mock_runner._clear_user_prompt_image_messages(
            [UserPromptMessage(content="Plain text message")]
        )
        assert len(cleaned) == 1
        assert cleaned[0].content == "Plain text message"

    def test_original_messages_not_modified(self, mock_runner):
        """The helper deep-copies: the caller's message keeps its list content."""
        from core.model_runtime.entities.message_entities import (
            ImagePromptMessageContent,
            TextPromptMessageContent,
        )

        source = [
            UserPromptMessage(
                content=[
                    TextPromptMessageContent(data="Text part"),
                    ImagePromptMessageContent(
                        data="http://example.com/image.jpg",
                        format="url",
                        mime_type="image/jpeg",
                    ),
                ]
            ),
        ]
        cleaned = mock_runner._clear_user_prompt_image_messages(source)
        # The input still carries the multi-part list...
        assert isinstance(source[0].content, list)
        # ...while the cleaned copy was flattened to a string.
        assert isinstance(cleaned[0].content, str)
class TestToolInvokeHook:
    """Tests for _create_tool_invoke_hook method.

    The runner shell carries only the attributes the hook implementation
    reads; ToolEngine is patched so no real tool invocation happens.
    """

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            # Minimal attribute set consumed by the hook.
            runner.user_id = "test-user"
            runner.tenant_id = "test-tenant"
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.trace_manager = None
            runner.application_generate_entity.invoke_from = "api"
            runner.application_generate_entity.app_config = MagicMock()
            runner.application_generate_entity.app_config.app_id = "test-app"
            runner.agent_callback = MagicMock()
            runner.conversation = MagicMock()
            runner.conversation.id = "test-conversation"
            runner.queue_manager = MagicMock()
            runner._current_message_file_ids = []
            return runner

    def test_hook_calls_agent_invoke(self, mock_runner):
        """Test that the hook calls ToolEngine.agent_invoke."""
        from core.tools.entities.tool_entities import ToolInvokeMeta

        mock_message = MagicMock()
        mock_message.id = "test-message"
        mock_tool = MagicMock()
        mock_tool_meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={
                "tool_provider_type": "test_provider",
                "tool_provider": "test_id",
            },
        )
        with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
            # Canned (content, file_ids, meta) triple from the engine.
            mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
            hook = mock_runner._create_tool_invoke_hook(mock_message)
            result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool")
            # Verify ToolEngine.agent_invoke was called
            mock_engine.agent_invoke.assert_called_once()
            # Verify return values
            assert result_content == "Tool result"
            assert result_files == ["file-1", "file-2"]
            assert result_meta == mock_tool_meta

    def test_hook_publishes_file_events(self, mock_runner):
        """Test that the hook publishes QueueMessageFileEvent for files."""
        from core.tools.entities.tool_entities import ToolInvokeMeta

        mock_message = MagicMock()
        mock_message.id = "test-message"
        mock_tool = MagicMock()
        mock_tool_meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={},
        )
        with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
            mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
            hook = mock_runner._create_tool_invoke_hook(mock_message)
            hook(mock_tool, {}, "test_tool")
            # Verify file events were published: one publish per returned file id.
            assert mock_runner.queue_manager.publish.call_count == 2
            assert mock_runner._current_message_file_ids == ["file-1", "file-2"]
class TestAgentLogProcessing:
    """Tests for AgentLog processing in run method."""

    def test_agent_log_status_enum(self):
        """LogStatus members compare equal to their wire strings."""
        for member, value in [
            (AgentLog.LogStatus.START, "start"),
            (AgentLog.LogStatus.SUCCESS, "success"),
            (AgentLog.LogStatus.ERROR, "error"),
        ]:
            assert member == value

    def test_agent_log_metadata_enum(self):
        """LogMetadata members compare equal to their wire strings."""
        for member, value in [
            (AgentLog.LogMetadata.STARTED_AT, "started_at"),
            (AgentLog.LogMetadata.FINISHED_AT, "finished_at"),
            (AgentLog.LogMetadata.ELAPSED_TIME, "elapsed_time"),
            (AgentLog.LogMetadata.TOTAL_PRICE, "total_price"),
            (AgentLog.LogMetadata.TOTAL_TOKENS, "total_tokens"),
            (AgentLog.LogMetadata.LLM_USAGE, "llm_usage"),
        ]:
            assert member == value

    def test_agent_result_structure(self):
        """AgentResult exposes text, files, usage and finish_reason as given."""
        token_usage = LLMUsage(
            prompt_tokens=100,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal("0.1"),
            completion_tokens=50,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal("0.1"),
            total_tokens=150,
            total_price=Decimal("0.2"),
            currency="USD",
            latency=0.5,
        )
        outcome = AgentResult(
            text="Final answer",
            files=[],
            usage=token_usage,
            finish_reason="stop",
        )
        assert outcome.text == "Final answer"
        assert outcome.files == []
        assert outcome.usage == token_usage
        assert outcome.finish_reason == "stop"
class TestOrganizeUserQuery:
    """Tests for _organize_user_query method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            # No attached files and no upload config by default.
            runner.files = []
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.file_upload_config = None
            return runner

    def test_simple_query_without_files(self, mock_runner):
        """Test organizing a simple query without files."""
        result = mock_runner._organize_user_query("Hello world", [])
        assert len(result) == 1
        assert isinstance(result[0], UserPromptMessage)
        assert result[0].content == "Hello world"

    def test_query_with_files(self, mock_runner):
        """Test organizing a query with files."""
        from core.file.models import File

        mock_file = MagicMock(spec=File)
        mock_runner.files = [mock_file]
        with patch("core.agent.agent_app_runner.file_manager") as mock_fm:
            from core.model_runtime.entities.message_entities import ImagePromptMessageContent

            # file_manager converts the attached file into image prompt content.
            mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent(
                data="http://example.com/image.jpg",
                format="url",
                mime_type="image/jpeg",
            )
            result = mock_runner._organize_user_query("Describe this image", [])
            assert len(result) == 1
            assert isinstance(result[0], UserPromptMessage)
            assert isinstance(result[0].content, list)
            assert len(result[0].content) == 2  # Image + Text

View File

@ -0,0 +1,191 @@
"""Tests for agent entities."""
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext
class TestExecutionContext:
    """Tests for the ExecutionContext entity."""

    def test_create_with_all_fields(self):
        """Every constructor field round-trips to its attribute."""
        ctx = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
            conversation_id="conv-789",
            message_id="msg-012",
            tenant_id="tenant-345",
        )
        assert ctx.user_id == "user-123"
        assert ctx.app_id == "app-456"
        assert ctx.conversation_id == "conv-789"
        assert ctx.message_id == "msg-012"
        assert ctx.tenant_id == "tenant-345"

    def test_create_minimal(self):
        """create_minimal fills only user_id, leaving the rest None."""
        ctx = ExecutionContext.create_minimal(user_id="user-123")
        assert ctx.user_id == "user-123"
        for attr in ("app_id", "conversation_id", "message_id", "tenant_id"):
            assert getattr(ctx, attr) is None

    def test_to_dict(self):
        """to_dict serializes all five identifiers verbatim."""
        ctx = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
            conversation_id="conv-789",
            message_id="msg-012",
            tenant_id="tenant-345",
        )
        assert ctx.to_dict() == {
            "user_id": "user-123",
            "app_id": "app-456",
            "conversation_id": "conv-789",
            "message_id": "msg-012",
            "tenant_id": "tenant-345",
        }

    def test_with_updates(self):
        """with_updates returns a modified copy and leaves the source intact."""
        base = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
        )
        derived = base.with_updates(message_id="msg-789")
        # The original context is immutable with respect to the update...
        assert base.message_id is None
        # ...while the derived copy carries the new value plus the old fields.
        assert derived.message_id == "msg-789"
        assert derived.user_id == "user-123"
        assert derived.app_id == "app-456"
class TestAgentLog:
    """Tests for the AgentLog entity."""

    def test_create_log_with_required_fields(self):
        """Required fields are stored; optional ones default sensibly."""
        entry = AgentLog(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"key": "value"},
        )
        assert entry.label == "ROUND 1"
        assert entry.log_type == AgentLog.LogType.ROUND
        assert entry.status == AgentLog.LogStatus.START
        assert entry.data == {"key": "value"}
        # The id is auto-generated; parent linkage and error start unset.
        assert entry.id is not None
        assert entry.parent_id is None
        assert entry.error is None

    def test_log_type_enum(self):
        """LogType members compare equal to their wire strings."""
        for member, value in [
            (AgentLog.LogType.ROUND, "round"),
            (AgentLog.LogType.THOUGHT, "thought"),
            (AgentLog.LogType.TOOL_CALL, "tool_call"),
        ]:
            assert member == value

    def test_log_status_enum(self):
        """LogStatus members compare equal to their wire strings."""
        for member, value in [
            (AgentLog.LogStatus.START, "start"),
            (AgentLog.LogStatus.SUCCESS, "success"),
            (AgentLog.LogStatus.ERROR, "error"),
        ]:
            assert member == value

    def test_log_metadata_enum(self):
        """LogMetadata members compare equal to their wire strings."""
        for member, value in [
            (AgentLog.LogMetadata.STARTED_AT, "started_at"),
            (AgentLog.LogMetadata.FINISHED_AT, "finished_at"),
            (AgentLog.LogMetadata.ELAPSED_TIME, "elapsed_time"),
            (AgentLog.LogMetadata.TOTAL_PRICE, "total_price"),
            (AgentLog.LogMetadata.TOTAL_TOKENS, "total_tokens"),
            (AgentLog.LogMetadata.LLM_USAGE, "llm_usage"),
        ]:
            assert member == value
class TestAgentScratchpadUnit:
    """Tests for the AgentScratchpadUnit entity."""

    def test_is_final_with_final_answer_action(self):
        """A 'Final Answer' action marks the unit as final."""
        unit = AgentScratchpadUnit(
            thought="I know the answer",
            action=AgentScratchpadUnit.Action(
                action_name="Final Answer",
                action_input="The answer is 42",
            ),
        )
        assert unit.is_final() is True

    def test_is_final_with_tool_action(self):
        """A tool action means more iterations remain."""
        unit = AgentScratchpadUnit(
            thought="I need to search",
            action=AgentScratchpadUnit.Action(
                action_name="search",
                action_input={"query": "test"},
            ),
        )
        assert unit.is_final() is False

    def test_is_final_with_no_action(self):
        """A unit with no action at all is treated as final."""
        assert AgentScratchpadUnit(thought="Just thinking").is_final() is True

    def test_action_to_dict(self):
        """Action.to_dict emits the action/action_input mapping."""
        serialized = AgentScratchpadUnit.Action(
            action_name="search",
            action_input={"query": "test"},
        ).to_dict()
        assert serialized == {
            "action": "search",
            "action_input": {"query": "test"},
        }
class TestAgentEntity:
    """Tests for AgentEntity."""

    def test_strategy_enum(self):
        """Strategy members compare equal to their wire strings."""
        assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought"
        assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling"

    def test_create_with_prompt(self):
        """All constructor arguments are exposed as attributes."""
        react_prompt = AgentPromptEntity(
            first_prompt="You are a helpful assistant.",
            next_iteration="Continue thinking...",
        )
        entity = AgentEntity(
            provider="openai",
            model="gpt-4",
            strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
            prompt=react_prompt,
            max_iteration=5,
        )
        assert entity.provider == "openai"
        assert entity.model == "gpt-4"
        assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT
        assert entity.prompt == react_prompt
        assert entity.max_iteration == 5

View File

@ -0,0 +1,390 @@
"""
Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method,
specifically testing the ANSWER node message_replace logic.
"""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
from core.app.entities.queue_entities import QueueNodeSucceededEvent
from core.workflow.enums import NodeType
from models import EndUser
from models.model import AppMode
class TestAnswerNodeMessageReplace:
    """Test cases for ANSWER node message_replace event logic.

    Each test seeds the pipeline's accumulated answer, pushes one
    QueueNodeSucceededEvent through ``_handle_node_succeeded_event`` via the
    ``_handle`` helper, and then checks both the resulting accumulated answer
    and whether a ``message_replace`` stream response was emitted.
    """

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Create a mock application generate entity."""
        entity = Mock(spec=AdvancedChatAppGenerateEntity)
        entity.task_id = "test-task-id"
        entity.app_id = "test-app-id"
        entity.workflow_run_id = "test-workflow-run-id"
        # minimal app_config used by pipeline internals
        entity.app_config = SimpleNamespace(
            tenant_id="test-tenant-id",
            app_id="test-app-id",
            app_mode=AppMode.ADVANCED_CHAT,
            app_model_config_dict={},
            additional_features=None,
            sensitive_word_avoidance=None,
        )
        entity.query = "test query"
        entity.files = []
        entity.extras = {}
        entity.trace_manager = None
        entity.inputs = {}
        entity.invoke_from = "debugger"
        return entity

    @pytest.fixture
    def mock_workflow(self):
        """Create a mock workflow."""
        workflow = Mock()
        workflow.id = "test-workflow-id"
        workflow.features_dict = {}
        return workflow

    @pytest.fixture
    def mock_queue_manager(self):
        """Create a mock queue manager."""
        manager = Mock()
        manager.listen.return_value = []
        manager.graph_runtime_state = None
        return manager

    @pytest.fixture
    def mock_conversation(self):
        """Create a mock conversation."""
        conversation = Mock()
        conversation.id = "test-conversation-id"
        conversation.mode = "advanced_chat"
        return conversation

    @pytest.fixture
    def mock_message(self):
        """Create a mock message."""
        message = Mock()
        message.id = "test-message-id"
        message.query = "test query"
        message.created_at = Mock()
        message.created_at.timestamp.return_value = 1234567890
        return message

    @pytest.fixture
    def mock_user(self):
        """Create a mock end user."""
        user = MagicMock(spec=EndUser)
        user.id = "test-user-id"
        user.session_id = "test-session-id"
        return user

    @pytest.fixture
    def mock_draft_var_saver_factory(self):
        """Create a mock draft variable saver factory."""
        return Mock()

    @pytest.fixture
    def pipeline(
        self,
        mock_application_generate_entity,
        mock_workflow,
        mock_queue_manager,
        mock_conversation,
        mock_message,
        mock_user,
        mock_draft_var_saver_factory,
    ):
        """Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies."""
        from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline

        with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"):
            pipeline = AdvancedChatAppGenerateTaskPipeline(
                application_generate_entity=mock_application_generate_entity,
                workflow=mock_workflow,
                queue_manager=mock_queue_manager,
                conversation=mock_conversation,
                message=mock_message,
                user=mock_user,
                stream=True,
                dialogue_count=1,
                draft_var_saver_factory=mock_draft_var_saver_factory,
            )
            # Initialize workflow run id to avoid validation errors
            pipeline._workflow_run_id = "test-workflow-run-id"
            # Track message_replace emissions without touching the real stream.
            pipeline._message_cycle_manager.message_replace_to_stream_response = Mock()
            return pipeline

    @staticmethod
    def _make_event(
        node_type=NodeType.ANSWER,
        outputs=None,
        node_execution_id="test-node-execution-id",
        node_id="test-answer-node",
    ):
        """Build a QueueNodeSucceededEvent with test defaults.

        ``outputs=None`` yields an empty outputs dict.
        """
        return QueueNodeSucceededEvent(
            node_execution_id=node_execution_id,
            node_id=node_id,
            node_type=node_type,
            start_at=datetime.now(),
            outputs=outputs if outputs is not None else {},
        )

    @staticmethod
    def _handle(pipeline, event):
        """Silence converter/saver side effects, then drive the event through the pipeline."""
        pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
        pipeline._save_output_for_event = Mock()
        return list(pipeline._handle_node_succeeded_event(event))

    def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity):
        """A differing final ANSWER output updates the accumulated answer and emits message_replace."""
        pipeline._task_state.answer = "initial answer"
        self._handle(pipeline, self._make_event(outputs={"answer": "updated final answer"}))
        assert pipeline._task_state.answer == "updated final answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
            answer="updated final answer", reason="variable_update"
        )

    def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline):
        """An identical final output leaves the answer alone and emits nothing."""
        pipeline._task_state.answer = "same answer"
        self._handle(pipeline, self._make_event(outputs={"answer": "same answer"}))
        assert pipeline._task_state.answer == "same answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()

    def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline):
        """A None 'answer' output is ignored."""
        pipeline._task_state.answer = "existing answer"
        self._handle(pipeline, self._make_event(outputs={"answer": None}))
        assert pipeline._task_state.answer == "existing answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()

    def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline):
        """An empty outputs dict is ignored."""
        pipeline._task_state.answer = "existing answer"
        self._handle(pipeline, self._make_event(outputs={}))
        assert pipeline._task_state.answer == "existing answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()

    def test_answer_node_with_no_answer_key_in_outputs(self, pipeline):
        """Outputs lacking an 'answer' key are ignored."""
        pipeline._task_state.answer = "existing answer"
        self._handle(pipeline, self._make_event(outputs={"other_key": "some value"}))
        assert pipeline._task_state.answer == "existing answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()

    def test_non_answer_node_does_not_send_message_replace(self, pipeline):
        """LLM nodes never trigger message_replace, even with an 'answer' output."""
        pipeline._task_state.answer = "existing answer"
        llm_event = self._make_event(
            node_type=NodeType.LLM,
            outputs={"answer": "different answer"},
            node_execution_id="test-llm-execution-id",
            node_id="test-llm-node",
        )
        self._handle(pipeline, llm_event)
        assert pipeline._task_state.answer == "existing answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()

    def test_end_node_does_not_send_message_replace(self, pipeline):
        """END nodes never trigger message_replace, even with an 'answer' output."""
        pipeline._task_state.answer = "existing answer"
        end_event = self._make_event(
            node_type=NodeType.END,
            outputs={"answer": "different answer"},
            node_execution_id="test-end-execution-id",
            node_id="test-end-node",
        )
        self._handle(pipeline, end_event)
        assert pipeline._task_state.answer == "existing answer"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()

    def test_answer_node_with_numeric_output_converts_to_string(self, pipeline):
        """Numeric final outputs are coerced to strings before replacement."""
        pipeline._task_state.answer = "text answer"
        self._handle(pipeline, self._make_event(outputs={"answer": 12345}))
        assert pipeline._task_state.answer == "12345"
        pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
            answer="12345", reason="variable_update"
        )

    def test_answer_node_files_are_recorded(self, pipeline):
        """ANSWER nodes record files carried in their outputs."""
        pipeline._task_state.answer = "existing answer"
        files = [
            {"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"}
        ]
        pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=files)
        self._handle(pipeline, self._make_event(outputs={"answer": "same answer", "files": files}))
        assert len(pipeline._recorded_files) == 1
        assert pipeline._recorded_files[0] == files[0]

View File

@ -431,10 +431,10 @@ class TestWorkflowResponseConverterServiceApiTruncation:
description="Explore calls should have truncation enabled",
),
TestCase(
name="published_truncation_enabled",
invoke_from=InvokeFrom.PUBLISHED,
name="published_pipeline_truncation_enabled",
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
expected_truncation_enabled=True,
description="Published app calls should have truncation enabled",
description="Published pipeline calls should have truncation enabled",
),
],
ids=lambda x: x.name,

View File

@ -0,0 +1,47 @@
from unittest.mock import MagicMock
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.workflow.graph_events import NodeRunStreamChunkEvent
from core.workflow.nodes import NodeType
class DummyQueueManager:
    """Minimal queue-manager stand-in that records every published event."""

    def __init__(self) -> None:
        # (event, publish_from) tuples in publication order.
        self.published = []

    def publish(self, event, publish_from: PublishFrom) -> None:
        """Capture the event instead of dispatching it anywhere."""
        self.published.append((event, publish_from))
def test_skip_empty_final_chunk() -> None:
    """An empty final stream chunk is dropped; a normal chunk is forwarded."""
    captured = DummyQueueManager()
    app_runner = WorkflowBasedAppRunner(queue_manager=captured, app_id="app")

    # A final chunk with no text must not be published at all.
    app_runner._handle_event(
        workflow_entry=MagicMock(),
        event=NodeRunStreamChunkEvent(
            id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            selector=["node", "text"],
            chunk="",
            is_final=True,
        ),
    )
    assert captured.published == []

    # A non-final chunk with content is published exactly once, attributed
    # to the application manager.
    app_runner._handle_event(
        workflow_entry=MagicMock(),
        event=NodeRunStreamChunkEvent(
            id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            selector=["node", "text"],
            chunk="hi",
            is_final=False,
        ),
    )
    assert len(captured.published) == 1
    forwarded, origin = captured.published[0]
    assert origin == PublishFrom.APPLICATION_MANAGER
    assert forwarded.text == "hi"

View File

@ -0,0 +1,144 @@
from collections.abc import Sequence
from datetime import UTC, datetime
from unittest.mock import Mock

from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
from core.variables import StringVariable
from core.variables.segments import Segment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.system_variable import SystemVariable
class MockReadOnlyVariablePool:
    """In-memory read-only variable pool stub keyed by (node_id, variable_name)."""

    def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
        # Fall back to an empty mapping when nothing is supplied.
        self._variables = {} if not variables else variables

    def get(self, selector: Sequence[str]) -> Segment | None:
        # A usable selector carries at least (node_id, variable_name).
        if len(selector) >= 2:
            node_id, name = selector[0], selector[1]
            return self._variables.get((node_id, name))
        return None

    def get_all_by_node(self, node_id: str) -> dict[str, object]:
        # Collect every variable whose key starts with the requested node id.
        collected: dict[str, object] = {}
        for (nid, name), segment in self._variables.items():
            if nid == node_id:
                collected[name] = segment
        return collected

    def get_by_prefix(self, prefix: str) -> dict[str, object]:
        # Same lookup as get_all_by_node, matching the prefix against node ids.
        matched: dict[str, object] = {}
        for (nid, name), segment in self._variables.items():
            if nid == prefix:
                matched[name] = segment
        return matched
def _build_graph_runtime_state(
    variable_pool: MockReadOnlyVariablePool,
    conversation_id: str | None = None,
) -> ReadOnlyGraphRuntimeState:
    """Build a mocked runtime state exposing *variable_pool* and a system-variable view."""
    graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
    graph_runtime_state.variable_pool = variable_pool
    # The persistence layer reads conversation_id through the system-variable view.
    graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
    return graph_runtime_state
def _build_node_run_succeeded_event(
    *,
    node_type: NodeType,
    outputs: dict[str, object] | None = None,
    process_data: dict[str, object] | None = None,
) -> NodeRunSucceededEvent:
    """Build a SUCCEEDED NodeRunSucceededEvent for *node_type*.

    ``outputs`` and ``process_data`` default to empty dicts so callers can omit
    whichever side of the run result they do not care about.
    """
    return NodeRunSucceededEvent(
        id="node-exec-id",
        node_id="assigner",
        node_type=node_type,
        # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
        # timestamp; use an explicit timezone-aware UTC timestamp instead.
        start_at=datetime.now(UTC),
        node_run_result=NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs or {},
            process_data=process_data or {},
        ),
    )
def test_persists_conversation_variables_from_assigner_output():
    """An assigner node that updated a conversation variable triggers persist + flush."""
    conversation_id = "conv-123"
    variable = StringVariable(
        id="var-1",
        name="name",
        value="updated",
        selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
    )
    # Record the updated variable in process_data the same way the assigner node does.
    process_data = common_helpers.set_updated_variables(
        {}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
    )
    variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
    updater = Mock()
    layer = ConversationVariablePersistenceLayer(updater)
    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
    event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
    layer.on_event(event)
    # Exactly this variable must be persisted for this conversation, then flushed once.
    updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
    updater.flush.assert_called_once()
def test_skips_when_outputs_missing():
    """No persistence happens when the assigner event carries no updated-variable data."""
    conversation_id = "conv-456"
    variable = StringVariable(
        id="var-2",
        name="name",
        value="updated",
        selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
    )
    variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
    updater = Mock()
    layer = ConversationVariablePersistenceLayer(updater)
    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
    # The event is built without outputs/process_data, so the layer has nothing to persist.
    event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
    layer.on_event(event)
    updater.update.assert_not_called()
    updater.flush.assert_not_called()
def test_skips_non_assigner_nodes():
    """Events from non-assigner node types (here LLM) are ignored entirely."""
    updater = Mock()
    layer = ConversationVariablePersistenceLayer(updater)
    layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
    event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
    layer.on_event(event)
    updater.update.assert_not_called()
    updater.flush.assert_not_called()
def test_skips_non_conversation_variables():
    """Updated variables outside the conversation namespace are not persisted."""
    conversation_id = "conv-789"
    # Selector points at the "environment" namespace, not CONVERSATION_VARIABLE_NODE_ID.
    non_conversation_variable = StringVariable(
        id="var-3",
        name="name",
        value="updated",
        selector=["environment", "name"],
    )
    process_data = common_helpers.set_updated_variables(
        {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
    )
    variable_pool = MockReadOnlyVariablePool()
    updater = Mock()
    layer = ConversationVariablePersistenceLayer(updater)
    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
    event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
    layer.on_event(event)
    updater.update.assert_not_called()
    updater.flush.assert_not_called()

View File

@ -1,4 +1,5 @@
import json
from collections.abc import Sequence
from time import time
from unittest.mock import Mock
@ -67,8 +68,10 @@ class MockReadOnlyVariablePool:
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
self._variables = variables or {}
def get(self, node_id: str, variable_key: str) -> Segment | None:
value = self._variables.get((node_id, variable_key))
def get(self, selector: Sequence[str]) -> Segment | None:
if len(selector) < 2:
return None
value = self._variables.get((selector[0], selector[1]))
if value is None:
return None
mock_segment = Mock(spec=Segment)

View File

@ -1,8 +1,12 @@
"""Primarily used for testing merged cell scenarios"""
import os
import tempfile
from types import SimpleNamespace
from docx import Document
from docx.oxml import OxmlElement
from docx.oxml.ns import qn
import core.rag.extractor.word_extractor as we
from core.rag.extractor.word_extractor import WordExtractor
@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url():
dify_config.FILES_URL = original_files_url
if original_internal_files_url is not None:
dify_config.INTERNAL_FILES_URL = original_internal_files_url
def test_extract_hyperlinks(monkeypatch):
    """Modern w:hyperlink elements are extracted as markdown links."""
    # Mock db and storage to avoid issues during image extraction (even if no images are present)
    monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
    db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
    monkeypatch.setattr(we, "db", db_stub)
    monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
    monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
    doc = Document()
    p = doc.add_paragraph("Visit ")
    # Adding a hyperlink manually: a w:hyperlink wrapping a run, referenced by r:id
    r_id = "rId99"
    hyperlink = OxmlElement("w:hyperlink")
    hyperlink.set(qn("r:id"), r_id)
    new_run = OxmlElement("w:r")
    t = OxmlElement("w:t")
    t.text = "Dify"
    new_run.append(t)
    hyperlink.append(new_run)
    p._p.append(hyperlink)
    # Add the external relationship the r:id resolves to
    doc.part.rels.add_relationship(
        "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
        "https://dify.ai",
        r_id,
        is_external=True,
    )
    # delete=False so the extractor can reopen the file after the handle closes
    with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
        doc.save(tmp.name)
        tmp_path = tmp.name
    try:
        extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
        docs = extractor.extract()
        # Verify modern hyperlink extraction.
        # NOTE(review): the expected text has no space after "Visit" although the
        # paragraph was created as "Visit " — presumably the extractor strips/joins
        # run text; confirm against WordExtractor's paragraph handling.
        assert "Visit[Dify](https://dify.ai)" in docs[0].page_content
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def test_extract_legacy_hyperlinks(monkeypatch):
    """Legacy HYPERLINK field codes (fldChar/instrText) are extracted as markdown links."""
    # Mock db and storage
    monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
    db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
    monkeypatch.setattr(we, "db", db_stub)
    monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
    monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
    doc = Document()
    p = doc.add_paragraph()
    # Construct a legacy HYPERLINK field:
    # 1. w:fldChar (begin)
    # 2. w:instrText (HYPERLINK "http://example.com")
    # 3. w:fldChar (separate)
    # 4. w:r (visible text "Example")
    # 5. w:fldChar (end)
    run1 = OxmlElement("w:r")
    fldCharBegin = OxmlElement("w:fldChar")
    fldCharBegin.set(qn("w:fldCharType"), "begin")
    run1.append(fldCharBegin)
    p._p.append(run1)
    run2 = OxmlElement("w:r")
    instrText = OxmlElement("w:instrText")
    instrText.text = ' HYPERLINK "http://example.com" '
    run2.append(instrText)
    p._p.append(run2)
    run3 = OxmlElement("w:r")
    fldCharSep = OxmlElement("w:fldChar")
    fldCharSep.set(qn("w:fldCharType"), "separate")
    run3.append(fldCharSep)
    p._p.append(run3)
    run4 = OxmlElement("w:r")
    t4 = OxmlElement("w:t")
    t4.text = "Example"
    run4.append(t4)
    p._p.append(run4)
    run5 = OxmlElement("w:r")
    fldCharEnd = OxmlElement("w:fldChar")
    fldCharEnd.set(qn("w:fldCharType"), "end")
    run5.append(fldCharEnd)
    p._p.append(run5)
    # delete=False so the extractor can reopen the file after the handle closes
    with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
        doc.save(tmp.name)
        tmp_path = tmp.name
    try:
        extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
        docs = extractor.extract()
        # Verify legacy hyperlink extraction
        assert "[Example](http://example.com)" in docs[0].page_content
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import queue
import threading
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=execution_coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
assert event_queue.empty()
@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int:
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing
the behavior of mock nodes during testing.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@ -95,67 +97,67 @@ class MockConfigBuilder:
def __init__(self) -> None:
self._config = MockConfig()
def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder":
def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder:
"""Enable or disable auto-mocking."""
self._config.enable_auto_mock = enabled
return self
def with_delays(self, enabled: bool = True) -> "MockConfigBuilder":
def with_delays(self, enabled: bool = True) -> MockConfigBuilder:
"""Enable or disable simulated execution delays."""
self._config.simulate_delays = enabled
return self
def with_llm_response(self, response: str) -> "MockConfigBuilder":
def with_llm_response(self, response: str) -> MockConfigBuilder:
"""Set default LLM response."""
self._config.default_llm_response = response
return self
def with_agent_response(self, response: str) -> "MockConfigBuilder":
def with_agent_response(self, response: str) -> MockConfigBuilder:
"""Set default agent response."""
self._config.default_agent_response = response
return self
def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder:
"""Set default tool response."""
self._config.default_tool_response = response
return self
def with_retrieval_response(self, response: str) -> "MockConfigBuilder":
def with_retrieval_response(self, response: str) -> MockConfigBuilder:
"""Set default retrieval response."""
self._config.default_retrieval_response = response
return self
def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder:
"""Set default HTTP response."""
self._config.default_http_response = response
return self
def with_template_transform_response(self, response: str) -> "MockConfigBuilder":
def with_template_transform_response(self, response: str) -> MockConfigBuilder:
"""Set default template transform response."""
self._config.default_template_transform_response = response
return self
def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder:
"""Set default code execution response."""
self._config.default_code_response = response
return self
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder":
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder:
"""Set outputs for a specific node."""
self._config.set_node_outputs(node_id, outputs)
return self
def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder":
def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder:
"""Set error for a specific node."""
self._config.set_node_error(node_id, error)
return self
def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder":
def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder:
"""Add a node-specific configuration."""
self._config.set_node_config(config.node_id, config)
return self
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder":
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder:
"""Set default configuration for a node type."""
self._config.set_default_config(node_type, config)
return self

View File

@ -0,0 +1,231 @@
"""Tests for ResponseStreamCoordinator object field streaming."""
from unittest.mock import MagicMock
from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.enums import NodeType
from core.workflow.graph.graph import Graph
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.template import Template, VariableSegment
from core.workflow.runtime import VariablePool
class TestResponseCoordinatorObjectStreaming:
    """Test streaming of object-type variables with child fields."""

    def test_object_field_streaming(self):
        """Test that when selecting an object variable, all child field streams are forwarded."""
        # Create mock graph and variable pool
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        # Mock nodes: an LLM producer and an answer (response) node
        llm_node = MagicMock()
        llm_node.id = "llm_node"
        llm_node.node_type = NodeType.LLM
        llm_node.execution_type = MagicMock()
        llm_node.blocks_variable_output = MagicMock(return_value=False)
        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        response_node.execution_type = MagicMock()
        response_node.blocks_variable_output = MagicMock(return_value=False)
        # Mock template for response node: selects the whole "generation" object
        response_node.node_data = MagicMock(spec=BaseNodeData)
        response_node.node_data.answer = "{{#llm_node.generation#}}"
        graph.nodes = {
            "llm_node": llm_node,
            "response_node": response_node,
        }
        graph.root_node = llm_node
        graph.get_outgoing_edges = MagicMock(return_value=[])
        # Create coordinator
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        # Track execution so node ids map to execution ids
        coordinator.track_node_execution("llm_node", "exec_123")
        coordinator.track_node_execution("response_node", "exec_456")
        # Simulate streaming events for child fields of generation object
        # 1. Content stream
        content_event_1 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "content"],
            chunk="Hello",
            is_final=False,
            chunk_type=ChunkType.TEXT,
        )
        content_event_2 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "content"],
            chunk=" world",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )
        # 2. Tool call stream
        tool_call_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "tool_calls"],
            chunk='{"query": "test"}',
            is_final=True,
            chunk_type=ChunkType.TOOL_CALL,
            tool_call=ToolCall(
                id="call_123",
                name="search",
                arguments='{"query": "test"}',
            ),
        )
        # 3. Tool result stream
        tool_result_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "tool_results"],
            chunk="Found 10 results",
            is_final=True,
            chunk_type=ChunkType.TOOL_RESULT,
            tool_result=ToolResult(
                id="call_123",
                name="search",
                output="Found 10 results",
                files=[],
                status=ToolResultStatus.SUCCESS,
            ),
        )
        # Intercept these events (deliberately out of order to exercise buffering)
        coordinator.intercept_event(content_event_1)
        coordinator.intercept_event(tool_call_event)
        coordinator.intercept_event(tool_result_event)
        coordinator.intercept_event(content_event_2)
        # Verify that all child streams are buffered
        assert ("llm_node", "generation", "content") in coordinator._stream_buffers
        assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
        assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
        # Verify payloads are preserved in buffered events
        buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
        assert buffered_call.tool_call is not None
        assert buffered_call.tool_call.id == "call_123"
        buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
        assert buffered_result.tool_result is not None
        # NOTE(review): compares the enum to the string "success" — assumes
        # ToolResultStatus is a str-based enum; confirm.
        assert buffered_result.tool_result.status == "success"
        # Verify we can find child streams under the parent "generation" selector
        child_streams = coordinator._find_child_streams(["llm_node", "generation"])
        assert len(child_streams) == 3
        assert ("llm_node", "generation", "content") in child_streams
        assert ("llm_node", "generation", "tool_calls") in child_streams
        assert ("llm_node", "generation", "tool_results") in child_streams

    def test_find_child_streams(self):
        """Test the _find_child_streams method."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        # Add some mock streams directly into the internal buffer map
        coordinator._stream_buffers = {
            ("node1", "generation", "content"): [],
            ("node1", "generation", "tool_calls"): [],
            ("node1", "generation", "thought"): [],
            ("node1", "text"): [],  # Not a child of generation
            ("node2", "generation", "content"): [],  # Different node
        }
        # Find children of node1.generation
        children = coordinator._find_child_streams(["node1", "generation"])
        assert len(children) == 3
        assert ("node1", "generation", "content") in children
        assert ("node1", "generation", "tool_calls") in children
        assert ("node1", "generation", "thought") in children
        assert ("node1", "text") not in children
        assert ("node2", "generation", "content") not in children

    def test_find_child_streams_with_closed_streams(self):
        """Test that _find_child_streams also considers closed streams."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        # Add some streams - some buffered, some closed
        coordinator._stream_buffers = {
            ("node1", "generation", "content"): [],
        }
        coordinator._closed_streams = {
            ("node1", "generation", "tool_calls"),
            ("node1", "generation", "thought"),
        }
        # Should find all children regardless of whether they're in buffers or closed
        children = coordinator._find_child_streams(["node1", "generation"])
        assert len(children) == 3
        assert ("node1", "generation", "content") in children
        assert ("node1", "generation", "tool_calls") in children
        assert ("node1", "generation", "thought") in children

    def test_special_selector_rewrites_to_active_response_node(self):
        """Ensure special selectors attribute streams to the active response node."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        graph.nodes = {"response_node": response_node}
        graph.root_node = response_node
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        coordinator.track_node_execution("response_node", "exec_resp")
        # Make the response node the active session for a ["sys", "foo"] template segment
        coordinator._active_session = ResponseSession(
            node_id="response_node",
            template=Template(segments=[VariableSegment(selector=["sys", "foo"])]),
        )
        # Buffer a chunk produced by a different node under the special selector
        event = NodeRunStreamChunkEvent(
            id="stream_1",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["sys", "foo"],
            chunk="hi",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )
        coordinator._stream_buffers[("sys", "foo")] = [event]
        coordinator._stream_positions[("sys", "foo")] = 0
        coordinator._closed_streams.add(("sys", "foo"))
        events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"]))
        assert is_complete
        assert len(events) == 1
        # The emitted event must be re-attributed to the active response node
        rewritten = events[0]
        assert rewritten.node_id == "response_node"
        assert rewritten.id == "exec_resp"

View File

@ -0,0 +1,539 @@
"""
Unit tests for stop_event functionality in GraphEngine.
Tests the unified stop_event management by GraphEngine and its propagation
to WorkerPool, Worker, Dispatcher, and Nodes.
"""
import threading
import time
from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
class TestStopEventPropagation:
    """Test suite for stop_event propagation through GraphEngine components."""

    def test_graph_engine_creates_stop_event(self):
        """Test that GraphEngine creates a stop_event on initialization."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Verify stop_event was created
        assert engine._stop_event is not None
        assert isinstance(engine._stop_event, threading.Event)
        # Verify it was set in graph_runtime_state (same object, not a copy)
        assert runtime_state.stop_event is not None
        assert runtime_state.stop_event is engine._stop_event

    def test_stop_event_cleared_on_start(self):
        """Test that stop_event is cleared when execution starts."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        mock_graph.root_node.id = "start"  # Set proper id
        start_node = StartNode(
            id="start",
            config={"id": "start", "data": {"title": "start", "variables": []}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        mock_graph.nodes["start"] = start_node
        mock_graph.get_outgoing_edges = MagicMock(return_value=[])
        mock_graph.get_incoming_edges = MagicMock(return_value=[])
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Set the stop_event before running
        engine._stop_event.set()
        assert engine._stop_event.is_set()
        # Run the engine (should clear the stop_event)
        events = list(engine.run())
        # After running, stop_event should be set again (by _stop_execution)
        # But during start it was cleared, so the run completes normally
        assert any(isinstance(e, GraphRunStartedEvent) for e in events)
        assert any(isinstance(e, GraphRunSucceededEvent) for e in events)

    def test_stop_event_set_on_stop(self):
        """Test that stop_event is set when execution stops."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        mock_graph.root_node.id = "start"  # Set proper id
        start_node = StartNode(
            id="start",
            config={"id": "start", "data": {"title": "start", "variables": []}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        mock_graph.nodes["start"] = start_node
        mock_graph.get_outgoing_edges = MagicMock(return_value=[])
        mock_graph.get_incoming_edges = MagicMock(return_value=[])
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Initially not set
        assert not engine._stop_event.is_set()
        # Run the engine
        list(engine.run())
        # After execution completes, stop_event should be set
        assert engine._stop_event.is_set()

    def test_stop_event_passed_to_worker_pool(self):
        """Test that stop_event is passed to WorkerPool."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Verify WorkerPool shares the engine's stop_event (identity, not equality)
        assert engine._worker_pool._stop_event is not None
        assert engine._worker_pool._stop_event is engine._stop_event

    def test_stop_event_passed_to_dispatcher(self):
        """Test that stop_event is passed to Dispatcher."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Verify Dispatcher shares the engine's stop_event (identity, not equality)
        assert engine._dispatcher._stop_event is not None
        assert engine._dispatcher._stop_event is engine._stop_event
class TestNodeStopCheck:
    """Test suite for Node._should_stop() functionality."""

    def test_node_should_stop_checks_runtime_state(self):
        """Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        answer_node = AnswerNode(
            id="answer",
            config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        # Initially stop_event is not set
        assert not answer_node._should_stop()
        # Set the stop_event
        runtime_state.stop_event.set()
        # Now _should_stop should return True
        assert answer_node._should_stop()

    def test_node_run_checks_stop_event_between_yields(self):
        """Test that Node.run() checks stop_event between yielding events."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        # Create a simple node
        answer_node = AnswerNode(
            id="answer",
            config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        # Set stop_event BEFORE running the node
        runtime_state.stop_event.set()
        # Run the node - should yield start event then detect stop
        # The node should check stop_event before processing
        assert answer_node._should_stop(), "stop_event should be set"
        # Run and collect events
        events = list(answer_node.run())
        # Since stop_event is set at the start, we should get:
        # 1. NodeRunStartedEvent (always yielded first)
        # 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
        assert len(events) >= 2
        assert isinstance(events[0], NodeRunStartedEvent)
        # Note: AnswerNode is very simple and might complete before stop check
        # The important thing is that _should_stop() returns True when stop_event is set
        assert answer_node._should_stop()
class TestStopEventIntegration:
    """Integration tests for stop_event in workflow execution."""

    def test_simple_workflow_respects_stop_event(self):
        """Test that a simple workflow respects stop_event."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        mock_graph.root_node.id = "start"
        # Create start and answer nodes sharing the same runtime state
        start_node = StartNode(
            id="start",
            config={"id": "start", "data": {"title": "start", "variables": []}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        answer_node = AnswerNode(
            id="answer",
            config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        mock_graph.nodes["start"] = start_node
        mock_graph.nodes["answer"] = answer_node
        mock_graph.get_outgoing_edges = MagicMock(return_value=[])
        mock_graph.get_incoming_edges = MagicMock(return_value=[])
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Set stop_event before running
        runtime_state.stop_event.set()
        # Run the engine
        events = list(engine.run())
        # Should get started event but not succeeded (due to stop)
        assert any(isinstance(e, GraphRunStartedEvent) for e in events)
        # The workflow should still complete (start node runs quickly)
        # but answer node might be cancelled depending on timing

    def test_stop_event_with_concurrent_nodes(self):
        """Test stop_event behavior with multiple concurrent nodes."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        # Create multiple nodes, all bound to the same runtime state
        for i in range(3):
            answer_node = AnswerNode(
                id=f"answer_{i}",
                config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
                graph_init_params=GraphInitParams(
                    tenant_id="test_tenant",
                    app_id="test_app",
                    workflow_id="test_workflow",
                    graph_config={},
                    user_id="test_user",
                    user_from=UserFrom.ACCOUNT,
                    invoke_from=InvokeFrom.DEBUGGER,
                    call_depth=0,
                ),
                graph_runtime_state=runtime_state,
            )
            mock_graph.nodes[f"answer_{i}"] = answer_node
        mock_graph.get_outgoing_edges = MagicMock(return_value=[])
        mock_graph.get_incoming_edges = MagicMock(return_value=[])
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # All nodes should share the same stop_event (identity checks)
        for node in mock_graph.nodes.values():
            assert node.graph_runtime_state.stop_event is runtime_state.stop_event
            assert node.graph_runtime_state.stop_event is engine._stop_event
class TestStopEventTimeoutBehavior:
    """Test stop_event behavior with join timeouts.

    Both tests patch the underlying thread/worker class so that stop() has a
    live-looking thread to join, then assert the join timeout is 2.0 seconds.
    """

    @patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread")
    def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
        """Test that Dispatcher uses 2s timeout instead of 10s."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        dispatcher = engine._dispatcher
        dispatcher.start()  # This will create and start the mocked thread
        mock_thread_instance = mock_thread_cls.return_value
        # Report the thread as still alive so stop() must join with a timeout.
        mock_thread_instance.is_alive.return_value = True
        dispatcher.stop()
        mock_thread_instance.join.assert_called_once_with(timeout=2.0)

    @patch("core.workflow.graph_engine.worker_management.worker_pool.Worker")
    def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
        """Test that WorkerPool uses 2s timeout instead of 10s."""
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        worker_pool = engine._worker_pool
        worker_pool.start(initial_count=1)  # Start with one worker
        mock_worker_instance = mock_worker_cls.return_value
        # Report the worker as still alive so stop() must join with a timeout.
        mock_worker_instance.is_alive.return_value = True
        worker_pool.stop()
        mock_worker_instance.join.assert_called_once_with(timeout=2.0)
class TestStopEventResumeBehavior:
    """Test stop_event behavior during workflow resume."""

    def test_stop_event_cleared_on_resume(self):
        """Test that stop_event is cleared when resuming a paused workflow.

        A pre-set stop_event simulates a workflow that was stopped earlier;
        run() is expected to clear the flag and complete the single-node
        graph successfully.
        """
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        mock_graph.root_node.id = "start"  # Set proper id
        # Use a real StartNode so the engine can execute one node end-to-end.
        start_node = StartNode(
            id="start",
            config={"id": "start", "data": {"title": "start", "variables": []}},
            graph_init_params=GraphInitParams(
                tenant_id="test_tenant",
                app_id="test_app",
                workflow_id="test_workflow",
                graph_config={},
                user_id="test_user",
                user_from=UserFrom.ACCOUNT,
                invoke_from=InvokeFrom.DEBUGGER,
                call_depth=0,
            ),
            graph_runtime_state=runtime_state,
        )
        mock_graph.nodes["start"] = start_node
        mock_graph.get_outgoing_edges = MagicMock(return_value=[])
        mock_graph.get_incoming_edges = MagicMock(return_value=[])
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Simulate a previous execution that set stop_event
        engine._stop_event.set()
        assert engine._stop_event.is_set()
        # Run the engine (should clear stop_event in _start_execution)
        events = list(engine.run())
        # Execution should complete successfully
        assert any(isinstance(e, GraphRunStartedEvent) for e in events)
        assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
class TestWorkerStopBehavior:
    """Test Worker behavior with shared stop_event."""

    def test_worker_uses_shared_stop_event(self):
        """Test that Worker uses shared stop_event from GraphEngine.

        Starts the engine's worker pool and asserts each worker holds the
        engine's stop_event by identity.
        """
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        mock_graph = MagicMock(spec=Graph)
        mock_graph.nodes = {}
        mock_graph.edges = {}
        mock_graph.root_node = MagicMock()
        engine = GraphEngine(
            workflow_id="test_workflow",
            graph=mock_graph,
            graph_runtime_state=runtime_state,
            command_channel=InMemoryChannel(),
        )
        # Get the worker pool and check workers
        worker_pool = engine._worker_pool
        # Start the worker pool to create workers
        worker_pool.start()
        # Check that at least one worker was created
        assert len(worker_pool._workers) > 0
        # Verify workers use the shared stop_event
        for worker in worker_pool._workers:
            assert worker._stop_event is engine._stop_event
        # Clean up
        worker_pool.stop()

    def test_worker_stop_is_noop(self):
        """Test that Worker.stop() is now a no-op."""
        # NOTE(review): runtime_state is never used in this test — candidate for removal.
        runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
        # Create a mock worker
        from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
        from core.workflow.graph_engine.worker import Worker

        ready_queue = InMemoryReadyQueue()
        event_queue = MagicMock()
        # Create a proper mock graph with real dict
        mock_graph = Mock(spec=Graph)
        mock_graph.nodes = {}  # Use real dict
        stop_event = threading.Event()
        worker = Worker(
            ready_queue=ready_queue,
            event_queue=event_queue,
            graph=mock_graph,
            layers=[],
            stop_event=stop_event,
        )
        # Calling stop() should do nothing (no-op)
        # and should NOT set the stop_event
        worker.stop()
        assert not stop_event.is_set()

View File

@ -0,0 +1,328 @@
"""Tests for StreamChunkEvent and its subclasses."""
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
from core.workflow.node_events import (
ChunkType,
StreamChunkEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
class TestChunkType:
    """Tests for the ChunkType enum."""

    def test_chunk_type_values(self):
        """Each ChunkType member must equal its expected string value."""
        expected = {
            ChunkType.TEXT: "text",
            ChunkType.TOOL_CALL: "tool_call",
            ChunkType.TOOL_RESULT: "tool_result",
            ChunkType.THOUGHT: "thought",
        }
        for member, literal in expected.items():
            assert member == literal

    def test_chunk_type_is_str_enum(self):
        """Every member's value must be a plain string."""
        assert all(isinstance(member.value, str) for member in ChunkType)
class TestStreamChunkEvent:
    """Tests for the base StreamChunkEvent."""

    def test_create_with_required_fields(self):
        """Only selector and chunk are required; the rest take defaults."""
        evt = StreamChunkEvent(selector=["node1", "text"], chunk="Hello")
        assert evt.selector == ["node1", "text"]
        assert evt.chunk == "Hello"
        assert evt.is_final is False
        assert evt.chunk_type == ChunkType.TEXT

    def test_create_with_all_fields(self):
        """All fields can be supplied explicitly."""
        evt = StreamChunkEvent(
            selector=["node1", "output"],
            chunk="World",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )
        assert evt.selector == ["node1", "output"]
        assert evt.chunk == "World"
        assert evt.is_final is True
        assert evt.chunk_type == ChunkType.TEXT

    def test_default_chunk_type_is_text(self):
        """chunk_type falls back to TEXT when not specified."""
        evt = StreamChunkEvent(selector=["node1", "text"], chunk="test")
        assert evt.chunk_type == ChunkType.TEXT

    def test_serialization(self):
        """model_dump renders the enum as its string value."""
        dumped = StreamChunkEvent(
            selector=["node1", "text"],
            chunk="Hello",
            is_final=True,
        ).model_dump()
        assert dumped["selector"] == ["node1", "text"]
        assert dumped["chunk"] == "Hello"
        assert dumped["is_final"] is True
        assert dumped["chunk_type"] == "text"
class TestToolCallChunkEvent:
    """Tests for ToolCallChunkEvent."""

    def test_create_with_required_fields(self):
        """Selector, chunk, and tool_call suffice to build the event."""
        call = ToolCall(id="call_123", name="weather", arguments=None)
        evt = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"city": "Beijing"}',
            tool_call=call,
        )
        assert evt.selector == ["node1", "tool_calls"]
        assert evt.chunk == '{"city": "Beijing"}'
        assert (evt.tool_call.id, evt.tool_call.name) == ("call_123", "weather")
        assert evt.chunk_type == ChunkType.TOOL_CALL

    def test_chunk_type_is_tool_call(self):
        """chunk_type is fixed to TOOL_CALL."""
        evt = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk="",
            tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
        )
        assert evt.chunk_type == ChunkType.TOOL_CALL

    def test_tool_arguments_field(self):
        """The raw argument string on the call is preserved verbatim."""
        evt = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"param": "value"}',
            tool_call=ToolCall(id="call_123", name="test_tool", arguments='{"param": "value"}'),
        )
        assert evt.tool_call.arguments == '{"param": "value"}'

    def test_serialization(self):
        """model_dump nests the tool_call payload under its own key."""
        dumped = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"city": "Beijing"}',
            tool_call=ToolCall(id="call_123", name="weather", arguments='{"city": "Beijing"}'),
            is_final=True,
        ).model_dump()
        assert dumped["chunk_type"] == "tool_call"
        assert dumped["tool_call"]["id"] == "call_123"
        assert dumped["tool_call"]["name"] == "weather"
        assert dumped["tool_call"]["arguments"] == '{"city": "Beijing"}'
        assert dumped["is_final"] is True
class TestToolResultChunkEvent:
    """Tests for ToolResultChunkEvent."""

    def test_create_with_required_fields(self):
        """Selector, chunk, and tool_result suffice to build the event."""
        result = ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C")
        evt = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="Weather: Sunny, 25°C",
            tool_result=result,
        )
        assert evt.selector == ["node1", "tool_results"]
        assert evt.chunk == "Weather: Sunny, 25°C"
        assert (evt.tool_result.id, evt.tool_result.name) == ("call_123", "weather")
        assert evt.chunk_type == ChunkType.TOOL_RESULT

    def test_chunk_type_is_tool_result(self):
        """chunk_type is fixed to TOOL_RESULT."""
        evt = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool"),
        )
        assert evt.chunk_type == ChunkType.TOOL_RESULT

    def test_tool_files_default_empty(self):
        """tool_result.files defaults to an empty list when omitted."""
        evt = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool"),
        )
        assert evt.tool_result.files == []

    def test_tool_files_with_values(self):
        """File identifiers supplied on the result are preserved."""
        attached = ["file_1", "file_2"]
        evt = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool", files=attached),
        )
        assert evt.tool_result.files == ["file_1", "file_2"]

    def test_tool_error_output(self):
        """Error output and ERROR status ride along on tool_result."""
        failed = ToolResult(
            id="call_123",
            name="test_tool",
            output="Tool execution failed",
            status=ToolResultStatus.ERROR,
        )
        evt = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="",
            tool_result=failed,
        )
        assert evt.tool_result.output == "Tool execution failed"
        assert evt.tool_result.status == ToolResultStatus.ERROR

    def test_serialization(self):
        """model_dump nests the tool_result payload under its own key."""
        dumped = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="Weather: Sunny",
            tool_result=ToolResult(
                id="call_123",
                name="weather",
                output="Weather: Sunny",
                files=["file_1"],
                status=ToolResultStatus.SUCCESS,
            ),
            is_final=True,
        ).model_dump()
        assert dumped["chunk_type"] == "tool_result"
        assert dumped["tool_result"]["id"] == "call_123"
        assert dumped["tool_result"]["name"] == "weather"
        assert dumped["tool_result"]["files"] == ["file_1"]
        assert dumped["is_final"] is True
class TestThoughtChunkEvent:
    """Tests for ThoughtChunkEvent."""

    def test_create_with_required_fields(self):
        """Selector and chunk build a thought event."""
        evt = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to query the weather...",
        )
        assert evt.selector == ["node1", "thought"]
        assert evt.chunk == "I need to query the weather..."
        assert evt.chunk_type == ChunkType.THOUGHT

    def test_chunk_type_is_thought(self):
        """chunk_type is fixed to THOUGHT."""
        evt = ThoughtChunkEvent(selector=["node1", "thought"], chunk="thinking...")
        assert evt.chunk_type == ChunkType.THOUGHT

    def test_serialization(self):
        """model_dump renders chunk_type as its string value."""
        dumped = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to analyze this...",
            is_final=False,
        ).model_dump()
        assert dumped["chunk_type"] == "thought"
        assert dumped["chunk"] == "I need to analyze this..."
        assert dumped["is_final"] is False
class TestEventInheritance:
    """Tests for event inheritance relationships."""

    def test_tool_call_is_stream_chunk(self):
        """ToolCallChunkEvent instances are StreamChunkEvent instances."""
        evt = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk="",
            tool_call=ToolCall(id="call_123", name="test", arguments=None),
        )
        assert isinstance(evt, StreamChunkEvent)

    def test_tool_result_is_stream_chunk(self):
        """ToolResultChunkEvent instances are StreamChunkEvent instances."""
        evt = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test"),
        )
        assert isinstance(evt, StreamChunkEvent)

    def test_thought_is_stream_chunk(self):
        """ThoughtChunkEvent instances are StreamChunkEvent instances."""
        evt = ThoughtChunkEvent(selector=["node1", "thought"], chunk="thinking...")
        assert isinstance(evt, StreamChunkEvent)

    def test_all_events_have_common_fields(self):
        """Every event type exposes the shared StreamChunkEvent fields."""
        samples = [
            StreamChunkEvent(selector=["n", "t"], chunk="a"),
            ToolCallChunkEvent(
                selector=["n", "t"],
                chunk="b",
                tool_call=ToolCall(id="1", name="t", arguments=None),
            ),
            ToolResultChunkEvent(
                selector=["n", "t"],
                chunk="c",
                tool_result=ToolResult(id="1", name="t"),
            ),
            ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
        ]
        for evt in samples:
            for field_name in ("selector", "chunk", "is_final", "chunk_type"):
                assert hasattr(evt, field_name)

View File

@ -78,7 +78,7 @@ class TestFileSaverImpl:
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)
def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
_TEST_URL = "https://example.com/image.png"

View File

@ -0,0 +1,148 @@
import types
from collections.abc import Generator
from typing import Any
import pytest
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import ToolCallResult
from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
from core.workflow.nodes.llm.node import LLMNode
class _StubModelInstance:
    """Minimal stub to satisfy _stream_llm_events signature."""

    # The only attribute these tests expect the model instance to carry;
    # None is a sufficient placeholder.
    provider_model_bundle = None
def _drain(generator: Generator[NodeEventBase, None, Any]):
events: list = []
try:
while True:
events.append(next(generator))
except StopIteration as exc:
return events, exc.value
@pytest.fixture(autouse=True)
def patch_deduct_llm_quota(monkeypatch):
    """Neutralize quota deduction for every test in this module."""
    # Avoid touching real quota logic during unit tests
    monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)
def _make_llm_node(reasoning_format: str) -> LLMNode:
    """Build a bare LLMNode with only the state _stream_llm_events reads.

    __new__ bypasses __init__ (and whatever setup it performs).
    object.__setattr__ is used for assignment — presumably LLMNode restricts
    normal attribute setting (e.g. a frozen/validated model); confirm if this
    helper breaks after refactors.
    """
    node = LLMNode.__new__(LLMNode)
    object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
    object.__setattr__(node, "tenant_id", "tenant")
    return node
def test_stream_llm_events_extracts_reasoning_for_tagged():
    """In 'tagged' mode, <think>...</think> content is split out for generation."""
    node = _make_llm_node(reasoning_format="tagged")
    tagged_text = "<think>Thought</think>Answer"
    usage = LLMUsage.empty_usage()

    def generator():
        # Simulate the model completing with reasoning embedded in the text.
        yield ModelInvokeCompletedEvent(
            text=tagged_text,
            usage=usage,
            finish_reason="stop",
            reasoning_content="",
            structured_output=None,
        )

    events, returned = _drain(
        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
    )
    # No intermediate events are expected; everything arrives via the return value.
    assert events == []
    clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
    assert clean_text == tagged_text  # original preserved for output
    assert reasoning_content == ""  # tagged mode keeps reasoning separate
    assert gen_clean == "Answer"  # stripped content for generation
    assert gen_reasoning == "Thought"  # reasoning extracted from <think> tag
    assert ret_usage == usage
    assert finish_reason == "stop"
    assert structured is None
    assert gen_data is None
    # generation building should include reasoning and sequence
    # NOTE(review): the assertion below compares two locally-built literals
    # (gen_clean == "Answer" is already asserted above), so it only re-checks
    # the length arithmetic and exercises no production code.
    generation_content = gen_clean or clean_text
    sequence = [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len(generation_content)},
    ]
    assert sequence == [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len("Answer")},
    ]
def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
    """Plain text with no <think> tag must yield no extracted reasoning."""
    node = _make_llm_node(reasoning_format="tagged")
    plain_text = "Hello world"
    usage = LLMUsage.empty_usage()

    def generator():
        yield ModelInvokeCompletedEvent(
            text=plain_text,
            usage=usage,
            finish_reason=None,
            reasoning_content="",
            structured_output=None,
        )

    events, returned = _drain(
        node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
    )
    assert events == []
    # Only the reasoning/clean-text slots of the return tuple matter here.
    _, _, gen_reasoning, gen_clean, *_ = returned
    generation_content = gen_clean or plain_text
    assert gen_reasoning == ""
    assert generation_content == plain_text
    # Empty reasoning should imply empty sequence in generation construction
    # NOTE(review): the two lines below compare a literal to itself and can
    # never fail; they document intent only and exercise no production code.
    sequence = []
    assert sequence == []
def test_serialize_tool_call_strips_files_to_ids():
    """_serialize_tool_call should reduce File objects to plain string ids."""
    # importorskip keeps this test optional when the file package is unavailable.
    file_cls = pytest.importorskip("core.file").File
    file_type = pytest.importorskip("core.file.enums").FileType
    transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod
    # File with a direct id: serialization is expected to use `id`.
    file_with_id = file_cls(
        id="f1",
        tenant_id="t",
        type=file_type.IMAGE,
        transfer_method=transfer_method.REMOTE_URL,
        remote_url="http://example.com/f1",
        storage_key="k1",
    )
    # File without an id: serialization is expected to fall back to `related_id`.
    file_with_related = file_cls(
        id=None,
        tenant_id="t",
        type=file_type.IMAGE,
        transfer_method=transfer_method.REMOTE_URL,
        related_id="rel2",
        remote_url="http://example.com/f2",
        storage_key="k2",
    )
    tool_call = ToolCallResult(
        id="tc",
        name="do",
        arguments='{"a":1}',
        output="ok",
        files=[file_with_id, file_with_related],
        status=ToolResultStatus.SUCCESS,
    )
    serialized = LLMNode._serialize_tool_call(tool_call)
    # Files collapse to their ids; all scalar fields pass through unchanged.
    assert serialized["files"] == ["f1", "rel2"]
    assert serialized["id"] == "tc"
    assert serialized["name"] == "do"
    assert serialized["arguments"] == '{"a":1}'
    assert serialized["output"] == "ok"

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import sys
import types
from collections.abc import Generator
@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only
@pytest.fixture
def tool_node(monkeypatch) -> "ToolNode":
def tool_node(monkeypatch) -> ToolNode:
module_name = "core.ops.ops_trace_manager"
if module_name not in sys.modules:
ops_stub = types.ModuleType(module_name)
@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
return events, stop.value
def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
def _identity_transform(messages, *_args, **_kwargs):
return messages
@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l
return _collect_events(generator)
def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
file_obj = File(
tenant_id="tenant-id",
type=FileType.DOCUMENT,
@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
assert files_segment.value == [file_obj]
def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
def test_plain_link_messages_remain_links(tool_node: ToolNode):
message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),

View File

@ -1,14 +1,14 @@
import time
import uuid
from unittest import mock
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable, StringVariable
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -86,9 +86,6 @@ def test_overwrite_string_variable():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@ -104,20 +101,14 @@ def test_overwrite_string_variable():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=input_variable.value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == input_variable.value
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@ -191,9 +182,6 @@ def test_append_variable_to_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@ -209,22 +197,14 @@ def test_append_variable_to_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=expected_value,
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == ["the first value", "the second value"]
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
@ -287,9 +267,6 @@ def test_clear_array():
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
@ -305,20 +282,14 @@ def test_clear_array():
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,
name=conversation_variable.name,
description=conversation_variable.description,
selector=conversation_variable.selector,
value_type=conversation_variable.value_type,
value=[],
)
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
mock_conv_var_updater.flush.assert_called_once()
events = list(node.run())
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
assert updated_variables is not None
assert updated_variables[0].name == conversation_variable.name
assert updated_variables[0].new_value == []
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None

View File

@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
def test_node_factory_creates_variable_assigner_node():
    """DifyNodeFactory should map node type 'assigner' v2 to VariableAssignerNode."""
    graph_config = {
        "edges": [],
        "nodes": [
            {
                "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
                "id": "assigner",
            },
        ],
    }
    init_params = GraphInitParams(
        tenant_id="1",
        app_id="1",
        workflow_id="1",
        graph_config=graph_config,
        user_id="1",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=0,
    )
    variable_pool = VariablePool(
        system_variables=SystemVariable(conversation_id="conversation_id"),
        user_inputs={},
        environment_variables=[],
        conversation_variables=[],
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
    node_factory = DifyNodeFactory(
        graph_init_params=init_params,
        graph_runtime_state=graph_runtime_state,
    )
    # The factory dispatches on the node's "type"/"version" data fields.
    node = node_factory.create_node(graph_config["nodes"][0])
    assert isinstance(node, VariableAssignerNode)

View File

@ -0,0 +1,78 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig
def test_file_response_serializes_datetime() -> None:
    """FileResponse should emit created_at as epoch seconds and pass other fields through."""
    created_at = datetime(2024, 1, 1, 12, 0, 0)
    source = SimpleNamespace(
        id="file-1",
        name="example.txt",
        size=1024,
        extension="txt",
        mime_type="text/plain",
        created_by="user-1",
        created_at=created_at,
        preview_url="https://preview",
        source_url="https://source",
        original_url="https://origin",
        user_id="user-1",
        tenant_id="tenant-1",
        conversation_id="conv-1",
        file_key="key-1",
    )
    dumped = FileResponse.model_validate(source, from_attributes=True).model_dump(mode="json")
    # Datetime is converted to a Unix timestamp.
    assert dumped["created_at"] == int(created_at.timestamp())
    # Everything else round-trips unchanged.
    passthrough = {
        "id": "file-1",
        "preview_url": "https://preview",
        "source_url": "https://source",
        "original_url": "https://origin",
        "user_id": "user-1",
        "tenant_id": "tenant-1",
        "conversation_id": "conv-1",
        "file_key": "key-1",
    }
    for key, value in passthrough.items():
        assert dumped[key] == value
def test_file_with_signed_url_builds_payload() -> None:
    """FileWithSignedUrl keeps the signed url and serializes created_at to epoch seconds."""
    created = datetime(2024, 1, 2, 0, 0, 0)
    dumped = FileWithSignedUrl(
        id="file-2",
        name="remote.pdf",
        size=2048,
        extension="pdf",
        url="https://signed",
        mime_type="application/pdf",
        created_by="user-2",
        created_at=created,
    ).model_dump(mode="json")
    assert dumped["url"] == "https://signed"
    assert dumped["created_at"] == int(created.timestamp())
def test_remote_file_info_and_upload_config() -> None:
    """RemoteFileInfo and UploadConfig round-trip their fields through model_dump."""
    remote = RemoteFileInfo(file_type="text/plain", file_length=123)
    assert remote.model_dump(mode="json") == {"file_type": "text/plain", "file_length": 123}
    limits = {
        "file_size_limit": 1,
        "batch_count_limit": 2,
        "file_upload_limit": 3,
        "image_file_size_limit": 4,
        "video_file_size_limit": 5,
        "audio_file_size_limit": 6,
        "workflow_file_upload_limit": 7,
        "image_file_batch_limit": 8,
        "single_chunk_attachment_limit": 9,
        "attachment_image_file_size_limit": 10,
    }
    dumped = UploadConfig(**limits).model_dump(mode="json")
    assert dumped["file_upload_limit"] == 3
    assert dumped["attachment_image_file_size_limit"] == 10

View File

@ -1,6 +1,6 @@
import pytest
from libs.helper import extract_tenant_id
from libs.helper import escape_like_pattern, extract_tenant_id
from models.account import Account
from models.model import EndUser
@ -63,3 +63,51 @@ class TestExtractTenantId:
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
extract_tenant_id(dict_user)
class TestEscapeLikePattern:
    """Test cases for the escape_like_pattern utility function."""

    def test_escape_percent_character(self):
        """A literal percent sign gains a backslash prefix."""
        assert escape_like_pattern("50% discount") == "50\\% discount"

    def test_escape_underscore_character(self):
        """A literal underscore gains a backslash prefix."""
        assert escape_like_pattern("test_data") == "test\\_data"

    def test_escape_backslash_character(self):
        """Each backslash in the input is doubled."""
        assert escape_like_pattern("path\\to\\file") == "path\\\\to\\\\file"

    def test_escape_combined_special_characters(self):
        """%, _ and backslash are all escaped in one input."""
        assert escape_like_pattern("file_50%\\path") == "file\\_50\\%\\\\path"

    def test_escape_empty_string(self):
        """An empty string comes back unchanged."""
        assert escape_like_pattern("") == ""

    def test_escape_none_handling(self):
        """None is falsy, so it is returned untouched by the early exit."""
        assert escape_like_pattern(None) is None

    def test_escape_normal_string_no_change(self):
        """Inputs without LIKE metacharacters are returned verbatim."""
        assert escape_like_pattern("normal text") == "normal text"

    def test_escape_order_matters(self):
        """Backslash must be escaped before % and _ to avoid double escaping."""
        assert escape_like_pattern("test\\%_value") == "test\\\\\\%\\_value"

View File

@ -114,7 +114,7 @@ class TestAppModelValidation:
def test_icon_type_validation(self):
"""Test icon type enum values."""
# Assert
assert {t.value for t in IconType} == {"image", "emoji"}
assert {t.value for t in IconType} == {"image", "emoji", "link"}
def test_app_desc_or_prompt_with_description(self):
"""Test desc_or_prompt property when description exists."""