Merge branch 'main' into jzh

This commit is contained in:
JzoNg
2026-04-13 10:29:55 +08:00
25 changed files with 696 additions and 420 deletions

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import Any, TypedDict
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
return value_type.exposed_type().value
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
class FullContentDict(TypedDict):
size_bytes: int | None
value_type: str
length: int | None
download_url: str
def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
"""Serialize full_content information for large variables."""
if not variable.is_truncated():
return None
@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
variable_file = variable.variable_file
assert variable_file is not None
return {
result: FullContentDict = {
"size_bytes": variable_file.size,
"value_type": variable_file.value_type.exposed_type().value,
"length": variable_file.length,
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
}
return result
def _ensure_variable_access(

View File

@ -2,7 +2,7 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from typing import Any, TypedDict
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from graphon.model_runtime.entities.message_entities import (
@ -29,6 +29,13 @@ from models.model import Message
logger = logging.getLogger(__name__)
class ActionDict(TypedDict):
"""Shape produced by AgentScratchpadUnit.Action.to_dict()."""
action: str
action_input: dict[str, Any] | str
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""

View File

@ -32,9 +32,9 @@ class Extensible:
name: str
tenant_id: str
config: dict | None = None
config: dict[str, Any] | None = None
def __init__(self, tenant_id: str, config: dict | None = None):
def __init__(self, tenant_id: str, config: dict[str, Any] | None = None):
self.tenant_id = tenant_id
self.config = config

View File

@ -1,3 +1,6 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
@ -7,6 +10,16 @@ from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class ApiToolConfig(TypedDict, total=False):
"""Expected config shape for ApiExternalDataTool.
Not used directly in method signatures (base class accepts dict[str, Any]);
kept here to document the keys this tool reads from config.
"""
api_based_extension_id: str
class ApiExternalDataTool(ExternalDataTool):
"""
The api external data tool.
@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -1,4 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Any
from core.extension.extensible import Extensible, ExtensionModule
@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC):
variable: str
"""the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None):
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict):
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
"""
Validate the incoming form config data.
@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def query(self, inputs: dict, query: str | None = None) -> str:
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
"""
Query the external data tool.

View File

@ -3,7 +3,7 @@
import logging
import traceback
from datetime import UTC, datetime
from typing import Any, TypedDict
from typing import Any, NotRequired, TypedDict
import orjson
@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False):
user_type: str
class LogDict(TypedDict):
ts: str
severity: str
service: str
caller: str
message: str
trace_id: NotRequired[str]
span_id: NotRequired[str]
identity: NotRequired[IdentityDict]
attributes: NotRequired[dict[str, Any]]
stack_trace: NotRequired[str]
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter):
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
def _build_log_dict(self, record: logging.LogRecord) -> LogDict:
# Core fields
log_dict: dict[str, Any] = {
log_dict: LogDict = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,

View File

@ -141,6 +141,12 @@ class RedisHealthParamsDict(TypedDict):
health_check_interval: int | None
class RedisClusterHealthParamsDict(TypedDict):
retry: Retry
socket_timeout: float | None
socket_connect_timeout: float | None
class RedisBaseParamsDict(TypedDict):
username: str | None
password: str | None
@ -211,7 +217,7 @@ def _get_connection_health_params() -> RedisHealthParamsDict:
)
def _get_cluster_connection_health_params() -> dict[str, Any]:
def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict:
"""Get retry and timeout parameters for Redis Cluster clients.
RedisCluster does not support ``health_check_interval`` as a constructor
@ -219,8 +225,13 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
are passed through.
"""
params: dict[str, Any] = dict(_get_connection_health_params())
return {k: v for k, v in params.items() if k != "health_check_interval"}
health_params = _get_connection_health_params()
result: RedisClusterHealthParamsDict = {
"retry": health_params["retry"],
"socket_timeout": health_params["socket_timeout"],
"socket_connect_timeout": health_params["socket_connect_timeout"],
}
return result
def _get_base_redis_params() -> RedisBaseParamsDict:

View File

@ -353,13 +353,17 @@ class Dataset(Base):
if self.provider != "external":
return None
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
select(ExternalKnowledgeBindings).where(
ExternalKnowledgeBindings.dataset_id == self.id,
ExternalKnowledgeBindings.tenant_id == self.tenant_id,
)
)
if not external_knowledge_binding:
return None
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis).where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
ExternalKnowledgeApis.tenant_id == self.tenant_id,
)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:

View File

@ -1,5 +1,6 @@
import json
from datetime import datetime
from typing import Any, TypedDict
from uuid import uuid4
import sqlalchemy as sa
@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase):
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
class DataSourceApiKeyAuthBindingDict(TypedDict):
id: str
tenant_id: str
category: str
provider: str
credentials: Any
created_at: float
updated_at: float
disabled: bool
class DataSourceApiKeyAuthBinding(TypeBase):
__tablename__ = "data_source_api_key_auth_bindings"
__table_args__ = (
@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase):
)
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
def to_dict(self):
return {
def to_dict(self) -> DataSourceApiKeyAuthBindingDict:
result: DataSourceApiKeyAuthBindingDict = {
"id": self.id,
"tenant_id": self.tenant_id,
"category": self.category,
@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase):
"updated_at": self.updated_at.timestamp(),
"disabled": self.disabled,
}
return result

View File

@ -8,38 +8,38 @@ dependencies = [
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.83",
"boto3==1.42.88",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.6.2",
"charset-normalizer>=3.4.4",
"flask~=3.1.2",
"flask-compress>=1.17,<1.25",
"flask-cors~=6.0.0",
"flask~=3.1.3",
"flask-compress>=1.24,<1.25",
"flask-cors~=6.0.2",
"flask-login~=0.6.3",
"flask-migrate~=4.1.0",
"flask-orjson~=2.0.0",
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gmpy2~=2.3.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.193.0",
"google-auth>=2.47.0",
"google-api-core>=2.30.3",
"google-api-python-client==2.194.0",
"google-auth>=2.49.2",
"google-auth-httplib2==0.3.1",
"google-cloud-aiplatform>=1.123.0",
"googleapis-common-protos>=1.65.0",
"google-cloud-aiplatform>=1.147.0",
"googleapis-common-protos>=1.74.0",
"graphon>=0.1.2",
"gunicorn~=25.3.0",
"httpx[socks]~=0.28.0",
"jieba==0.42.1",
"json-repair>=0.55.1",
"langfuse>=3.0.0,<5.0.0",
"langsmith~=0.7.16",
"langfuse>=4.2.0,<5.0.0",
"langsmith~=0.7.30",
"markdown~=3.10.2",
"mlflow-skinny>=3.0.0",
"mlflow-skinny>=3.11.1",
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"opik~=1.11.2",
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.40.0",
"opentelemetry-distro==0.61b0",
@ -53,14 +53,14 @@ dependencies = [
"opentelemetry-instrumentation-httpx==0.61b0",
"opentelemetry-instrumentation-redis==0.61b0",
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
"opentelemetry-propagator-b3==1.40.0",
"opentelemetry-propagator-b3==1.41.0",
"opentelemetry-proto==1.40.0",
"opentelemetry-sdk==1.40.0",
"opentelemetry-semantic-conventions==0.61b0",
"opentelemetry-util-http==0.61b0",
"pandas[excel,output-formatting,performance]~=3.0.1",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"psycopg2-binary~=2.9.11",
"pycryptodome==3.23.0",
"pydantic~=2.12.5",
"pydantic-settings~=2.13.1",
@ -73,7 +73,7 @@ dependencies = [
"redis[hiredis]~=7.4.0",
"resend~=2.26.0",
"sentry-sdk[flask]~=2.55.0",
"sqlalchemy~=2.0.29",
"sqlalchemy~=2.0.49",
"starlette==1.0.0",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
@ -86,7 +86,7 @@ dependencies = [
"flask-restx~=1.3.2",
"packaging~=23.2",
"croniter>=6.0.0",
"weaviate-client==4.20.4",
"weaviate-client==4.20.5",
"apscheduler>=3.11.0",
"weave>=0.52.16",
"fastopenapi[flask]>=0.7.0",
@ -111,11 +111,11 @@ package = false
dev = [
"coverage~=7.13.4",
"dotenv-linter~=0.7.0",
"faker~=40.12.0",
"faker~=40.13.0",
"lxml-stubs~=0.5.1",
"basedpyright~=1.39.0",
"ruff~=0.15.5",
"pytest~=9.0.2",
"ruff~=0.15.10",
"pytest~=9.0.3",
"pytest-benchmark~=5.2.3",
"pytest-cov~=7.1.0",
"pytest-env~=1.6.0",
@ -130,8 +130,8 @@ dev = [
"types-docutils~=0.22.3",
"types-flask-cors~=6.0.0",
"types-flask-migrate~=4.1.0",
"types-gevent~=25.9.0",
"types-greenlet~=3.3.0",
"types-gevent~=26.4.0",
"types-greenlet~=3.4.0",
"types-html5lib~=1.1.11",
"types-markdown~=3.10.2",
"types-oauthlib~=3.3.0",
@ -149,20 +149,20 @@ dev = [
"types-pyyaml~=6.0.12",
"types-regex~=2026.4.4",
"types-shapely~=2.1.0",
"types-simplejson>=3.20.0",
"types-six>=1.17.0",
"types-tensorflow>=2.18.0",
"types-tqdm>=4.67.0",
"types-simplejson>=3.20.0.20260408",
"types-six>=1.17.0.20260408",
"types-tensorflow>=2.18.0.20260408",
"types-tqdm>=4.67.3.20260408",
"types-ujson>=5.10.0",
"boto3-stubs>=1.38.20",
"types-jmespath>=1.0.2.20240106",
"hypothesis>=6.131.15",
"boto3-stubs>=1.42.88",
"types-jmespath>=1.1.0.20260408",
"hypothesis>=6.151.12",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
"types_cffi>=2.0.0.20260408",
"types_setuptools>=82.0.0.20260408",
"pandas-stubs~=3.0.0",
"scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
"types-python-http-client>=3.3.7.20260408",
"import-linter>=2.3",
"types-redis>=4.6.0.20241004",
"celery-types>=0.23.0",
@ -180,10 +180,10 @@ dev = [
############################################################
storage = [
"azure-storage-blob==12.28.0",
"bce-python-sdk~=0.9.23",
"bce-python-sdk~=0.9.69",
"cos-python-sdk-v5==1.9.41",
"esdk-obs-python==3.26.2",
"google-cloud-storage>=3.0.0",
"google-cloud-storage>=3.10.1",
"opendal~=0.46.0",
"oss2==2.19.1",
"supabase~=2.18.1",
@ -209,19 +209,19 @@ vdb = [
"elasticsearch==8.14.0",
"opensearch-py==3.1.0",
"oracledb==3.4.2",
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvecto-rs[sqlalchemy]~=0.2.2",
"pgvector==0.4.2",
"pymilvus~=2.6.10",
"pymilvus~=2.6.12",
"pymochow==2.4.0",
"pyobvector~=0.2.17",
"qdrant-client==1.9.0",
"intersystems-irispython>=5.1.0",
"tablestore==6.4.3",
"tablestore==6.4.4",
"tcvectordb~=2.1.0",
"tidb-vector==0.0.15",
"upstash-vector==0.8.0",
"volcengine-compat~=1.0.0",
"weaviate-client==4.20.4",
"weaviate-client==4.20.5",
"xinference-client~=2.4.0",
"mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0",

View File

@ -528,6 +528,8 @@ class DatasetService:
raise ValueError("External knowledge id is required.")
if not external_knowledge_api_id:
raise ValueError("External knowledge api id is required.")
# Ensure the referenced external API template exists and belongs to the dataset tenant.
ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id)
# Update metadata fields
dataset.updated_by = user.id if user else None
dataset.updated_at = naive_utc_now()

View File

@ -317,7 +317,10 @@ class ExternalDatasetService:
external_knowledge_api = db.session.scalar(
select(ExternalKnowledgeApis)
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
.where(
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
ExternalKnowledgeApis.tenant_id == tenant_id,
)
.limit(1)
)
if external_knowledge_api is None or external_knowledge_api.settings is None:

View File

@ -3,8 +3,9 @@ import json
import logging
from collections.abc import Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import StrEnum
from typing import Any, ClassVar
from typing import Any, ClassVar, NotRequired, TypedDict
from graphon.enums import NodeType
from graphon.file import File
@ -725,8 +726,27 @@ def _batch_upsert_draft_variable(
session.execute(stmt)
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
d: dict[str, Any] = {
class _InsertionDict(TypedDict):
id: str
app_id: str
user_id: str | None
last_edited_at: datetime | None
node_id: str
name: str
selector: str
value_type: SegmentType
value: str
node_execution_id: str | None
file_id: str | None
visible: NotRequired[bool]
editable: NotRequired[bool]
created_at: NotRequired[datetime]
updated_at: NotRequired[datetime]
description: NotRequired[str]
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> _InsertionDict:
d: _InsertionDict = {
"id": model.id,
"app_id": model.app_id,
"user_id": model.user_id,

View File

@ -8,6 +8,7 @@ from collections.abc import Generator
import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app_factory import create_app
@ -83,15 +84,15 @@ def setup_account(request) -> Generator[Account, None, None]:
with _CACHED_APP.test_request_context():
with Session(bind=db.engine, expire_on_commit=False) as session:
account = session.query(Account).filter_by(email=email).one()
account = session.scalars(select(Account).filter_by(email=email)).one()
yield account
with _CACHED_APP.test_request_context():
db.session.query(DifySetup).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Account).delete()
db.session.query(Tenant).delete()
db.session.execute(delete(DifySetup))
db.session.execute(delete(TenantAccountJoin))
db.session.execute(delete(Account))
db.session.execute(delete(Tenant))
db.session.commit()

View File

@ -1,5 +1,5 @@
import pytest
from sqlalchemy import delete
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from models import Tenant
@ -61,7 +61,11 @@ class TestPluginPermissionLifecycle:
assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
with session_factory.create_session() as session:
count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
count = session.scalar(
select(func.count())
.select_from(TenantPluginPermission)
.where(TenantPluginPermission.tenant_id == tenant)
)
assert count == 1

View File

@ -3,7 +3,7 @@ import math
import uuid
import pytest
from sqlalchemy import delete
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from models import Tenant
@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
assert remaining == len(all_ids)
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == len(all_ids)
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
assert remaining == 0
def test_start_from_filters_correctly(self, seed_messages):
@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
with session_factory.create_session() as session:
all_ids = list(msg_ids.values())
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all())
assert msg_ids["old"] not in remaining_ids
assert msg_ids["very_old"] in remaining_ids
@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration:
assert stats["batches"] >= expected_batches
with session_factory.create_session() as session:
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids)))
assert remaining == 0
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 1
with session_factory.create_session() as session:
assert session.query(Message).where(Message.id == msg_id).count() == 0
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0
assert (
session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id))
== 0
)
assert (
session.scalar(
select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id)
)
== 0
)
def test_factory_from_time_range_validation(self):
with pytest.raises(ValueError, match="start_from"):

View File

@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes
from graphon.variables.segments import StringSegment
from graphon.variables.types import SegmentType
from graphon.variables.variables import StringVariable
from sqlalchemy import delete
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_user_id = str(uuid.uuid4())
self._session: Session = db.session()
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="sys_var",
value=build_segment("sys_value"),
node_execution_id=self._node_exec_id,
)
conv_var = WorkflowDraftVariable.new_conversation_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
name="conv_var",
value=build_segment("conv_value"),
)
node2_vars = [
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node2_id,
name="int_var",
value=build_segment(1),
@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
),
WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node2_id,
name="str_var",
value=build_segment("str_value"),
@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
]
node1_var = WorkflowDraftVariable.new_node_variable(
app_id=self._test_app_id,
user_id=self._test_user_id,
node_id=self._node1_id,
name="str_var",
value=build_segment("str_value"),
@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
def test_delete_node_variables(self):
srv = self._get_test_srv()
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
node2_var_count = (
self._session.query(WorkflowDraftVariable)
node2_var_count = self._session.scalar(
select(func.count())
.select_from(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id == self._test_app_id,
WorkflowDraftVariable.node_id == self._node2_id,
WorkflowDraftVariable.user_id == self._test_user_id,
)
.count()
)
assert node2_var_count == 0
def test_delete_variable(self):
srv = self._get_test_srv()
node_1_var = (
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
)
node_1_var = self._session.scalars(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
).one()
srv.delete_variable(node_1_var)
exists = bool(
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
self._session.scalars(
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
).first()
)
assert exists is False
@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase):
def tearDown(self):
with Session(bind=db.engine, expire_on_commit=False) as session:
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
synchronize_session=False
)
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id))
session.commit()
def test_variable_loader_with_empty_selector(self):
@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up
with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
session.query(UploadFile).filter_by(id=upload_file.id).delete()
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
session.execute(
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
)
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
session.commit()
# Clean up storage
try:
@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase):
# Clean up
with Session(bind=db.engine) as session:
# Query and delete by ID to ensure they're tracked in this session
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
session.query(UploadFile).filter_by(id=upload_file.id).delete()
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
session.execute(
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
)
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
session.commit()
# Clean up storage
try:

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
import pytest
from graphon.variables.segments import StringSegment
from sqlalchemy import delete
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from extensions.storage.storage_type import StorageType
@ -108,8 +108,12 @@ class TestDeleteDraftVariablesIntegration:
app2_id = data["app2"].id
with session_factory.create_session() as session:
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
app1_vars_before = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
)
app2_vars_before = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id)
)
assert app1_vars_before == 5
assert app2_vars_before == 5
@ -117,8 +121,12 @@ class TestDeleteDraftVariablesIntegration:
assert deleted_count == 5
with session_factory.create_session() as session:
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
app1_vars_after = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
)
app2_vars_after = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id)
)
assert app1_vars_after == 0
assert app2_vars_after == 5
@ -130,7 +138,9 @@ class TestDeleteDraftVariablesIntegration:
assert deleted_count == 5
with session_factory.create_session() as session:
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
remaining_vars = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
)
assert remaining_vars == 0
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
@ -143,14 +153,18 @@ class TestDeleteDraftVariablesIntegration:
app1_id = data["app1"].id
with session_factory.create_session() as session:
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
vars_before = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
)
assert vars_before == 5
deleted_count = _delete_draft_variables(app1_id)
assert deleted_count == 5
with session_factory.create_session() as session:
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
vars_after = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
)
assert vars_after == 0
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
@ -175,7 +189,9 @@ class TestDeleteDraftVariablesIntegration:
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
assert deleted_count == 25
with session_factory.create_session() as session:
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
remaining = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app.id)
)
assert remaining == 0
finally:
with session_factory.create_session() as session:
@ -307,13 +323,17 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
mock_storage.delete.return_value = None
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
draft_vars_before = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
var_files_before = session.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
)
upload_files_before = session.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@ -322,16 +342,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert deleted_count == 3
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
draft_vars_after = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = (
session.query(WorkflowDraftVariableFile)
var_files_after = session.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
upload_files_after = session.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
assert var_files_after == 0
assert upload_files_after == 0
@ -352,16 +376,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert deleted_count == 3
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
draft_vars_after = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = (
session.query(WorkflowDraftVariableFile)
var_files_after = session.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
upload_files_after = session.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
assert var_files_after == 0
assert upload_files_after == 0
@ -579,7 +607,9 @@ class TestDeleteDraftVariablesSessionCommit:
# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
remaining_count = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert deleted_count == 10
assert remaining_count == 0
@ -592,7 +622,9 @@ class TestDeleteDraftVariablesSessionCommit:
# Verify initial state
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
initial_count = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert initial_count == 10
# Perform deletion with small batch size to force multiple commits
@ -602,13 +634,17 @@ class TestDeleteDraftVariablesSessionCommit:
# Verify all data is deleted in a new session (proves commits worked)
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
final_count = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert final_count == 0
# Verify specific IDs are deleted
with session_factory.create_session() as session:
remaining_vars = (
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
remaining_vars = session.scalar(
select(func.count())
.select_from(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
)
assert remaining_vars == 0
@ -626,7 +662,9 @@ class TestDeleteDraftVariablesSessionCommit:
app_id = data["app"].id
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
initial_count = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert initial_count == 10
# Delete all in a single batch
@ -635,7 +673,9 @@ class TestDeleteDraftVariablesSessionCommit:
# Verify data is persisted
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
final_count = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
assert final_count == 0
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
@ -659,13 +699,17 @@ class TestDeleteDraftVariablesSessionCommit:
# Verify initial state
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
draft_vars_before = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
var_files_before = session.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
)
upload_files_before = session.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@ -676,13 +720,17 @@ class TestDeleteDraftVariablesSessionCommit:
# Verify all data is persisted (deleted) in new session
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
draft_vars_after = session.scalar(
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
)
var_files_after = session.scalar(
select(func.count())
.select_from(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
)
upload_files_after = session.scalar(
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_after == 0
assert var_files_after == 0
assert upload_files_after == 0

View File

@ -637,6 +637,40 @@ class TestConversationServiceSummarization:
assert conversation.name == new_name
assert conversation.updated_at == mock_time
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
"""
Test rename delegates to auto_generate_name when auto_generate is True.
When auto_generate is True, the service should call auto_generate_name
which uses an LLM to create a descriptive conversation title.
"""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
db_session_with_containers
)
conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
db_session_with_containers, app_model, user
)
ConversationServiceIntegrationTestDataFactory.create_message(
db_session_with_containers, app_model, conversation, user
)
generated_name = "Auto Generated Name"
mock_llm_generator.return_value = generated_name
# Act
result = ConversationService.rename(
app_model=app_model,
conversation_id=conversation.id,
user=user,
name=None,
auto_generate=True,
)
# Assert
assert result == conversation
assert conversation.name == generated_name
class TestConversationServiceMessageAnnotation:
"""
@ -1066,3 +1100,32 @@ class TestConversationServiceExport:
not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id))
assert not_deleted is not None
mock_delete_task.delay.assert_not_called()
@patch("services.conversation_service.delete_conversation_related_data")
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
"""
Test that delete propagates exceptions and does not trigger the cleanup task.
When a DB error occurs during deletion, the service must rollback the
transaction and re-raise the exception without scheduling async cleanup.
"""
# Arrange
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
db_session_with_containers
)
conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
db_session_with_containers, app_model, user
)
conversation_id = conversation.id
# Act — force an error during the delete to exercise the rollback path
with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")):
with pytest.raises(Exception, match="DB error"):
ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user)
# Assert — async cleanup must NOT have been scheduled
mock_delete_task.delay.assert_not_called()
# Conversation is still present because the deletion was never committed
still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id))
assert still_there is not None

View File

@ -1,3 +1,4 @@
import json
from unittest.mock import Mock, patch
from uuid import uuid4
@ -7,7 +8,7 @@ from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexTechniqueType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
from models.enums import DataSourceType
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
@ -103,6 +104,34 @@ class DatasetUpdateTestDataFactory:
db_session_with_containers.commit()
return binding
@staticmethod
def create_external_knowledge_api(
db_session_with_containers: Session,
tenant_id: str,
created_by: str,
api_id: str | None = None,
name: str = "test-api",
) -> ExternalKnowledgeApis:
"""Create a real external knowledge API template for tenant-scoped update validation."""
external_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=created_by,
updated_by=created_by,
name=name,
description="test description",
settings=json.dumps(
{
"endpoint": "https://example.com",
"api_key": "test-api-key",
}
),
)
if api_id is not None:
external_api.id = api_id
db_session_with_containers.add(external_api)
db_session_with_containers.commit()
return external_api
class TestDatasetServiceUpdateDataset:
"""
@ -138,6 +167,11 @@ class TestDatasetServiceUpdateDataset:
)
binding_id = binding.id
db_session_with_containers.expunge(binding)
external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
)
update_data = {
"name": "new_name",
@ -145,7 +179,7 @@ class TestDatasetServiceUpdateDataset:
"external_retrieval_model": "new_model",
"permission": "only_me",
"external_knowledge_id": "new_knowledge_id",
"external_knowledge_api_id": str(uuid4()),
"external_knowledge_api_id": external_api.id,
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
@ -218,11 +252,16 @@ class TestDatasetServiceUpdateDataset:
created_by=user.id,
provider="external",
)
external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
)
update_data = {
"name": "new_name",
"external_knowledge_id": "knowledge_id",
"external_knowledge_api_id": str(uuid4()),
"external_knowledge_api_id": external_api.id,
}
with pytest.raises(ValueError) as context:

View File

@ -12,7 +12,7 @@ This test suite covers:
import json
import pickle
from datetime import UTC, datetime
from unittest.mock import patch
from unittest.mock import Mock, patch
from uuid import uuid4
from core.rag.index_processor.constant.index_type import IndexTechniqueType
@ -25,6 +25,7 @@ from models.dataset import (
Document,
DocumentSegment,
Embedding,
ExternalKnowledgeBindings,
)
from models.enums import (
DataSourceType,
@ -180,6 +181,24 @@ class TestDatasetModelValidation:
assert result["top_k"] == 2
assert result["score_threshold"] == 0.0
def test_dataset_external_knowledge_info_returns_none_for_cross_tenant_template(self):
"""Test external datasets fail closed when the bound template is outside the tenant."""
dataset = Dataset(
tenant_id=str(uuid4()),
name="External Dataset",
data_source_type=DataSourceType.UPLOAD_FILE,
created_by=str(uuid4()),
provider="external",
)
binding = Mock(spec=ExternalKnowledgeBindings)
binding.external_knowledge_id = "knowledge-1"
binding.external_knowledge_api_id = str(uuid4())
with patch("models.dataset.db") as mock_db:
mock_db.session.scalar.side_effect = [binding, None]
assert dataset.external_knowledge_info is None
def test_dataset_retrieval_model_dict_property(self):
"""Test retrieval_model_dict property with default values."""
# Arrange

View File

@ -435,36 +435,6 @@ class TestConversationServiceRename:
assert conversation.name == "New Name"
mock_db_session.commit.assert_called_once()
@patch("services.conversation_service.db.session")
@patch("services.conversation_service.ConversationService.get_conversation")
@patch("services.conversation_service.ConversationService.auto_generate_name")
def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
"""
Test renaming conversation with auto-generation.
Should call auto_generate_name when auto_generate is True.
"""
# Arrange
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_account_mock()
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
mock_get_conversation.return_value = conversation
mock_auto_generate.return_value = conversation
# Act
result = ConversationService.rename(
app_model=app_model,
conversation_id="conv-123",
user=user,
name=None,
auto_generate=True,
)
# Assert
assert result == conversation
mock_auto_generate.assert_called_once_with(app_model, conversation)
class TestConversationServiceAutoGenerateName:
"""Test conversation auto-name generation operations."""
@ -576,29 +546,6 @@ class TestConversationServiceDelete:
mock_db_session.commit.assert_called_once()
mock_delete_task.delay.assert_called_once_with(conversation.id)
@patch("services.conversation_service.db.session")
@patch("services.conversation_service.ConversationService.get_conversation")
def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session):
"""
Test deletion handles exceptions and rolls back transaction.
Should rollback database changes when deletion fails.
"""
# Arrange
app_model = ConversationServiceTestDataFactory.create_app_mock()
user = ConversationServiceTestDataFactory.create_account_mock()
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
mock_get_conversation.return_value = conversation
mock_db_session.delete.side_effect = Exception("Database Error")
# Act & Assert
with pytest.raises(Exception, match="Database Error"):
ConversationService.delete(app_model, "conv-123", user)
# Assert rollback was called
mock_db_session.rollback.assert_called_once()
class TestConversationServiceConversationalVariable:
"""Test conversational variable operations."""

