mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
test: update unit tests for system message handling and workflow collaboration serices
This commit is contained in:
@ -134,8 +134,8 @@ class TestInitSystemMessage:
|
|||||||
|
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
def test_existing_system_message_not_duplicated(self, mock_runner):
|
def test_existing_system_message_replaced_with_template(self, mock_runner):
|
||||||
"""Test that system message is not duplicated if already present."""
|
"""Test that existing system message is replaced with the new template."""
|
||||||
existing_messages = [
|
existing_messages = [
|
||||||
SystemPromptMessage(content="Existing system"),
|
SystemPromptMessage(content="Existing system"),
|
||||||
UserPromptMessage(content="User message"),
|
UserPromptMessage(content="User message"),
|
||||||
@ -143,9 +143,8 @@ class TestInitSystemMessage:
|
|||||||
|
|
||||||
result = mock_runner._init_system_message("New template", existing_messages)
|
result = mock_runner._init_system_message("New template", existing_messages)
|
||||||
|
|
||||||
# Should not insert new system message
|
|
||||||
assert len(result) == 2
|
assert len(result) == 2
|
||||||
assert result[0].content == "Existing system"
|
assert result[0].content == "New template"
|
||||||
|
|
||||||
def test_system_message_inserted_when_missing(self, mock_runner):
|
def test_system_message_inserted_when_missing(self, mock_runner):
|
||||||
"""Test that system message is inserted when first message is not system."""
|
"""Test that system message is inserted when first message is not system."""
|
||||||
|
|||||||
@ -105,9 +105,12 @@ def test_generate_appends_pause_layer_and_forwards_state(mocker):
|
|||||||
|
|
||||||
graph_runtime_state = MagicMock()
|
graph_runtime_state = MagicMock()
|
||||||
|
|
||||||
|
workflow_mock = MagicMock()
|
||||||
|
workflow_mock.get_feature.return_value.enabled = False
|
||||||
|
|
||||||
result = generator._generate(
|
result = generator._generate(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
workflow=MagicMock(),
|
workflow=workflow_mock,
|
||||||
user=MagicMock(),
|
user=MagicMock(),
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
invoke_from="service-api",
|
invoke_from="service-api",
|
||||||
@ -143,8 +146,15 @@ def test_resume_path_runs_worker_with_runtime_state(mocker):
|
|||||||
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
|
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
|
||||||
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
|
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
|
||||||
|
|
||||||
|
sandbox_feature = SimpleNamespace(enabled=False)
|
||||||
workflow = SimpleNamespace(
|
workflow = SimpleNamespace(
|
||||||
id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
|
id="workflow",
|
||||||
|
tenant_id="tenant",
|
||||||
|
app_id="app",
|
||||||
|
graph_dict={},
|
||||||
|
type="workflow",
|
||||||
|
version="1",
|
||||||
|
get_feature=lambda _feature: sandbox_feature,
|
||||||
)
|
)
|
||||||
end_user = SimpleNamespace(session_id="end-user-session")
|
end_user = SimpleNamespace(session_id="end-user-session")
|
||||||
app_record = SimpleNamespace(id="app")
|
app_record = SimpleNamespace(id="app")
|
||||||
|
|||||||
@ -38,6 +38,7 @@ class TestSegmentTypeIsArrayType:
|
|||||||
SegmentType.NONE,
|
SegmentType.NONE,
|
||||||
SegmentType.GROUP,
|
SegmentType.GROUP,
|
||||||
SegmentType.BOOLEAN,
|
SegmentType.BOOLEAN,
|
||||||
|
SegmentType.ARRAY_PROMPT_MESSAGE,
|
||||||
]
|
]
|
||||||
|
|
||||||
for seg_type in expected_array_types:
|
for seg_type in expected_array_types:
|
||||||
|
|||||||
@ -581,11 +581,11 @@ class TestSegmentTypeIsValid:
|
|||||||
test_value = None
|
test_value = None
|
||||||
elif segment_type == SegmentType.GROUP:
|
elif segment_type == SegmentType.GROUP:
|
||||||
test_value = SegmentGroup(value=[StringSegment(value="test")])
|
test_value = SegmentGroup(value=[StringSegment(value="test")])
|
||||||
|
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE:
|
||||||
|
continue # Internal type, not validated via is_valid
|
||||||
elif segment_type.is_array_type():
|
elif segment_type.is_array_type():
|
||||||
test_value = [] # Empty array is valid for all array types
|
test_value = [] # Empty array is valid for all array types
|
||||||
else:
|
else:
|
||||||
# If we get here, there's a segment type we don't know how to test
|
|
||||||
# This should prompt us to add validation logic
|
|
||||||
pytest.fail(f"Unknown segment type {segment_type} needs validation logic and test case")
|
pytest.fail(f"Unknown segment type {segment_type} needs validation logic and test case")
|
||||||
|
|
||||||
# This should NOT raise AssertionError
|
# This should NOT raise AssertionError
|
||||||
@ -788,6 +788,7 @@ class TestSegmentTypeValidationIntegration:
|
|||||||
unhandled_types = {
|
unhandled_types = {
|
||||||
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
SegmentType.INTEGER, # Handled by NUMBER validation logic
|
||||||
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
SegmentType.FLOAT, # Handled by NUMBER validation logic
|
||||||
|
SegmentType.ARRAY_PROMPT_MESSAGE, # Internal type, not user-facing
|
||||||
}
|
}
|
||||||
|
|
||||||
# Verify all types are accounted for
|
# Verify all types are accounted for
|
||||||
|
|||||||
@ -180,7 +180,8 @@ class TestBuildContext:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
context = build_context(messages, "The weather in Beijing is sunny, 25°C.", generation_data)
|
accumulated_response = "Let me check the weather.The weather in Beijing is sunny, 25°C."
|
||||||
|
context = build_context(messages, accumulated_response, generation_data)
|
||||||
|
|
||||||
# Should have: user message + assistant with tool_call + tool result + final assistant
|
# Should have: user message + assistant with tool_call + tool result + final assistant
|
||||||
assert len(context) == 4
|
assert len(context) == 4
|
||||||
@ -263,7 +264,8 @@ class TestBuildContext:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
context = build_context(messages, "Beijing is sunny at 25°C, Shanghai is cloudy at 22°C.", generation_data)
|
accumulated_response = "I'll check both cities.Beijing is sunny at 25°C, Shanghai is cloudy at 22°C."
|
||||||
|
context = build_context(messages, accumulated_response, generation_data)
|
||||||
|
|
||||||
# Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
|
# Should have: user + assistant with 2 tool_calls + 2 tool results + final assistant
|
||||||
assert len(context) == 5
|
assert len(context) == 5
|
||||||
|
|||||||
@ -45,6 +45,8 @@ class TestWorkflowCollaborationRepository:
|
|||||||
"avatar": None,
|
"avatar": None,
|
||||||
"sid": "sid-1",
|
"sid": "sid-1",
|
||||||
"connected_at": 2,
|
"connected_at": 2,
|
||||||
|
"graph_active": False,
|
||||||
|
"active_skill_file_id": None,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -10,6 +10,13 @@ class TestWorkflowCollaborationService:
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def service(self) -> tuple[WorkflowCollaborationService, Mock, Mock]:
|
def service(self) -> tuple[WorkflowCollaborationService, Mock, Mock]:
|
||||||
repository = Mock(spec=WorkflowCollaborationRepository)
|
repository = Mock(spec=WorkflowCollaborationRepository)
|
||||||
|
repository.get_current_leader.return_value = None
|
||||||
|
repository.get_session_sids.return_value = []
|
||||||
|
repository.get_active_skill_file_id.return_value = None
|
||||||
|
repository.get_active_skill_session_sids.return_value = []
|
||||||
|
repository.is_graph_active.return_value = False
|
||||||
|
repository.get_skill_leader.return_value = None
|
||||||
|
repository.list_sessions.return_value = []
|
||||||
socketio = Mock()
|
socketio = Mock()
|
||||||
return WorkflowCollaborationService(repository, socketio), repository, socketio
|
return WorkflowCollaborationService(repository, socketio), repository, socketio
|
||||||
|
|
||||||
@ -124,6 +131,7 @@ class TestWorkflowCollaborationService:
|
|||||||
# Arrange
|
# Arrange
|
||||||
collaboration_service, repository, _socketio = service
|
collaboration_service, repository, _socketio = service
|
||||||
repository.get_current_leader.return_value = "sid-1"
|
repository.get_current_leader.return_value = "sid-1"
|
||||||
|
repository.is_graph_active.return_value = True
|
||||||
|
|
||||||
with patch.object(collaboration_service, "is_session_active", return_value=True):
|
with patch.object(collaboration_service, "is_session_active", return_value=True):
|
||||||
# Act
|
# Act
|
||||||
@ -265,6 +273,7 @@ class TestWorkflowCollaborationService:
|
|||||||
# Arrange
|
# Arrange
|
||||||
collaboration_service, repository, _socketio = service
|
collaboration_service, repository, _socketio = service
|
||||||
repository.get_current_leader.return_value = "sid-1"
|
repository.get_current_leader.return_value = "sid-1"
|
||||||
|
repository.is_graph_active.return_value = True
|
||||||
|
|
||||||
with patch.object(collaboration_service, "is_session_active", return_value=True):
|
with patch.object(collaboration_service, "is_session_active", return_value=True):
|
||||||
# Act
|
# Act
|
||||||
|
|||||||
@ -17,6 +17,10 @@ def mock_session(monkeypatch: pytest.MonkeyPatch) -> Mock:
|
|||||||
mock_db.engine = Mock()
|
mock_db.engine = Mock()
|
||||||
monkeypatch.setattr(service_module, "Session", Mock(return_value=context_manager))
|
monkeypatch.setattr(service_module, "Session", Mock(return_value=context_manager))
|
||||||
monkeypatch.setattr(service_module, "db", mock_db)
|
monkeypatch.setattr(service_module, "db", mock_db)
|
||||||
|
monkeypatch.setattr(service_module, "send_workflow_comment_mention_email_task", Mock())
|
||||||
|
scalars_default = Mock()
|
||||||
|
scalars_default.all.return_value = []
|
||||||
|
session.scalars.return_value = scalars_default
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class TestSystemOAuthEncrypter:
|
|||||||
|
|
||||||
def test_init_with_none_secret_key(self):
|
def test_init_with_none_secret_key(self):
|
||||||
"""Test initialization with None secret key falls back to config"""
|
"""Test initialization with None secret key falls back to config"""
|
||||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||||
mock_config.SECRET_KEY = "config_secret"
|
mock_config.SECRET_KEY = "config_secret"
|
||||||
encrypter = SystemEncrypter(secret_key=None)
|
encrypter = SystemEncrypter(secret_key=None)
|
||||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||||
@ -43,7 +43,7 @@ class TestSystemOAuthEncrypter:
|
|||||||
|
|
||||||
def test_init_without_secret_key_uses_config(self):
|
def test_init_without_secret_key_uses_config(self):
|
||||||
"""Test initialization without secret key uses config"""
|
"""Test initialization without secret key uses config"""
|
||||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||||
mock_config.SECRET_KEY = "default_secret"
|
mock_config.SECRET_KEY = "default_secret"
|
||||||
encrypter = SystemEncrypter()
|
encrypter = SystemEncrypter()
|
||||||
expected_key = hashlib.sha256(b"default_secret").digest()
|
expected_key = hashlib.sha256(b"default_secret").digest()
|
||||||
@ -302,7 +302,7 @@ class TestSystemOAuthEncrypter:
|
|||||||
decrypted2 = encrypter2.decrypt_params(encrypted2)
|
decrypted2 = encrypter2.decrypt_params(encrypted2)
|
||||||
assert decrypted1 == decrypted2 == oauth_params
|
assert decrypted1 == decrypted2 == oauth_params
|
||||||
|
|
||||||
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
|
@patch("core.tools.utils.system_encryption.get_random_bytes")
|
||||||
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
|
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
|
||||||
"""Test encryption when crypto operation fails"""
|
"""Test encryption when crypto operation fails"""
|
||||||
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
||||||
@ -315,7 +315,7 @@ class TestSystemOAuthEncrypter:
|
|||||||
|
|
||||||
assert "Encryption failed" in str(exc_info.value)
|
assert "Encryption failed" in str(exc_info.value)
|
||||||
|
|
||||||
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
|
@patch("core.tools.utils.system_encryption.TypeAdapter")
|
||||||
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
|
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
|
||||||
"""Test encryption when JSON serialization fails"""
|
"""Test encryption when JSON serialization fails"""
|
||||||
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
||||||
@ -370,7 +370,7 @@ class TestFactoryFunctions:
|
|||||||
|
|
||||||
def test_create_system_oauth_encrypter_without_secret(self):
|
def test_create_system_oauth_encrypter_without_secret(self):
|
||||||
"""Test factory function without secret key"""
|
"""Test factory function without secret key"""
|
||||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||||
mock_config.SECRET_KEY = "config_secret"
|
mock_config.SECRET_KEY = "config_secret"
|
||||||
encrypter = create_system_encrypter()
|
encrypter = create_system_encrypter()
|
||||||
|
|
||||||
@ -380,7 +380,7 @@ class TestFactoryFunctions:
|
|||||||
|
|
||||||
def test_create_system_oauth_encrypter_with_none_secret(self):
|
def test_create_system_oauth_encrypter_with_none_secret(self):
|
||||||
"""Test factory function with None secret key"""
|
"""Test factory function with None secret key"""
|
||||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||||
mock_config.SECRET_KEY = "config_secret"
|
mock_config.SECRET_KEY = "config_secret"
|
||||||
encrypter = create_system_encrypter(None)
|
encrypter = create_system_encrypter(None)
|
||||||
|
|
||||||
@ -412,7 +412,7 @@ class TestGlobalEncrypterInstance:
|
|||||||
|
|
||||||
core.tools.utils.system_encryption._encrypter = None
|
core.tools.utils.system_encryption._encrypter = None
|
||||||
|
|
||||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
with patch("core.tools.utils.system_encryption.dify_config") as mock_config:
|
||||||
mock_config.SECRET_KEY = "global_secret"
|
mock_config.SECRET_KEY = "global_secret"
|
||||||
encrypter = get_system_encrypter()
|
encrypter = get_system_encrypter()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user