Merge remote-tracking branch 'origin/main' into docs/e2e-writing-guide

This commit is contained in:
yyh
2026-04-13 09:36:59 +08:00
11 changed files with 178 additions and 90 deletions

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import Any, TypedDict
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
return value_type.exposed_type().value
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
class FullContentDict(TypedDict):
size_bytes: int | None
value_type: str
length: int | None
download_url: str
def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
"""Serialize full_content information for large variables."""
if not variable.is_truncated():
return None
@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
variable_file = variable.variable_file
assert variable_file is not None
return {
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
return result
def _ensure_variable_access(

View File

@ -2,7 +2,7 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import (
@ -29,6 +29,13 @@ from models.model import Message
logger = logging.getLogger(__name__)
class ActionDict(TypedDict):
"""Shape produced by AgentScratchpadUnit.Action.to_dict()."""
action: str
action_input: dict[str, Any] | str
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""

View File

@ -32,9 +32,9 @@ class Extensible:
name: str
tenant_id: str
config: dict | None = None
config: dict[str, Any] | None = None
def __init__(self, tenant_id: str, config: dict | None = None):
def __init__(self, tenant_id: str, config: dict[str, Any] | None = None):
self.tenant_id = tenant_id
self.config = config

View File

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
@ -7,6 +10,16 @@ from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
# Functional-syntax equivalent of the class-based ApiToolConfig declaration.
ApiToolConfig = TypedDict(
    "ApiToolConfig",
    {
        "api_based_extension_id": str,
    },
    total=False,  # every key optional: this type documents rather than enforces the shape
)
ApiToolConfig.__doc__ = (
    "Expected config shape for ApiExternalDataTool.\n\n"
    "Not used directly in method signatures (base class accepts dict[str, Any]); "
    "kept here to document the keys this tool reads from config."
)
class ApiExternalDataTool(ExternalDataTool):
"""
The api external data tool.
@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.extension.extensible import Extensible, ExtensionModule
@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC):
variable: str
"""the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None):
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -3,7 +3,7 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, TypedDict
from typing import Any, NotRequired, TypedDict
import orjson
@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False):
user_type: str
class LogDict(TypedDict):
    """Schema of one structured JSON log record built by StructuredJSONFormatter."""

    # Core fields — always present (populated in _build_log_dict).
    ts: str  # ISO-8601 UTC timestamp, millisecond precision, "+00:00" rewritten to "Z"
    severity: str  # stdlib level mapped through SEVERITY_MAP (defaults to "INFO")
    service: str  # formatter's configured service name
    caller: str
    message: str
    # Optional keys — attached only when the corresponding data is available
    # on the record (trace context, identity, extra attributes, exception info).
    trace_id: NotRequired[str]
    span_id: NotRequired[str]
    identity: NotRequired[IdentityDict]
    attributes: NotRequired[dict[str, Any]]
    stack_trace: NotRequired[str]
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter):
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
def _build_log_dict(self, record: logging.LogRecord) -> LogDict:
# Core fields
log_dict: dict[str, Any] = {
log_dict: LogDict = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,

View File

@ -141,6 +141,12 @@ class RedisHealthParamsDict(TypedDict):
health_check_interval: int | None
class RedisClusterHealthParamsDict(TypedDict):
    """Connection-health kwargs accepted by Redis Cluster clients.

    Deliberately omits ``health_check_interval``: RedisCluster does not
    support it as a constructor parameter, so only ``retry``,
    ``socket_timeout`` and ``socket_connect_timeout`` are passed through
    (see _get_cluster_connection_health_params).
    """

    retry: Retry
    socket_timeout: float | None
    socket_connect_timeout: float | None
class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
@ -211,7 +217,7 @@ def _get_connection_health_params() -> RedisHealthParamsDict:
)
def _get_cluster_connection_health_params() -> dict[str, Any]:
def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict:
"""Get retry and timeout parameters for Redis Cluster clients.
RedisCluster does not support ``health_check_interval`` as a constructor
@ -219,8 +225,13 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params: dict[str, Any] = dict(_get_connection_health_params())
return {k: v for k, v in params.items() if k != "health_check_interval"}
health_params = _get_connection_health_params()
result: RedisClusterHealthParamsDict = {
"retry": health_params["retry"],
"socket_timeout": health_params["socket_timeout"],
"socket_connect_timeout": health_params["socket_connect_timeout"],
}
return result
def _get_base_redis_params() -> RedisBaseParamsDict:

View File

@ -1,5 +1,6 @@
import json
from datetime import datetime
from typing import Any, TypedDict
from uuid import uuid4
import sqlalchemy as sa
@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase):
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
# Functional-syntax equivalent of the class-based declaration; same keys,
# same value types, all keys required.
DataSourceApiKeyAuthBindingDict = TypedDict(
    "DataSourceApiKeyAuthBindingDict",
    {
        "id": str,
        "tenant_id": str,
        "category": str,
        "provider": str,
        "credentials": Any,  # opaque credentials payload — shape not fixed here
        "created_at": float,  # POSIX timestamp (datetime.timestamp())
        "updated_at": float,  # POSIX timestamp (datetime.timestamp())
        "disabled": bool,
    },
)
DataSourceApiKeyAuthBindingDict.__doc__ = (
    "Plain-dict shape returned by DataSourceApiKeyAuthBinding.to_dict()."
)
class DataSourceApiKeyAuthBinding(TypeBase):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase):
)
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
def to_dict(self):
return {
def to_dict(self) -> DataSourceApiKeyAuthBindingDict:
result: DataSourceApiKeyAuthBindingDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"category": self.category,
@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase):
"updated_at": self.updated_at.timestamp(),
"disabled": self.disabled,
}
return result

View File

@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes
from graphon.variables.segments import StringSegment
from graphon.variables.types import SegmentType
from graphon.variables.variables import StringVariable
from sqlalchemy import delete
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_user_id = str(uuid.uuid4())
self._session: Session = db.session()
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
),
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node2_id,
name="str_var",
value=build_segment("str_value"),
@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
]
node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_delete_node_variables(self):
srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
node2_var_count = (
self._session.query(WorkflowDraftVariable)
node2_var_count = self._session.scalar(
select(func.count())
.select_from(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == self._test_app_id,
WorkflowDraftVariable.node_id == self._node2_id,
WorkflowDraftVariable.user_id == self._test_user_id,
)
.count()
)
assert node2_var_count == 0
def test_delete_variable(self):
srv = self._get_test_srv()
node_1_var = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
)
node_1_var = self._session.scalars(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
).one()
srv.delete_variable(node_1_var)
exists = bool(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
self._session.scalars(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
).first()
)
assert exists is False
@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase):
def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id))
session.commit()
def test_variable_loader_with_empty_selector(self):
@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up
with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
session.query(UploadFile).filter_by(id=upload_file.id).delete()
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
session.execute(
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
)
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
session.commit()
# Clean up storage
try:
@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up
with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
session.query(UploadFile).filter_by(id=upload_file.id).delete()
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
session.execute(
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
)
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
session.commit()
# Clean up storage
try:

View File

@ -637,6 +637,40 @@ class TestConversationServiceSummarization:
assert conversation.name == new_name
assert conversation.updated_at == mock_time
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
    """
    Test rename delegates to auto_generate_name when auto_generate is True.

    When auto_generate is True, the service should call auto_generate_name
    which uses an LLM to create a descriptive conversation title.

    Only the LLM call is patched; the rest of the rename flow runs against
    the database session supplied by the fixture.
    """
    # Arrange: real app, account and conversation persisted via the factory.
    app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
        db_session_with_containers
    )
    conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
        db_session_with_containers, app_model, user
    )
    # NOTE(review): a message is created first — presumably auto-generation
    # needs message content to summarize; confirm against auto_generate_name.
    ConversationServiceIntegrationTestDataFactory.create_message(
        db_session_with_containers, app_model, conversation, user
    )
    generated_name = "Auto Generated Name"
    mock_llm_generator.return_value = generated_name

    # Act: name=None together with auto_generate=True selects the LLM-backed path.
    result = ConversationService.rename(
        app_model=app_model,
        conversation_id=conversation.id,
        user=user,
        name=None,
        auto_generate=True,
    )

    # Assert: same conversation object is returned, renamed to the mocked value.
    assert result == conversation
    assert conversation.name == generated_name
class TestConversationServiceMessageAnnotation:
"""
@ -1066,3 +1100,32 @@ class TestConversationServiceExport:
not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id))
assert not_deleted is not None
mock_delete_task.delay.assert_not_called()
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
    """
    Test that delete propagates exceptions and does not trigger the cleanup task.

    When a DB error occurs during deletion, the service must rollback the
    transaction and re-raise the exception without scheduling async cleanup.
    """
    # Arrange: persist a real app, account and conversation via the factory.
    app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
        db_session_with_containers
    )
    conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
        db_session_with_containers, app_model, user
    )
    # Capture the id up front so the existence check below does not depend
    # on the ORM instance after the failed delete attempt.
    conversation_id = conversation.id

    # Act — force an error during the delete to exercise the rollback path
    with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")):
        with pytest.raises(Exception, match="DB error"):
            ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user)

    # Assert — async cleanup must NOT have been scheduled
    mock_delete_task.delay.assert_not_called()
    # Conversation is still present because the deletion was never committed
    still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id))
    assert still_there is not None

