mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
Merge branch 'feat/agent-node-v2' into feat/support-agent-sandbox
This commit is contained in:
@ -444,6 +444,78 @@ class TestAnnotationService:
|
||||
assert total == 1
|
||||
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
|
||||
|
||||
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
|
||||
|
||||
This test verifies:
|
||||
- Special characters (%, _, \) in keyword are properly escaped
|
||||
- Search treats special characters as literal characters, not wildcards
|
||||
- SQL injection via LIKE wildcards is prevented
|
||||
"""
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotations with special characters in content
|
||||
annotation_with_percent = {
|
||||
"question": "Question with 50% discount",
|
||||
"answer": "Answer about 50% discount offer",
|
||||
}
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_with_percent, app.id)
|
||||
|
||||
annotation_with_underscore = {
|
||||
"question": "Question with test_data",
|
||||
"answer": "Answer about test_data value",
|
||||
}
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_with_underscore, app.id)
|
||||
|
||||
annotation_with_backslash = {
|
||||
"question": "Question with path\\to\\file",
|
||||
"answer": "Answer about path\\to\\file location",
|
||||
}
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_with_backslash, app.id)
|
||||
|
||||
# Create annotation that should NOT match (contains % but as part of different text)
|
||||
annotation_no_match = {
|
||||
"question": "Question with 100% different",
|
||||
"answer": "Answer about 100% different content",
|
||||
}
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_no_match, app.id)
|
||||
|
||||
# Test 1: Search with % character - should find exact match only
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
|
||||
app.id, page=1, limit=10, keyword="50%"
|
||||
)
|
||||
assert total == 1
|
||||
assert len(annotation_list) == 1
|
||||
assert "50%" in annotation_list[0].question or "50%" in annotation_list[0].content
|
||||
|
||||
# Test 2: Search with _ character - should find exact match only
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
|
||||
app.id, page=1, limit=10, keyword="test_data"
|
||||
)
|
||||
assert total == 1
|
||||
assert len(annotation_list) == 1
|
||||
assert "test_data" in annotation_list[0].question or "test_data" in annotation_list[0].content
|
||||
|
||||
# Test 3: Search with \ character - should find exact match only
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
|
||||
app.id, page=1, limit=10, keyword="path\\to\\file"
|
||||
)
|
||||
assert total == 1
|
||||
assert len(annotation_list) == 1
|
||||
assert "path\\to\\file" in annotation_list[0].question or "path\\to\\file" in annotation_list[0].content
|
||||
|
||||
# Test 4: Search with % should NOT match 100% (verifies escaping works)
|
||||
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(
|
||||
app.id, page=1, limit=10, keyword="50%"
|
||||
)
|
||||
# Should only find the 50% annotation, not the 100% one
|
||||
assert total == 1
|
||||
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
|
||||
|
||||
def test_get_annotation_list_by_app_id_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -7,7 +7,9 @@ from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
from models.model import App, Site
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
|
||||
|
||||
class TestAppService:
|
||||
@ -71,6 +73,9 @@ class TestAppService:
|
||||
}
|
||||
|
||||
# Create app
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -109,6 +114,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Test different app modes
|
||||
@ -159,6 +167,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
created_app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -194,6 +205,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create multiple apps
|
||||
@ -245,6 +259,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create apps with different modes
|
||||
@ -315,6 +332,9 @@ class TestAppService:
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create an app
|
||||
@ -392,6 +412,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -458,6 +481,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -508,6 +534,9 @@ class TestAppService:
|
||||
"icon_background": "#45B7D1",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -562,6 +591,9 @@ class TestAppService:
|
||||
"icon_background": "#74B9FF",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -617,6 +649,9 @@ class TestAppService:
|
||||
"icon_background": "#A29BFE",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -672,6 +707,9 @@ class TestAppService:
|
||||
"icon_background": "#FD79A8",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -720,6 +758,9 @@ class TestAppService:
|
||||
"icon_background": "#E17055",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -768,6 +809,9 @@ class TestAppService:
|
||||
"icon_background": "#00B894",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -826,6 +870,9 @@ class TestAppService:
|
||||
"icon_background": "#6C5CE7",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -862,6 +909,9 @@ class TestAppService:
|
||||
"icon_background": "#FDCB6E",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -899,6 +949,9 @@ class TestAppService:
|
||||
"icon_background": "#E84393",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -947,8 +1000,132 @@ class TestAppService:
|
||||
"icon_background": "#D63031",
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Attempt to create app with invalid mode
|
||||
with pytest.raises(ValueError, match="invalid mode value"):
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
def test_get_apps_with_special_characters_in_name(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test app retrieval with special characters in name search to verify SQL injection prevention.
|
||||
|
||||
This test verifies:
|
||||
- Special characters (%, _, \) in name search are properly escaped
|
||||
- Search treats special characters as literal characters, not wildcards
|
||||
- SQL injection via LIKE wildcards is prevented
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
# Create account and tenant first
|
||||
account = AccountService.create_account(
|
||||
email=fake.email(),
|
||||
name=fake.name(),
|
||||
interface_language="en-US",
|
||||
password=fake.password(length=12),
|
||||
)
|
||||
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
|
||||
tenant = account.current_tenant
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
|
||||
# Create apps with special characters in names
|
||||
app_with_percent = app_service.create_app(
|
||||
tenant.id,
|
||||
{
|
||||
"name": "App with 50% discount",
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
},
|
||||
account,
|
||||
)
|
||||
|
||||
app_with_underscore = app_service.create_app(
|
||||
tenant.id,
|
||||
{
|
||||
"name": "test_data_app",
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
},
|
||||
account,
|
||||
)
|
||||
|
||||
app_with_backslash = app_service.create_app(
|
||||
tenant.id,
|
||||
{
|
||||
"name": "path\\to\\app",
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
},
|
||||
account,
|
||||
)
|
||||
|
||||
# Create app that should NOT match
|
||||
app_no_match = app_service.create_app(
|
||||
tenant.id,
|
||||
{
|
||||
"name": "100% different",
|
||||
"description": fake.text(max_nb_chars=100),
|
||||
"mode": "chat",
|
||||
"icon_type": "emoji",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FF6B6B",
|
||||
"api_rph": 100,
|
||||
"api_rpm": 10,
|
||||
},
|
||||
account,
|
||||
)
|
||||
|
||||
# Test 1: Search with % character
|
||||
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
assert len(paginated_apps.items) == 1
|
||||
assert paginated_apps.items[0].name == "App with 50% discount"
|
||||
|
||||
# Test 2: Search with _ character
|
||||
args = {"name": "test_data", "mode": "chat", "page": 1, "limit": 10}
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
assert len(paginated_apps.items) == 1
|
||||
assert paginated_apps.items[0].name == "test_data_app"
|
||||
|
||||
# Test 3: Search with \ character
|
||||
args = {"name": "path\\to\\app", "mode": "chat", "page": 1, "limit": 10}
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
assert len(paginated_apps.items) == 1
|
||||
assert paginated_apps.items[0].name == "path\\to\\app"
|
||||
|
||||
# Test 4: Search with % should NOT match 100% (verifies escaping works)
|
||||
args = {"name": "50%", "mode": "chat", "page": 1, "limit": 10}
|
||||
paginated_apps = app_service.get_paginate_apps(account.id, tenant.id, args)
|
||||
assert paginated_apps is not None
|
||||
assert paginated_apps.total == 1
|
||||
assert all("50%" in app.name for app in paginated_apps.items)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
@ -312,6 +313,85 @@ class TestTagService:
|
||||
result_no_match = TagService.get_tags("app", tenant.id, keyword="nonexistent")
|
||||
assert len(result_no_match) == 0
|
||||
|
||||
def test_get_tags_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test tag retrieval with special characters in keyword to verify SQL injection prevention.
|
||||
|
||||
This test verifies:
|
||||
- Special characters (%, _, \) in keyword are properly escaped
|
||||
- Search treats special characters as literal characters, not wildcards
|
||||
- SQL injection via LIKE wildcards is prevented
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create tags with special characters in names
|
||||
tag_with_percent = Tag(
|
||||
name="50% discount",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_percent.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_percent)
|
||||
|
||||
tag_with_underscore = Tag(
|
||||
name="test_data_tag",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_underscore.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_underscore)
|
||||
|
||||
tag_with_backslash = Tag(
|
||||
name="path\\to\\tag",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_backslash.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_backslash)
|
||||
|
||||
# Create tag that should NOT match
|
||||
tag_no_match = Tag(
|
||||
name="100% different",
|
||||
type="app",
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_no_match.id = str(uuid.uuid4())
|
||||
db.session.add(tag_no_match)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Act & Assert: Test 1 - Search with % character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "50% discount"
|
||||
|
||||
# Test 2 - Search with _ character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="test_data")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "test_data_tag"
|
||||
|
||||
# Test 3 - Search with \ character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="path\\to\\tag")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "path\\to\\tag"
|
||||
|
||||
# Test 4 - Search with % should NOT match 100% (verifies escaping works)
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
assert len(result) == 1
|
||||
assert all("50%" in item.name for item in result)
|
||||
|
||||
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tag retrieval when no tags exist.
|
||||
|
||||
@ -10,7 +10,9 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
|
||||
from models.enums import CreatorUserRole
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.app_service import AppService
|
||||
|
||||
# Delay import of AppService to avoid circular dependency
|
||||
# from services.app_service import AppService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
@ -86,6 +88,9 @@ class TestWorkflowAppService:
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -147,6 +152,9 @@ class TestWorkflowAppService:
|
||||
"api_rpm": 10,
|
||||
}
|
||||
|
||||
# Import here to avoid circular dependency
|
||||
from services.app_service import AppService
|
||||
|
||||
app_service = AppService()
|
||||
app = app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
@ -308,6 +316,156 @@ class TestWorkflowAppService:
|
||||
assert result_no_match["total"] == 0
|
||||
assert len(result_no_match["data"]) == 0
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
|
||||
|
||||
This test verifies:
|
||||
- Special characters (%, _) in keyword are properly escaped
|
||||
- Search treats special characters as literal characters, not wildcards
|
||||
- SQL injection via LIKE wildcards is prevented
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = WorkflowAppService()
|
||||
|
||||
# Test 1: Search with % character
|
||||
workflow_run_1 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
status="succeeded",
|
||||
inputs=json.dumps({"search_term": "50% discount", "input2": "other_value"}),
|
||||
outputs=json.dumps({"result": "50% discount applied", "status": "success"}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_1)
|
||||
db.session.flush()
|
||||
|
||||
workflow_app_log_1 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_1.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
workflow_app_log_1.id = str(uuid.uuid4())
|
||||
workflow_app_log_1.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_1)
|
||||
db.session.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
|
||||
)
|
||||
# Should find the workflow_run_1 entry
|
||||
assert result["total"] >= 1
|
||||
assert len(result["data"]) >= 1
|
||||
assert any(log.workflow_run_id == workflow_run_1.id for log in result["data"])
|
||||
|
||||
# Test 2: Search with _ character
|
||||
workflow_run_2 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
status="succeeded",
|
||||
inputs=json.dumps({"search_term": "test_data_value", "input2": "other_value"}),
|
||||
outputs=json.dumps({"result": "test_data_value found", "status": "success"}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_2)
|
||||
db.session.flush()
|
||||
|
||||
workflow_app_log_2 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_2.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
workflow_app_log_2.id = str(uuid.uuid4())
|
||||
workflow_app_log_2.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_2)
|
||||
db.session.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
|
||||
)
|
||||
# Should find the workflow_run_2 entry
|
||||
assert result["total"] >= 1
|
||||
assert len(result["data"]) >= 1
|
||||
assert any(log.workflow_run_id == workflow_run_2.id for log in result["data"])
|
||||
|
||||
# Test 3: Search with % should NOT match 100% (verifies escaping works correctly)
|
||||
workflow_run_4 = WorkflowRun(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="1.0.0",
|
||||
graph=json.dumps({"nodes": [], "edges": []}),
|
||||
status="succeeded",
|
||||
inputs=json.dumps({"search_term": "100% different", "input2": "other_value"}),
|
||||
outputs=json.dumps({"result": "100% different result", "status": "success"}),
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_4)
|
||||
db.session.flush()
|
||||
|
||||
workflow_app_log_4 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_run_id=workflow_run_4.id,
|
||||
created_from="service-api",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=account.id,
|
||||
)
|
||||
workflow_app_log_4.id = str(uuid.uuid4())
|
||||
workflow_app_log_4.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_4)
|
||||
db.session.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
|
||||
)
|
||||
# Should only find the 50% entry (workflow_run_1), not the 100% entry (workflow_run_4)
|
||||
# This verifies that escaping works correctly - 50% should not match 100%
|
||||
assert result["total"] >= 1
|
||||
assert len(result["data"]) >= 1
|
||||
# Verify that we found workflow_run_1 (50% discount) but not workflow_run_4 (100% different)
|
||||
found_run_ids = [log.workflow_run_id for log in result["data"]]
|
||||
assert workflow_run_1.id in found_run_ids
|
||||
assert workflow_run_4.id not in found_run_ids
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_status_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -165,7 +165,7 @@ class TestRagPipelineRunTasks:
|
||||
"files": [],
|
||||
"user_id": account.id,
|
||||
"stream": False,
|
||||
"invoke_from": "published",
|
||||
"invoke_from": InvokeFrom.PUBLISHED_PIPELINE.value,
|
||||
"workflow_execution_id": str(uuid.uuid4()),
|
||||
"pipeline_config": {
|
||||
"app_id": str(uuid.uuid4()),
|
||||
@ -249,7 +249,7 @@ class TestRagPipelineRunTasks:
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
@ -294,7 +294,7 @@ class TestRagPipelineRunTasks:
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
@ -743,7 +743,7 @@ class TestRagPipelineRunTasks:
|
||||
assert call_kwargs["pipeline"].id == pipeline.id
|
||||
assert call_kwargs["workflow_id"] == workflow.id
|
||||
assert call_kwargs["user"].id == account.id
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED
|
||||
assert call_kwargs["invoke_from"] == InvokeFrom.PUBLISHED_PIPELINE
|
||||
assert call_kwargs["streaming"] == False
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30") # Custom value for testing
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
@ -51,6 +52,7 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
# Set minimal required env vars
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
@ -75,6 +77,7 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch):
|
||||
# Set environment variables using monkeypatch
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
@ -124,6 +127,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
|
||||
# Set environment variables using monkeypatch
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
@ -140,6 +144,7 @@ def test_inner_api_config_exist(monkeypatch: pytest.MonkeyPatch):
|
||||
def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that DB_EXTRAS options are properly merged with default timezone setting"""
|
||||
# Set environment variables
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
@ -199,6 +204,7 @@ def test_celery_broker_url_with_special_chars_password(
|
||||
# Set up basic required environment variables (following existing pattern)
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_TYPE", "postgresql")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
|
||||
@ -0,0 +1,285 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask.views import MethodView
|
||||
|
||||
# kombu references MethodView as a global when importing celery/kombu pools.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_app_module():
|
||||
module_name = "controllers.console.app.app"
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
|
||||
root = Path(__file__).resolve().parents[5]
|
||||
module_path = root / "controllers" / "console" / "app" / "app.py"
|
||||
|
||||
class _StubNamespace:
|
||||
def __init__(self):
|
||||
self.models: dict[str, Any] = {}
|
||||
self.payload = None
|
||||
|
||||
def schema_model(self, name, schema):
|
||||
self.models[name] = schema
|
||||
|
||||
def _decorator(self, obj):
|
||||
return obj
|
||||
|
||||
def doc(self, *args, **kwargs):
|
||||
return self._decorator
|
||||
|
||||
def expect(self, *args, **kwargs):
|
||||
return self._decorator
|
||||
|
||||
def response(self, *args, **kwargs):
|
||||
return self._decorator
|
||||
|
||||
def route(self, *args, **kwargs):
|
||||
def decorator(obj):
|
||||
return obj
|
||||
|
||||
return decorator
|
||||
|
||||
stub_namespace = _StubNamespace()
|
||||
|
||||
original_console = sys.modules.get("controllers.console")
|
||||
original_app_pkg = sys.modules.get("controllers.console.app")
|
||||
stubbed_modules: list[tuple[str, ModuleType | None]] = []
|
||||
|
||||
console_module = ModuleType("controllers.console")
|
||||
console_module.__path__ = [str(root / "controllers" / "console")]
|
||||
console_module.console_ns = stub_namespace
|
||||
console_module.api = None
|
||||
console_module.bp = None
|
||||
sys.modules["controllers.console"] = console_module
|
||||
|
||||
app_package = ModuleType("controllers.console.app")
|
||||
app_package.__path__ = [str(root / "controllers" / "console" / "app")]
|
||||
sys.modules["controllers.console.app"] = app_package
|
||||
console_module.app = app_package
|
||||
|
||||
def _stub_module(name: str, attrs: dict[str, Any]):
|
||||
original = sys.modules.get(name)
|
||||
module = ModuleType(name)
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
sys.modules[name] = module
|
||||
stubbed_modules.append((name, original))
|
||||
|
||||
class _OpsTraceManager:
|
||||
@staticmethod
|
||||
def get_app_tracing_config(app_id: str) -> dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def update_app_tracing_config(app_id: str, **kwargs) -> None:
|
||||
return None
|
||||
|
||||
_stub_module(
|
||||
"core.ops.ops_trace_manager",
|
||||
{
|
||||
"OpsTraceManager": _OpsTraceManager,
|
||||
"TraceQueueManager": object,
|
||||
"TraceTask": object,
|
||||
},
|
||||
)
|
||||
|
||||
spec = util.spec_from_file_location(module_name, module_path)
|
||||
module = util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
try:
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
finally:
|
||||
for name, original in reversed(stubbed_modules):
|
||||
if original is not None:
|
||||
sys.modules[name] = original
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
if original_console is not None:
|
||||
sys.modules["controllers.console"] = original_console
|
||||
else:
|
||||
sys.modules.pop("controllers.console", None)
|
||||
if original_app_pkg is not None:
|
||||
sys.modules["controllers.console.app"] = original_app_pkg
|
||||
else:
|
||||
sys.modules.pop("controllers.console.app", None)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
_app_module = _load_app_module()
|
||||
AppDetailWithSite = _app_module.AppDetailWithSite
|
||||
AppPagination = _app_module.AppPagination
|
||||
AppPartial = _app_module.AppPartial
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_signed_url(monkeypatch):
|
||||
"""Ensure icon URL generation uses a deterministic helper for tests."""
|
||||
|
||||
def _fake_signed_url(key: str | None) -> str | None:
|
||||
if not key:
|
||||
return None
|
||||
return f"signed:{key}"
|
||||
|
||||
monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
|
||||
|
||||
|
||||
def _ts(hour: int = 12) -> datetime:
|
||||
return datetime(2024, 1, 1, hour, 0, 0)
|
||||
|
||||
|
||||
def _dummy_model_config():
|
||||
return SimpleNamespace(
|
||||
model_dict={"provider": "openai", "name": "gpt-4o"},
|
||||
pre_prompt="hello",
|
||||
created_by="config-author",
|
||||
created_at=_ts(9),
|
||||
updated_by="config-editor",
|
||||
updated_at=_ts(10),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_workflow():
|
||||
return SimpleNamespace(
|
||||
id="wf-1",
|
||||
created_by="workflow-author",
|
||||
created_at=_ts(8),
|
||||
updated_by="workflow-editor",
|
||||
updated_at=_ts(9),
|
||||
)
|
||||
|
||||
|
||||
def test_app_partial_serialization_uses_aliases():
|
||||
created_at = _ts()
|
||||
app_obj = SimpleNamespace(
|
||||
id="app-1",
|
||||
name="My App",
|
||||
desc_or_prompt="Prompt snippet",
|
||||
mode_compatible_with_agent="chat",
|
||||
icon_type="image",
|
||||
icon="icon-key",
|
||||
icon_background="#fff",
|
||||
app_model_config=_dummy_model_config(),
|
||||
workflow=_dummy_workflow(),
|
||||
created_by="creator",
|
||||
created_at=created_at,
|
||||
updated_by="editor",
|
||||
updated_at=created_at,
|
||||
tags=[SimpleNamespace(id="tag-1", name="Utilities", type="app")],
|
||||
access_mode="private",
|
||||
create_user_name="Creator",
|
||||
author_name="Author",
|
||||
has_draft_trigger=True,
|
||||
)
|
||||
|
||||
serialized = AppPartial.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["description"] == "Prompt snippet"
|
||||
assert serialized["mode"] == "chat"
|
||||
assert serialized["icon_url"] == "signed:icon-key"
|
||||
assert serialized["created_at"] == int(created_at.timestamp())
|
||||
assert serialized["updated_at"] == int(created_at.timestamp())
|
||||
assert serialized["model_config"]["model"] == {"provider": "openai", "name": "gpt-4o"}
|
||||
assert serialized["workflow"]["id"] == "wf-1"
|
||||
assert serialized["tags"][0]["name"] == "Utilities"
|
||||
|
||||
|
||||
def test_app_detail_with_site_includes_nested_serialization():
|
||||
timestamp = _ts(14)
|
||||
site = SimpleNamespace(
|
||||
code="site-code",
|
||||
title="Public Site",
|
||||
icon_type="image",
|
||||
icon="site-icon",
|
||||
created_at=timestamp,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
app_obj = SimpleNamespace(
|
||||
id="app-2",
|
||||
name="Detailed App",
|
||||
description="Desc",
|
||||
mode_compatible_with_agent="advanced-chat",
|
||||
icon_type="image",
|
||||
icon="detail-icon",
|
||||
icon_background="#123456",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
app_model_config={
|
||||
"opening_statement": "hi",
|
||||
"model": {"provider": "openai", "name": "gpt-4o"},
|
||||
"retriever_resource": {"enabled": True},
|
||||
},
|
||||
workflow=_dummy_workflow(),
|
||||
tracing={"enabled": True},
|
||||
use_icon_as_answer_icon=True,
|
||||
created_by="creator",
|
||||
created_at=timestamp,
|
||||
updated_by="editor",
|
||||
updated_at=timestamp,
|
||||
access_mode="public",
|
||||
tags=[SimpleNamespace(id="tag-2", name="Prod", type="app")],
|
||||
api_base_url="https://api.example.com/v1",
|
||||
max_active_requests=5,
|
||||
deleted_tools=[{"type": "api", "tool_name": "search", "provider_id": "prov"}],
|
||||
site=site,
|
||||
)
|
||||
|
||||
serialized = AppDetailWithSite.model_validate(app_obj, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["icon_url"] == "signed:detail-icon"
|
||||
assert serialized["model_config"]["retriever_resource"] == {"enabled": True}
|
||||
assert serialized["deleted_tools"][0]["tool_name"] == "search"
|
||||
assert serialized["site"]["icon_url"] == "signed:site-icon"
|
||||
assert serialized["site"]["created_at"] == int(timestamp.timestamp())
|
||||
|
||||
|
||||
def test_app_pagination_aliases_per_page_and_has_next():
|
||||
item_one = SimpleNamespace(
|
||||
id="app-10",
|
||||
name="Paginated One",
|
||||
desc_or_prompt="Summary",
|
||||
mode_compatible_with_agent="chat",
|
||||
icon_type="image",
|
||||
icon="first-icon",
|
||||
created_at=_ts(15),
|
||||
updated_at=_ts(15),
|
||||
)
|
||||
item_two = SimpleNamespace(
|
||||
id="app-11",
|
||||
name="Paginated Two",
|
||||
desc_or_prompt="Summary",
|
||||
mode_compatible_with_agent="agent-chat",
|
||||
icon_type="emoji",
|
||||
icon="🙂",
|
||||
created_at=_ts(16),
|
||||
updated_at=_ts(16),
|
||||
)
|
||||
pagination = SimpleNamespace(
|
||||
page=2,
|
||||
per_page=10,
|
||||
total=50,
|
||||
has_next=True,
|
||||
items=[item_one, item_two],
|
||||
)
|
||||
|
||||
serialized = AppPagination.model_validate(pagination, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["page"] == 2
|
||||
assert serialized["limit"] == 10
|
||||
assert serialized["has_more"] is True
|
||||
assert len(serialized["data"]) == 2
|
||||
assert serialized["data"][0]["icon_url"] == "signed:first-icon"
|
||||
assert serialized["data"][1]["icon_url"] is None
|
||||
@ -1,7 +1,9 @@
|
||||
import builtins
|
||||
import io
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.errors import (
|
||||
@ -14,6 +16,9 @@ from controllers.common.errors import (
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestFileUploadSecurity:
|
||||
"""Test file upload security logic without complex framework setup"""
|
||||
@ -128,7 +133,7 @@ class TestFileUploadSecurity:
|
||||
# Test passes if no exception is raised
|
||||
|
||||
# Test 4: Service error handling
|
||||
@patch("services.file_service.FileService.upload_file")
|
||||
@patch("controllers.console.files.FileService.upload_file")
|
||||
def test_should_handle_file_too_large_error(self, mock_upload):
|
||||
"""Test that service FileTooLargeError is properly converted"""
|
||||
mock_upload.side_effect = ServiceFileTooLargeError("File too large")
|
||||
@ -140,7 +145,7 @@ class TestFileUploadSecurity:
|
||||
with pytest.raises(FileTooLargeError):
|
||||
raise FileTooLargeError(e.description)
|
||||
|
||||
@patch("services.file_service.FileService.upload_file")
|
||||
@patch("controllers.console.files.FileService.upload_file")
|
||||
def test_should_handle_unsupported_file_type_error(self, mock_upload):
|
||||
"""Test that service UnsupportedFileTypeError is properly converted"""
|
||||
mock_upload.side_effect = ServiceUnsupportedFileTypeError()
|
||||
|
||||
3
api/tests/unit_tests/core/agent/__init__.py
Normal file
3
api/tests/unit_tests/core/agent/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Mark agent test modules as a package to avoid import name collisions.
|
||||
"""
|
||||
324
api/tests/unit_tests/core/agent/patterns/test_base.py
Normal file
324
api/tests/unit_tests/core/agent/patterns/test_base.py
Normal file
@ -0,0 +1,324 @@
|
||||
"""Tests for AgentPattern base class."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentLog, ExecutionContext
|
||||
from core.agent.patterns.base import AgentPattern
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class ConcreteAgentPattern(AgentPattern):
|
||||
"""Concrete implementation of AgentPattern for testing."""
|
||||
|
||||
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
|
||||
"""Minimal implementation for testing."""
|
||||
yield from []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_pattern(mock_model_instance, mock_context):
|
||||
"""Create a concrete agent pattern for testing."""
|
||||
return ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
|
||||
class TestAccumulateUsage:
|
||||
"""Tests for _accumulate_usage method."""
|
||||
|
||||
def test_accumulate_usage_to_empty_dict(self, agent_pattern):
|
||||
"""Test accumulating usage to an empty dict creates a copy."""
|
||||
total_usage: dict = {"usage": None}
|
||||
delta_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
agent_pattern._accumulate_usage(total_usage, delta_usage)
|
||||
|
||||
assert total_usage["usage"] is not None
|
||||
assert total_usage["usage"].total_tokens == 150
|
||||
assert total_usage["usage"].prompt_tokens == 100
|
||||
assert total_usage["usage"].completion_tokens == 50
|
||||
# Verify it's a copy, not a reference
|
||||
assert total_usage["usage"] is not delta_usage
|
||||
|
||||
def test_accumulate_usage_adds_to_existing(self, agent_pattern):
|
||||
"""Test accumulating usage adds to existing values."""
|
||||
initial_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
total_usage: dict = {"usage": initial_usage}
|
||||
|
||||
delta_usage = LLMUsage(
|
||||
prompt_tokens=200,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.2"),
|
||||
completion_tokens=100,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.2"),
|
||||
total_tokens=300,
|
||||
total_price=Decimal("0.4"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
agent_pattern._accumulate_usage(total_usage, delta_usage)
|
||||
|
||||
assert total_usage["usage"].total_tokens == 450 # 150 + 300
|
||||
assert total_usage["usage"].prompt_tokens == 300 # 100 + 200
|
||||
assert total_usage["usage"].completion_tokens == 150 # 50 + 100
|
||||
|
||||
def test_accumulate_usage_multiple_rounds(self, agent_pattern):
|
||||
"""Test accumulating usage across multiple rounds."""
|
||||
total_usage: dict = {"usage": None}
|
||||
|
||||
# Round 1: 100 tokens
|
||||
round1_usage = LLMUsage(
|
||||
prompt_tokens=70,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.07"),
|
||||
completion_tokens=30,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.06"),
|
||||
total_tokens=100,
|
||||
total_price=Decimal("0.13"),
|
||||
currency="USD",
|
||||
latency=0.3,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round1_usage)
|
||||
assert total_usage["usage"].total_tokens == 100
|
||||
|
||||
# Round 2: 150 tokens
|
||||
round2_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.4,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round2_usage)
|
||||
assert total_usage["usage"].total_tokens == 250 # 100 + 150
|
||||
|
||||
# Round 3: 200 tokens
|
||||
round3_usage = LLMUsage(
|
||||
prompt_tokens=130,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.13"),
|
||||
completion_tokens=70,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.14"),
|
||||
total_tokens=200,
|
||||
total_price=Decimal("0.27"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round3_usage)
|
||||
assert total_usage["usage"].total_tokens == 450 # 100 + 150 + 200
|
||||
|
||||
|
||||
class TestCreateLog:
|
||||
"""Tests for _create_log method."""
|
||||
|
||||
def test_create_log_with_label_and_status(self, agent_pattern):
|
||||
"""Test creating a log with label and status."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"key": "value"},
|
||||
)
|
||||
|
||||
assert log.label == "ROUND 1"
|
||||
assert log.log_type == AgentLog.LogType.ROUND
|
||||
assert log.status == AgentLog.LogStatus.START
|
||||
assert log.data == {"key": "value"}
|
||||
assert log.parent_id is None
|
||||
|
||||
def test_create_log_with_parent_id(self, agent_pattern):
|
||||
"""Test creating a log with parent_id."""
|
||||
parent_log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
child_log = agent_pattern._create_log(
|
||||
label="CALL tool",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=parent_log.id,
|
||||
)
|
||||
|
||||
assert child_log.parent_id == parent_log.id
|
||||
assert child_log.log_type == AgentLog.LogType.TOOL_CALL
|
||||
|
||||
|
||||
class TestFinishLog:
|
||||
"""Tests for _finish_log method."""
|
||||
|
||||
def test_finish_log_updates_status(self, agent_pattern):
|
||||
"""Test that finish_log updates status to SUCCESS."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
finished_log = agent_pattern._finish_log(log, data={"result": "done"})
|
||||
|
||||
assert finished_log.status == AgentLog.LogStatus.SUCCESS
|
||||
assert finished_log.data == {"result": "done"}
|
||||
|
||||
def test_finish_log_adds_usage_metadata(self, agent_pattern):
|
||||
"""Test that finish_log adds usage to metadata."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
finished_log = agent_pattern._finish_log(log, usage=usage)
|
||||
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2")
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD"
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage
|
||||
|
||||
|
||||
class TestFindToolByName:
|
||||
"""Tests for _find_tool_by_name method."""
|
||||
|
||||
def test_find_existing_tool(self, mock_model_instance, mock_context):
|
||||
"""Test finding an existing tool by name."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
found_tool = pattern._find_tool_by_name("test_tool")
|
||||
assert found_tool == mock_tool
|
||||
|
||||
def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context):
|
||||
"""Test that finding a nonexistent tool returns None."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
found_tool = pattern._find_tool_by_name("nonexistent_tool")
|
||||
assert found_tool is None
|
||||
|
||||
|
||||
class TestMaxIterationsCapping:
|
||||
"""Tests for max_iterations capping."""
|
||||
|
||||
def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is capped at 99."""
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=150,
|
||||
)
|
||||
|
||||
assert pattern.max_iterations == 99
|
||||
|
||||
def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is not capped when under 99."""
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=50,
|
||||
)
|
||||
|
||||
assert pattern.max_iterations == 50
|
||||
332
api/tests/unit_tests/core/agent/patterns/test_function_call.py
Normal file
332
api/tests/unit_tests/core/agent/patterns/test_function_call.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Tests for FunctionCallStrategy."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentLog, ExecutionContext
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool."""
|
||||
tool = MagicMock()
|
||||
tool.entity.identity.name = "test_tool"
|
||||
tool.to_prompt_message_tool.return_value = PromptMessageTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string", "description": "A parameter"}},
|
||||
"required": ["param1"],
|
||||
},
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
class TestFunctionCallStrategyInit:
|
||||
"""Tests for FunctionCallStrategy initialization."""
|
||||
|
||||
def test_initialization(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test basic initialization."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
assert strategy.model_instance == mock_model_instance
|
||||
assert strategy.context == mock_context
|
||||
assert strategy.max_iterations == 10
|
||||
assert len(strategy.tools) == 1
|
||||
|
||||
def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test initialization with tool_invoke_hook."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
mock_hook = MagicMock()
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
assert strategy.tool_invoke_hook == mock_hook
|
||||
|
||||
|
||||
class TestConvertToolsToPromptFormat:
|
||||
"""Tests for _convert_tools_to_prompt_format method."""
|
||||
|
||||
def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that _convert_tools_to_prompt_format returns PromptMessageTool list."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
tools = strategy._convert_tools_to_prompt_format()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], PromptMessageTool)
|
||||
assert tools[0].name == "test_tool"
|
||||
|
||||
def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context):
|
||||
"""Test that _convert_tools_to_prompt_format returns empty list when no tools."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
tools = strategy._convert_tools_to_prompt_format()
|
||||
|
||||
assert tools == []
|
||||
|
||||
|
||||
class TestAgentLogGeneration:
|
||||
"""Tests for AgentLog generation during run."""
|
||||
|
||||
def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that round logs have correct structure."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
# Create a round log
|
||||
round_log = strategy._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"inputs": {"query": "test"}},
|
||||
)
|
||||
|
||||
assert round_log.label == "ROUND 1"
|
||||
assert round_log.log_type == AgentLog.LogType.ROUND
|
||||
assert round_log.status == AgentLog.LogStatus.START
|
||||
assert "inputs" in round_log.data
|
||||
|
||||
def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool call logs have correct structure."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
# Create a parent round log
|
||||
round_log = strategy._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
# Create a tool call log
|
||||
tool_log = strategy._create_log(
|
||||
label="CALL test_tool",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}},
|
||||
parent_id=round_log.id,
|
||||
)
|
||||
|
||||
assert tool_log.label == "CALL test_tool"
|
||||
assert tool_log.log_type == AgentLog.LogType.TOOL_CALL
|
||||
assert tool_log.parent_id == round_log.id
|
||||
assert tool_log.data["tool_name"] == "test_tool"
|
||||
|
||||
|
||||
class TestToolInvocation:
|
||||
"""Tests for tool invocation."""
|
||||
|
||||
def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool invocation uses hook when provided."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_hook = MagicMock()
|
||||
mock_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={"tool_provider_type": "test", "tool_provider": "test_id"},
|
||||
)
|
||||
mock_hook.return_value = ("Tool result", ["file-1"], mock_meta)
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool")
|
||||
|
||||
mock_hook.assert_called_once()
|
||||
assert result == "Tool result"
|
||||
assert files == [] # Hook returns file IDs, but _invoke_tool returns empty File list
|
||||
assert meta == mock_meta
|
||||
|
||||
def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool_invoke_hook is None when not provided."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=None,
|
||||
)
|
||||
|
||||
# Verify that tool_invoke_hook is None
|
||||
assert strategy.tool_invoke_hook is None
|
||||
|
||||
|
||||
class TestUsageTracking:
|
||||
"""Tests for usage tracking across rounds."""
|
||||
|
||||
def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context):
|
||||
"""Test that round usage is tracked separately from total."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
# Simulate two rounds of usage
|
||||
total_usage: dict = {"usage": None}
|
||||
round1_usage: dict = {"usage": None}
|
||||
round2_usage: dict = {"usage": None}
|
||||
|
||||
# Round 1
|
||||
usage1 = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
strategy._accumulate_usage(round1_usage, usage1)
|
||||
strategy._accumulate_usage(total_usage, usage1)
|
||||
|
||||
# Round 2
|
||||
usage2 = LLMUsage(
|
||||
prompt_tokens=200,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.2"),
|
||||
completion_tokens=100,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.2"),
|
||||
total_tokens=300,
|
||||
total_price=Decimal("0.4"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
strategy._accumulate_usage(round2_usage, usage2)
|
||||
strategy._accumulate_usage(total_usage, usage2)
|
||||
|
||||
# Verify round usage is separate
|
||||
assert round1_usage["usage"].total_tokens == 150
|
||||
assert round2_usage["usage"].total_tokens == 300
|
||||
# Verify total is accumulated
|
||||
assert total_usage["usage"].total_tokens == 450
|
||||
|
||||
|
||||
class TestPromptMessageHandling:
|
||||
"""Tests for prompt message handling."""
|
||||
|
||||
def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that messages include system and user prompts."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Hello"),
|
||||
]
|
||||
|
||||
# Just verify the messages can be processed
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], SystemPromptMessage)
|
||||
assert isinstance(messages[1], UserPromptMessage)
|
||||
|
||||
def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that assistant messages can contain tool calls."""
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="call_123",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments='{"param1": "value1"}',
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(
|
||||
content="I'll help you with that.",
|
||||
tool_calls=[tool_call],
|
||||
)
|
||||
|
||||
assert len(assistant_message.tool_calls) == 1
|
||||
assert assistant_message.tool_calls[0].function.name == "test_tool"
|
||||
224
api/tests/unit_tests/core/agent/patterns/test_react.py
Normal file
224
api/tests/unit_tests/core/agent/patterns/test_react.py
Normal file
@ -0,0 +1,224 @@
|
||||
"""Tests for ReActStrategy."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import ExecutionContext
|
||||
from core.agent.patterns.react import ReActStrategy
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool."""
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
|
||||
tool = MagicMock()
|
||||
tool.entity.identity.name = "test_tool"
|
||||
tool.entity.identity.provider = "test_provider"
|
||||
|
||||
# Use real PromptMessageTool for proper serialization
|
||||
prompt_tool = PromptMessageTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
tool.to_prompt_message_tool.return_value = prompt_tool
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
class TestReActStrategyInit:
|
||||
"""Tests for ReActStrategy initialization."""
|
||||
|
||||
def test_init_with_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that instruction is stored correctly."""
|
||||
instruction = "You are a helpful assistant."
|
||||
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
assert strategy.instruction == instruction
|
||||
|
||||
def test_init_with_empty_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that empty instruction is handled correctly."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
assert strategy.instruction == ""
|
||||
|
||||
|
||||
class TestBuildPromptWithReactFormat:
|
||||
"""Tests for _build_prompt_with_react_format method."""
|
||||
|
||||
def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that {{tools}} placeholder is replaced."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "You have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
UserPromptMessage(content="Hello"),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# The tools placeholder should be replaced with JSON
|
||||
assert "{{tools}}" not in result[0].content
|
||||
assert "test_tool" in result[0].content
|
||||
|
||||
def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that {{tool_names}} placeholder is replaced."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "Valid actions: {{tool_names}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
assert "{{tool_names}}" not in result[0].content
|
||||
assert '"test_tool"' in result[0].content
|
||||
|
||||
def test_replace_instruction_placeholder(self, mock_model_instance, mock_context):
|
||||
"""Test that {{instruction}} placeholder is replaced."""
|
||||
instruction = "You are a helpful coding assistant."
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
system_content = "{{instruction}}\n\nYou have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True, instruction)
|
||||
|
||||
assert "{{instruction}}" not in result[0].content
|
||||
assert instruction in result[0].content
|
||||
|
||||
def test_no_tools_available_message(self, mock_model_instance, mock_context):
|
||||
"""Test that 'No tools available' is shown when include_tools is False."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "You have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], False)
|
||||
|
||||
assert "No tools available" in result[0].content
|
||||
|
||||
def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context):
|
||||
"""Test that agent scratchpad is appended as AssistantPromptMessage."""
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="System prompt"),
|
||||
UserPromptMessage(content="User query"),
|
||||
]
|
||||
|
||||
scratchpad = [
|
||||
AgentScratchpadUnit(
|
||||
thought="I need to search for information",
|
||||
action_str='{"action": "search", "action_input": "query"}',
|
||||
observation="Search results here",
|
||||
)
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, scratchpad, True)
|
||||
|
||||
# The last message should be an AssistantPromptMessage with scratchpad content
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[-1], AssistantPromptMessage)
|
||||
assert "I need to search for information" in result[-1].content
|
||||
assert "Search results here" in result[-1].content
|
||||
|
||||
def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context):
|
||||
"""Test that empty scratchpad doesn't add extra message."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="System prompt"),
|
||||
UserPromptMessage(content="User query"),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# Should only have the original 2 messages
|
||||
assert len(result) == 2
|
||||
|
||||
def test_original_messages_not_modified(self, mock_model_instance, mock_context):
|
||||
"""Test that original messages list is not modified."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
original_content = "Original system prompt {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=original_content),
|
||||
]
|
||||
|
||||
strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# Original message should not be modified
|
||||
assert messages[0].content == original_content
|
||||
@ -0,0 +1,203 @@
|
||||
"""Tests for StrategyFactory."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
from core.agent.patterns.react import ReActStrategy
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
class TestStrategyFactory:
|
||||
"""Tests for StrategyFactory.create_strategy method."""
|
||||
|
||||
def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports TOOL_CALL."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL."""
|
||||
model_features = [ModelFeature.MULTI_TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL."""
|
||||
model_features = [ModelFeature.STREAM_TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy is created when model doesn't support tool calling."""
|
||||
model_features = [ModelFeature.VISION] # Only vision, no tool calling
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy is created when model has no features."""
|
||||
model_features: list[ModelFeature] = []
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context):
|
||||
"""Test explicit FUNCTION_CALLING strategy selection with model support."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_explicit_function_calling_strategy_without_support_falls_back_to_react(
|
||||
self, mock_model_instance, mock_context
|
||||
):
|
||||
"""Test that explicit FUNCTION_CALLING falls back to ReAct when not supported."""
|
||||
model_features: list[ModelFeature] = [] # No tool calling support
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
)
|
||||
|
||||
# Should fall back to ReAct since FC is not supported
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test explicit CHAIN_OF_THOUGHT strategy selection."""
|
||||
model_features = [ModelFeature.TOOL_CALL] # Even with tool call support
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_react_strategy_with_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy receives instruction parameter."""
|
||||
model_features: list[ModelFeature] = []
|
||||
instruction = "You are a helpful assistant."
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
assert strategy.instruction == instruction
|
||||
|
||||
def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is passed to the strategy."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
max_iterations = 5
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
assert strategy.max_iterations == max_iterations
|
||||
|
||||
def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test that tool_invoke_hook is passed to the strategy."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
mock_hook = MagicMock()
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
assert strategy.tool_invoke_hook == mock_hook
|
||||
388
api/tests/unit_tests/core/agent/test_agent_app_runner.py
Normal file
388
api/tests/unit_tests/core/agent/test_agent_app_runner.py
Normal file
@ -0,0 +1,388 @@
|
||||
"""Tests for AgentAppRunner."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
"""Tests for _organize_prompt_messages method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
# We'll patch the class to avoid complex initialization
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
|
||||
# Set up required attributes
|
||||
runner.config = MagicMock(spec=AgentEntity)
|
||||
runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner.config.prompt = None
|
||||
|
||||
runner.app_config = MagicMock()
|
||||
runner.app_config.prompt_template = MagicMock()
|
||||
runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
runner.query = "Hello"
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.model_config = MagicMock()
|
||||
runner.memory = None
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.file_upload_config = None
|
||||
|
||||
return runner
|
||||
|
||||
def test_function_calling_uses_simple_prompt(self, mock_runner):
|
||||
"""Test that function calling strategy uses simple_prompt_template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="You are a helpful assistant.")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with simple_prompt_template
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "You are a helpful assistant."
|
||||
|
||||
def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
|
||||
"""Test that chain of thought strategy uses agent prompt template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
mock_runner.config.prompt = AgentPromptEntity(
|
||||
first_prompt="ReAct prompt template with {{tools}}",
|
||||
next_iteration="Continue...",
|
||||
)
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="ReAct prompt template with {{tools}}")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with agent prompt
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "ReAct prompt template with {{tools}}"
|
||||
|
||||
def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
|
||||
"""Test that chain of thought without prompt falls back to simple_prompt_template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
mock_runner.config.prompt = None
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="You are a helpful assistant.")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with simple_prompt_template
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "You are a helpful assistant."
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
"""Tests for _init_system_message method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
return runner
|
||||
|
||||
def test_empty_messages_with_template(self, mock_runner):
|
||||
"""Test that system message is created when messages are empty."""
|
||||
result = mock_runner._init_system_message("System template", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert result[0].content == "System template"
|
||||
|
||||
def test_empty_messages_without_template(self, mock_runner):
|
||||
"""Test that empty list is returned when no template and no messages."""
|
||||
result = mock_runner._init_system_message("", [])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_existing_system_message_not_duplicated(self, mock_runner):
|
||||
"""Test that system message is not duplicated if already present."""
|
||||
existing_messages = [
|
||||
SystemPromptMessage(content="Existing system"),
|
||||
UserPromptMessage(content="User message"),
|
||||
]
|
||||
|
||||
result = mock_runner._init_system_message("New template", existing_messages)
|
||||
|
||||
# Should not insert new system message
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "Existing system"
|
||||
|
||||
def test_system_message_inserted_when_missing(self, mock_runner):
|
||||
"""Test that system message is inserted when first message is not system."""
|
||||
existing_messages = [
|
||||
UserPromptMessage(content="User message"),
|
||||
]
|
||||
|
||||
result = mock_runner._init_system_message("System template", existing_messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert result[0].content == "System template"
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
"""Tests for _clear_user_prompt_image_messages method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
return runner
|
||||
|
||||
def test_text_content_unchanged(self, mock_runner):
|
||||
"""Test that text content is unchanged."""
|
||||
messages = [
|
||||
UserPromptMessage(content="Plain text message"),
|
||||
]
|
||||
|
||||
result = mock_runner._clear_user_prompt_image_messages(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Plain text message"
|
||||
|
||||
def test_original_messages_not_modified(self, mock_runner):
|
||||
"""Test that original messages are not modified (deep copy)."""
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
messages = [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="Text part"),
|
||||
ImagePromptMessageContent(
|
||||
data="http://example.com/image.jpg",
|
||||
format="url",
|
||||
mime_type="image/jpeg",
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
result = mock_runner._clear_user_prompt_image_messages(messages)
|
||||
|
||||
# Original should still have list content
|
||||
assert isinstance(messages[0].content, list)
|
||||
# Result should have string content
|
||||
assert isinstance(result[0].content, str)
|
||||
|
||||
|
||||
class TestToolInvokeHook:
|
||||
"""Tests for _create_tool_invoke_hook method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
|
||||
runner.user_id = "test-user"
|
||||
runner.tenant_id = "test-tenant"
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.trace_manager = None
|
||||
runner.application_generate_entity.invoke_from = "api"
|
||||
runner.application_generate_entity.app_config = MagicMock()
|
||||
runner.application_generate_entity.app_config.app_id = "test-app"
|
||||
runner.agent_callback = MagicMock()
|
||||
runner.conversation = MagicMock()
|
||||
runner.conversation.id = "test-conversation"
|
||||
runner.queue_manager = MagicMock()
|
||||
runner._current_message_file_ids = []
|
||||
|
||||
return runner
|
||||
|
||||
def test_hook_calls_agent_invoke(self, mock_runner):
|
||||
"""Test that the hook calls ToolEngine.agent_invoke."""
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = "test-message"
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={
|
||||
"tool_provider_type": "test_provider",
|
||||
"tool_provider": "test_id",
|
||||
},
|
||||
)
|
||||
|
||||
with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
|
||||
mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
|
||||
|
||||
hook = mock_runner._create_tool_invoke_hook(mock_message)
|
||||
result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool")
|
||||
|
||||
# Verify ToolEngine.agent_invoke was called
|
||||
mock_engine.agent_invoke.assert_called_once()
|
||||
|
||||
# Verify return values
|
||||
assert result_content == "Tool result"
|
||||
assert result_files == ["file-1", "file-2"]
|
||||
assert result_meta == mock_tool_meta
|
||||
|
||||
def test_hook_publishes_file_events(self, mock_runner):
|
||||
"""Test that the hook publishes QueueMessageFileEvent for files."""
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = "test-message"
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={},
|
||||
)
|
||||
|
||||
with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
|
||||
mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
|
||||
|
||||
hook = mock_runner._create_tool_invoke_hook(mock_message)
|
||||
hook(mock_tool, {}, "test_tool")
|
||||
|
||||
# Verify file events were published
|
||||
assert mock_runner.queue_manager.publish.call_count == 2
|
||||
assert mock_runner._current_message_file_ids == ["file-1", "file-2"]
|
||||
|
||||
|
||||
class TestAgentLogProcessing:
|
||||
"""Tests for AgentLog processing in run method."""
|
||||
|
||||
def test_agent_log_status_enum(self):
|
||||
"""Test AgentLog status enum values."""
|
||||
assert AgentLog.LogStatus.START == "start"
|
||||
assert AgentLog.LogStatus.SUCCESS == "success"
|
||||
assert AgentLog.LogStatus.ERROR == "error"
|
||||
|
||||
def test_agent_log_metadata_enum(self):
|
||||
"""Test AgentLog metadata enum values."""
|
||||
assert AgentLog.LogMetadata.STARTED_AT == "started_at"
|
||||
assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
|
||||
assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
|
||||
assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
|
||||
assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
|
||||
assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"
|
||||
|
||||
def test_agent_result_structure(self):
|
||||
"""Test AgentResult structure."""
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
result = AgentResult(
|
||||
text="Final answer",
|
||||
files=[],
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
assert result.text == "Final answer"
|
||||
assert result.files == []
|
||||
assert result.usage == usage
|
||||
assert result.finish_reason == "stop"
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
"""Tests for _organize_user_query method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
runner.files = []
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.file_upload_config = None
|
||||
return runner
|
||||
|
||||
def test_simple_query_without_files(self, mock_runner):
|
||||
"""Test organizing a simple query without files."""
|
||||
result = mock_runner._organize_user_query("Hello world", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert result[0].content == "Hello world"
|
||||
|
||||
def test_query_with_files(self, mock_runner):
|
||||
"""Test organizing a query with files."""
|
||||
from core.file.models import File
|
||||
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_runner.files = [mock_file]
|
||||
|
||||
with patch("core.agent.agent_app_runner.file_manager") as mock_fm:
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent(
|
||||
data="http://example.com/image.jpg",
|
||||
format="url",
|
||||
mime_type="image/jpeg",
|
||||
)
|
||||
|
||||
result = mock_runner._organize_user_query("Describe this image", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert isinstance(result[0].content, list)
|
||||
assert len(result[0].content) == 2 # Image + Text
|
||||
191
api/tests/unit_tests/core/agent/test_entities.py
Normal file
191
api/tests/unit_tests/core/agent/test_entities.py
Normal file
@ -0,0 +1,191 @@
|
||||
"""Tests for agent entities."""
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext
|
||||
|
||||
|
||||
class TestExecutionContext:
|
||||
"""Tests for ExecutionContext entity."""
|
||||
|
||||
def test_create_with_all_fields(self):
|
||||
"""Test creating ExecutionContext with all fields."""
|
||||
context = ExecutionContext(
|
||||
user_id="user-123",
|
||||
app_id="app-456",
|
||||
conversation_id="conv-789",
|
||||
message_id="msg-012",
|
||||
tenant_id="tenant-345",
|
||||
)
|
||||
|
||||
assert context.user_id == "user-123"
|
||||
assert context.app_id == "app-456"
|
||||
assert context.conversation_id == "conv-789"
|
||||
assert context.message_id == "msg-012"
|
||||
assert context.tenant_id == "tenant-345"
|
||||
|
||||
def test_create_minimal(self):
|
||||
"""Test creating minimal ExecutionContext."""
|
||||
context = ExecutionContext.create_minimal(user_id="user-123")
|
||||
|
||||
assert context.user_id == "user-123"
|
||||
assert context.app_id is None
|
||||
assert context.conversation_id is None
|
||||
assert context.message_id is None
|
||||
assert context.tenant_id is None
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting ExecutionContext to dictionary."""
|
||||
context = ExecutionContext(
|
||||
user_id="user-123",
|
||||
app_id="app-456",
|
||||
conversation_id="conv-789",
|
||||
message_id="msg-012",
|
||||
tenant_id="tenant-345",
|
||||
)
|
||||
|
||||
result = context.to_dict()
|
||||
|
||||
assert result == {
|
||||
"user_id": "user-123",
|
||||
"app_id": "app-456",
|
||||
"conversation_id": "conv-789",
|
||||
"message_id": "msg-012",
|
||||
"tenant_id": "tenant-345",
|
||||
}
|
||||
|
||||
def test_with_updates(self):
|
||||
"""Test creating new context with updates."""
|
||||
original = ExecutionContext(
|
||||
user_id="user-123",
|
||||
app_id="app-456",
|
||||
)
|
||||
|
||||
updated = original.with_updates(message_id="msg-789")
|
||||
|
||||
# Original should be unchanged
|
||||
assert original.message_id is None
|
||||
# Updated should have new value
|
||||
assert updated.message_id == "msg-789"
|
||||
assert updated.user_id == "user-123"
|
||||
assert updated.app_id == "app-456"
|
||||
|
||||
|
||||
class TestAgentLog:
|
||||
"""Tests for AgentLog entity."""
|
||||
|
||||
def test_create_log_with_required_fields(self):
|
||||
"""Test creating AgentLog with required fields."""
|
||||
log = AgentLog(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"key": "value"},
|
||||
)
|
||||
|
||||
assert log.label == "ROUND 1"
|
||||
assert log.log_type == AgentLog.LogType.ROUND
|
||||
assert log.status == AgentLog.LogStatus.START
|
||||
assert log.data == {"key": "value"}
|
||||
assert log.id is not None # Auto-generated
|
||||
assert log.parent_id is None
|
||||
assert log.error is None
|
||||
|
||||
def test_log_type_enum(self):
|
||||
"""Test LogType enum values."""
|
||||
assert AgentLog.LogType.ROUND == "round"
|
||||
assert AgentLog.LogType.THOUGHT == "thought"
|
||||
assert AgentLog.LogType.TOOL_CALL == "tool_call"
|
||||
|
||||
def test_log_status_enum(self):
|
||||
"""Test LogStatus enum values."""
|
||||
assert AgentLog.LogStatus.START == "start"
|
||||
assert AgentLog.LogStatus.SUCCESS == "success"
|
||||
assert AgentLog.LogStatus.ERROR == "error"
|
||||
|
||||
def test_log_metadata_enum(self):
|
||||
"""Test LogMetadata enum values."""
|
||||
assert AgentLog.LogMetadata.STARTED_AT == "started_at"
|
||||
assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
|
||||
assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
|
||||
assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
|
||||
assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
|
||||
assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"
|
||||
|
||||
|
||||
class TestAgentScratchpadUnit:
|
||||
"""Tests for AgentScratchpadUnit entity."""
|
||||
|
||||
def test_is_final_with_final_answer_action(self):
|
||||
"""Test is_final returns True for Final Answer action."""
|
||||
unit = AgentScratchpadUnit(
|
||||
thought="I know the answer",
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name="Final Answer",
|
||||
action_input="The answer is 42",
|
||||
),
|
||||
)
|
||||
|
||||
assert unit.is_final() is True
|
||||
|
||||
def test_is_final_with_tool_action(self):
|
||||
"""Test is_final returns False for tool action."""
|
||||
unit = AgentScratchpadUnit(
|
||||
thought="I need to search",
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name="search",
|
||||
action_input={"query": "test"},
|
||||
),
|
||||
)
|
||||
|
||||
assert unit.is_final() is False
|
||||
|
||||
def test_is_final_with_no_action(self):
|
||||
"""Test is_final returns True when no action."""
|
||||
unit = AgentScratchpadUnit(
|
||||
thought="Just thinking",
|
||||
)
|
||||
|
||||
assert unit.is_final() is True
|
||||
|
||||
def test_action_to_dict(self):
|
||||
"""Test Action.to_dict method."""
|
||||
action = AgentScratchpadUnit.Action(
|
||||
action_name="search",
|
||||
action_input={"query": "test"},
|
||||
)
|
||||
|
||||
result = action.to_dict()
|
||||
|
||||
assert result == {
|
||||
"action": "search",
|
||||
"action_input": {"query": "test"},
|
||||
}
|
||||
|
||||
|
||||
class TestAgentEntity:
|
||||
"""Tests for AgentEntity."""
|
||||
|
||||
def test_strategy_enum(self):
|
||||
"""Test Strategy enum values."""
|
||||
assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought"
|
||||
assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling"
|
||||
|
||||
def test_create_with_prompt(self):
|
||||
"""Test creating AgentEntity with prompt."""
|
||||
prompt = AgentPromptEntity(
|
||||
first_prompt="You are a helpful assistant.",
|
||||
next_iteration="Continue thinking...",
|
||||
)
|
||||
|
||||
entity = AgentEntity(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
prompt=prompt,
|
||||
max_iteration=5,
|
||||
)
|
||||
|
||||
assert entity.provider == "openai"
|
||||
assert entity.model == "gpt-4"
|
||||
assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
assert entity.prompt == prompt
|
||||
assert entity.max_iteration == 5
|
||||
@ -0,0 +1,390 @@
|
||||
"""
|
||||
Tests for AdvancedChatAppGenerateTaskPipeline._handle_node_succeeded_event method,
|
||||
specifically testing the ANSWER node message_replace logic.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueNodeSucceededEvent
|
||||
from core.workflow.enums import NodeType
|
||||
from models import EndUser
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TestAnswerNodeMessageReplace:
|
||||
"""Test cases for ANSWER node message_replace event logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_application_generate_entity(self):
|
||||
"""Create a mock application generate entity."""
|
||||
entity = Mock(spec=AdvancedChatAppGenerateEntity)
|
||||
entity.task_id = "test-task-id"
|
||||
entity.app_id = "test-app-id"
|
||||
entity.workflow_run_id = "test-workflow-run-id"
|
||||
# minimal app_config used by pipeline internals
|
||||
entity.app_config = SimpleNamespace(
|
||||
tenant_id="test-tenant-id",
|
||||
app_id="test-app-id",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
app_model_config_dict={},
|
||||
additional_features=None,
|
||||
sensitive_word_avoidance=None,
|
||||
)
|
||||
entity.query = "test query"
|
||||
entity.files = []
|
||||
entity.extras = {}
|
||||
entity.trace_manager = None
|
||||
entity.inputs = {}
|
||||
entity.invoke_from = "debugger"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow(self):
|
||||
"""Create a mock workflow."""
|
||||
workflow = Mock()
|
||||
workflow.id = "test-workflow-id"
|
||||
workflow.features_dict = {}
|
||||
return workflow
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_manager(self):
|
||||
"""Create a mock queue manager."""
|
||||
manager = Mock()
|
||||
manager.listen.return_value = []
|
||||
manager.graph_runtime_state = None
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation(self):
|
||||
"""Create a mock conversation."""
|
||||
conversation = Mock()
|
||||
conversation.id = "test-conversation-id"
|
||||
conversation.mode = "advanced_chat"
|
||||
return conversation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message(self):
|
||||
"""Create a mock message."""
|
||||
message = Mock()
|
||||
message.id = "test-message-id"
|
||||
message.query = "test query"
|
||||
message.created_at = Mock()
|
||||
message.created_at.timestamp.return_value = 1234567890
|
||||
return message
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock end user."""
|
||||
user = MagicMock(spec=EndUser)
|
||||
user.id = "test-user-id"
|
||||
user.session_id = "test-session-id"
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def mock_draft_var_saver_factory(self):
|
||||
"""Create a mock draft variable saver factory."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(
|
||||
self,
|
||||
mock_application_generate_entity,
|
||||
mock_workflow,
|
||||
mock_queue_manager,
|
||||
mock_conversation,
|
||||
mock_message,
|
||||
mock_user,
|
||||
mock_draft_var_saver_factory,
|
||||
):
|
||||
"""Create an AdvancedChatAppGenerateTaskPipeline instance with mocked dependencies."""
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
|
||||
with patch("core.app.apps.advanced_chat.generate_task_pipeline.db"):
|
||||
pipeline = AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=mock_application_generate_entity,
|
||||
workflow=mock_workflow,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
user=mock_user,
|
||||
stream=True,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=mock_draft_var_saver_factory,
|
||||
)
|
||||
# Initialize workflow run id to avoid validation errors
|
||||
pipeline._workflow_run_id = "test-workflow-run-id"
|
||||
# Mock the message cycle manager methods we need to track
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response = Mock()
|
||||
return pipeline
|
||||
|
||||
def test_answer_node_with_different_output_sends_message_replace(self, pipeline, mock_application_generate_entity):
|
||||
"""
|
||||
Test that when an ANSWER node's final output differs from accumulated answer,
|
||||
a message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "initial answer"
|
||||
|
||||
# Create ANSWER node succeeded event with different final output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "updated final answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter to avoid extra processing
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
responses = list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert
|
||||
assert pipeline._task_state.answer == "updated final answer"
|
||||
# Verify message_replace was called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
|
||||
answer="updated final answer", reason="variable_update"
|
||||
)
|
||||
|
||||
def test_answer_node_with_same_output_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's final output is the same as accumulated answer,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "same answer"
|
||||
|
||||
# Create ANSWER node succeeded event with same output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "same answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "same answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_none_output_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's output is None or missing 'answer' key,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with None output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": None},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_empty_outputs_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node has empty outputs dict,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with empty outputs
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_no_answer_key_in_outputs(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's outputs don't contain 'answer' key,
|
||||
no message_replace event is sent.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event without 'answer' key in outputs
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"other_key": "some value"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_non_answer_node_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that non-ANSWER nodes (e.g., LLM, END) don't trigger message_replace events.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Test with LLM node
|
||||
llm_event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-llm-execution-id",
|
||||
node_id="test-llm-node",
|
||||
node_type=NodeType.LLM,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "different answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(llm_event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_end_node_does_not_send_message_replace(self, pipeline):
|
||||
"""
|
||||
Test that END nodes don't trigger message_replace events even with 'answer' output.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create END node succeeded event with answer output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-end-execution-id",
|
||||
node_id="test-end-node",
|
||||
node_type=NodeType.END,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": "different answer"},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should remain unchanged
|
||||
assert pipeline._task_state.answer == "existing answer"
|
||||
# Verify message_replace was NOT called
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_not_called()
|
||||
|
||||
def test_answer_node_with_numeric_output_converts_to_string(self, pipeline):
|
||||
"""
|
||||
Test that when an ANSWER node's final output is numeric,
|
||||
it gets converted to string properly.
|
||||
"""
|
||||
# Arrange: Set initial accumulated answer
|
||||
pipeline._task_state.answer = "text answer"
|
||||
|
||||
# Create ANSWER node succeeded event with numeric output
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={"answer": 12345},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: answer should be converted to string
|
||||
assert pipeline._task_state.answer == "12345"
|
||||
# Verify message_replace was called with string
|
||||
pipeline._message_cycle_manager.message_replace_to_stream_response.assert_called_once_with(
|
||||
answer="12345", reason="variable_update"
|
||||
)
|
||||
|
||||
def test_answer_node_files_are_recorded(self, pipeline):
|
||||
"""
|
||||
Test that ANSWER nodes properly record files from outputs.
|
||||
"""
|
||||
# Arrange
|
||||
pipeline._task_state.answer = "existing answer"
|
||||
|
||||
# Create ANSWER node succeeded event with files
|
||||
event = QueueNodeSucceededEvent(
|
||||
node_execution_id="test-node-execution-id",
|
||||
node_id="test-answer-node",
|
||||
node_type=NodeType.ANSWER,
|
||||
start_at=datetime.now(),
|
||||
outputs={
|
||||
"answer": "same answer",
|
||||
"files": [
|
||||
{"type": "image", "transfer_method": "remote_url", "remote_url": "http://example.com/img.png"}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
# Mock the workflow response converter
|
||||
pipeline._workflow_response_converter.fetch_files_from_node_outputs = Mock(return_value=event.outputs["files"])
|
||||
pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = Mock(return_value=None)
|
||||
pipeline._save_output_for_event = Mock()
|
||||
|
||||
# Act
|
||||
list(pipeline._handle_node_succeeded_event(event))
|
||||
|
||||
# Assert: files should be recorded
|
||||
assert len(pipeline._recorded_files) == 1
|
||||
assert pipeline._recorded_files[0] == event.outputs["files"][0]
|
||||
@ -431,10 +431,10 @@ class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
description="Explore calls should have truncation enabled",
|
||||
),
|
||||
TestCase(
|
||||
name="published_truncation_enabled",
|
||||
invoke_from=InvokeFrom.PUBLISHED,
|
||||
name="published_pipeline_truncation_enabled",
|
||||
invoke_from=InvokeFrom.PUBLISHED_PIPELINE,
|
||||
expected_truncation_enabled=True,
|
||||
description="Published app calls should have truncation enabled",
|
||||
description="Published pipeline calls should have truncation enabled",
|
||||
),
|
||||
],
|
||||
ids=lambda x: x.name,
|
||||
|
||||
@ -0,0 +1,47 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
|
||||
class DummyQueueManager:
|
||||
def __init__(self) -> None:
|
||||
self.published = []
|
||||
|
||||
def publish(self, event, publish_from: PublishFrom) -> None:
|
||||
self.published.append((event, publish_from))
|
||||
|
||||
|
||||
def test_skip_empty_final_chunk() -> None:
|
||||
queue_manager = DummyQueueManager()
|
||||
runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app")
|
||||
|
||||
empty_final_event = NodeRunStreamChunkEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["node", "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry=MagicMock(), event=empty_final_event)
|
||||
assert queue_manager.published == []
|
||||
|
||||
normal_event = NodeRunStreamChunkEvent(
|
||||
id="exec",
|
||||
node_id="node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["node", "text"],
|
||||
chunk="hi",
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry=MagicMock(), event=normal_event)
|
||||
|
||||
assert len(queue_manager.published) == 1
|
||||
published_event, publish_from = queue_manager.published[0]
|
||||
assert publish_from == PublishFrom.APPLICATION_MANAGER
|
||||
assert published_event.text == "hi"
|
||||
@ -0,0 +1,144 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock
|
||||
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.variables import StringVariable
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class MockReadOnlyVariablePool:
|
||||
def __init__(self, variables: dict[tuple[str, str], Segment] | None = None) -> None:
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, selector: Sequence[str]) -> Segment | None:
|
||||
if len(selector) < 2:
|
||||
return None
|
||||
return self._variables.get((selector[0], selector[1]))
|
||||
|
||||
def get_all_by_node(self, node_id: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
|
||||
|
||||
def get_by_prefix(self, prefix: str) -> dict[str, object]:
|
||||
return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
|
||||
|
||||
|
||||
def _build_graph_runtime_state(
|
||||
variable_pool: MockReadOnlyVariablePool,
|
||||
conversation_id: str | None = None,
|
||||
) -> ReadOnlyGraphRuntimeState:
|
||||
graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
|
||||
graph_runtime_state.variable_pool = variable_pool
|
||||
graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
|
||||
return graph_runtime_state
|
||||
|
||||
|
||||
def _build_node_run_succeeded_event(
|
||||
*,
|
||||
node_type: NodeType,
|
||||
outputs: dict[str, object] | None = None,
|
||||
process_data: dict[str, object] | None = None,
|
||||
) -> NodeRunSucceededEvent:
|
||||
return NodeRunSucceededEvent(
|
||||
id="node-exec-id",
|
||||
node_id="assigner",
|
||||
node_type=node_type,
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs or {},
|
||||
process_data=process_data or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_persists_conversation_variables_from_assigner_output():
|
||||
conversation_id = "conv-123"
|
||||
variable = StringVariable(
|
||||
id="var-1",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
|
||||
)
|
||||
process_data = common_helpers.set_updated_variables(
|
||||
{}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
|
||||
updater.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_skips_when_outputs_missing():
|
||||
conversation_id = "conv-456"
|
||||
variable = StringVariable(
|
||||
id="var-2",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
|
||||
|
||||
def test_skips_non_assigner_nodes():
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.LLM)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
|
||||
|
||||
def test_skips_non_conversation_variables():
|
||||
conversation_id = "conv-789"
|
||||
non_conversation_variable = StringVariable(
|
||||
id="var-3",
|
||||
name="name",
|
||||
value="updated",
|
||||
selector=["environment", "name"],
|
||||
)
|
||||
process_data = common_helpers.set_updated_variables(
|
||||
{}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
|
||||
)
|
||||
|
||||
variable_pool = MockReadOnlyVariablePool()
|
||||
|
||||
updater = Mock()
|
||||
layer = ConversationVariablePersistenceLayer(updater)
|
||||
layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
|
||||
|
||||
event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data)
|
||||
layer.on_event(event)
|
||||
|
||||
updater.update.assert_not_called()
|
||||
updater.flush.assert_not_called()
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from time import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -67,8 +68,10 @@ class MockReadOnlyVariablePool:
|
||||
def __init__(self, variables: dict[tuple[str, str], object] | None = None):
|
||||
self._variables = variables or {}
|
||||
|
||||
def get(self, node_id: str, variable_key: str) -> Segment | None:
|
||||
value = self._variables.get((node_id, variable_key))
|
||||
def get(self, selector: Sequence[str]) -> Segment | None:
|
||||
if len(selector) < 2:
|
||||
return None
|
||||
value = self._variables.get((selector[0], selector[1]))
|
||||
if value is None:
|
||||
return None
|
||||
mock_segment = Mock(spec=Segment)
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
"""Primarily used for testing merged cell scenarios"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from types import SimpleNamespace
|
||||
|
||||
from docx import Document
|
||||
from docx.oxml import OxmlElement
|
||||
from docx.oxml.ns import qn
|
||||
|
||||
import core.rag.extractor.word_extractor as we
|
||||
from core.rag.extractor.word_extractor import WordExtractor
|
||||
@ -165,3 +169,110 @@ def test_extract_images_from_docx_uses_internal_files_url():
|
||||
dify_config.FILES_URL = original_files_url
|
||||
if original_internal_files_url is not None:
|
||||
dify_config.INTERNAL_FILES_URL = original_internal_files_url
|
||||
|
||||
|
||||
def test_extract_hyperlinks(monkeypatch):
|
||||
# Mock db and storage to avoid issues during image extraction (even if no images are present)
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
|
||||
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
|
||||
|
||||
doc = Document()
|
||||
p = doc.add_paragraph("Visit ")
|
||||
|
||||
# Adding a hyperlink manually
|
||||
r_id = "rId99"
|
||||
hyperlink = OxmlElement("w:hyperlink")
|
||||
hyperlink.set(qn("r:id"), r_id)
|
||||
|
||||
new_run = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "Dify"
|
||||
new_run.append(t)
|
||||
hyperlink.append(new_run)
|
||||
p._p.append(hyperlink)
|
||||
|
||||
# Add relationship to the part
|
||||
doc.part.rels.add_relationship(
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
|
||||
"https://dify.ai",
|
||||
r_id,
|
||||
is_external=True,
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
|
||||
docs = extractor.extract()
|
||||
# Verify modern hyperlink extraction
|
||||
assert "Visit[Dify](https://dify.ai)" in docs[0].page_content
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_extract_legacy_hyperlinks(monkeypatch):
|
||||
# Mock db and storage
|
||||
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda k, d: None))
|
||||
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda o: None, commit=lambda: None))
|
||||
monkeypatch.setattr(we, "db", db_stub)
|
||||
monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False)
|
||||
monkeypatch.setattr(we.dify_config, "STORAGE_TYPE", "local", raising=False)
|
||||
|
||||
doc = Document()
|
||||
p = doc.add_paragraph()
|
||||
|
||||
# Construct a legacy HYPERLINK field:
|
||||
# 1. w:fldChar (begin)
|
||||
# 2. w:instrText (HYPERLINK "http://example.com")
|
||||
# 3. w:fldChar (separate)
|
||||
# 4. w:r (visible text "Example")
|
||||
# 5. w:fldChar (end)
|
||||
|
||||
run1 = OxmlElement("w:r")
|
||||
fldCharBegin = OxmlElement("w:fldChar")
|
||||
fldCharBegin.set(qn("w:fldCharType"), "begin")
|
||||
run1.append(fldCharBegin)
|
||||
p._p.append(run1)
|
||||
|
||||
run2 = OxmlElement("w:r")
|
||||
instrText = OxmlElement("w:instrText")
|
||||
instrText.text = ' HYPERLINK "http://example.com" '
|
||||
run2.append(instrText)
|
||||
p._p.append(run2)
|
||||
|
||||
run3 = OxmlElement("w:r")
|
||||
fldCharSep = OxmlElement("w:fldChar")
|
||||
fldCharSep.set(qn("w:fldCharType"), "separate")
|
||||
run3.append(fldCharSep)
|
||||
p._p.append(run3)
|
||||
|
||||
run4 = OxmlElement("w:r")
|
||||
t4 = OxmlElement("w:t")
|
||||
t4.text = "Example"
|
||||
run4.append(t4)
|
||||
p._p.append(run4)
|
||||
|
||||
run5 = OxmlElement("w:r")
|
||||
fldCharEnd = OxmlElement("w:fldChar")
|
||||
fldCharEnd.set(qn("w:fldCharType"), "end")
|
||||
run5.append(fldCharEnd)
|
||||
p._p.append(run5)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
extractor = WordExtractor(tmp_path, "tenant_id", "user_id")
|
||||
docs = extractor.extract()
|
||||
# Verify legacy hyperlink extraction
|
||||
assert "[Example](http://example.com)" in docs[0].page_content
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
@ -36,6 +37,7 @@ def test_dispatcher_should_consume_remains_events_after_pause():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=execution_coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
dispatcher._dispatcher_loop()
|
||||
assert event_queue.empty()
|
||||
@ -96,6 +98,7 @@ def _run_dispatcher_for_event(event) -> int:
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
@ -181,6 +184,7 @@ def test_dispatcher_drain_event_queue():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
@ -5,6 +5,8 @@ This module provides a flexible configuration system for customizing
|
||||
the behavior of mock nodes during testing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
@ -95,67 +97,67 @@ class MockConfigBuilder:
|
||||
def __init__(self) -> None:
|
||||
self._config = MockConfig()
|
||||
|
||||
def with_auto_mock(self, enabled: bool = True) -> "MockConfigBuilder":
|
||||
def with_auto_mock(self, enabled: bool = True) -> MockConfigBuilder:
|
||||
"""Enable or disable auto-mocking."""
|
||||
self._config.enable_auto_mock = enabled
|
||||
return self
|
||||
|
||||
def with_delays(self, enabled: bool = True) -> "MockConfigBuilder":
|
||||
def with_delays(self, enabled: bool = True) -> MockConfigBuilder:
|
||||
"""Enable or disable simulated execution delays."""
|
||||
self._config.simulate_delays = enabled
|
||||
return self
|
||||
|
||||
def with_llm_response(self, response: str) -> "MockConfigBuilder":
|
||||
def with_llm_response(self, response: str) -> MockConfigBuilder:
|
||||
"""Set default LLM response."""
|
||||
self._config.default_llm_response = response
|
||||
return self
|
||||
|
||||
def with_agent_response(self, response: str) -> "MockConfigBuilder":
|
||||
def with_agent_response(self, response: str) -> MockConfigBuilder:
|
||||
"""Set default agent response."""
|
||||
self._config.default_agent_response = response
|
||||
return self
|
||||
|
||||
def with_tool_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
||||
def with_tool_response(self, response: dict[str, Any]) -> MockConfigBuilder:
|
||||
"""Set default tool response."""
|
||||
self._config.default_tool_response = response
|
||||
return self
|
||||
|
||||
def with_retrieval_response(self, response: str) -> "MockConfigBuilder":
|
||||
def with_retrieval_response(self, response: str) -> MockConfigBuilder:
|
||||
"""Set default retrieval response."""
|
||||
self._config.default_retrieval_response = response
|
||||
return self
|
||||
|
||||
def with_http_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
||||
def with_http_response(self, response: dict[str, Any]) -> MockConfigBuilder:
|
||||
"""Set default HTTP response."""
|
||||
self._config.default_http_response = response
|
||||
return self
|
||||
|
||||
def with_template_transform_response(self, response: str) -> "MockConfigBuilder":
|
||||
def with_template_transform_response(self, response: str) -> MockConfigBuilder:
|
||||
"""Set default template transform response."""
|
||||
self._config.default_template_transform_response = response
|
||||
return self
|
||||
|
||||
def with_code_response(self, response: dict[str, Any]) -> "MockConfigBuilder":
|
||||
def with_code_response(self, response: dict[str, Any]) -> MockConfigBuilder:
|
||||
"""Set default code execution response."""
|
||||
self._config.default_code_response = response
|
||||
return self
|
||||
|
||||
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> "MockConfigBuilder":
|
||||
def with_node_output(self, node_id: str, outputs: dict[str, Any]) -> MockConfigBuilder:
|
||||
"""Set outputs for a specific node."""
|
||||
self._config.set_node_outputs(node_id, outputs)
|
||||
return self
|
||||
|
||||
def with_node_error(self, node_id: str, error: str) -> "MockConfigBuilder":
|
||||
def with_node_error(self, node_id: str, error: str) -> MockConfigBuilder:
|
||||
"""Set error for a specific node."""
|
||||
self._config.set_node_error(node_id, error)
|
||||
return self
|
||||
|
||||
def with_node_config(self, config: NodeMockConfig) -> "MockConfigBuilder":
|
||||
def with_node_config(self, config: NodeMockConfig) -> MockConfigBuilder:
|
||||
"""Add a node-specific configuration."""
|
||||
self._config.set_node_config(config.node_id, config)
|
||||
return self
|
||||
|
||||
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> "MockConfigBuilder":
|
||||
def with_default_config(self, node_type: NodeType, config: dict[str, Any]) -> MockConfigBuilder:
|
||||
"""Set default configuration for a node type."""
|
||||
self._config.set_default_config(node_type, config)
|
||||
return self
|
||||
|
||||
@ -0,0 +1,231 @@
|
||||
"""Tests for ResponseStreamCoordinator object field streaming."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities.tool_entities import ToolResultStatus
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph.graph import Graph
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.template import Template, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
class TestResponseCoordinatorObjectStreaming:
|
||||
"""Test streaming of object-type variables with child fields."""
|
||||
|
||||
def test_object_field_streaming(self):
|
||||
"""Test that when selecting an object variable, all child field streams are forwarded."""
|
||||
# Create mock graph and variable pool
|
||||
graph = MagicMock(spec=Graph)
|
||||
variable_pool = MagicMock(spec=VariablePool)
|
||||
|
||||
# Mock nodes
|
||||
llm_node = MagicMock()
|
||||
llm_node.id = "llm_node"
|
||||
llm_node.node_type = NodeType.LLM
|
||||
llm_node.execution_type = MagicMock()
|
||||
llm_node.blocks_variable_output = MagicMock(return_value=False)
|
||||
|
||||
response_node = MagicMock()
|
||||
response_node.id = "response_node"
|
||||
response_node.node_type = NodeType.ANSWER
|
||||
response_node.execution_type = MagicMock()
|
||||
response_node.blocks_variable_output = MagicMock(return_value=False)
|
||||
|
||||
# Mock template for response node
|
||||
response_node.node_data = MagicMock(spec=BaseNodeData)
|
||||
response_node.node_data.answer = "{{#llm_node.generation#}}"
|
||||
|
||||
graph.nodes = {
|
||||
"llm_node": llm_node,
|
||||
"response_node": response_node,
|
||||
}
|
||||
graph.root_node = llm_node
|
||||
graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
|
||||
# Create coordinator
|
||||
coordinator = ResponseStreamCoordinator(variable_pool, graph)
|
||||
|
||||
# Track execution
|
||||
coordinator.track_node_execution("llm_node", "exec_123")
|
||||
coordinator.track_node_execution("response_node", "exec_456")
|
||||
|
||||
# Simulate streaming events for child fields of generation object
|
||||
# 1. Content stream
|
||||
content_event_1 = NodeRunStreamChunkEvent(
|
||||
id="exec_123",
|
||||
node_id="llm_node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["llm_node", "generation", "content"],
|
||||
chunk="Hello",
|
||||
is_final=False,
|
||||
chunk_type=ChunkType.TEXT,
|
||||
)
|
||||
content_event_2 = NodeRunStreamChunkEvent(
|
||||
id="exec_123",
|
||||
node_id="llm_node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["llm_node", "generation", "content"],
|
||||
chunk=" world",
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TEXT,
|
||||
)
|
||||
|
||||
# 2. Tool call stream
|
||||
tool_call_event = NodeRunStreamChunkEvent(
|
||||
id="exec_123",
|
||||
node_id="llm_node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["llm_node", "generation", "tool_calls"],
|
||||
chunk='{"query": "test"}',
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TOOL_CALL,
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Tool result stream
|
||||
tool_result_event = NodeRunStreamChunkEvent(
|
||||
id="exec_123",
|
||||
node_id="llm_node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["llm_node", "generation", "tool_results"],
|
||||
chunk="Found 10 results",
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TOOL_RESULT,
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="search",
|
||||
output="Found 10 results",
|
||||
files=[],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
)
|
||||
|
||||
# Intercept these events
|
||||
coordinator.intercept_event(content_event_1)
|
||||
coordinator.intercept_event(tool_call_event)
|
||||
coordinator.intercept_event(tool_result_event)
|
||||
coordinator.intercept_event(content_event_2)
|
||||
|
||||
# Verify that all child streams are buffered
|
||||
assert ("llm_node", "generation", "content") in coordinator._stream_buffers
|
||||
assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
|
||||
assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
|
||||
|
||||
# Verify payloads are preserved in buffered events
|
||||
buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
|
||||
assert buffered_call.tool_call is not None
|
||||
assert buffered_call.tool_call.id == "call_123"
|
||||
buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
|
||||
assert buffered_result.tool_result is not None
|
||||
assert buffered_result.tool_result.status == "success"
|
||||
|
||||
# Verify we can find child streams
|
||||
child_streams = coordinator._find_child_streams(["llm_node", "generation"])
|
||||
assert len(child_streams) == 3
|
||||
assert ("llm_node", "generation", "content") in child_streams
|
||||
assert ("llm_node", "generation", "tool_calls") in child_streams
|
||||
assert ("llm_node", "generation", "tool_results") in child_streams
|
||||
|
||||
def test_find_child_streams(self):
|
||||
"""Test the _find_child_streams method."""
|
||||
graph = MagicMock(spec=Graph)
|
||||
variable_pool = MagicMock(spec=VariablePool)
|
||||
|
||||
coordinator = ResponseStreamCoordinator(variable_pool, graph)
|
||||
|
||||
# Add some mock streams
|
||||
coordinator._stream_buffers = {
|
||||
("node1", "generation", "content"): [],
|
||||
("node1", "generation", "tool_calls"): [],
|
||||
("node1", "generation", "thought"): [],
|
||||
("node1", "text"): [], # Not a child of generation
|
||||
("node2", "generation", "content"): [], # Different node
|
||||
}
|
||||
|
||||
# Find children of node1.generation
|
||||
children = coordinator._find_child_streams(["node1", "generation"])
|
||||
|
||||
assert len(children) == 3
|
||||
assert ("node1", "generation", "content") in children
|
||||
assert ("node1", "generation", "tool_calls") in children
|
||||
assert ("node1", "generation", "thought") in children
|
||||
assert ("node1", "text") not in children
|
||||
assert ("node2", "generation", "content") not in children
|
||||
|
||||
def test_find_child_streams_with_closed_streams(self):
|
||||
"""Test that _find_child_streams also considers closed streams."""
|
||||
graph = MagicMock(spec=Graph)
|
||||
variable_pool = MagicMock(spec=VariablePool)
|
||||
|
||||
coordinator = ResponseStreamCoordinator(variable_pool, graph)
|
||||
|
||||
# Add some streams - some buffered, some closed
|
||||
coordinator._stream_buffers = {
|
||||
("node1", "generation", "content"): [],
|
||||
}
|
||||
coordinator._closed_streams = {
|
||||
("node1", "generation", "tool_calls"),
|
||||
("node1", "generation", "thought"),
|
||||
}
|
||||
|
||||
# Should find all children regardless of whether they're in buffers or closed
|
||||
children = coordinator._find_child_streams(["node1", "generation"])
|
||||
|
||||
assert len(children) == 3
|
||||
assert ("node1", "generation", "content") in children
|
||||
assert ("node1", "generation", "tool_calls") in children
|
||||
assert ("node1", "generation", "thought") in children
|
||||
|
||||
def test_special_selector_rewrites_to_active_response_node(self):
|
||||
"""Ensure special selectors attribute streams to the active response node."""
|
||||
graph = MagicMock(spec=Graph)
|
||||
variable_pool = MagicMock(spec=VariablePool)
|
||||
|
||||
response_node = MagicMock()
|
||||
response_node.id = "response_node"
|
||||
response_node.node_type = NodeType.ANSWER
|
||||
graph.nodes = {"response_node": response_node}
|
||||
graph.root_node = response_node
|
||||
|
||||
coordinator = ResponseStreamCoordinator(variable_pool, graph)
|
||||
coordinator.track_node_execution("response_node", "exec_resp")
|
||||
|
||||
coordinator._active_session = ResponseSession(
|
||||
node_id="response_node",
|
||||
template=Template(segments=[VariableSegment(selector=["sys", "foo"])]),
|
||||
)
|
||||
|
||||
event = NodeRunStreamChunkEvent(
|
||||
id="stream_1",
|
||||
node_id="llm_node",
|
||||
node_type=NodeType.LLM,
|
||||
selector=["sys", "foo"],
|
||||
chunk="hi",
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TEXT,
|
||||
)
|
||||
|
||||
coordinator._stream_buffers[("sys", "foo")] = [event]
|
||||
coordinator._stream_positions[("sys", "foo")] = 0
|
||||
coordinator._closed_streams.add(("sys", "foo"))
|
||||
|
||||
events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"]))
|
||||
|
||||
assert is_complete
|
||||
assert len(events) == 1
|
||||
rewritten = events[0]
|
||||
assert rewritten.node_id == "response_node"
|
||||
assert rewritten.id == "exec_resp"
|
||||
@ -0,0 +1,539 @@
|
||||
"""
|
||||
Unit tests for stop_event functionality in GraphEngine.
|
||||
|
||||
Tests the unified stop_event management by GraphEngine and its propagation
|
||||
to WorkerPool, Worker, Dispatcher, and Nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class TestStopEventPropagation:
|
||||
"""Test suite for stop_event propagation through GraphEngine components."""
|
||||
|
||||
def test_graph_engine_creates_stop_event(self):
|
||||
"""Test that GraphEngine creates a stop_event on initialization."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Verify stop_event was created
|
||||
assert engine._stop_event is not None
|
||||
assert isinstance(engine._stop_event, threading.Event)
|
||||
|
||||
# Verify it was set in graph_runtime_state
|
||||
assert runtime_state.stop_event is not None
|
||||
assert runtime_state.stop_event is engine._stop_event
|
||||
|
||||
def test_stop_event_cleared_on_start(self):
|
||||
"""Test that stop_event is cleared when execution starts."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Set the stop_event before running
|
||||
engine._stop_event.set()
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
# Run the engine (should clear the stop_event)
|
||||
events = list(engine.run())
|
||||
|
||||
# After running, stop_event should be set again (by _stop_execution)
|
||||
# But during start it was cleared
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
|
||||
def test_stop_event_set_on_stop(self):
|
||||
"""Test that stop_event is set when execution stops."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Initially not set
|
||||
assert not engine._stop_event.is_set()
|
||||
|
||||
# Run the engine
|
||||
list(engine.run())
|
||||
|
||||
# After execution completes, stop_event should be set
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
def test_stop_event_passed_to_worker_pool(self):
|
||||
"""Test that stop_event is passed to WorkerPool."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Verify WorkerPool has the stop_event
|
||||
assert engine._worker_pool._stop_event is not None
|
||||
assert engine._worker_pool._stop_event is engine._stop_event
|
||||
|
||||
def test_stop_event_passed_to_dispatcher(self):
|
||||
"""Test that stop_event is passed to Dispatcher."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Verify Dispatcher has the stop_event
|
||||
assert engine._dispatcher._stop_event is not None
|
||||
assert engine._dispatcher._stop_event is engine._stop_event
|
||||
|
||||
|
||||
class TestNodeStopCheck:
|
||||
"""Test suite for Node._should_stop() functionality."""
|
||||
|
||||
def test_node_should_stop_checks_runtime_state(self):
|
||||
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Initially stop_event is not set
|
||||
assert not answer_node._should_stop()
|
||||
|
||||
# Set the stop_event
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Now _should_stop should return True
|
||||
assert answer_node._should_stop()
|
||||
|
||||
def test_node_run_checks_stop_event_between_yields(self):
|
||||
"""Test that Node.run() checks stop_event between yielding events."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a simple node
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Set stop_event BEFORE running the node
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Run the node - should yield start event then detect stop
|
||||
# The node should check stop_event before processing
|
||||
assert answer_node._should_stop(), "stop_event should be set"
|
||||
|
||||
# Run and collect events
|
||||
events = list(answer_node.run())
|
||||
|
||||
# Since stop_event is set at the start, we should get:
|
||||
# 1. NodeRunStartedEvent (always yielded first)
|
||||
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
|
||||
assert len(events) >= 2
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
|
||||
# Note: AnswerNode is very simple and might complete before stop check
|
||||
# The important thing is that _should_stop() returns True when stop_event is set
|
||||
assert answer_node._should_stop()
|
||||
|
||||
|
||||
class TestStopEventIntegration:
|
||||
"""Integration tests for stop_event in workflow execution."""
|
||||
|
||||
def test_simple_workflow_respects_stop_event(self):
|
||||
"""Test that a simple workflow respects stop_event."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
# Create start and answer nodes
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.nodes["answer"] = answer_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Set stop_event before running
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Run the engine
|
||||
events = list(engine.run())
|
||||
|
||||
# Should get started event but not succeeded (due to stop)
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
# The workflow should still complete (start node runs quickly)
|
||||
# but answer node might be cancelled depending on timing
|
||||
|
||||
def test_stop_event_with_concurrent_nodes(self):
|
||||
"""Test stop_event behavior with multiple concurrent nodes."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
# Create multiple nodes
|
||||
for i in range(3):
|
||||
answer_node = AnswerNode(
|
||||
id=f"answer_{i}",
|
||||
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes[f"answer_{i}"] = answer_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# All nodes should share the same stop_event
|
||||
for node in mock_graph.nodes.values():
|
||||
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
|
||||
assert node.graph_runtime_state.stop_event is engine._stop_event
|
||||
|
||||
|
||||
class TestStopEventTimeoutBehavior:
|
||||
"""Test stop_event behavior with join timeouts."""
|
||||
|
||||
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread")
|
||||
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
|
||||
"""Test that Dispatcher uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
dispatcher = engine._dispatcher
|
||||
dispatcher.start() # This will create and start the mocked thread
|
||||
|
||||
mock_thread_instance = mock_thread_cls.return_value
|
||||
mock_thread_instance.is_alive.return_value = True
|
||||
|
||||
dispatcher.stop()
|
||||
|
||||
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker")
|
||||
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
|
||||
"""Test that WorkerPool uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
worker_pool = engine._worker_pool
|
||||
worker_pool.start(initial_count=1) # Start with one worker
|
||||
|
||||
mock_worker_instance = mock_worker_cls.return_value
|
||||
mock_worker_instance.is_alive.return_value = True
|
||||
|
||||
worker_pool.stop()
|
||||
|
||||
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestStopEventResumeBehavior:
|
||||
"""Test stop_event behavior during workflow resume."""
|
||||
|
||||
def test_stop_event_cleared_on_resume(self):
|
||||
"""Test that stop_event is cleared when resuming a paused workflow."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Simulate a previous execution that set stop_event
|
||||
engine._stop_event.set()
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
# Run the engine (should clear stop_event in _start_execution)
|
||||
events = list(engine.run())
|
||||
|
||||
# Execution should complete successfully
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
|
||||
|
||||
class TestWorkerStopBehavior:
|
||||
"""Test Worker behavior with shared stop_event."""
|
||||
|
||||
def test_worker_uses_shared_stop_event(self):
|
||||
"""Test that Worker uses shared stop_event from GraphEngine."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
|
||||
# Get the worker pool and check workers
|
||||
worker_pool = engine._worker_pool
|
||||
|
||||
# Start the worker pool to create workers
|
||||
worker_pool.start()
|
||||
|
||||
# Check that at least one worker was created
|
||||
assert len(worker_pool._workers) > 0
|
||||
|
||||
# Verify workers use the shared stop_event
|
||||
for worker in worker_pool._workers:
|
||||
assert worker._stop_event is engine._stop_event
|
||||
|
||||
# Clean up
|
||||
worker_pool.stop()
|
||||
|
||||
def test_worker_stop_is_noop(self):
|
||||
"""Test that Worker.stop() is now a no-op."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a mock worker
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.worker import Worker
|
||||
|
||||
ready_queue = InMemoryReadyQueue()
|
||||
event_queue = MagicMock()
|
||||
|
||||
# Create a proper mock graph with real dict
|
||||
mock_graph = Mock(spec=Graph)
|
||||
mock_graph.nodes = {} # Use real dict
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=ready_queue,
|
||||
event_queue=event_queue,
|
||||
graph=mock_graph,
|
||||
layers=[],
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
# Calling stop() should do nothing (no-op)
|
||||
# and should NOT set the stop_event
|
||||
worker.stop()
|
||||
assert not stop_event.is_set()
|
||||
@ -0,0 +1,328 @@
|
||||
"""Tests for StreamChunkEvent and its subclasses."""
|
||||
|
||||
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.node_events import (
|
||||
ChunkType,
|
||||
StreamChunkEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
|
||||
|
||||
class TestChunkType:
|
||||
"""Tests for ChunkType enum."""
|
||||
|
||||
def test_chunk_type_values(self):
|
||||
"""Test that ChunkType has expected values."""
|
||||
assert ChunkType.TEXT == "text"
|
||||
assert ChunkType.TOOL_CALL == "tool_call"
|
||||
assert ChunkType.TOOL_RESULT == "tool_result"
|
||||
assert ChunkType.THOUGHT == "thought"
|
||||
|
||||
def test_chunk_type_is_str_enum(self):
|
||||
"""Test that ChunkType values are strings."""
|
||||
for chunk_type in ChunkType:
|
||||
assert isinstance(chunk_type.value, str)
|
||||
|
||||
|
||||
class TestStreamChunkEvent:
|
||||
"""Tests for base StreamChunkEvent."""
|
||||
|
||||
def test_create_with_required_fields(self):
|
||||
"""Test creating StreamChunkEvent with required fields."""
|
||||
event = StreamChunkEvent(
|
||||
selector=["node1", "text"],
|
||||
chunk="Hello",
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "text"]
|
||||
assert event.chunk == "Hello"
|
||||
assert event.is_final is False
|
||||
assert event.chunk_type == ChunkType.TEXT
|
||||
|
||||
def test_create_with_all_fields(self):
|
||||
"""Test creating StreamChunkEvent with all fields."""
|
||||
event = StreamChunkEvent(
|
||||
selector=["node1", "output"],
|
||||
chunk="World",
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TEXT,
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "output"]
|
||||
assert event.chunk == "World"
|
||||
assert event.is_final is True
|
||||
assert event.chunk_type == ChunkType.TEXT
|
||||
|
||||
def test_default_chunk_type_is_text(self):
|
||||
"""Test that default chunk_type is TEXT."""
|
||||
event = StreamChunkEvent(
|
||||
selector=["node1", "text"],
|
||||
chunk="test",
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TEXT
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = StreamChunkEvent(
|
||||
selector=["node1", "text"],
|
||||
chunk="Hello",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["selector"] == ["node1", "text"]
|
||||
assert data["chunk"] == "Hello"
|
||||
assert data["is_final"] is True
|
||||
assert data["chunk_type"] == "text"
|
||||
|
||||
|
||||
class TestToolCallChunkEvent:
|
||||
"""Tests for ToolCallChunkEvent."""
|
||||
|
||||
def test_create_with_required_fields(self):
|
||||
"""Test creating ToolCallChunkEvent with required fields."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"city": "Beijing"}',
|
||||
tool_call=ToolCall(id="call_123", name="weather", arguments=None),
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "tool_calls"]
|
||||
assert event.chunk == '{"city": "Beijing"}'
|
||||
assert event.tool_call.id == "call_123"
|
||||
assert event.tool_call.name == "weather"
|
||||
assert event.chunk_type == ChunkType.TOOL_CALL
|
||||
|
||||
def test_chunk_type_is_tool_call(self):
|
||||
"""Test that chunk_type is always TOOL_CALL."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TOOL_CALL
|
||||
|
||||
def test_tool_arguments_field(self):
|
||||
"""Test tool_arguments field."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"param": "value"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_call.arguments == '{"param": "value"}'
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"city": "Beijing"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="weather",
|
||||
arguments='{"city": "Beijing"}',
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "tool_call"
|
||||
assert data["tool_call"]["id"] == "call_123"
|
||||
assert data["tool_call"]["name"] == "weather"
|
||||
assert data["tool_call"]["arguments"] == '{"city": "Beijing"}'
|
||||
assert data["is_final"] is True
|
||||
|
||||
|
||||
class TestToolResultChunkEvent:
|
||||
"""Tests for ToolResultChunkEvent."""
|
||||
|
||||
def test_create_with_required_fields(self):
|
||||
"""Test creating ToolResultChunkEvent with required fields."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="Weather: Sunny, 25°C",
|
||||
tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "tool_results"]
|
||||
assert event.chunk == "Weather: Sunny, 25°C"
|
||||
assert event.tool_result.id == "call_123"
|
||||
assert event.tool_result.name == "weather"
|
||||
assert event.chunk_type == ChunkType.TOOL_RESULT
|
||||
|
||||
def test_chunk_type_is_tool_result(self):
|
||||
"""Test that chunk_type is always TOOL_RESULT."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_result=ToolResult(id="call_123", name="test_tool"),
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TOOL_RESULT
|
||||
|
||||
def test_tool_files_default_empty(self):
|
||||
"""Test that tool_files defaults to empty list."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_result=ToolResult(id="call_123", name="test_tool"),
|
||||
)
|
||||
|
||||
assert event.tool_result.files == []
|
||||
|
||||
def test_tool_files_with_values(self):
|
||||
"""Test tool_files with file IDs."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
files=["file_1", "file_2"],
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_result.files == ["file_1", "file_2"]
|
||||
|
||||
def test_tool_error_output(self):
|
||||
"""Test error output captured in tool_result."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="",
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
output="Tool execution failed",
|
||||
status=ToolResultStatus.ERROR,
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_result.output == "Tool execution failed"
|
||||
assert event.tool_result.status == ToolResultStatus.ERROR
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="Weather: Sunny",
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="weather",
|
||||
output="Weather: Sunny",
|
||||
files=["file_1"],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "tool_result"
|
||||
assert data["tool_result"]["id"] == "call_123"
|
||||
assert data["tool_result"]["name"] == "weather"
|
||||
assert data["tool_result"]["files"] == ["file_1"]
|
||||
assert data["is_final"] is True
|
||||
|
||||
|
||||
class TestThoughtChunkEvent:
|
||||
"""Tests for ThoughtChunkEvent."""
|
||||
|
||||
def test_create_with_required_fields(self):
|
||||
"""Test creating ThoughtChunkEvent with required fields."""
|
||||
event = ThoughtChunkEvent(
|
||||
selector=["node1", "thought"],
|
||||
chunk="I need to query the weather...",
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "thought"]
|
||||
assert event.chunk == "I need to query the weather..."
|
||||
assert event.chunk_type == ChunkType.THOUGHT
|
||||
|
||||
def test_chunk_type_is_thought(self):
|
||||
"""Test that chunk_type is always THOUGHT."""
|
||||
event = ThoughtChunkEvent(
|
||||
selector=["node1", "thought"],
|
||||
chunk="thinking...",
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.THOUGHT
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ThoughtChunkEvent(
|
||||
selector=["node1", "thought"],
|
||||
chunk="I need to analyze this...",
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "thought"
|
||||
assert data["chunk"] == "I need to analyze this..."
|
||||
assert data["is_final"] is False
|
||||
|
||||
|
||||
class TestEventInheritance:
|
||||
"""Tests for event inheritance relationships."""
|
||||
|
||||
def test_tool_call_is_stream_chunk(self):
|
||||
"""Test that ToolCallChunkEvent is a StreamChunkEvent."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call=ToolCall(id="call_123", name="test", arguments=None),
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
|
||||
def test_tool_result_is_stream_chunk(self):
|
||||
"""Test that ToolResultChunkEvent is a StreamChunkEvent."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_result=ToolResult(id="call_123", name="test"),
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
|
||||
def test_thought_is_stream_chunk(self):
|
||||
"""Test that ThoughtChunkEvent is a StreamChunkEvent."""
|
||||
event = ThoughtChunkEvent(
|
||||
selector=["node1", "thought"],
|
||||
chunk="thinking...",
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
|
||||
def test_all_events_have_common_fields(self):
|
||||
"""Test that all events have common StreamChunkEvent fields."""
|
||||
events = [
|
||||
StreamChunkEvent(selector=["n", "t"], chunk="a"),
|
||||
ToolCallChunkEvent(
|
||||
selector=["n", "t"],
|
||||
chunk="b",
|
||||
tool_call=ToolCall(id="1", name="t", arguments=None),
|
||||
),
|
||||
ToolResultChunkEvent(
|
||||
selector=["n", "t"],
|
||||
chunk="c",
|
||||
tool_result=ToolResult(id="1", name="t"),
|
||||
),
|
||||
ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
|
||||
]
|
||||
|
||||
for event in events:
|
||||
assert hasattr(event, "selector")
|
||||
assert hasattr(event, "chunk")
|
||||
assert hasattr(event, "is_final")
|
||||
assert hasattr(event, "chunk_type")
|
||||
@ -78,7 +78,7 @@ class TestFileSaverImpl:
|
||||
file_binary=_PNG_DATA,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||
mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True)
|
||||
|
||||
def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
|
||||
@ -0,0 +1,148 @@
|
||||
import types
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities import ToolCallResult
|
||||
from core.workflow.entities.tool_entities import ToolResultStatus
|
||||
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
|
||||
|
||||
class _StubModelInstance:
|
||||
"""Minimal stub to satisfy _stream_llm_events signature."""
|
||||
|
||||
provider_model_bundle = None
|
||||
|
||||
|
||||
def _drain(generator: Generator[NodeEventBase, None, Any]):
|
||||
events: list = []
|
||||
try:
|
||||
while True:
|
||||
events.append(next(generator))
|
||||
except StopIteration as exc:
|
||||
return events, exc.value
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_deduct_llm_quota(monkeypatch):
|
||||
# Avoid touching real quota logic during unit tests
|
||||
monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", lambda **_: None)
|
||||
|
||||
|
||||
def _make_llm_node(reasoning_format: str) -> LLMNode:
|
||||
node = LLMNode.__new__(LLMNode)
|
||||
object.__setattr__(node, "_node_data", types.SimpleNamespace(reasoning_format=reasoning_format, tools=[]))
|
||||
object.__setattr__(node, "tenant_id", "tenant")
|
||||
return node
|
||||
|
||||
|
||||
def test_stream_llm_events_extracts_reasoning_for_tagged():
|
||||
node = _make_llm_node(reasoning_format="tagged")
|
||||
tagged_text = "<think>Thought</think>Answer"
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
def generator():
|
||||
yield ModelInvokeCompletedEvent(
|
||||
text=tagged_text,
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
reasoning_content="",
|
||||
structured_output=None,
|
||||
)
|
||||
|
||||
events, returned = _drain(
|
||||
node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
|
||||
)
|
||||
|
||||
assert events == []
|
||||
clean_text, reasoning_content, gen_reasoning, gen_clean, ret_usage, finish_reason, structured, gen_data = returned
|
||||
assert clean_text == tagged_text # original preserved for output
|
||||
assert reasoning_content == "" # tagged mode keeps reasoning separate
|
||||
assert gen_clean == "Answer" # stripped content for generation
|
||||
assert gen_reasoning == "Thought" # reasoning extracted from <think> tag
|
||||
assert ret_usage == usage
|
||||
assert finish_reason == "stop"
|
||||
assert structured is None
|
||||
assert gen_data is None
|
||||
|
||||
# generation building should include reasoning and sequence
|
||||
generation_content = gen_clean or clean_text
|
||||
sequence = [
|
||||
{"type": "reasoning", "index": 0},
|
||||
{"type": "content", "start": 0, "end": len(generation_content)},
|
||||
]
|
||||
assert sequence == [
|
||||
{"type": "reasoning", "index": 0},
|
||||
{"type": "content", "start": 0, "end": len("Answer")},
|
||||
]
|
||||
|
||||
|
||||
def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
|
||||
node = _make_llm_node(reasoning_format="tagged")
|
||||
plain_text = "Hello world"
|
||||
usage = LLMUsage.empty_usage()
|
||||
|
||||
def generator():
|
||||
yield ModelInvokeCompletedEvent(
|
||||
text=plain_text,
|
||||
usage=usage,
|
||||
finish_reason=None,
|
||||
reasoning_content="",
|
||||
structured_output=None,
|
||||
)
|
||||
|
||||
events, returned = _drain(
|
||||
node._stream_llm_events(generator(), model_instance=types.SimpleNamespace(provider_model_bundle=None))
|
||||
)
|
||||
|
||||
assert events == []
|
||||
_, _, gen_reasoning, gen_clean, *_ = returned
|
||||
generation_content = gen_clean or plain_text
|
||||
assert gen_reasoning == ""
|
||||
assert generation_content == plain_text
|
||||
# Empty reasoning should imply empty sequence in generation construction
|
||||
sequence = []
|
||||
assert sequence == []
|
||||
|
||||
|
||||
def test_serialize_tool_call_strips_files_to_ids():
|
||||
file_cls = pytest.importorskip("core.file").File
|
||||
file_type = pytest.importorskip("core.file.enums").FileType
|
||||
transfer_method = pytest.importorskip("core.file.enums").FileTransferMethod
|
||||
|
||||
file_with_id = file_cls(
|
||||
id="f1",
|
||||
tenant_id="t",
|
||||
type=file_type.IMAGE,
|
||||
transfer_method=transfer_method.REMOTE_URL,
|
||||
remote_url="http://example.com/f1",
|
||||
storage_key="k1",
|
||||
)
|
||||
file_with_related = file_cls(
|
||||
id=None,
|
||||
tenant_id="t",
|
||||
type=file_type.IMAGE,
|
||||
transfer_method=transfer_method.REMOTE_URL,
|
||||
related_id="rel2",
|
||||
remote_url="http://example.com/f2",
|
||||
storage_key="k2",
|
||||
)
|
||||
tool_call = ToolCallResult(
|
||||
id="tc",
|
||||
name="do",
|
||||
arguments='{"a":1}',
|
||||
output="ok",
|
||||
files=[file_with_id, file_with_related],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
)
|
||||
|
||||
serialized = LLMNode._serialize_tool_call(tool_call)
|
||||
|
||||
assert serialized["files"] == ["f1", "rel2"]
|
||||
assert serialized["id"] == "tc"
|
||||
assert serialized["name"] == "do"
|
||||
assert serialized["arguments"] == '{"a":1}'
|
||||
assert serialized["output"] == "ok"
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Generator
|
||||
@ -21,7 +23,7 @@ if TYPE_CHECKING: # pragma: no cover - imported for type checking only
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_node(monkeypatch) -> "ToolNode":
|
||||
def tool_node(monkeypatch) -> ToolNode:
|
||||
module_name = "core.ops.ops_trace_manager"
|
||||
if module_name not in sys.modules:
|
||||
ops_stub = types.ModuleType(module_name)
|
||||
@ -85,7 +87,7 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
|
||||
return events, stop.value
|
||||
|
||||
|
||||
def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
|
||||
def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
|
||||
def _identity_transform(messages, *_args, **_kwargs):
|
||||
return messages
|
||||
|
||||
@ -103,7 +105,7 @@ def _run_transform(tool_node: "ToolNode", message: ToolInvokeMessage) -> tuple[l
|
||||
return _collect_events(generator)
|
||||
|
||||
|
||||
def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
|
||||
def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
|
||||
file_obj = File(
|
||||
tenant_id="tenant-id",
|
||||
type=FileType.DOCUMENT,
|
||||
@ -139,7 +141,7 @@ def test_link_messages_with_file_populate_files_output(tool_node: "ToolNode"):
|
||||
assert files_segment.value == [file_obj]
|
||||
|
||||
|
||||
def test_plain_link_messages_remain_links(tool_node: "ToolNode"):
|
||||
def test_plain_link_messages_remain_links(tool_node: ToolNode):
|
||||
message = ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest import mock
|
||||
from uuid import uuid4
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import ArrayStringVariable, StringVariable
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events.node import NodeRunSucceededEvent
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@ -86,9 +86,6 @@ def test_overwrite_string_variable():
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -104,20 +101,14 @@ def test_overwrite_string_variable():
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
list(node.run())
|
||||
expected_var = StringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=input_variable.value,
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
events = list(node.run())
|
||||
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
|
||||
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
|
||||
assert updated_variables is not None
|
||||
assert updated_variables[0].name == conversation_variable.name
|
||||
assert updated_variables[0].new_value == input_variable.value
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@ -191,9 +182,6 @@ def test_append_variable_to_array():
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -209,22 +197,14 @@ def test_append_variable_to_array():
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
list(node.run())
|
||||
expected_value = list(conversation_variable.value)
|
||||
expected_value.append(input_variable.value)
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=expected_value,
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
events = list(node.run())
|
||||
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
|
||||
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
|
||||
assert updated_variables is not None
|
||||
assert updated_variables[0].name == conversation_variable.name
|
||||
assert updated_variables[0].new_value == ["the first value", "the second value"]
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
@ -287,9 +267,6 @@ def test_clear_array():
|
||||
)
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
|
||||
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
|
||||
|
||||
node_config = {
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
@ -305,20 +282,14 @@ def test_clear_array():
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config=node_config,
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
list(node.run())
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
name=conversation_variable.name,
|
||||
description=conversation_variable.description,
|
||||
selector=conversation_variable.selector,
|
||||
value_type=conversation_variable.value_type,
|
||||
value=[],
|
||||
)
|
||||
mock_conv_var_updater.update.assert_called_once_with(conversation_id=conversation_id, variable=expected_var)
|
||||
mock_conv_var_updater.flush.assert_called_once()
|
||||
events = list(node.run())
|
||||
succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent))
|
||||
updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data)
|
||||
assert updated_variables is not None
|
||||
assert updated_variables[0].name == conversation_variable.name
|
||||
assert updated_variables[0].new_value == []
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
|
||||
@ -390,3 +390,42 @@ def test_remove_last_from_empty_array():
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
|
||||
|
||||
def test_node_factory_creates_variable_assigner_node():
|
||||
graph_config = {
|
||||
"edges": [],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []},
|
||||
"id": "assigner",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(conversation_id="conversation_id"),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
node = node_factory.create_node(graph_config["nodes"][0])
|
||||
|
||||
assert isinstance(node, VariableAssignerNode)
|
||||
|
||||
78
api/tests/unit_tests/fields/test_file_fields.py
Normal file
78
api/tests/unit_tests/fields/test_file_fields.py
Normal file
@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fields.file_fields import FileResponse, FileWithSignedUrl, RemoteFileInfo, UploadConfig
|
||||
|
||||
|
||||
def test_file_response_serializes_datetime() -> None:
|
||||
created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
file_obj = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="example.txt",
|
||||
size=1024,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by="user-1",
|
||||
created_at=created_at,
|
||||
preview_url="https://preview",
|
||||
source_url="https://source",
|
||||
original_url="https://origin",
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
conversation_id="conv-1",
|
||||
file_key="key-1",
|
||||
)
|
||||
|
||||
serialized = FileResponse.model_validate(file_obj, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
assert serialized["id"] == "file-1"
|
||||
assert serialized["created_at"] == int(created_at.timestamp())
|
||||
assert serialized["preview_url"] == "https://preview"
|
||||
assert serialized["source_url"] == "https://source"
|
||||
assert serialized["original_url"] == "https://origin"
|
||||
assert serialized["user_id"] == "user-1"
|
||||
assert serialized["tenant_id"] == "tenant-1"
|
||||
assert serialized["conversation_id"] == "conv-1"
|
||||
assert serialized["file_key"] == "key-1"
|
||||
|
||||
|
||||
def test_file_with_signed_url_builds_payload() -> None:
|
||||
payload = FileWithSignedUrl(
|
||||
id="file-2",
|
||||
name="remote.pdf",
|
||||
size=2048,
|
||||
extension="pdf",
|
||||
url="https://signed",
|
||||
mime_type="application/pdf",
|
||||
created_by="user-2",
|
||||
created_at=datetime(2024, 1, 2, 0, 0, 0),
|
||||
)
|
||||
|
||||
dumped = payload.model_dump(mode="json")
|
||||
|
||||
assert dumped["url"] == "https://signed"
|
||||
assert dumped["created_at"] == int(datetime(2024, 1, 2, 0, 0, 0).timestamp())
|
||||
|
||||
|
||||
def test_remote_file_info_and_upload_config() -> None:
|
||||
info = RemoteFileInfo(file_type="text/plain", file_length=123)
|
||||
assert info.model_dump(mode="json") == {"file_type": "text/plain", "file_length": 123}
|
||||
|
||||
config = UploadConfig(
|
||||
file_size_limit=1,
|
||||
batch_count_limit=2,
|
||||
file_upload_limit=3,
|
||||
image_file_size_limit=4,
|
||||
video_file_size_limit=5,
|
||||
audio_file_size_limit=6,
|
||||
workflow_file_upload_limit=7,
|
||||
image_file_batch_limit=8,
|
||||
single_chunk_attachment_limit=9,
|
||||
attachment_image_file_size_limit=10,
|
||||
)
|
||||
|
||||
dumped = config.model_dump(mode="json")
|
||||
assert dumped["file_upload_limit"] == 3
|
||||
assert dumped["attachment_image_file_size_limit"] == 10
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from libs.helper import extract_tenant_id
|
||||
from libs.helper import escape_like_pattern, extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -63,3 +63,51 @@ class TestExtractTenantId:
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"):
|
||||
extract_tenant_id(dict_user)
|
||||
|
||||
|
||||
class TestEscapeLikePattern:
|
||||
"""Test cases for the escape_like_pattern utility function."""
|
||||
|
||||
def test_escape_percent_character(self):
|
||||
"""Test escaping percent character."""
|
||||
result = escape_like_pattern("50% discount")
|
||||
assert result == "50\\% discount"
|
||||
|
||||
def test_escape_underscore_character(self):
|
||||
"""Test escaping underscore character."""
|
||||
result = escape_like_pattern("test_data")
|
||||
assert result == "test\\_data"
|
||||
|
||||
def test_escape_backslash_character(self):
|
||||
"""Test escaping backslash character."""
|
||||
result = escape_like_pattern("path\\to\\file")
|
||||
assert result == "path\\\\to\\\\file"
|
||||
|
||||
def test_escape_combined_special_characters(self):
|
||||
"""Test escaping multiple special characters together."""
|
||||
result = escape_like_pattern("file_50%\\path")
|
||||
assert result == "file\\_50\\%\\\\path"
|
||||
|
||||
def test_escape_empty_string(self):
|
||||
"""Test escaping empty string returns empty string."""
|
||||
result = escape_like_pattern("")
|
||||
assert result == ""
|
||||
|
||||
def test_escape_none_handling(self):
|
||||
"""Test escaping None returns None (falsy check handles it)."""
|
||||
# The function checks `if not pattern`, so None is falsy and returns as-is
|
||||
result = escape_like_pattern(None)
|
||||
assert result is None
|
||||
|
||||
def test_escape_normal_string_no_change(self):
|
||||
"""Test that normal strings without special characters are unchanged."""
|
||||
result = escape_like_pattern("normal text")
|
||||
assert result == "normal text"
|
||||
|
||||
def test_escape_order_matters(self):
|
||||
"""Test that backslash is escaped first to prevent double escaping."""
|
||||
# If we escape % first, then escape \, we might get wrong results
|
||||
# This test ensures the order is correct: \ first, then % and _
|
||||
result = escape_like_pattern("test\\%_value")
|
||||
# Should be: test\\\%\_value
|
||||
assert result == "test\\\\\\%\\_value"
|
||||
|
||||
@ -114,7 +114,7 @@ class TestAppModelValidation:
|
||||
def test_icon_type_validation(self):
|
||||
"""Test icon type enum values."""
|
||||
# Assert
|
||||
assert {t.value for t in IconType} == {"image", "emoji"}
|
||||
assert {t.value for t in IconType} == {"image", "emoji", "link"}
|
||||
|
||||
def test_app_desc_or_prompt_with_description(self):
|
||||
"""Test desc_or_prompt property when description exists."""
|
||||
|
||||
Reference in New Issue
Block a user