mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 03:07:39 +08:00
Merge branch 'main' into jzh
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@ -86,7 +86,14 @@ def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
|
||||
return value_type.exposed_type().value
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
class FullContentDict(TypedDict):
|
||||
size_bytes: int | None
|
||||
value_type: str
|
||||
length: int | None
|
||||
download_url: str
|
||||
|
||||
|
||||
def _serialize_full_content(variable: WorkflowDraftVariable) -> FullContentDict | None:
|
||||
"""Serialize full_content information for large variables."""
|
||||
if not variable.is_truncated():
|
||||
return None
|
||||
@ -94,12 +101,13 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None:
|
||||
variable_file = variable.variable_file
|
||||
assert variable_file is not None
|
||||
|
||||
return {
|
||||
result: FullContentDict = {
|
||||
"size_bytes": variable_file.size,
|
||||
"value_type": variable_file.value_type.exposed_type().value,
|
||||
"length": variable_file.length,
|
||||
"download_url": file_helpers.get_signed_file_url(variable_file.upload_file_id, as_attachment=True),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_variable_access(
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
@ -29,6 +29,13 @@ from models.model import Message
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionDict(TypedDict):
|
||||
"""Shape produced by AgentScratchpadUnit.Action.to_dict()."""
|
||||
|
||||
action: str
|
||||
action_input: dict[str, Any] | str
|
||||
|
||||
|
||||
class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
_is_first_iteration = True
|
||||
_ignore_observation_providers = ["wenxin"]
|
||||
@ -331,7 +338,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
|
||||
return tool_invoke_response, tool_invoke_meta
|
||||
|
||||
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
|
||||
def _convert_dict_to_action(self, action: ActionDict) -> AgentScratchpadUnit.Action:
|
||||
"""
|
||||
convert dict to action
|
||||
"""
|
||||
|
||||
@ -32,9 +32,9 @@ class Extensible:
|
||||
|
||||
name: str
|
||||
tenant_id: str
|
||||
config: dict | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
|
||||
def __init__(self, tenant_id: str, config: dict | None = None):
|
||||
def __init__(self, tenant_id: str, config: dict[str, Any] | None = None):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
@ -7,6 +10,16 @@ from extensions.ext_database import db
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
|
||||
|
||||
class ApiToolConfig(TypedDict, total=False):
|
||||
"""Expected config shape for ApiExternalDataTool.
|
||||
|
||||
Not used directly in method signatures (base class accepts dict[str, Any]);
|
||||
kept here to document the keys this tool reads from config.
|
||||
"""
|
||||
|
||||
api_based_extension_id: str
|
||||
|
||||
|
||||
class ApiExternalDataTool(ExternalDataTool):
|
||||
"""
|
||||
The api external data tool.
|
||||
@ -16,7 +29,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
"""the unique name of external data tool"""
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -37,7 +50,7 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
if not api_based_extension:
|
||||
raise ValueError("api_based_extension_id is invalid")
|
||||
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.extension.extensible import Extensible, ExtensionModule
|
||||
|
||||
@ -15,14 +17,14 @@ class ExternalDataTool(Extensible, ABC):
|
||||
variable: str
|
||||
"""the tool variable name of app tool"""
|
||||
|
||||
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict | None = None):
|
||||
def __init__(self, tenant_id: str, app_id: str, variable: str, config: dict[str, Any] | None = None):
|
||||
super().__init__(tenant_id, config)
|
||||
self.app_id = app_id
|
||||
self.variable = variable
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@ -33,7 +35,7 @@ class ExternalDataTool(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def query(self, inputs: dict, query: str | None = None) -> str:
|
||||
def query(self, inputs: Mapping[str, Any], query: str | None = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
import orjson
|
||||
|
||||
@ -16,6 +16,19 @@ class IdentityDict(TypedDict, total=False):
|
||||
user_type: str
|
||||
|
||||
|
||||
class LogDict(TypedDict):
|
||||
ts: str
|
||||
severity: str
|
||||
service: str
|
||||
caller: str
|
||||
message: str
|
||||
trace_id: NotRequired[str]
|
||||
span_id: NotRequired[str]
|
||||
identity: NotRequired[IdentityDict]
|
||||
attributes: NotRequired[dict[str, Any]]
|
||||
stack_trace: NotRequired[str]
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
@ -55,9 +68,9 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
|
||||
return json.dumps(log_dict, default=str, ensure_ascii=False)
|
||||
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
|
||||
def _build_log_dict(self, record: logging.LogRecord) -> LogDict:
|
||||
# Core fields
|
||||
log_dict: dict[str, Any] = {
|
||||
log_dict: LogDict = {
|
||||
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
|
||||
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
|
||||
"service": self._service_name,
|
||||
|
||||
@ -141,6 +141,12 @@ class RedisHealthParamsDict(TypedDict):
|
||||
health_check_interval: int | None
|
||||
|
||||
|
||||
class RedisClusterHealthParamsDict(TypedDict):
|
||||
retry: Retry
|
||||
socket_timeout: float | None
|
||||
socket_connect_timeout: float | None
|
||||
|
||||
|
||||
class RedisBaseParamsDict(TypedDict):
|
||||
username: str | None
|
||||
password: str | None
|
||||
@ -211,7 +217,7 @@ def _get_connection_health_params() -> RedisHealthParamsDict:
|
||||
)
|
||||
|
||||
|
||||
def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
def _get_cluster_connection_health_params() -> RedisClusterHealthParamsDict:
|
||||
"""Get retry and timeout parameters for Redis Cluster clients.
|
||||
|
||||
RedisCluster does not support ``health_check_interval`` as a constructor
|
||||
@ -219,8 +225,13 @@ def _get_cluster_connection_health_params() -> dict[str, Any]:
|
||||
here. Only ``retry``, ``socket_timeout``, and ``socket_connect_timeout``
|
||||
are passed through.
|
||||
"""
|
||||
params: dict[str, Any] = dict(_get_connection_health_params())
|
||||
return {k: v for k, v in params.items() if k != "health_check_interval"}
|
||||
health_params = _get_connection_health_params()
|
||||
result: RedisClusterHealthParamsDict = {
|
||||
"retry": health_params["retry"],
|
||||
"socket_timeout": health_params["socket_timeout"],
|
||||
"socket_connect_timeout": health_params["socket_connect_timeout"],
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def _get_base_redis_params() -> RedisBaseParamsDict:
|
||||
|
||||
@ -353,13 +353,17 @@ class Dataset(Base):
|
||||
if self.provider != "external":
|
||||
return None
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
|
||||
select(ExternalKnowledgeBindings).where(
|
||||
ExternalKnowledgeBindings.dataset_id == self.id,
|
||||
ExternalKnowledgeBindings.tenant_id == self.tenant_id,
|
||||
)
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
return None
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis).where(
|
||||
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
|
||||
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
|
||||
ExternalKnowledgeApis.tenant_id == self.tenant_id,
|
||||
)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, TypedDict
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@ -38,6 +39,17 @@ class DataSourceOauthBinding(TypeBase):
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBindingDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
category: str
|
||||
provider: str
|
||||
credentials: Any
|
||||
created_at: float
|
||||
updated_at: float
|
||||
disabled: bool
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
__tablename__ = "data_source_api_key_auth_bindings"
|
||||
__table_args__ = (
|
||||
@ -65,8 +77,8 @@ class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"), default=False)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
def to_dict(self) -> DataSourceApiKeyAuthBindingDict:
|
||||
result: DataSourceApiKeyAuthBindingDict = {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"category": self.category,
|
||||
@ -76,3 +88,4 @@ class DataSourceApiKeyAuthBinding(TypeBase):
|
||||
"updated_at": self.updated_at.timestamp(),
|
||||
"disabled": self.disabled,
|
||||
}
|
||||
return result
|
||||
|
||||
@ -8,38 +8,38 @@ dependencies = [
|
||||
"arize-phoenix-otel~=0.15.0",
|
||||
"azure-identity==1.25.3",
|
||||
"beautifulsoup4==4.14.3",
|
||||
"boto3==1.42.83",
|
||||
"boto3==1.42.88",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.6.2",
|
||||
"charset-normalizer>=3.4.4",
|
||||
"flask~=3.1.2",
|
||||
"flask-compress>=1.17,<1.25",
|
||||
"flask-cors~=6.0.0",
|
||||
"flask~=3.1.3",
|
||||
"flask-compress>=1.24,<1.25",
|
||||
"flask-cors~=6.0.2",
|
||||
"flask-login~=0.6.3",
|
||||
"flask-migrate~=4.1.0",
|
||||
"flask-orjson~=2.0.0",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=25.9.1",
|
||||
"gmpy2~=2.3.0",
|
||||
"google-api-core>=2.19.1",
|
||||
"google-api-python-client==2.193.0",
|
||||
"google-auth>=2.47.0",
|
||||
"google-api-core>=2.30.3",
|
||||
"google-api-python-client==2.194.0",
|
||||
"google-auth>=2.49.2",
|
||||
"google-auth-httplib2==0.3.1",
|
||||
"google-cloud-aiplatform>=1.123.0",
|
||||
"googleapis-common-protos>=1.65.0",
|
||||
"google-cloud-aiplatform>=1.147.0",
|
||||
"googleapis-common-protos>=1.74.0",
|
||||
"graphon>=0.1.2",
|
||||
"gunicorn~=25.3.0",
|
||||
"httpx[socks]~=0.28.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"langfuse>=3.0.0,<5.0.0",
|
||||
"langsmith~=0.7.16",
|
||||
"langfuse>=4.2.0,<5.0.0",
|
||||
"langsmith~=0.7.30",
|
||||
"markdown~=3.10.2",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"mlflow-skinny>=3.11.1",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"opik~=1.11.2",
|
||||
"litellm==1.83.0", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
@ -53,14 +53,14 @@ dependencies = [
|
||||
"opentelemetry-instrumentation-httpx==0.61b0",
|
||||
"opentelemetry-instrumentation-redis==0.61b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
|
||||
"opentelemetry-propagator-b3==1.40.0",
|
||||
"opentelemetry-propagator-b3==1.41.0",
|
||||
"opentelemetry-proto==1.40.0",
|
||||
"opentelemetry-sdk==1.40.0",
|
||||
"opentelemetry-semantic-conventions==0.61b0",
|
||||
"opentelemetry-util-http==0.61b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"psycopg2-binary~=2.9.11",
|
||||
"pycryptodome==3.23.0",
|
||||
"pydantic~=2.12.5",
|
||||
"pydantic-settings~=2.13.1",
|
||||
@ -73,7 +73,7 @@ dependencies = [
|
||||
"redis[hiredis]~=7.4.0",
|
||||
"resend~=2.26.0",
|
||||
"sentry-sdk[flask]~=2.55.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"sqlalchemy~=2.0.49",
|
||||
"starlette==1.0.0",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
@ -86,7 +86,7 @@ dependencies = [
|
||||
"flask-restx~=1.3.2",
|
||||
"packaging~=23.2",
|
||||
"croniter>=6.0.0",
|
||||
"weaviate-client==4.20.4",
|
||||
"weaviate-client==4.20.5",
|
||||
"apscheduler>=3.11.0",
|
||||
"weave>=0.52.16",
|
||||
"fastopenapi[flask]>=0.7.0",
|
||||
@ -111,11 +111,11 @@ package = false
|
||||
dev = [
|
||||
"coverage~=7.13.4",
|
||||
"dotenv-linter~=0.7.0",
|
||||
"faker~=40.12.0",
|
||||
"faker~=40.13.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"basedpyright~=1.39.0",
|
||||
"ruff~=0.15.5",
|
||||
"pytest~=9.0.2",
|
||||
"ruff~=0.15.10",
|
||||
"pytest~=9.0.3",
|
||||
"pytest-benchmark~=5.2.3",
|
||||
"pytest-cov~=7.1.0",
|
||||
"pytest-env~=1.6.0",
|
||||
@ -130,8 +130,8 @@ dev = [
|
||||
"types-docutils~=0.22.3",
|
||||
"types-flask-cors~=6.0.0",
|
||||
"types-flask-migrate~=4.1.0",
|
||||
"types-gevent~=25.9.0",
|
||||
"types-greenlet~=3.3.0",
|
||||
"types-gevent~=26.4.0",
|
||||
"types-greenlet~=3.4.0",
|
||||
"types-html5lib~=1.1.11",
|
||||
"types-markdown~=3.10.2",
|
||||
"types-oauthlib~=3.3.0",
|
||||
@ -149,20 +149,20 @@ dev = [
|
||||
"types-pyyaml~=6.0.12",
|
||||
"types-regex~=2026.4.4",
|
||||
"types-shapely~=2.1.0",
|
||||
"types-simplejson>=3.20.0",
|
||||
"types-six>=1.17.0",
|
||||
"types-tensorflow>=2.18.0",
|
||||
"types-tqdm>=4.67.0",
|
||||
"types-simplejson>=3.20.0.20260408",
|
||||
"types-six>=1.17.0.20260408",
|
||||
"types-tensorflow>=2.18.0.20260408",
|
||||
"types-tqdm>=4.67.3.20260408",
|
||||
"types-ujson>=5.10.0",
|
||||
"boto3-stubs>=1.38.20",
|
||||
"types-jmespath>=1.0.2.20240106",
|
||||
"hypothesis>=6.131.15",
|
||||
"boto3-stubs>=1.42.88",
|
||||
"types-jmespath>=1.1.0.20260408",
|
||||
"hypothesis>=6.151.12",
|
||||
"types_pyOpenSSL>=24.1.0",
|
||||
"types_cffi>=1.17.0",
|
||||
"types_setuptools>=80.9.0",
|
||||
"types_cffi>=2.0.0.20260408",
|
||||
"types_setuptools>=82.0.0.20260408",
|
||||
"pandas-stubs~=3.0.0",
|
||||
"scipy-stubs>=1.15.3.0",
|
||||
"types-python-http-client>=3.3.7.20240910",
|
||||
"types-python-http-client>=3.3.7.20260408",
|
||||
"import-linter>=2.3",
|
||||
"types-redis>=4.6.0.20241004",
|
||||
"celery-types>=0.23.0",
|
||||
@ -180,10 +180,10 @@ dev = [
|
||||
############################################################
|
||||
storage = [
|
||||
"azure-storage-blob==12.28.0",
|
||||
"bce-python-sdk~=0.9.23",
|
||||
"bce-python-sdk~=0.9.69",
|
||||
"cos-python-sdk-v5==1.9.41",
|
||||
"esdk-obs-python==3.26.2",
|
||||
"google-cloud-storage>=3.0.0",
|
||||
"google-cloud-storage>=3.10.1",
|
||||
"opendal~=0.46.0",
|
||||
"oss2==2.19.1",
|
||||
"supabase~=2.18.1",
|
||||
@ -209,19 +209,19 @@ vdb = [
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==3.1.0",
|
||||
"oracledb==3.4.2",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.2",
|
||||
"pgvector==0.4.2",
|
||||
"pymilvus~=2.6.10",
|
||||
"pymilvus~=2.6.12",
|
||||
"pymochow==2.4.0",
|
||||
"pyobvector~=0.2.17",
|
||||
"qdrant-client==1.9.0",
|
||||
"intersystems-irispython>=5.1.0",
|
||||
"tablestore==6.4.3",
|
||||
"tablestore==6.4.4",
|
||||
"tcvectordb~=2.1.0",
|
||||
"tidb-vector==0.0.15",
|
||||
"upstash-vector==0.8.0",
|
||||
"volcengine-compat~=1.0.0",
|
||||
"weaviate-client==4.20.4",
|
||||
"weaviate-client==4.20.5",
|
||||
"xinference-client~=2.4.0",
|
||||
"mo-vector~=0.1.13",
|
||||
"mysql-connector-python>=9.3.0",
|
||||
|
||||
@ -528,6 +528,8 @@ class DatasetService:
|
||||
raise ValueError("External knowledge id is required.")
|
||||
if not external_knowledge_api_id:
|
||||
raise ValueError("External knowledge api id is required.")
|
||||
# Ensure the referenced external API template exists and belongs to the dataset tenant.
|
||||
ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id, dataset.tenant_id)
|
||||
# Update metadata fields
|
||||
dataset.updated_by = user.id if user else None
|
||||
dataset.updated_at = naive_utc_now()
|
||||
|
||||
@ -317,7 +317,10 @@ class ExternalDatasetService:
|
||||
|
||||
external_knowledge_api = db.session.scalar(
|
||||
select(ExternalKnowledgeApis)
|
||||
.where(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
|
||||
.where(
|
||||
ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id,
|
||||
ExternalKnowledgeApis.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
if external_knowledge_api is None or external_knowledge_api.settings is None:
|
||||
|
||||
@ -3,8 +3,9 @@ import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any, ClassVar, NotRequired, TypedDict
|
||||
|
||||
from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
@ -725,8 +726,27 @@ def _batch_upsert_draft_variable(
|
||||
session.execute(stmt)
|
||||
|
||||
|
||||
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
|
||||
d: dict[str, Any] = {
|
||||
class _InsertionDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
user_id: str | None
|
||||
last_edited_at: datetime | None
|
||||
node_id: str
|
||||
name: str
|
||||
selector: str
|
||||
value_type: SegmentType
|
||||
value: str
|
||||
node_execution_id: str | None
|
||||
file_id: str | None
|
||||
visible: NotRequired[bool]
|
||||
editable: NotRequired[bool]
|
||||
created_at: NotRequired[datetime]
|
||||
updated_at: NotRequired[datetime]
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
def _model_to_insertion_dict(model: WorkflowDraftVariable) -> _InsertionDict:
|
||||
d: _InsertionDict = {
|
||||
"id": model.id,
|
||||
"app_id": model.app_id,
|
||||
"user_id": model.user_id,
|
||||
|
||||
@ -8,6 +8,7 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app_factory import create_app
|
||||
@ -83,15 +84,15 @@ def setup_account(request) -> Generator[Account, None, None]:
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
account = session.query(Account).filter_by(email=email).one()
|
||||
account = session.scalars(select(Account).filter_by(email=email)).one()
|
||||
|
||||
yield account
|
||||
|
||||
with _CACHED_APP.test_request_context():
|
||||
db.session.query(DifySetup).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.execute(delete(DifySetup))
|
||||
db.session.execute(delete(TenantAccountJoin))
|
||||
db.session.execute(delete(Account))
|
||||
db.session.execute(delete(Tenant))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
@ -61,7 +61,11 @@ class TestPluginPermissionLifecycle:
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
|
||||
count = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(TenantPluginPermission)
|
||||
.where(TenantPluginPermission.tenant_id == tenant)
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import math
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
@ -210,7 +210,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
|
||||
assert remaining == len(all_ids)
|
||||
|
||||
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
|
||||
@ -231,7 +231,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == len(all_ids)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(all_ids)))
|
||||
assert remaining == 0
|
||||
|
||||
def test_start_from_filters_correctly(self, seed_messages):
|
||||
@ -254,7 +254,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
all_ids = list(msg_ids.values())
|
||||
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
|
||||
remaining_ids = set(session.scalars(select(Message.id).where(Message.id.in_(all_ids))).all())
|
||||
|
||||
assert msg_ids["old"] not in remaining_ids
|
||||
assert msg_ids["very_old"] in remaining_ids
|
||||
@ -282,7 +282,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["batches"] >= expected_batches
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
|
||||
remaining = session.scalar(select(func.count()).select_from(Message).where(Message.id.in_(msg_ids)))
|
||||
assert remaining == 0
|
||||
|
||||
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
|
||||
@ -319,9 +319,17 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(Message).where(Message.id == msg_id).count() == 0
|
||||
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
|
||||
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
|
||||
assert session.scalar(select(func.count()).select_from(Message).where(Message.id == msg_id)) == 0
|
||||
assert (
|
||||
session.scalar(select(func.count()).select_from(MessageFeedback).where(MessageFeedback.id == fb_id))
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
session.scalar(
|
||||
select(func.count()).select_from(MessageAnnotation).where(MessageAnnotation.id == ann_id)
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_factory_from_time_range_validation(self):
|
||||
with pytest.raises(ValueError, match="start_from"):
|
||||
|
||||
@ -7,7 +7,7 @@ from graphon.nodes import BuiltinNodeTypes
|
||||
from graphon.variables.segments import StringSegment
|
||||
from graphon.variables.types import SegmentType
|
||||
from graphon.variables.variables import StringVariable
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
@ -38,21 +38,25 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._test_app_id = str(uuid.uuid4())
|
||||
self._test_user_id = str(uuid.uuid4())
|
||||
self._session: Session = db.session()
|
||||
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="sys_var",
|
||||
value=build_segment("sys_value"),
|
||||
node_execution_id=self._node_exec_id,
|
||||
)
|
||||
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
name="conv_var",
|
||||
value=build_segment("conv_value"),
|
||||
)
|
||||
node2_vars = [
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node2_id,
|
||||
name="int_var",
|
||||
value=build_segment(1),
|
||||
@ -61,6 +65,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
),
|
||||
WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node2_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
@ -70,6 +75,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
]
|
||||
node1_var = WorkflowDraftVariable.new_node_variable(
|
||||
app_id=self._test_app_id,
|
||||
user_id=self._test_user_id,
|
||||
node_id=self._node1_id,
|
||||
name="str_var",
|
||||
value=build_segment("str_value"),
|
||||
@ -141,24 +147,27 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||
def test_delete_node_variables(self):
|
||||
srv = self._get_test_srv()
|
||||
srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id)
|
||||
node2_var_count = (
|
||||
self._session.query(WorkflowDraftVariable)
|
||||
node2_var_count = self._session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariable)
|
||||
.where(
|
||||
WorkflowDraftVariable.app_id == self._test_app_id,
|
||||
WorkflowDraftVariable.node_id == self._node2_id,
|
||||
WorkflowDraftVariable.user_id == self._test_user_id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
assert node2_var_count == 0
|
||||
|
||||
def test_delete_variable(self):
|
||||
srv = self._get_test_srv()
|
||||
node_1_var = (
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
|
||||
)
|
||||
node_1_var = self._session.scalars(
|
||||
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
|
||||
).one()
|
||||
srv.delete_variable(node_1_var)
|
||||
exists = bool(
|
||||
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
|
||||
self._session.scalars(
|
||||
select(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id)
|
||||
).first()
|
||||
)
|
||||
assert exists is False
|
||||
|
||||
@ -248,9 +257,7 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == self._test_app_id))
|
||||
session.commit()
|
||||
|
||||
def test_variable_loader_with_empty_selector(self):
|
||||
@ -431,9 +438,11 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
|
||||
session.execute(
|
||||
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
|
||||
)
|
||||
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
@ -534,9 +543,11 @@ class TestDraftVariableLoader(unittest.TestCase):
|
||||
# Clean up
|
||||
with Session(bind=db.engine) as session:
|
||||
# Query and delete by ID to ensure they're tracked in this session
|
||||
session.query(WorkflowDraftVariable).filter_by(id=offloaded_var.id).delete()
|
||||
session.query(WorkflowDraftVariableFile).filter_by(id=variable_file.id).delete()
|
||||
session.query(UploadFile).filter_by(id=upload_file.id).delete()
|
||||
session.execute(delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id == offloaded_var.id))
|
||||
session.execute(
|
||||
delete(WorkflowDraftVariableFile).where(WorkflowDraftVariableFile.id == variable_file.id)
|
||||
)
|
||||
session.execute(delete(UploadFile).where(UploadFile.id == upload_file.id))
|
||||
session.commit()
|
||||
# Clean up storage
|
||||
try:
|
||||
|
||||
@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from graphon.variables.segments import StringSegment
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from extensions.storage.storage_type import StorageType
|
||||
@ -108,8 +108,12 @@ class TestDeleteDraftVariablesIntegration:
|
||||
app2_id = data["app2"].id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
app1_vars_before = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
|
||||
)
|
||||
app2_vars_before = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id)
|
||||
)
|
||||
assert app1_vars_before == 5
|
||||
assert app2_vars_before == 5
|
||||
|
||||
@ -117,8 +121,12 @@ class TestDeleteDraftVariablesIntegration:
|
||||
assert deleted_count == 5
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
|
||||
app1_vars_after = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
|
||||
)
|
||||
app2_vars_after = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app2_id)
|
||||
)
|
||||
assert app1_vars_after == 0
|
||||
assert app2_vars_after == 5
|
||||
|
||||
@ -130,7 +138,9 @@ class TestDeleteDraftVariablesIntegration:
|
||||
assert deleted_count == 5
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
remaining_vars = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
|
||||
)
|
||||
assert remaining_vars == 0
|
||||
|
||||
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
|
||||
@ -143,14 +153,18 @@ class TestDeleteDraftVariablesIntegration:
|
||||
app1_id = data["app1"].id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
vars_before = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
|
||||
)
|
||||
assert vars_before == 5
|
||||
|
||||
deleted_count = _delete_draft_variables(app1_id)
|
||||
assert deleted_count == 5
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
|
||||
vars_after = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app1_id)
|
||||
)
|
||||
assert vars_after == 0
|
||||
|
||||
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
|
||||
@ -175,7 +189,9 @@ class TestDeleteDraftVariablesIntegration:
|
||||
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
|
||||
assert deleted_count == 25
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
|
||||
remaining = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app.id)
|
||||
)
|
||||
assert remaining == 0
|
||||
finally:
|
||||
with session_factory.create_session() as session:
|
||||
@ -307,13 +323,17 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
mock_storage.delete.return_value = None
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
draft_vars_before = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
var_files_before = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
)
|
||||
upload_files_before = session.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
@ -322,16 +342,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
assert deleted_count == 3
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
draft_vars_after = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
assert draft_vars_after == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
var_files_after = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
upload_files_after = session.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -352,16 +376,20 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
|
||||
assert deleted_count == 3
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
draft_vars_after = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
assert draft_vars_after == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
var_files_after = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
|
||||
.count()
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
upload_files_after = session.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -579,7 +607,9 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
|
||||
# Verify all data was deleted (proves transaction was committed)
|
||||
with session_factory.create_session() as session:
|
||||
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
remaining_count = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
|
||||
assert deleted_count == 10
|
||||
assert remaining_count == 0
|
||||
@ -592,7 +622,9 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
initial_count = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
assert initial_count == 10
|
||||
|
||||
# Perform deletion with small batch size to force multiple commits
|
||||
@ -602,13 +634,17 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
|
||||
# Verify all data is deleted in a new session (proves commits worked)
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
final_count = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
assert final_count == 0
|
||||
|
||||
# Verify specific IDs are deleted
|
||||
with session_factory.create_session() as session:
|
||||
remaining_vars = (
|
||||
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
|
||||
remaining_vars = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariable)
|
||||
.where(WorkflowDraftVariable.id.in_(variable_ids))
|
||||
)
|
||||
assert remaining_vars == 0
|
||||
|
||||
@ -626,7 +662,9 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
app_id = data["app"].id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
initial_count = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
assert initial_count == 10
|
||||
|
||||
# Delete all in a single batch
|
||||
@ -635,7 +673,9 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
|
||||
# Verify data is persisted
|
||||
with session_factory.create_session() as session:
|
||||
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
final_count = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
assert final_count == 0
|
||||
|
||||
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
|
||||
@ -659,13 +699,17 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
|
||||
# Verify initial state
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_before = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
draft_vars_before = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
var_files_before = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
)
|
||||
upload_files_before = session.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_before == 3
|
||||
assert var_files_before == 2
|
||||
assert upload_files_before == 2
|
||||
@ -676,13 +720,17 @@ class TestDeleteDraftVariablesSessionCommit:
|
||||
|
||||
# Verify all data is persisted (deleted) in new session
|
||||
with session_factory.create_session() as session:
|
||||
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||
var_files_after = (
|
||||
session.query(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
.count()
|
||||
draft_vars_after = session.scalar(
|
||||
select(func.count()).select_from(WorkflowDraftVariable).filter_by(app_id=app_id)
|
||||
)
|
||||
var_files_after = session.scalar(
|
||||
select(func.count())
|
||||
.select_from(WorkflowDraftVariableFile)
|
||||
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||
)
|
||||
upload_files_after = session.scalar(
|
||||
select(func.count()).select_from(UploadFile).where(UploadFile.id.in_(upload_file_ids))
|
||||
)
|
||||
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||
assert draft_vars_after == 0
|
||||
assert var_files_after == 0
|
||||
assert upload_files_after == 0
|
||||
|
||||
@ -637,6 +637,40 @@ class TestConversationServiceSummarization:
|
||||
assert conversation.name == new_name
|
||||
assert conversation.updated_at == mock_time
|
||||
|
||||
@patch("services.conversation_service.LLMGenerator.generate_conversation_name")
|
||||
def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers):
|
||||
"""
|
||||
Test rename delegates to auto_generate_name when auto_generate is True.
|
||||
|
||||
When auto_generate is True, the service should call auto_generate_name
|
||||
which uses an LLM to create a descriptive conversation title.
|
||||
"""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
db_session_with_containers
|
||||
)
|
||||
conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
|
||||
db_session_with_containers, app_model, user
|
||||
)
|
||||
ConversationServiceIntegrationTestDataFactory.create_message(
|
||||
db_session_with_containers, app_model, conversation, user
|
||||
)
|
||||
generated_name = "Auto Generated Name"
|
||||
mock_llm_generator.return_value = generated_name
|
||||
|
||||
# Act
|
||||
result = ConversationService.rename(
|
||||
app_model=app_model,
|
||||
conversation_id=conversation.id,
|
||||
user=user,
|
||||
name=None,
|
||||
auto_generate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == conversation
|
||||
assert conversation.name == generated_name
|
||||
|
||||
|
||||
class TestConversationServiceMessageAnnotation:
|
||||
"""
|
||||
@ -1066,3 +1100,32 @@ class TestConversationServiceExport:
|
||||
not_deleted = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation.id))
|
||||
assert not_deleted is not None
|
||||
mock_delete_task.delay.assert_not_called()
|
||||
|
||||
@patch("services.conversation_service.delete_conversation_related_data")
|
||||
def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers):
|
||||
"""
|
||||
Test that delete propagates exceptions and does not trigger the cleanup task.
|
||||
|
||||
When a DB error occurs during deletion, the service must rollback the
|
||||
transaction and re-raise the exception without scheduling async cleanup.
|
||||
"""
|
||||
# Arrange
|
||||
app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account(
|
||||
db_session_with_containers
|
||||
)
|
||||
conversation = ConversationServiceIntegrationTestDataFactory.create_conversation(
|
||||
db_session_with_containers, app_model, user
|
||||
)
|
||||
conversation_id = conversation.id
|
||||
|
||||
# Act — force an error during the delete to exercise the rollback path
|
||||
with patch("services.conversation_service.db.session.delete", side_effect=Exception("DB error")):
|
||||
with pytest.raises(Exception, match="DB error"):
|
||||
ConversationService.delete(app_model=app_model, conversation_id=conversation_id, user=user)
|
||||
|
||||
# Assert — async cleanup must NOT have been scheduled
|
||||
mock_delete_task.delay.assert_not_called()
|
||||
|
||||
# Conversation is still present because the deletion was never committed
|
||||
still_there = db_session_with_containers.scalar(select(Conversation).where(Conversation.id == conversation_id))
|
||||
assert still_there is not None
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
@ -7,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from models.dataset import Dataset, ExternalKnowledgeApis, ExternalKnowledgeBindings
|
||||
from models.enums import DataSourceType
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
@ -103,6 +104,34 @@ class DatasetUpdateTestDataFactory:
|
||||
db_session_with_containers.commit()
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_external_knowledge_api(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
api_id: str | None = None,
|
||||
name: str = "test-api",
|
||||
) -> ExternalKnowledgeApis:
|
||||
"""Create a real external knowledge API template for tenant-scoped update validation."""
|
||||
external_api = ExternalKnowledgeApis(
|
||||
tenant_id=tenant_id,
|
||||
created_by=created_by,
|
||||
updated_by=created_by,
|
||||
name=name,
|
||||
description="test description",
|
||||
settings=json.dumps(
|
||||
{
|
||||
"endpoint": "https://example.com",
|
||||
"api_key": "test-api-key",
|
||||
}
|
||||
),
|
||||
)
|
||||
if api_id is not None:
|
||||
external_api.id = api_id
|
||||
db_session_with_containers.add(external_api)
|
||||
db_session_with_containers.commit()
|
||||
return external_api
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateDataset:
|
||||
"""
|
||||
@ -138,6 +167,11 @@ class TestDatasetServiceUpdateDataset:
|
||||
)
|
||||
binding_id = binding.id
|
||||
db_session_with_containers.expunge(binding)
|
||||
external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
@ -145,7 +179,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
"external_retrieval_model": "new_model",
|
||||
"permission": "only_me",
|
||||
"external_knowledge_id": "new_knowledge_id",
|
||||
"external_knowledge_api_id": str(uuid4()),
|
||||
"external_knowledge_api_id": external_api.id,
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
@ -218,11 +252,16 @@ class TestDatasetServiceUpdateDataset:
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
)
|
||||
external_api = DatasetUpdateTestDataFactory.create_external_knowledge_api(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": str(uuid4()),
|
||||
"external_knowledge_api_id": external_api.id,
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
|
||||
@ -12,7 +12,7 @@ This test suite covers:
|
||||
import json
|
||||
import pickle
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexTechniqueType
|
||||
@ -25,6 +25,7 @@ from models.dataset import (
|
||||
Document,
|
||||
DocumentSegment,
|
||||
Embedding,
|
||||
ExternalKnowledgeBindings,
|
||||
)
|
||||
from models.enums import (
|
||||
DataSourceType,
|
||||
@ -180,6 +181,24 @@ class TestDatasetModelValidation:
|
||||
assert result["top_k"] == 2
|
||||
assert result["score_threshold"] == 0.0
|
||||
|
||||
def test_dataset_external_knowledge_info_returns_none_for_cross_tenant_template(self):
|
||||
"""Test external datasets fail closed when the bound template is outside the tenant."""
|
||||
dataset = Dataset(
|
||||
tenant_id=str(uuid4()),
|
||||
name="External Dataset",
|
||||
data_source_type=DataSourceType.UPLOAD_FILE,
|
||||
created_by=str(uuid4()),
|
||||
provider="external",
|
||||
)
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
binding.external_knowledge_id = "knowledge-1"
|
||||
binding.external_knowledge_api_id = str(uuid4())
|
||||
|
||||
with patch("models.dataset.db") as mock_db:
|
||||
mock_db.session.scalar.side_effect = [binding, None]
|
||||
|
||||
assert dataset.external_knowledge_info is None
|
||||
|
||||
def test_dataset_retrieval_model_dict_property(self):
|
||||
"""Test retrieval_model_dict property with default values."""
|
||||
# Arrange
|
||||
|
||||
@ -435,36 +435,6 @@ class TestConversationServiceRename:
|
||||
assert conversation.name == "New Name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
@patch("services.conversation_service.ConversationService.get_conversation")
|
||||
@patch("services.conversation_service.ConversationService.auto_generate_name")
|
||||
def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
|
||||
"""
|
||||
Test renaming conversation with auto-generation.
|
||||
|
||||
Should call auto_generate_name when auto_generate is True.
|
||||
"""
|
||||
# Arrange
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
|
||||
|
||||
mock_get_conversation.return_value = conversation
|
||||
mock_auto_generate.return_value = conversation
|
||||
|
||||
# Act
|
||||
result = ConversationService.rename(
|
||||
app_model=app_model,
|
||||
conversation_id="conv-123",
|
||||
user=user,
|
||||
name=None,
|
||||
auto_generate=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == conversation
|
||||
mock_auto_generate.assert_called_once_with(app_model, conversation)
|
||||
|
||||
|
||||
class TestConversationServiceAutoGenerateName:
|
||||
"""Test conversation auto-name generation operations."""
|
||||
@ -576,29 +546,6 @@ class TestConversationServiceDelete:
|
||||
mock_db_session.commit.assert_called_once()
|
||||
mock_delete_task.delay.assert_called_once_with(conversation.id)
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
@patch("services.conversation_service.ConversationService.get_conversation")
|
||||
def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session):
|
||||
"""
|
||||
Test deletion handles exceptions and rolls back transaction.
|
||||
|
||||
Should rollback database changes when deletion fails.
|
||||
"""
|
||||
# Arrange
|
||||
app_model = ConversationServiceTestDataFactory.create_app_mock()
|
||||
user = ConversationServiceTestDataFactory.create_account_mock()
|
||||
conversation = ConversationServiceTestDataFactory.create_conversation_mock()
|
||||
|
||||
mock_get_conversation.return_value = conversation
|
||||
mock_db_session.delete.side_effect = Exception("Database Error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception, match="Database Error"):
|
||||
ConversationService.delete(app_model, "conv-123", user)
|
||||
|
||||
# Assert rollback was called
|
||||
mock_db_session.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestConversationServiceConversationalVariable:
|
||||
"""Test conversational variable operations."""
|
||||
|
||||
@ -532,6 +532,9 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
|
||||
with (
|
||||
patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
|
||||
patch(
|
||||
"services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()
|
||||
) as get_external_knowledge_api,
|
||||
patch("services.dataset_service.naive_utc_now", return_value=now),
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
):
|
||||
@ -557,6 +560,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM
|
||||
assert dataset.updated_by == "user-1"
|
||||
assert dataset.updated_at is now
|
||||
get_external_knowledge_api.assert_called_once_with("api-1", dataset.tenant_id)
|
||||
update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1")
|
||||
mock_db.session.add.assert_called_once_with(dataset)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
@ -574,6 +578,31 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
with pytest.raises(ValueError, match=message):
|
||||
DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1"))
|
||||
|
||||
def test_update_external_dataset_rejects_cross_tenant_external_api_id(self):
|
||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"services.dataset_service.ExternalDatasetService.get_external_knowledge_api",
|
||||
side_effect=ValueError("api template not found"),
|
||||
) as get_external_knowledge_api,
|
||||
patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
):
|
||||
with pytest.raises(ValueError, match="api template not found"):
|
||||
DatasetService._update_external_dataset(
|
||||
dataset,
|
||||
{
|
||||
"external_knowledge_id": "knowledge-1",
|
||||
"external_knowledge_api_id": "foreign-api",
|
||||
},
|
||||
SimpleNamespace(id="user-1"),
|
||||
)
|
||||
|
||||
get_external_knowledge_api.assert_called_once_with("foreign-api", dataset.tenant_id)
|
||||
update_binding.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
def test_update_external_knowledge_binding_updates_changed_binding_values(self):
|
||||
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
|
||||
session = MagicMock()
|
||||
|
||||
@ -1560,6 +1560,17 @@ class TestExternalDatasetServiceFetchRetrieval:
|
||||
with pytest.raises(ValueError, match="external knowledge binding not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
|
||||
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_cross_tenant_api_template_error(self, mock_db, factory):
|
||||
"""Test error when a binding points to an API template outside the dataset tenant."""
|
||||
# Arrange
|
||||
binding = factory.create_external_knowledge_binding_mock()
|
||||
mock_db.session.scalar.side_effect = [binding, None]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="external api template not found"):
|
||||
ExternalDatasetService.fetch_external_knowledge_retrieval("tenant-123", "dataset-123", "query", {})
|
||||
|
||||
@patch("services.external_knowledge_service.ExternalDatasetService.process_external_api")
|
||||
@patch("services.external_knowledge_service.db")
|
||||
def test_fetch_external_knowledge_retrieval_empty_results(self, mock_db, mock_process, factory):
|
||||
|
||||
446
api/uv.lock
generated
446
api/uv.lock
generated
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user