mirror of
https://github.com/langgenius/dify.git
synced 2026-04-19 10:17:26 +08:00
Merge remote-tracking branch 'origin/main' into docs/e2e-writing-guide
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
class FullContentDict(TypedDict):
|
||||
size_bytes: int | None
|
||||
value_type: str
|
||||
length: int | None
|
||||
download_url: str
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
|
||||
"""Serialize full_content information for large variables."""
|
||||
if not variable.is_truncated():
|
||||
return None
|
||||
@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
variable_file = variable.variable_file
|
||||
assert variable_file is not None
|
||||
|
||||
return {
|
||||
result: FullContentDict = {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -29,6 +29,13 @@ from models.model import Message
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionDict(TypedDict):
|
||||
"""Shape produced by AgentScratchpadUnit.Action.to_dict()."""
|
||||
|
||||
action: str
|
||||
action_input: dict[str, Any] | str
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
|
||||
def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
|
||||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
|
||||
@ -32,9 +32,9 @@ class Extensible:
|
||||
|
||||
name: str
|
||||
tenant_id: str
|
||||
config: dict | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
def __init__(self, tenant_id: str, config: dict | None = None):
|
||||
def __init__(self, tenant_id: str, config: dict[str, Any] | None = None):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
@ -7,6 +10,16 @@ from extensions.ext_database import db
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
|
||||
|
||||
class ApiToolConfig(TypedDict, total=False):
|
||||
"""Expected config shape for ApiExternalDataTool.
|
||||
|
||||
Not used directly in method signatures (base class accepts dict[str, Any]);
|
||||
kept here to document the keys this tool reads from config.
|
||||
"""
|
||||
|
||||
api_based_extension_id: str
|
||||
|
||||
|
||||
class ApiExternalDataTool(ExternalDataTool):
|
||||
"""
|
||||
The api external data tool.
|
||||
@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC):
|
||||
variable: str
|
||||
"""the tool variable name of app tool"""
|
||||
|
||||
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None):
|
||||
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None):
|
||||
super().__init__(tenant_id, config)
|
||||
self.app_id = app_id
|
||||
self.variable = variable
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
import orjson
|
||||
|
||||
@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False):
|
||||
user_type: str
|
||||
|
||||
|
||||
class LogDict(TypedDict):
|
||||
ts: str
|
||||
severity: str
|
||||
service: str
|
||||
caller: str
|
||||
message: str
|
||||
trace_id: NotRequired[str]
|
||||
span_id: NotRequired[str]
|
||||
identity: NotRequired[IdentityDict]
|
||||
attributes: NotRequired[dict[str, Any]]
|
||||
stack_trace: NotRequired[str]
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
|
||||
return json.dumps(log_dict, default=str, ensure_ascii=False)
|
||||
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> LogDict:
|
||||
# Core fields
|
||||
log_dict: dict[str, Any] = {
|
||||
log_dict: LogDict = {
|
||||
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
|
||||
"service": self._service_name,
|
||||
|
||||
@ -141,6 +141,12 @@ class RedisHealthParamsDict(TypedDict):
|
||||
health_check_interval: int | None
|
||||
|
||||
|
||||
class RedisClusterHealthParamsDict(TypedDict):
|
||||
retry: Retry
|
||||
socket_timeout: float | None
|
||||
socket_connect_timeout: float | None
|
||||
|
||||
|
||||
class RedisBaseParamsDict(TypedDict):
|
||||
username: str | None
|
||||
password: str | None
|
||||
@ -211,7 +217,7 @@ def _get_connection_health_params() -> RedisHealthParamsDict:
|
||||
)
|
||||
|
||||
|
||||
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict:
|
||||
"""Get retry and timeout parameters for Redis Cluster clients.
|
||||
|
||||
RedisCluster does not support ``health_check_interval`` as a constructor
|
||||
@ -219,8 +225,13 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
||||
are passed through.
|
||||
"""
|
||||
params: dict[str, Any] = dict(_get_connection_health_params())
|
||||
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
||||
health_params = _get_connection_health_params()
|
||||
result: RedisClusterHealthParamsDict = {
|
||||
"retry": health_params["retry"],
|
||||
"socket_timeout": health_params["socket_timeout"],
|
||||
"socket_connect_timeout": health_params["socket_connect_timeout"],
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _get_base_redis_params() -> RedisBaseParamsDict:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, TypedDict
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase):
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBindingDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
category: str
|
||||
provider: str
|
||||
credentials: Any
|
||||
created_at: float
|
||||
updated_at: float
|
||||
disabled: bool
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
def to_dict(self) -> DataSourceApiKeyAuthBindingDict:
|
||||
result: DataSourceApiKeyAuthBindingDict = {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"category": self.category,
|
||||
@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
"updated_at": self.updated_at.timestamp(),
|
||||
"disabled": self.disabled,
|
||||
}
|
||||
return result
|
||||
|
||||
@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.variables.segments import StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
from graphon.variables.variables import StringVariable
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_user_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session()
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node2_vars = [
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node2_id,
|
||||
name="int_var",
|
||||
value=build_segment(1),
|
||||
@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
),
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node2_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
]
|
||||
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
def test_delete_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
|
||||
node2_var_count = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
node2_var_count = self._session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == self._test_app_id,
|
||||
WorkflowDraftVariable.node_id == self._node2_id,
|
||||
WorkflowDraftVariable.user_id == self._test_user_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
assert node2_var_count == 0
|
||||
|
||||
def test_delete_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_1_var = (
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
|
||||
)
|
||||
node_1_var = self._session.scalars(
|
||||
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
|
||||
).one()
|
||||
srv.delete_variable(node_1_var)
|
||||
exists = bool(
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
|
||||
self._session.scalars(
|
||||
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
|
||||
).first()
|
||||
)
|
||||
assert exists is False
|
||||
|
||||
@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id))
|
||||
session.commit()
|
||||
|
||||
def test_variable_loader_with_empty_selector(self):
|
||||
@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
|
||||
session.execute(
|
||||
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
|
||||
)
|
||||
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
|
||||
session.execute(
|
||||
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
|
||||
)
|
||||
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
|
||||
@ -637,6 +637,40 @@ class TestConversationServiceSummarization:
|
||||
assert conversation.name == new_name
|
||||
assert conversation.updated_at == mock_time
|
||||
|
||||
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
|
||||
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
|
||||
"""
|
||||
Test rename delegates to auto_generate_name when auto_generate is True.
|
||||
|
||||
When auto_generate is True, the service should call auto_generate_name
|
||||
which uses an LLM to create a descriptive conversation title.
|
||||
"""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
db_session_with_containers
|
||||
)
|
||||
conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
|
||||
db_session_with_containers, app_model, user
|
||||
)
|
||||
ConversationServiceIntegrationTestDataFactory.create_message(
|
||||
db_session_with_containers, app_model, conversation, user
|
||||
)
|
||||
generated_name = "Auto Generated Name"
|
||||
mock_llm_generator.return_value = generated_name
|
||||
|
||||
# Act
|
||||
result = ConversationService.rename(
|
||||
app_model=app_model,
|
||||
conversation_id=conversation.id,
|
||||
user=user,
|
||||
name=None,
|
||||
auto_generate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == conversation
|
||||
assert conversation.name == generated_name
|
||||
|
||||
|
||||
class TestConversationServiceMessageAnnotation:
|
||||
"""
|
||||
@ -1066,3 +1100,32 @@ class TestConversationServiceExport:
|
||||
not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id))
|
||||
assert not_deleted is not None
|
||||
mock_delete_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.conversation_service.delete_conversation_related_data")
|
||||
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
|
||||
"""
|
||||
Test that delete propagates exceptions and does not trigger the cleanup task.
|
||||
|
||||
When a DB error occurs during deletion, the service must rollback the
|
||||
transaction and re-raise the exception without scheduling async cleanup.
|
||||
"""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
db_session_with_containers
|
||||
)
|
||||
conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
|
||||
db_session_with_containers, app_model, user
|
||||
)
|
||||
conversation_id = conversation.id
|
||||
|
||||
# Act — force an error during the delete to exercise the rollback path
|
||||
with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# Assert — async cleanup must NOT have been scheduled
|
||||
mock_delete_task.delay.assert_not_called()
|
||||
|
||||
# Conversation is still present because the deletion was never committed
|
||||
still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id))
|
||||
assert still_there is not None
|
||||
|
||||
@ -435,36 +435,6 @@ class TestConversationServiceRename:
|
||||
assert conversation.name == "New Name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
@patch("services.conversation_service.ConversationService.get_conversation")
|
||||
@patch("services.conversation_service.ConversationService.auto_generate_name")
|
||||
def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
|
||||
"""
|
||||
Test renaming conversation with auto-generation.
|
||||
|
||||
Should call auto_generate_name when auto_generate is True.
|
||||
"""
|
||||
# Arrange
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
|
||||
|
||||
mock_get_conversation.return_value = conversation
|
||||
mock_auto_generate.return_value = conversation
|
||||
|
||||
# Act
|
||||
result = ConversationService.rename(
|
||||
app_model=app_model,
|
||||
conversation_id="conv-123",
|
||||
user=user,
|
||||
name=None,
|
||||
auto_generate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == conversation
|
||||
mock_auto_generate.assert_called_once_with(app_model, conversation)
|
||||
|
||||
|
||||
class TestConversationServiceAutoGenerateName:
|
||||
"""Test conversation auto-name generation operations."""
|
||||
@ -576,29 +546,6 @@ class TestConversationServiceDelete:
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_delete_task.delay.assert_called_once_with(conversation.id)
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
@patch("services.conversation_service.ConversationService.get_conversation")
|
||||
def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session):
|
||||
"""
|
||||
Test deletion handles exceptions and rolls back transaction.
|
||||
|
||||
Should rollback database changes when deletion fails.
|
||||
"""
|
||||
# Arrange
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
|
||||
|
||||
mock_get_conversation.return_value = conversation
|
||||
mock_db_session.delete.side_effect = Exception("Database Error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Database Error"):
|
||||
ConversationService.delete(app_model, "conv-123", user)
|
||||
|
||||
# Assert rollback was called
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestConversationServiceConversationalVariable:
|
||||
"""Test conversational variable operations."""
|
||||
|
||||
Reference in New Issue
Block a user