View File

@ -435,36 +435,6 @@ class TestConversationServiceRename:
assert conversation.name == "New Name"
mock_db_session.commit.assert_called_once()
@patch("services.conversation_service.db.session")
@patch("services.conversation_service.ConversationService.get_conversation")
@patch("services.conversation_service.ConversationService.auto_generate_name")
def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
"""
Test renaming conversation with auto-generation.
Should call auto_generate_name when auto_generate is True.
"""
# Arrange
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_account_mock()
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
mock_get_conversation.return_value = conversation
mock_auto_generate.return_value = conversation
# Act
result = ConversationService.rename(
app_model=app_model,
conversation_id="conv-123",
user=user,
name=None,
auto_generate=True,
)
# Assert
assert result == conversation
mock_auto_generate.assert_called_once_with(app_model, conversation)
class TestConversationServiceAutoGenerateName:
"""Test conversation auto-name generation operations."""
@ -576,29 +546,6 @@ class TestConversationServiceDelete:
mock_db_session.commit.assert_called_once()
mock_delete_task.delay.assert_called_once_with(conversation.id)
@patch("services.conversation_service.db.session")
@patch("services.conversation_service.ConversationService.get_conversation")
def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session):
"""
Test deletion handles exceptions and rolls back transaction.
Should rollback database changes when deletion fails.
"""
# Arrange
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_account_mock()
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
mock_get_conversation.return_value = conversation
mock_db_session.delete.side_effect = Exception("Database Error")
# Act & Assert
with pytest.raises(Exception, match="Database Error"):
ConversationService.delete(app_model, "conv-123", user)
# Assert rollback was called
mock_db_session.rollback.assert_called_once()
class TestConversationServiceConversationalVariable:
"""Test conversational variable operations."""