View File

@ -532,6 +532,9 @@ class TestDatasetServiceCreationAndUpdate:
with (
patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
patch(
"services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()
) as get_external_knowledge_api,
patch("services.dataset_service.naive_utc_now", return_value=now),
patch("services.dataset_service.db") as mock_db,
):
@ -557,6 +560,7 @@ class TestDatasetServiceCreationAndUpdate:
assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM
assert dataset.updated_by == "user-1"
assert dataset.updated_at is now
get_external_knowledge_api.assert_called_once_with("api-1", dataset.tenant_id)
update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1")
mock_db.session.add.assert_called_once_with(dataset)
mock_db.session.commit.assert_called_once()
@ -574,6 +578,31 @@ class TestDatasetServiceCreationAndUpdate:
with pytest.raises(ValueError, match=message):
DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1"))
def test_update_external_dataset_rejects_cross_tenant_external_api_id(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
with (
patch(
"services.dataset_service.ExternalDatasetService.get_external_knowledge_api",
side_effect=ValueError("api template not found"),
) as get_external_knowledge_api,
patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
patch("services.dataset_service.db") as mock_db,
):
with pytest.raises(ValueError, match="api template not found"):
DatasetService._update_external_dataset(
dataset,
{
"external_knowledge_id": "knowledge-1",
"external_knowledge_api_id": "foreign-api",
},
SimpleNamespace(id="user-1"),
)
get_external_knowledge_api.assert_called_once_with("foreign-api", dataset.tenant_id)
update_binding.assert_not_called()
mock_db.session.commit.assert_not_called()
def test_update_external_knowledge_binding_updates_changed_binding_values(self):
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
session = MagicMock()

View File

@ -1560,6 +1560,17 @@ class TestExternalDatasetServiceFetchRetrieval:
with pytest.raises(ValueError, match="external knowledge binding not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(self, mock_db, factory):
"""Test error when a binding points to an API template outside the dataset tenant."""
# Arrange
binding = factory.create_external_knowledge_binding_mock()
mock_db.session.scalar.side_effect = [binding, None]
# Act & Assert
with pytest.raises(ValueError, match="external api template not found"):
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
@patch("services.external_knowledge_service.db")
def test_fetch_external_knowledge_retrieval_empty_results(self, mock_db, mock_process, factory):

446
api/uv.lock generated

File diff suppressed because it is too large Load Diff